multi_agent.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import gym
  2. import random
  3. from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent
  4. from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
  5. from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
  6. from ray.rllib.utils.deprecation import Deprecated
@Deprecated(
    old="ray.rllib.examples.env.multi_agent.make_multiagent",
    new="ray.rllib.env.multi_agent_env.make_multi_agent",
    error=False)
def make_multiagent(env_name_or_creator):
    """Deprecated alias kept for backward compatibility.

    Forwards directly to `ray.rllib.env.multi_agent_env.make_multi_agent`;
    the `@Deprecated(error=False)` decorator emits a warning instead of
    raising, so existing callers keep working.
    """
    return make_multi_agent(env_name_or_creator)
  13. class BasicMultiAgent(MultiAgentEnv):
  14. """Env of N independent agents, each of which exits after 25 steps."""
  15. def __init__(self, num):
  16. self.agents = [MockEnv(25) for _ in range(num)]
  17. self.dones = set()
  18. self.observation_space = gym.spaces.Discrete(2)
  19. self.action_space = gym.spaces.Discrete(2)
  20. self.resetted = False
  21. def reset(self):
  22. self.resetted = True
  23. self.dones = set()
  24. return {i: a.reset() for i, a in enumerate(self.agents)}
  25. def step(self, action_dict):
  26. obs, rew, done, info = {}, {}, {}, {}
  27. for i, action in action_dict.items():
  28. obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
  29. if done[i]:
  30. self.dones.add(i)
  31. done["__all__"] = len(self.dones) == len(self.agents)
  32. return obs, rew, done, info
  33. class EarlyDoneMultiAgent(MultiAgentEnv):
  34. """Env for testing when the env terminates (after agent 0 does)."""
  35. def __init__(self):
  36. self.agents = [MockEnv(3), MockEnv(5)]
  37. self.dones = set()
  38. self.last_obs = {}
  39. self.last_rew = {}
  40. self.last_done = {}
  41. self.last_info = {}
  42. self.i = 0
  43. self.observation_space = gym.spaces.Discrete(10)
  44. self.action_space = gym.spaces.Discrete(2)
  45. def reset(self):
  46. self.dones = set()
  47. self.last_obs = {}
  48. self.last_rew = {}
  49. self.last_done = {}
  50. self.last_info = {}
  51. self.i = 0
  52. for i, a in enumerate(self.agents):
  53. self.last_obs[i] = a.reset()
  54. self.last_rew[i] = None
  55. self.last_done[i] = False
  56. self.last_info[i] = {}
  57. obs_dict = {self.i: self.last_obs[self.i]}
  58. self.i = (self.i + 1) % len(self.agents)
  59. return obs_dict
  60. def step(self, action_dict):
  61. assert len(self.dones) != len(self.agents)
  62. for i, action in action_dict.items():
  63. (self.last_obs[i], self.last_rew[i], self.last_done[i],
  64. self.last_info[i]) = self.agents[i].step(action)
  65. obs = {self.i: self.last_obs[self.i]}
  66. rew = {self.i: self.last_rew[self.i]}
  67. done = {self.i: self.last_done[self.i]}
  68. info = {self.i: self.last_info[self.i]}
  69. if done[self.i]:
  70. rew[self.i] = 0
  71. self.dones.add(self.i)
  72. self.i = (self.i + 1) % len(self.agents)
  73. done["__all__"] = len(self.dones) == len(self.agents) - 1
  74. return obs, rew, done, info
  75. class FlexAgentsMultiAgent(MultiAgentEnv):
  76. """Env of independent agents, each of which exits after n steps."""
  77. def __init__(self):
  78. self.agents = {}
  79. self.agentID = 0
  80. self.dones = set()
  81. self.observation_space = gym.spaces.Discrete(2)
  82. self.action_space = gym.spaces.Discrete(2)
  83. self.resetted = False
  84. def spawn(self):
  85. # Spawn a new agent into the current episode.
  86. agentID = self.agentID
  87. self.agents[agentID] = MockEnv(25)
  88. self.agentID += 1
  89. return agentID
  90. def reset(self):
  91. self.agents = {}
  92. self.spawn()
  93. self.resetted = True
  94. self.dones = set()
  95. obs = {}
  96. for i, a in self.agents.items():
  97. obs[i] = a.reset()
  98. return obs
  99. def step(self, action_dict):
  100. obs, rew, done, info = {}, {}, {}, {}
  101. # Apply the actions.
  102. for i, action in action_dict.items():
  103. obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
  104. if done[i]:
  105. self.dones.add(i)
  106. # Sometimes, add a new agent to the episode.
  107. if random.random() > 0.75:
  108. i = self.spawn()
  109. obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
  110. if done[i]:
  111. self.dones.add(i)
  112. # Sometimes, kill an existing agent.
  113. if len(self.agents) > 1 and random.random() > 0.25:
  114. keys = list(self.agents.keys())
  115. key = random.choice(keys)
  116. done[key] = True
  117. del self.agents[key]
  118. done["__all__"] = len(self.dones) == len(self.agents)
  119. return obs, rew, done, info
  120. class RoundRobinMultiAgent(MultiAgentEnv):
  121. """Env of N independent agents, each of which exits after 5 steps.
  122. On each step() of the env, only one agent takes an action."""
  123. def __init__(self, num, increment_obs=False):
  124. if increment_obs:
  125. # Observations are 0, 1, 2, 3... etc. as time advances
  126. self.agents = [MockEnv2(5) for _ in range(num)]
  127. else:
  128. # Observations are all zeros
  129. self.agents = [MockEnv(5) for _ in range(num)]
  130. self.dones = set()
  131. self.last_obs = {}
  132. self.last_rew = {}
  133. self.last_done = {}
  134. self.last_info = {}
  135. self.i = 0
  136. self.num = num
  137. self.observation_space = gym.spaces.Discrete(10)
  138. self.action_space = gym.spaces.Discrete(2)
  139. def reset(self):
  140. self.dones = set()
  141. self.last_obs = {}
  142. self.last_rew = {}
  143. self.last_done = {}
  144. self.last_info = {}
  145. self.i = 0
  146. for i, a in enumerate(self.agents):
  147. self.last_obs[i] = a.reset()
  148. self.last_rew[i] = None
  149. self.last_done[i] = False
  150. self.last_info[i] = {}
  151. obs_dict = {self.i: self.last_obs[self.i]}
  152. self.i = (self.i + 1) % self.num
  153. return obs_dict
  154. def step(self, action_dict):
  155. assert len(self.dones) != len(self.agents)
  156. for i, action in action_dict.items():
  157. (self.last_obs[i], self.last_rew[i], self.last_done[i],
  158. self.last_info[i]) = self.agents[i].step(action)
  159. obs = {self.i: self.last_obs[self.i]}
  160. rew = {self.i: self.last_rew[self.i]}
  161. done = {self.i: self.last_done[self.i]}
  162. info = {self.i: self.last_info[self.i]}
  163. if done[self.i]:
  164. rew[self.i] = 0
  165. self.dones.add(self.i)
  166. self.i = (self.i + 1) % self.num
  167. done["__all__"] = len(self.dones) == len(self.agents)
  168. return obs, rew, done, info
# Ready-made multi-agent wrappers around common single-agent envs,
# built via `make_multi_agent` (each agent runs its own env copy).
MultiAgentCartPole = make_multi_agent("CartPole-v0")
MultiAgentMountainCar = make_multi_agent("MountainCarContinuous-v0")
MultiAgentPendulum = make_multi_agent("Pendulum-v1")
# StatelessCartPole has no registered gym id, so pass a creator callable.
MultiAgentStatelessCartPole = make_multi_agent(
    lambda config: StatelessCartPole(config))