test_placement_groups.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os
  2. import unittest
  3. import ray
  4. from ray import tune
  5. from ray.tune import Callback
  6. from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
  7. from ray.tune.ray_trial_executor import RayTrialExecutor
  8. from ray.tune.trial import Trial
  9. from ray.tune.utils.placement_groups import PlacementGroupFactory
  10. from ray.util import placement_group_table
  11. trial_executor = None
  12. class _TestCallback(Callback):
  13. def on_step_end(self, iteration, trials, **info):
  14. num_finished = len([
  15. t for t in trials
  16. if t.status == Trial.TERMINATED or t.status == Trial.ERROR
  17. ])
  18. num_running = len([t for t in trials if t.status == Trial.RUNNING])
  19. num_staging = sum(
  20. len(s) for s in trial_executor._pg_manager._staging.values())
  21. num_ready = sum(
  22. len(s) for s in trial_executor._pg_manager._ready.values())
  23. num_in_use = len(trial_executor._pg_manager._in_use_pgs)
  24. num_cached = len(trial_executor._pg_manager._cached_pgs)
  25. total_num_tracked = num_staging + num_ready + \
  26. num_in_use + num_cached
  27. num_non_removed_pgs = len([
  28. p for pid, p in placement_group_table().items()
  29. if p["state"] != "REMOVED"
  30. ])
  31. num_removal_scheduled_pgs = len(
  32. trial_executor._pg_manager._pgs_for_removal)
  33. # All 3 trials (3 different learning rates) should be scheduled.
  34. assert 3 == min(3, len(trials))
  35. # Cannot run more than 2 at a time
  36. # (due to different resource restrictions in the test cases).
  37. assert num_running <= 2
  38. # The number of placement groups should decrease
  39. # when trials finish.
  40. assert max(3, len(trials)) - num_finished == total_num_tracked
  41. # The number of actual placement groups should match this.
  42. assert max(3, len(trials)) - num_finished == \
  43. num_non_removed_pgs - num_removal_scheduled_pgs
  44. class TestPlacementGroups(unittest.TestCase):
  45. def setUp(self) -> None:
  46. os.environ["TUNE_PLACEMENT_GROUP_RECON_INTERVAL"] = "0"
  47. ray.init(num_cpus=6)
  48. def tearDown(self) -> None:
  49. ray.shutdown()
  50. def test_overriding_default_resource_request(self):
  51. config = DEFAULT_CONFIG.copy()
  52. config["model"]["fcnet_hiddens"] = [10]
  53. config["num_workers"] = 2
  54. # 3 Trials: Can only run 2 at a time (num_cpus=6; needed: 3).
  55. config["lr"] = tune.grid_search([0.1, 0.01, 0.001])
  56. config["env"] = "CartPole-v0"
  57. config["framework"] = "tf"
  58. class DefaultResourceRequest:
  59. @classmethod
  60. def default_resource_request(cls, config):
  61. head_bundle = {"CPU": 1, "GPU": 0}
  62. child_bundle = {"CPU": 1}
  63. return PlacementGroupFactory(
  64. [head_bundle, child_bundle, child_bundle],
  65. strategy=config["placement_strategy"])
  66. # Create a trainer with an overridden default_resource_request
  67. # method that returns a PlacementGroupFactory.
  68. MyTrainer = PGTrainer.with_updates(mixins=[DefaultResourceRequest])
  69. tune.register_trainable("my_trainable", MyTrainer)
  70. global trial_executor
  71. trial_executor = RayTrialExecutor(reuse_actors=False)
  72. tune.run(
  73. "my_trainable",
  74. config=config,
  75. stop={"training_iteration": 2},
  76. trial_executor=trial_executor,
  77. callbacks=[_TestCallback()],
  78. verbose=2,
  79. )
  80. def test_default_resource_request(self):
  81. config = DEFAULT_CONFIG.copy()
  82. config["model"]["fcnet_hiddens"] = [10]
  83. config["num_workers"] = 2
  84. config["num_cpus_per_worker"] = 2
  85. # 3 Trials: Can only run 1 at a time (num_cpus=6; needed: 5).
  86. config["lr"] = tune.grid_search([0.1, 0.01, 0.001])
  87. config["env"] = "CartPole-v0"
  88. config["framework"] = "torch"
  89. config["placement_strategy"] = "SPREAD"
  90. global trial_executor
  91. trial_executor = RayTrialExecutor(reuse_actors=False)
  92. tune.run(
  93. "PG",
  94. config=config,
  95. stop={"training_iteration": 2},
  96. trial_executor=trial_executor,
  97. callbacks=[_TestCallback()],
  98. verbose=2,
  99. )
  100. def test_default_resource_request_plus_manual_leads_to_error(self):
  101. config = DEFAULT_CONFIG.copy()
  102. config["model"]["fcnet_hiddens"] = [10]
  103. config["num_workers"] = 0
  104. config["env"] = "CartPole-v0"
  105. try:
  106. tune.run(
  107. "PG",
  108. config=config,
  109. stop={"training_iteration": 2},
  110. resources_per_trial=PlacementGroupFactory([{
  111. "CPU": 1
  112. }]),
  113. verbose=2,
  114. )
  115. except ValueError as e:
  116. assert "have been automatically set to" in e.args[0]
  117. if __name__ == "__main__":
  118. import pytest
  119. import sys
  120. sys.exit(pytest.main(["-v", __file__]))