external_multi_agent_env.py

import uuid
import gym
from typing import Optional

from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode
from ray.rllib.utils.typing import MultiAgentDict


@PublicAPI
class ExternalMultiAgentEnv(ExternalEnv):
    """This is the multi-agent version of ExternalEnv."""

    @PublicAPI
    def __init__(self,
                 action_space: gym.Space,
                 observation_space: gym.Space,
                 max_concurrent: int = 100):
        """Initializes an ExternalMultiAgentEnv instance.

        Args:
            action_space: Action space of the env.
            observation_space: Observation space of the env.
            max_concurrent: Max number of active episodes to allow at
                once. Exceeding this limit raises an error.
        """
        ExternalEnv.__init__(self, action_space, observation_space,
                             max_concurrent)

        # We need to know all agents' spaces.
        if isinstance(self.action_space, dict) or isinstance(
                self.observation_space, dict):
            if not (self.action_space.keys() ==
                    self.observation_space.keys()):
                raise ValueError("Agent ids disagree for action space and obs "
                                 "space dict: {} {}".format(
                                     self.action_space.keys(),
                                     self.observation_space.keys()))

    @PublicAPI
    def run(self):
        """Override this to implement the multi-agent run loop.

        Your loop should continuously:
            1. Call self.start_episode(episode_id)
            2. Call self.get_action(episode_id, obs_dict)
                   -or-
                   self.log_action(episode_id, obs_dict, action_dict)
            3. Call self.log_returns(episode_id, reward_dict)
            4. Call self.end_episode(episode_id, obs_dict)
            5. Wait if nothing to do.

        Multiple episodes may be started at the same time.
        """
        raise NotImplementedError

    @PublicAPI
    @override(ExternalEnv)
    def start_episode(self,
                      episode_id: Optional[str] = None,
                      training_enabled: bool = True) -> str:
        if episode_id is None:
            episode_id = uuid.uuid4().hex

        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id in self._episodes:
            raise ValueError(
                "Episode {} is already started".format(episode_id))

        self._episodes[episode_id] = _ExternalEnvEpisode(
            episode_id,
            self._results_avail_condition,
            training_enabled,
            multiagent=True)

        return episode_id
    @PublicAPI
    @override(ExternalEnv)
    def get_action(self, episode_id: str,
                   observation_dict: MultiAgentDict) -> MultiAgentDict:
        """Record an observation and get the on-policy action.

        observation_dict is expected to contain observations for all agents
        acting in this episode step.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.

        Returns:
            action: A dict mapping agent ids to actions from the env
                action space.
        """
        episode = self._get(episode_id)
        return episode.wait_for_action(observation_dict)

    @PublicAPI
    @override(ExternalEnv)
    def log_action(self, episode_id: str, observation_dict: MultiAgentDict,
                   action_dict: MultiAgentDict) -> None:
        """Record an observation and (off-policy) action taken.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.
            action_dict: Action for the observation.
        """
        episode = self._get(episode_id)
        episode.log_action(observation_dict, action_dict)
    @PublicAPI
    @override(ExternalEnv)
    def log_returns(self,
                    episode_id: str,
                    reward_dict: MultiAgentDict,
                    info_dict: MultiAgentDict = None,
                    multiagent_done_dict: MultiAgentDict = None) -> None:
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Args:
            episode_id: Episode id returned from start_episode().
            reward_dict: Reward from the environment agents.
            info_dict: Optional info dict.
            multiagent_done_dict: Optional done dict for agents.
        """
        episode = self._get(episode_id)

        # Accumulate reward by agent.
        # For existing agents, we want to add the reward up.
        for agent, rew in reward_dict.items():
            if agent in episode.cur_reward_dict:
                episode.cur_reward_dict[agent] += rew
            else:
                episode.cur_reward_dict[agent] = rew
        if multiagent_done_dict:
            for agent, done in multiagent_done_dict.items():
                episode.cur_done_dict[agent] = done
        if info_dict:
            episode.cur_info_dict = info_dict or {}

    @PublicAPI
    @override(ExternalEnv)
    def end_episode(self, episode_id: str,
                    observation_dict: MultiAgentDict) -> None:
        """Record the end of an episode.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.
        """
        episode = self._get(episode_id)
        self._finished.add(episode.episode_id)
        episode.done(observation_dict)
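
Below is a minimal usage sketch of the run loop described in run()'s docstring, assuming a hypothetical external simulator: `_SimStub`, its two agent ids, and the small Box/Discrete spaces are illustrative stand-ins, not part of RLlib. The subclass starts an episode, queries on-policy actions with get_action(), attributes per-agent rewards with log_returns(), and closes the episode with end_episode().

import numpy as np


class _SimStub:
    """Hypothetical two-agent external simulator (illustrative only)."""

    def __init__(self):
        self.agents = ["agent_0", "agent_1"]

    def reset(self) -> dict:
        # Initial observation for every agent.
        return {a: np.zeros(4, dtype=np.float32) for a in self.agents}

    def step(self, action_dict: dict):
        # Next observations and rewards for the agents that just acted,
        # plus a single episode-level done flag.
        obs = {a: np.random.uniform(-1.0, 1.0, 4).astype(np.float32)
               for a in action_dict}
        rewards = {a: 1.0 for a in action_dict}
        done = np.random.random() < 0.05
        return obs, rewards, done


class MyExternalMultiAgentEnv(ExternalMultiAgentEnv):
    def __init__(self):
        super().__init__(
            action_space=gym.spaces.Discrete(2),
            observation_space=gym.spaces.Box(-1.0, 1.0, (4,), np.float32))
        self.sim = _SimStub()

    def run(self):
        # The loop from run()'s docstring: start an episode, query on-policy
        # actions, log the resulting rewards, and end the episode once the
        # simulator reports done.
        while True:
            episode_id = self.start_episode()
            obs_dict = self.sim.reset()
            done = False
            while not done:
                action_dict = self.get_action(episode_id, obs_dict)
                obs_dict, reward_dict, done = self.sim.step(action_dict)
                self.log_returns(episode_id, reward_dict)
            self.end_episode(episode_id, obs_dict)

For off-policy data, the inner loop would instead call self.log_action(episode_id, obs_dict, action_dict) with externally chosen actions, and per-agent termination can be reported through log_returns(..., multiagent_done_dict=...).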