test_worker_set.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import gymnasium as gym
  2. import unittest
  3. import ray
  4. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  5. from ray.rllib.evaluation.worker_set import WorkerSet
  6. from ray.rllib.examples.policy.random_policy import RandomPolicy
  7. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  8. class TestWorkerSet(unittest.TestCase):
  9. @classmethod
  10. def setUpClass(cls):
  11. ray.init()
  12. @classmethod
  13. def tearDownClass(cls):
  14. ray.shutdown()
  15. def test_foreach_worker(self):
  16. """Test to make sure basic sychronous calls to remote workers work."""
  17. ws = WorkerSet(
  18. env_creator=lambda _: gym.make("CartPole-v1"),
  19. default_policy_class=RandomPolicy,
  20. config=AlgorithmConfig().rollouts(num_rollout_workers=2),
  21. num_workers=2,
  22. )
  23. policies = ws.foreach_worker(
  24. lambda w: w.get_policy(DEFAULT_POLICY_ID),
  25. local_worker=True,
  26. )
  27. # 3 policies including the one from the local worker.
  28. self.assertEqual(len(policies), 3)
  29. for p in policies:
  30. self.assertIsInstance(p, RandomPolicy)
  31. policies = ws.foreach_worker(
  32. lambda w: w.get_policy(DEFAULT_POLICY_ID),
  33. local_worker=False,
  34. )
  35. # 2 policies from only the remote workers.
  36. self.assertEqual(len(policies), 2)
  37. ws.stop()
  38. def test_foreach_worker_return_obj_refss(self):
  39. """Test to make sure return_obj_refs parameter works."""
  40. ws = WorkerSet(
  41. env_creator=lambda _: gym.make("CartPole-v1"),
  42. default_policy_class=RandomPolicy,
  43. config=AlgorithmConfig().rollouts(num_rollout_workers=2),
  44. num_workers=2,
  45. )
  46. policy_refs = ws.foreach_worker(
  47. lambda w: w.get_policy(DEFAULT_POLICY_ID),
  48. local_worker=False,
  49. return_obj_refs=True,
  50. )
  51. # 2 policy references from remote workers.
  52. self.assertEqual(len(policy_refs), 2)
  53. self.assertTrue(isinstance(policy_refs[0], ray.ObjectRef))
  54. self.assertTrue(isinstance(policy_refs[1], ray.ObjectRef))
  55. ws.stop()
  56. def test_foreach_worker_async(self):
  57. """Test to make sure basic asychronous calls to remote workers work."""
  58. ws = WorkerSet(
  59. env_creator=lambda _: gym.make("CartPole-v1"),
  60. default_policy_class=RandomPolicy,
  61. config=AlgorithmConfig().rollouts(num_rollout_workers=2),
  62. num_workers=2,
  63. )
  64. # Fired async request against both remote workers.
  65. self.assertEqual(
  66. ws.foreach_worker_async(
  67. lambda w: w.get_policy(DEFAULT_POLICY_ID),
  68. ),
  69. 2,
  70. )
  71. remote_results = ws.fetch_ready_async_reqs(timeout_seconds=None)
  72. self.assertEqual(len(remote_results), 2)
  73. for p in remote_results:
  74. # p is in the format of (worker_id, result).
  75. # First is the id of the remote worker.
  76. self.assertTrue(p[0] in [1, 2])
  77. # Next is the actual policy.
  78. self.assertIsInstance(p[1], RandomPolicy)
  79. ws.stop()
  80. if __name__ == "__main__":
  81. import pytest
  82. import sys
  83. sys.exit(pytest.main(["-v", __file__]))