# test_external_multi_agent_env.py (2.4 KB)

  1. import gym
  2. import numpy as np
  3. import unittest
  4. import ray
  5. from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
  6. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  7. from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
  8. from ray.rllib.examples.env.multi_agent import BasicMultiAgent
  9. from ray.rllib.policy.sample_batch import SampleBatch
  10. from ray.rllib.tests.test_external_env import make_simple_serving
# Serving-env class used by every test in this module, produced by calling
# make_simple_serving() with ExternalMultiAgentEnv.
# NOTE(review): the leading True presumably selects the multi-agent variant --
# confirm against ray.rllib.tests.test_external_env.make_simple_serving.
SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)
  12. class TestExternalMultiAgentEnv(unittest.TestCase):
  13. @classmethod
  14. def setUpClass(cls) -> None:
  15. ray.init()
  16. @classmethod
  17. def tearDownClass(cls) -> None:
  18. ray.shutdown()
  19. def test_external_multi_agent_env_complete_episodes(self):
  20. agents = 4
  21. ev = RolloutWorker(
  22. env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
  23. policy_spec=MockPolicy,
  24. rollout_fragment_length=40,
  25. batch_mode="complete_episodes")
  26. for _ in range(3):
  27. batch = ev.sample()
  28. self.assertEqual(batch.count, 40)
  29. self.assertEqual(
  30. len(np.unique(batch[SampleBatch.AGENT_INDEX])), agents)
  31. def test_external_multi_agent_env_truncate_episodes(self):
  32. agents = 4
  33. ev = RolloutWorker(
  34. env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
  35. policy_spec=MockPolicy,
  36. rollout_fragment_length=40,
  37. batch_mode="truncate_episodes")
  38. for _ in range(3):
  39. batch = ev.sample()
  40. self.assertEqual(batch.count, 160)
  41. self.assertEqual(
  42. len(np.unique(batch[SampleBatch.AGENT_INDEX])), agents)
  43. def test_external_multi_agent_env_sample(self):
  44. agents = 2
  45. act_space = gym.spaces.Discrete(2)
  46. obs_space = gym.spaces.Discrete(2)
  47. ev = RolloutWorker(
  48. env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
  49. policy_spec={
  50. "p0": (MockPolicy, obs_space, act_space, {}),
  51. "p1": (MockPolicy, obs_space, act_space, {}),
  52. },
  53. policy_mapping_fn=lambda aid, **kwargs: "p{}".format(aid % 2),
  54. rollout_fragment_length=50)
  55. batch = ev.sample()
  56. self.assertEqual(batch.count, 50)
  57. if __name__ == "__main__":
  58. import pytest
  59. import sys
  60. sys.exit(pytest.main(["-v", __file__]))