open_spiel.py

from gym.spaces import Box, Discrete
import numpy as np
import pyspiel

from ray.rllib.env.multi_agent_env import MultiAgentEnv


class OpenSpielEnv(MultiAgentEnv):
    """RLlib MultiAgentEnv wrapper around an open_spiel game.

    Handles both sequential and simultaneous-move games. In sequential
    games, an invalid action is replaced by a random legal one and the
    offending agent receives a -0.1 penalty on its reward for that step.
    """

    def __init__(self, env):
        self.env = env
        # Agent IDs are ints, starting from 0.
        self.num_agents = self.env.num_players()
        # Store the open-spiel game type.
        self.type = self.env.get_type()
        # Stores the current open-spiel game state.
        self.state = None

        # Extract observation- and action spaces from game.
        self.observation_space = Box(
            float("-inf"), float("inf"),
            (self.env.observation_tensor_size(), ))
        self.action_space = Discrete(self.env.num_distinct_actions())

    def reset(self):
        self.state = self.env.new_initial_state()
        return self._get_obs()

    def step(self, action):
        # Before applying action(s), there could be chance nodes.
        # E.g. if env has to figure out, which agent's action should get
        # resolved first in a simultaneous node.
        self._solve_chance_nodes()

        penalties = {}

        # Sequential game:
        if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
            curr_player = self.state.current_player()
            assert curr_player in action
            try:
                self.state.apply_action(action[curr_player])
            # TODO: (sven) resolve this hack by publishing legal actions
            #  with each step.
            except pyspiel.SpielError:
                self.state.apply_action(
                    np.random.choice(self.state.legal_actions()))
                penalties[curr_player] = -0.1

            # Compile rewards dict.
            rewards = {ag: r for ag, r in enumerate(self.state.returns())}
        # Simultaneous game.
        else:
            assert self.state.current_player() == -2
            # Apparently, this works, even if one or more actions are invalid.
            self.state.apply_actions(
                [action[ag] for ag in range(self.num_agents)])

        # Now that we have applied all actions, get the next obs.
        obs = self._get_obs()

        # Compile rewards dict and add the accumulated penalties
        # (for taking invalid actions).
        rewards = {ag: r for ag, r in enumerate(self.state.returns())}
        for ag, penalty in penalties.items():
            rewards[ag] += penalty

        # Are we done?
        is_done = self.state.is_terminal()
        dones = dict({ag: is_done
                      for ag in range(self.num_agents)},
                     **{"__all__": is_done})

        return obs, rewards, dones, {}

    def render(self, mode=None) -> None:
        if mode == "human":
            print(self.state)

    def _get_obs(self):
        # Before calculating an observation, there could be chance nodes
        # (that may have an effect on the actual observations).
        # E.g. After reset, figure out initial (random) positions of the
        # agents.
        self._solve_chance_nodes()

        if self.state.is_terminal():
            return {}

        # Sequential game:
        if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
            curr_player = self.state.current_player()
            return {
                curr_player: np.reshape(self.state.observation_tensor(), [-1])
            }
        # Simultaneous game.
        else:
            assert self.state.current_player() == -2
            return {
                ag: np.reshape(self.state.observation_tensor(ag), [-1])
                for ag in range(self.num_agents)
            }

    def _solve_chance_nodes(self):
        # Chance node(s): Sample a (non-player) action and apply.
        while self.state.is_chance_node():
            assert self.state.current_player() == -1
            actions, probs = zip(*self.state.chance_outcomes())
            action = np.random.choice(actions, p=probs)
            self.state.apply_action(action)
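

# Usage sketch (not part of the original wrapper): this is one plausible way
# to drive the env, assuming the game name "connect_four" as an example of a
# sequential open_spiel game with observation tensors. It wraps the loaded
# game, resets, and plays one episode with random legal actions, checking the
# "__all__" done flag that step() returns.
if __name__ == "__main__":
    game = pyspiel.load_game("connect_four")
    env = OpenSpielEnv(game)

    obs = env.reset()
    done = False
    while not done:
        # Pick a random legal action for every agent that received an obs
        # (just the current player in a sequential game).
        actions = {
            ag: np.random.choice(env.state.legal_actions(ag))
            for ag in obs.keys()
        }
        obs, rewards, dones, infos = env.step(actions)
        done = dones["__all__"]
        env.render(mode="human")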