group_agents_wrapper.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from collections import OrderedDict
  2. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  3. # info key for the individual rewards of an agent, for example:
  4. # info: {
  5. # group_1: {
  6. # _group_rewards: [5, -1, 1], # 3 agents in this group
  7. # }
  8. # }
  9. GROUP_REWARDS = "_group_rewards"
  10. # info key for the individual infos of an agent, for example:
  11. # info: {
  12. # group_1: {
  13. # _group_infos: [{"foo": ...}, {}], # 2 agents in this group
  14. # }
  15. # }
  16. GROUP_INFO = "_group_info"
  17. class GroupAgentsWrapper(MultiAgentEnv):
  18. """Wraps a MultiAgentEnv environment with agents grouped as specified.
  19. See multi_agent_env.py for the specification of groups.
  20. This API is experimental.
  21. """
  22. def __init__(self, env, groups, obs_space=None, act_space=None):
  23. """Wrap an existing multi-agent env to group agents together.
  24. See MultiAgentEnv.with_agent_groups() for usage info.
  25. Args:
  26. env (MultiAgentEnv): env to wrap
  27. groups (dict): Grouping spec as documented in MultiAgentEnv.
  28. obs_space (Space): Optional observation space for the grouped
  29. env. Must be a tuple space.
  30. act_space (Space): Optional action space for the grouped env.
  31. Must be a tuple space.
  32. """
  33. super().__init__()
  34. self.env = env
  35. self.groups = groups
  36. self.agent_id_to_group = {}
  37. for group_id, agent_ids in groups.items():
  38. for agent_id in agent_ids:
  39. if agent_id in self.agent_id_to_group:
  40. raise ValueError(
  41. "Agent id {} is in multiple groups".format(agent_id))
  42. self.agent_id_to_group[agent_id] = group_id
  43. if obs_space is not None:
  44. self.observation_space = obs_space
  45. if act_space is not None:
  46. self.action_space = act_space
  47. def seed(self, seed=None):
  48. if not hasattr(self.env, "seed"):
  49. # This is a silent fail. However, OpenAI gyms also silently fail
  50. # here.
  51. return
  52. self.env.seed(seed)
  53. def reset(self):
  54. obs = self.env.reset()
  55. return self._group_items(obs)
  56. def step(self, action_dict):
  57. # Ungroup and send actions
  58. action_dict = self._ungroup_items(action_dict)
  59. obs, rewards, dones, infos = self.env.step(action_dict)
  60. # Apply grouping transforms to the env outputs
  61. obs = self._group_items(obs)
  62. rewards = self._group_items(
  63. rewards, agg_fn=lambda gvals: list(gvals.values()))
  64. dones = self._group_items(
  65. dones, agg_fn=lambda gvals: all(gvals.values()))
  66. infos = self._group_items(
  67. infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())})
  68. # Aggregate rewards, but preserve the original values in infos
  69. for agent_id, rew in rewards.items():
  70. if isinstance(rew, list):
  71. rewards[agent_id] = sum(rew)
  72. if agent_id not in infos:
  73. infos[agent_id] = {}
  74. infos[agent_id][GROUP_REWARDS] = rew
  75. return obs, rewards, dones, infos
  76. def _ungroup_items(self, items):
  77. out = {}
  78. for agent_id, value in items.items():
  79. if agent_id in self.groups:
  80. assert len(value) == len(self.groups[agent_id]), \
  81. (agent_id, value, self.groups)
  82. for a, v in zip(self.groups[agent_id], value):
  83. out[a] = v
  84. else:
  85. out[agent_id] = value
  86. return out
  87. def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())):
  88. grouped_items = {}
  89. for agent_id, item in items.items():
  90. if agent_id in self.agent_id_to_group:
  91. group_id = self.agent_id_to_group[agent_id]
  92. if group_id in grouped_items:
  93. continue # already added
  94. group_out = OrderedDict()
  95. for a in self.groups[group_id]:
  96. if a in items:
  97. group_out[a] = items[a]
  98. else:
  99. raise ValueError(
  100. "Missing member of group {}: {}: {}".format(
  101. group_id, a, items))
  102. grouped_items[group_id] = agg_fn(group_out)
  103. else:
  104. grouped_items[agent_id] = item
  105. return grouped_items