open_spiel.py

from gym.spaces import Box, Discrete
import numpy as np
import pyspiel

from ray.rllib.env.multi_agent_env import MultiAgentEnv


class OpenSpielEnv(MultiAgentEnv):
    def __init__(self, env):
        super().__init__()
        self.env = env

        # Agent IDs are ints, starting from 0.
        self.num_agents = self.env.num_players()
        # Store the open-spiel game type.
        self.type = self.env.get_type()
        # Stores the current open-spiel game state.
        self.state = None

        # Extract observation- and action spaces from game.
        self.observation_space = Box(
            float("-inf"), float("inf"),
            (self.env.observation_tensor_size(), ))
        self.action_space = Discrete(self.env.num_distinct_actions())
    def reset(self):
        self.state = self.env.new_initial_state()
        return self._get_obs()
    def step(self, action):
        # Before applying action(s), there could be chance nodes.
        # E.g. if the env has to figure out which agent's action should get
        # resolved first in a simultaneous node.
        self._solve_chance_nodes()

        penalties = {}

        # Sequential game:
        if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
            curr_player = self.state.current_player()
            assert curr_player in action
            try:
                self.state.apply_action(action[curr_player])
            # TODO: (sven) resolve this hack by publishing legal actions
            #  with each step.
            except pyspiel.SpielError:
                self.state.apply_action(
                    np.random.choice(self.state.legal_actions()))
                penalties[curr_player] = -0.1

            # Compile rewards dict.
            rewards = {ag: r for ag, r in enumerate(self.state.returns())}
        # Simultaneous game.
        else:
            # -2 is open-spiel's player ID for simultaneous-move nodes.
            assert self.state.current_player() == -2
            # Apparently, this works even if one or more actions are invalid.
            self.state.apply_actions(
                [action[ag] for ag in range(self.num_agents)])

        # Now that we have applied all actions, get the next obs.
        obs = self._get_obs()

        # Compile rewards dict and add the accumulated penalties
        # (for taking invalid actions).
        rewards = {ag: r for ag, r in enumerate(self.state.returns())}
        for ag, penalty in penalties.items():
            rewards[ag] += penalty

        # Are we done?
        is_done = self.state.is_terminal()
        dones = dict({ag: is_done
                      for ag in range(self.num_agents)},
                     **{"__all__": is_done})

        return obs, rewards, dones, {}
    def render(self, mode=None) -> None:
        if mode == "human":
            print(self.state)
    def _get_obs(self):
        # Before calculating an observation, there could be chance nodes
        # (that may have an effect on the actual observations).
        # E.g. After reset, figure out initial (random) positions of the
        # agents.
        self._solve_chance_nodes()

        if self.state.is_terminal():
            return {}

        # Sequential game:
        if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
            curr_player = self.state.current_player()
            return {
                curr_player: np.reshape(self.state.observation_tensor(), [-1])
            }
        # Simultaneous game.
        else:
            assert self.state.current_player() == -2
            return {
                ag: np.reshape(self.state.observation_tensor(ag), [-1])
                for ag in range(self.num_agents)
            }
    def _solve_chance_nodes(self):
        # Chance node(s): Sample a (non-player) action and apply.
        while self.state.is_chance_node():
            # -1 is open-spiel's player ID for the chance "player".
            assert self.state.current_player() == -1
            actions, probs = zip(*self.state.chance_outcomes())
            action = np.random.choice(actions, p=probs)
            self.state.apply_action(action)
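

# Minimal usage sketch for a sequential game: assuming OpenSpiel's built-in
# "tic_tac_toe" game is available, wrap it and play random legal moves until
# the episode terminates.
if __name__ == "__main__":
    game = pyspiel.load_game("tic_tac_toe")
    env = OpenSpielEnv(game)
    obs = env.reset()
    done = False
    while not done:
        # Act only for the player whose turn it currently is.
        curr_player = env.state.current_player()
        action = {curr_player: np.random.choice(env.state.legal_actions())}
        obs, rewards, dones, _ = env.step(action)
        done = dones["__all__"]
    # Print the final board state.
    env.render(mode="human")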