test_episode_v2.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import unittest
  2. import ray
  3. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  4. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  5. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  6. from ray.rllib.examples.env.mock_env import MockEnv3
  7. from ray.rllib.policy import Policy
  8. from ray.rllib.utils import override
# Episode length of each MockEnv3: the env terminates after this many steps.
NUM_STEPS = 25
# Number of independent MockEnv3 sub-envs (one per agent) in EpisodeEnv.
NUM_AGENTS = 4
class EchoPolicy(Policy):
    """Trivial test policy: the action is the argmax index of each obs row.

    Assumes observations arrive as a 2-D batch (rows of per-step
    observations) supporting ``.argmax(axis=1)`` — e.g. a numpy array.
    """

    @override(Policy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        **kwargs
    ):
        # Returns (actions, state_outs, extra_fetches): one argmax action
        # per observation row, no RNN state, no extra fetches.
        return obs_batch.argmax(axis=1), [], {}
  25. class EpisodeEnv(MultiAgentEnv):
  26. def __init__(self, episode_length, num):
  27. super().__init__()
  28. self._skip_env_checking = True
  29. self.agents = [MockEnv3(episode_length) for _ in range(num)]
  30. self.terminateds = set()
  31. self.truncateds = set()
  32. self.observation_space = self.agents[0].observation_space
  33. self.action_space = self.agents[0].action_space
  34. def reset(self, *, seed=None, options=None):
  35. self.terminateds = set()
  36. self.truncateds = set()
  37. obs_and_infos = [a.reset() for a in self.agents]
  38. return (
  39. {i: oi[0] for i, oi in enumerate(obs_and_infos)},
  40. {i: oi[1] for i, oi in enumerate(obs_and_infos)},
  41. )
  42. def step(self, action_dict):
  43. obs, rew, terminated, truncated, info = {}, {}, {}, {}, {}
  44. for i, action in action_dict.items():
  45. obs[i], rew[i], terminated[i], truncated[i], info[i] = self.agents[i].step(
  46. action
  47. )
  48. obs[i] = obs[i] + i
  49. rew[i] = rew[i] + i
  50. info[i]["timestep"] = info[i]["timestep"] + i
  51. if terminated[i]:
  52. self.terminateds.add(i)
  53. if truncated[i]:
  54. self.truncateds.add(i)
  55. terminated["__all__"] = len(self.terminateds) == len(self.agents)
  56. truncated["__all__"] = len(self.truncateds) == len(self.agents)
  57. return obs, rew, terminated, truncated, info
  58. class TestEpisodeV2(unittest.TestCase):
  59. @classmethod
  60. def setUpClass(cls):
  61. ray.init(num_cpus=1)
  62. @classmethod
  63. def tearDownClass(cls):
  64. ray.shutdown()
  65. def test_single_agent_env(self):
  66. ev = RolloutWorker(
  67. env_creator=lambda _: MockEnv3(NUM_STEPS),
  68. default_policy_class=EchoPolicy,
  69. config=AlgorithmConfig().rollouts(
  70. enable_connectors=True,
  71. num_rollout_workers=0,
  72. ),
  73. )
  74. sample_batch = ev.sample()
  75. self.assertEqual(sample_batch.count, 200)
  76. # EnvRunnerV2 always returns MultiAgentBatch, even for single-agent envs.
  77. for agent_id, sample_batch in sample_batch.policy_batches.items():
  78. # A batch of 100. 4 episodes, each 25.
  79. self.assertEqual(len(set(sample_batch["eps_id"])), 8)
  80. def test_multi_agent_env(self):
  81. temp_env = EpisodeEnv(NUM_STEPS, NUM_AGENTS)
  82. ev = RolloutWorker(
  83. env_creator=lambda _: temp_env,
  84. default_policy_class=EchoPolicy,
  85. config=AlgorithmConfig()
  86. .multi_agent(
  87. policies={str(agent_id) for agent_id in range(NUM_AGENTS)},
  88. policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
  89. str(agent_id)
  90. ),
  91. )
  92. .rollouts(enable_connectors=True, num_rollout_workers=0),
  93. )
  94. sample_batches = ev.sample()
  95. self.assertEqual(len(sample_batches.policy_batches), 4)
  96. for agent_id, sample_batch in sample_batches.policy_batches.items():
  97. self.assertEqual(sample_batch.count, 200)
  98. # A batch of 100. 4 episodes, each 25.
  99. self.assertEqual(len(set(sample_batch["eps_id"])), 8)
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests under pytest and propagate its exit code.
    sys.exit(pytest.main(["-v", __file__]))