123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- import gym
- import random
- from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent
- from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
- from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
- from ray.rllib.utils.deprecation import Deprecated
@Deprecated(
    old="ray.rllib.examples.env.multi_agent.make_multiagent",
    new="ray.rllib.env.multi_agent_env.make_multi_agent",
    error=False,
)
def make_multiagent(env_name_or_creator):
    """Deprecated alias that forwards to ``make_multi_agent``.

    Kept only for backward compatibility. New code should call
    ``ray.rllib.env.multi_agent_env.make_multi_agent`` directly. The
    ``@Deprecated(..., error=False)`` marker presumably emits a warning
    rather than raising (TODO confirm against the ``Deprecated`` helper).

    Args:
        env_name_or_creator: Passed through unchanged to ``make_multi_agent``.

    Returns:
        Whatever ``make_multi_agent`` returns for that argument.
    """
    wrapped = make_multi_agent(env_name_or_creator)
    return wrapped
class BasicMultiAgent(MultiAgentEnv):
    """Env of N independent agents, each of which exits after 25 steps.

    Agent IDs are the integers ``0 .. num-1`` (indices into ``self.agents``).
    Every agent that receives an action is stepped on its own ``MockEnv``.
    The episode is over (``"__all__"``) once every agent has reported done.
    """

    def __init__(self, num):
        # One independent sub-env per agent; its list index is the agent ID.
        self.agents = [MockEnv(25) for _ in range(num)]
        # IDs of agents that have reported done in the current episode.
        self.dones = set()
        self.observation_space = gym.spaces.Discrete(2)
        self.action_space = gym.spaces.Discrete(2)
        # Flag recording that reset() has been called at least once.
        self.resetted = False

    def reset(self):
        """Reset every sub-env and return ``{agent_id: initial_obs}``."""
        self.resetted = True
        self.dones = set()
        initial_obs = {}
        for agent_id, sub_env in enumerate(self.agents):
            initial_obs[agent_id] = sub_env.reset()
        return initial_obs

    def step(self, action_dict):
        """Step each agent named in ``action_dict`` with its action.

        Returns:
            ``(obs, rew, done, info)`` dicts keyed by the acting agents' IDs.
            ``done`` additionally holds ``"__all__"``, which is True once all
            agents have reported done.
        """
        obs, rew, done, info = {}, {}, {}, {}
        for agent_id, action in action_dict.items():
            (obs[agent_id], rew[agent_id], done[agent_id],
             info[agent_id]) = self.agents[agent_id].step(action)
            if done[agent_id]:
                self.dones.add(agent_id)
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
class EarlyDoneMultiAgent(MultiAgentEnv):
    """Env for testing when the env terminates (after agent 0 does).

    Two agents, backed by ``MockEnv(3)`` (agent 0) and ``MockEnv(5)``
    (agent 1), take turns. Each ``reset()``/``step()`` call reports data for
    exactly one agent, cycling 0, 1, 0, 1, ... . ``"__all__"`` becomes True
    as soon as all but one agent have been reported done, so the episode can
    end while the other agent is still running.
    """

    def __init__(self):
        self.agents = [MockEnv(3), MockEnv(5)]
        # IDs of agents that have been reported done to the caller.
        self.dones = set()
        # Per-agent cache of the most recent reset()/step() results.
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        # ID of the agent whose data is reported next.
        self.i = 0
        self.observation_space = gym.spaces.Discrete(10)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Reset both sub-envs and report the initial obs of agent 0 only."""
        self.dones = set()
        self.i = 0
        agent_ids = range(len(self.agents))
        self.last_obs = {idx: env.reset() for idx, env in enumerate(self.agents)}
        # No reward has been produced yet, so each agent's cached reward is None.
        self.last_rew = dict.fromkeys(agent_ids)
        self.last_done = dict.fromkeys(agent_ids, False)
        # A separate (empty) info dict per agent.
        self.last_info = {idx: {} for idx in agent_ids}
        first = self.i
        self.i = (first + 1) % len(self.agents)
        return {first: self.last_obs[first]}

    def step(self, action_dict):
        """Apply the given actions, then report the cached data of the agent whose turn it is.

        Returns:
            ``(obs, rew, done, info)`` dicts containing a single agent. The
            reward is forced to 0 when that agent is done. ``done["__all__"]``
            is True once all but one agent have been reported done.
        """
        assert len(self.dones) != len(self.agents)
        for idx, act in action_dict.items():
            (self.last_obs[idx], self.last_rew[idx], self.last_done[idx],
             self.last_info[idx]) = self.agents[idx].step(act)
        turn = self.i
        finished = self.last_done[turn]
        obs = {turn: self.last_obs[turn]}
        rew = {turn: self.last_rew[turn]}
        done = {turn: finished}
        info = {turn: self.last_info[turn]}
        if finished:
            rew[turn] = 0
            self.dones.add(turn)
        self.i = (turn + 1) % len(self.agents)
        done["__all__"] = len(self.dones) == len(self.agents) - 1
        return obs, rew, done, info
class FlexAgentsMultiAgent(MultiAgentEnv):
    """Env whose set of agents changes at random during an episode.

    Each agent is backed by its own ``MockEnv(25)``, i.e. it exits after 25
    of its own steps, like the agents of ``BasicMultiAgent``. ``reset()``
    starts an episode with a single freshly spawned agent. Each ``step()``
    may then spawn one new agent (probability ~25%) and, while more than one
    agent exists, may kill a randomly chosen one (probability ~75%).

    Agent IDs come from a counter (``self.agentID``) that ``reset()`` never
    rewinds, so IDs stay unique across episodes.
    """

    def __init__(self):
        # Agents currently in the episode: agent ID -> MockEnv.
        self.agents = {}
        # Next agent ID to hand out.
        self.agentID = 0
        # IDs of agents still in ``self.agents`` that have reported done.
        self.dones = set()
        self.observation_space = gym.spaces.Discrete(2)
        self.action_space = gym.spaces.Discrete(2)
        # Flag recording that reset() has been called at least once.
        self.resetted = False

    def spawn(self):
        """Spawn a new ``MockEnv(25)`` agent into the current episode.

        Returns:
            The newly assigned agent ID.
        """
        agentID = self.agentID
        self.agents[agentID] = MockEnv(25)
        self.agentID += 1
        return agentID

    def reset(self):
        """Start a new episode containing one freshly spawned agent.

        Returns:
            Dict mapping that agent's ID to its initial observation.
        """
        self.agents = {}
        self.spawn()
        self.resetted = True
        self.dones = set()
        return {i: a.reset() for i, a in self.agents.items()}

    def step(self, action_dict):
        """Step the acting agents, then randomly spawn and/or kill agents.

        Args:
            action_dict: Maps agent IDs (keys of ``self.agents``) to actions.

        Returns:
            ``(obs, rew, done, info)`` dicts keyed by agent ID. A killed agent
            appears in ``done`` (as True) but not in the other dicts.
            ``done["__all__"]`` is True once every agent still in the episode
            has reported done.
        """
        obs, rew, done, info = {}, {}, {}, {}
        # Apply the actions.
        for i, action in action_dict.items():
            obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
            if done[i]:
                self.dones.add(i)
        # Sometimes, add a new agent to the episode. The newcomer is stepped
        # once (without a reset() call) using the last action from
        # ``action_dict``. It must therefore be skipped when no action was
        # given, because ``action`` would be unbound and this would raise
        # NameError. The RNG draw still happens first, so the random stream
        # is unchanged.
        if random.random() > 0.75 and action_dict:
            i = self.spawn()
            obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
            if done[i]:
                self.dones.add(i)
        # Sometimes, kill an existing agent.
        if len(self.agents) > 1 and random.random() > 0.25:
            key = random.choice(list(self.agents.keys()))
            done[key] = True
            del self.agents[key]
            # Drop the killed agent from ``self.dones`` as well. The
            # "__all__" check below then counts only agents still in the
            # episode. A stale entry could otherwise end the episode while
            # a live agent is still running, or make it impossible to end.
            self.dones.discard(key)
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
class RoundRobinMultiAgent(MultiAgentEnv):
    """Env of N independent agents, each of which exits after 5 steps.

    On each step() of the env, only one agent takes an action. Agents take
    turns in the order 0, 1, ..., num-1, 0, ... . Each ``reset()``/``step()``
    reports the cached data of exactly one agent. The episode is over
    (``"__all__"``) once every agent has been reported done.
    """

    def __init__(self, num, increment_obs=False):
        # MockEnv2: observations are 0, 1, 2, 3... etc. as time advances.
        # MockEnv: observations are all zeros.
        env_cls = MockEnv2 if increment_obs else MockEnv
        self.agents = [env_cls(5) for _ in range(num)]
        # IDs of agents that have been reported done to the caller.
        self.dones = set()
        # Per-agent cache of the most recent reset()/step() results.
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        # ID of the agent whose data is reported next.
        self.i = 0
        self.num = num
        self.observation_space = gym.spaces.Discrete(10)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Reset every sub-env and report the initial obs of agent 0 only."""
        self.dones = set()
        self.i = 0
        agent_ids = range(len(self.agents))
        self.last_obs = {idx: env.reset() for idx, env in enumerate(self.agents)}
        # No reward has been produced yet, so each agent's cached reward is None.
        self.last_rew = dict.fromkeys(agent_ids)
        self.last_done = dict.fromkeys(agent_ids, False)
        # A separate (empty) info dict per agent.
        self.last_info = {idx: {} for idx in agent_ids}
        first = self.i
        self.i = (first + 1) % self.num
        return {first: self.last_obs[first]}

    def step(self, action_dict):
        """Apply the given actions, then report the cached data of the agent whose turn it is.

        Returns:
            ``(obs, rew, done, info)`` dicts containing a single agent. The
            reward is forced to 0 when that agent is done. ``done["__all__"]``
            is True once every agent has been reported done.
        """
        assert len(self.dones) != len(self.agents)
        for idx, act in action_dict.items():
            (self.last_obs[idx], self.last_rew[idx], self.last_done[idx],
             self.last_info[idx]) = self.agents[idx].step(act)
        turn = self.i
        finished = self.last_done[turn]
        obs = {turn: self.last_obs[turn]}
        rew = {turn: self.last_rew[turn]}
        done = {turn: finished}
        info = {turn: self.last_info[turn]}
        if finished:
            rew[turn] = 0
            self.dones.add(turn)
        self.i = (turn + 1) % self.num
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
# Ready-made multi-agent versions of single-agent envs, each built by
# make_multi_agent from either an env ID string or a creator callable that
# receives the env config.
MultiAgentCartPole = make_multi_agent("CartPole-v0")
MultiAgentMountainCar = make_multi_agent("MountainCarContinuous-v0")
MultiAgentPendulum = make_multi_agent("Pendulum-v1")
MultiAgentStatelessCartPole = make_multi_agent(
    lambda env_config: StatelessCartPole(env_config))
|