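"""Tests Episode's tracking of per-agent "last values" (obs, raw obs, infos,
actions, rewards, terminated/truncated flags) during a RolloutWorker rollout,
in both single-agent and multi-agent setups."""
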
import ray
import unittest
from typing import Dict, List, Optional, Union, Tuple

import numpy as np

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.examples.env.mock_env import MockEnv3
from ray.rllib.policy import Policy
from ray.rllib.utils import override
from ray.rllib.utils.typing import TensorStructType, TensorType

NUM_STEPS = 25
NUM_AGENTS = 4
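

# Callback that re-checks the Episode's per-agent "last value" caches at
# episode start, after every env step, and at episode end.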
class LastInfoCallback(DefaultCallbacks):
    def __init__(self):
        super(LastInfoCallback, self).__init__()
        self.tc = unittest.TestCase()
        self.step = 0

    def on_episode_start(
        self, worker, base_env, policies, episode, env_index, **kwargs
    ):
        self.step = 0
        self._check_last_values(episode)

    def on_episode_step(self, worker, base_env, episode, env_index=None, **kwargs):
        self.step += 1
        self._check_last_values(episode)

    def on_episode_end(self, worker, base_env, policies, episode, **kwargs):
        self._check_last_values(episode)
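
    # Compare the Episode's cached last obs/info/action/reward/terminated/
    # truncated values for every agent against the known mock-env dynamics.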
    def _check_last_values(self, episode):
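        # The cached obs are stored in preprocessed (one-hot) form; decode
        # them back to their original integer values.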
        last_obs = {
            k: np.where(v)[0].item() for k, v in episode._agent_to_last_obs.items()
        }
        last_raw_obs = episode._agent_to_last_raw_obs
        last_info = episode._agent_to_last_info
        last_terminated = episode._agent_to_last_terminated
        last_truncated = episode._agent_to_last_truncated
        last_action = episode._agent_to_last_action
        last_reward = {k: v[-1] for k, v in episode._agent_reward_history.items()}
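        # Right after reset, only the raw obs and the ("__common__") infos are
        # populated; all other per-agent caches must still be empty.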
        if self.step == 0:
            for last in [
                last_obs,
                last_terminated,
                last_truncated,
                last_action,
                last_reward,
            ]:
                self.tc.assertEqual(last, {})
            self.tc.assertTrue("__common__" in last_info)
            self.tc.assertTrue(len(last_raw_obs) > 0)
            for agent in last_raw_obs.keys():
                index = int(str(agent).replace("agent", ""))
                self.tc.assertEqual(last_raw_obs[agent], 0)
                self.tc.assertEqual(last_info[agent]["timestep"], self.step + index)
        else:
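            # From step 1 on, agent i's obs/reward/info-timestep equal
            # step + i (see EpisodeEnv.step()). The cached action echoes the
            # *previous* obs: at step 1 that is the reset obs, which is 0 for
            # every agent (reset applies no index offset). Terminated and
            # truncated only flip on the final step.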
            for agent in last_obs.keys():
                index = int(str(agent).replace("agent", ""))
                self.tc.assertEqual(last_obs[agent], self.step + index)
                self.tc.assertEqual(last_reward[agent], self.step + index)
                self.tc.assertEqual(last_terminated[agent], self.step == NUM_STEPS)
                self.tc.assertEqual(last_truncated[agent], self.step == NUM_STEPS)
                if self.step == 1:
                    self.tc.assertEqual(last_action[agent], 0)
                else:
                    self.tc.assertEqual(last_action[agent], self.step + index - 1)
                self.tc.assertEqual(last_info[agent]["timestep"], self.step + index)
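

# Deterministic policy that simply echoes the observation back as its action
# (argmax over the one-hot encoded obs batch).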
class EchoPolicy(Policy):
    @override(Policy)
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        return obs_batch.argmax(axis=1), [], {}
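

# Multi-agent env wrapping `num` independent MockEnv3 copies. Each agent's
# obs, reward, and info["timestep"] are offset by the agent's index, so the
# callback can tell the agents' "last" values apart.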
class EpisodeEnv(MultiAgentEnv):
    def __init__(self, episode_length, num):
        super().__init__()
        self._skip_env_checking = True
        self.agents = [MockEnv3(episode_length) for _ in range(num)]
        self.terminateds = set()
        self.truncateds = set()
        self.observation_space = self.agents[0].observation_space
        self.action_space = self.agents[0].action_space

    def reset(self, *, seed=None, options=None):
        self.terminateds = set()
        self.truncateds = set()
        obs_and_infos = [a.reset() for a in self.agents]
        return (
            {i: oi[0] for i, oi in enumerate(obs_and_infos)},
            {i: dict(oi[1], **{"timestep": i}) for i, oi in enumerate(obs_and_infos)},
        )
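
    # Step all agents, offset each agent's obs/reward/timestep by its index,
    # and aggregate the per-agent terminated/truncated flags into "__all__".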
    def step(self, action_dict):
        obs, rew, terminated, truncated, info = {}, {}, {}, {}, {}
        for i, action in action_dict.items():
            obs[i], rew[i], terminated[i], truncated[i], info[i] = self.agents[i].step(
                action
            )
            obs[i] = obs[i] + i
            rew[i] = rew[i] + i
            info[i]["timestep"] = info[i]["timestep"] + i
            if terminated[i]:
                self.terminateds.add(i)
            if truncated[i]:
                self.truncateds.add(i)
        terminated["__all__"] = len(self.terminateds) == len(self.agents)
        truncated["__all__"] = len(self.truncateds) == len(self.agents)
        return obs, rew, terminated, truncated, info
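

# Run one rollout each on the single- and multi-agent env; LastInfoCallback
# performs the actual assertions on every episode boundary and step.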
class TestEpisodeLastValues(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=1)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_single_agent_env(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv3(NUM_STEPS),
            default_policy_class=EchoPolicy,
            # Episode only works with env runner v1.
            config=AlgorithmConfig()
            .rollouts(enable_connectors=False, num_rollout_workers=0)
            .callbacks(LastInfoCallback),
        )
        ev.sample()

    def test_multi_agent_env(self):
        ev = RolloutWorker(
            env_creator=lambda _: EpisodeEnv(NUM_STEPS, NUM_AGENTS),
            default_policy_class=EchoPolicy,
            # Episode only works with env runner v1.
            config=AlgorithmConfig()
            .rollouts(enable_connectors=False, num_rollout_workers=0)
            .callbacks(LastInfoCallback)
            .multi_agent(
                policies={str(agent_id) for agent_id in range(NUM_AGENTS)},
                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
                    str(agent_id)
                ),
            ),
        )
        ev.sample()


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))