kaggle_wrapper.py

  1. """Wrap Kaggle's environment
  2. Source: https://github.com/Kaggle/kaggle-environments
  3. """
  4. from copy import deepcopy
  5. from typing import Any, Dict, Optional, Tuple
  6. try:
  7. import kaggle_environments
  8. except (ImportError, ModuleNotFoundError):
  9. pass
  10. import numpy as np
  11. from gym.spaces import Box
  12. from gym.spaces import Dict as DictSpace
  13. from gym.spaces import Discrete, MultiBinary, MultiDiscrete, Space
  14. from gym.spaces import Tuple as TupleSpace
  15. from ray.rllib.env import MultiAgentEnv
  16. from ray.rllib.utils.typing import MultiAgentDict, AgentID


class KaggleFootballMultiAgentEnv(MultiAgentEnv):
    """An interface to Kaggle's football environment.

    See: https://github.com/Kaggle/kaggle-environments
    """

    def __init__(self, configuration: Optional[Dict[str, Any]] = None) -> None:
        """Initializes a Kaggle football environment.

        Args:
            configuration (Optional[Dict[str, Any]]): configuration of the
                football environment. For detailed information, see:
                https://github.com/Kaggle/kaggle-environments/blob/master/kaggle_environments/envs/football/football.json
        """
        super().__init__()
        self.kaggle_env = kaggle_environments.make(
            "football", configuration=configuration or {})
        self.last_cumulative_reward = None

    def reset(self) -> MultiAgentDict:
        kaggle_state = self.kaggle_env.reset()
        self.last_cumulative_reward = None
        return {
            f"agent{idx}": self._convert_obs(agent_state["observation"])
            for idx, agent_state in enumerate(kaggle_state)
            if agent_state["status"] == "ACTIVE"
        }

    def step(
            self, action_dict: Dict[AgentID, int]
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        # Convert action_dict (used by RLlib) to a list of actions (used by
        # kaggle_environments)
        action_list = [None] * len(self.kaggle_env.state)
        for idx, agent_state in enumerate(self.kaggle_env.state):
            if agent_state["status"] == "ACTIVE":
                action = action_dict[f"agent{idx}"]
                action_list[idx] = [action]
        self.kaggle_env.step(action_list)

        # Parse (obs, reward, done, info) from kaggle's "state" representation
        obs = {}
        cumulative_reward = {}
        done = {"__all__": self.kaggle_env.done}
        info = {}
        for idx in range(len(self.kaggle_env.state)):
            agent_state = self.kaggle_env.state[idx]
            agent_name = f"agent{idx}"
            if agent_state["status"] == "ACTIVE":
                obs[agent_name] = self._convert_obs(agent_state["observation"])
            cumulative_reward[agent_name] = agent_state["reward"]
            done[agent_name] = agent_state["status"] != "ACTIVE"
            info[agent_name] = agent_state["info"]

        # Compute the step rewards from the cumulative rewards
        if self.last_cumulative_reward is not None:
            reward = {
                agent_id: agent_reward - self.last_cumulative_reward[agent_id]
                for agent_id, agent_reward in cumulative_reward.items()
            }
        else:
            reward = cumulative_reward
        self.last_cumulative_reward = cumulative_reward
        return obs, reward, done, info

    def _convert_obs(self, obs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert raw observations.

        These conversions are necessary to make the observations fall into
        the observation space defined below.
        """
        new_obs = deepcopy(obs)
        if new_obs["players_raw"][0]["ball_owned_team"] == -1:
            new_obs["players_raw"][0]["ball_owned_team"] = 2
        if new_obs["players_raw"][0]["ball_owned_player"] == -1:
            new_obs["players_raw"][0]["ball_owned_player"] = 11
        new_obs["players_raw"][0]["steps_left"] = [
            new_obs["players_raw"][0]["steps_left"]
        ]
        return new_obs

    def build_agent_spaces(self) -> Tuple[Space, Space]:
        """Construct the action and observation spaces.

        Description of actions and observations:
        https://github.com/google-research/football/blob/master/gfootball/doc/observation.md
        """  # noqa: E501
        action_space = Discrete(19)
        # The football field's corners are at [+-1., +-0.42]. However, the
        # players and the ball may move outside the field, so we multiply
        # those limits by a factor of 2.
        xlim = 1. * 2
        ylim = 0.42 * 2
        num_players: int = 11
        xy_space = Box(
            np.array([-xlim, -ylim], dtype=np.float32),
            np.array([xlim, ylim], dtype=np.float32))
        xyz_space = Box(
            np.array([-xlim, -ylim, 0], dtype=np.float32),
            np.array([xlim, ylim, np.inf], dtype=np.float32))
        observation_space = DictSpace({
            "controlled_players": Discrete(2),
            "players_raw": TupleSpace([
                DictSpace({
                    # ball information
                    "ball": xyz_space,
                    "ball_direction": Box(-np.inf, np.inf, (3, )),
                    "ball_rotation": Box(-np.inf, np.inf, (3, )),
                    "ball_owned_team": Discrete(3),
                    "ball_owned_player": Discrete(num_players + 1),
                    # left team
                    "left_team": TupleSpace([xy_space] * num_players),
                    "left_team_direction": TupleSpace(
                        [xy_space] * num_players),
                    "left_team_tired_factor": Box(0., 1., (num_players, )),
                    "left_team_yellow_card": MultiBinary(num_players),
                    "left_team_active": MultiBinary(num_players),
                    "left_team_roles": MultiDiscrete([10] * num_players),
                    # right team
                    "right_team": TupleSpace([xy_space] * num_players),
                    "right_team_direction": TupleSpace(
                        [xy_space] * num_players),
                    "right_team_tired_factor": Box(0., 1., (num_players, )),
                    "right_team_yellow_card": MultiBinary(num_players),
                    "right_team_active": MultiBinary(num_players),
                    "right_team_roles": MultiDiscrete([10] * num_players),
                    # controlled player information
                    "active": Discrete(num_players),
                    "designated": Discrete(num_players),
                    "sticky_actions": MultiBinary(10),
                    # match state
                    "score": Box(-np.inf, np.inf, (2, )),
                    "steps_left": Box(0, np.inf, (1, )),
                    "game_mode": Discrete(7)
                })
            ])
        })
        return action_space, observation_space
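

# NOTE: The block below is not part of the original wrapper; it is a minimal
# usage sketch, assuming `kaggle_environments` (with its football
# dependencies) is installed. It instantiates the wrapper, samples random
# actions from the Discrete(19) action space returned by
# `build_agent_spaces()`, and steps until the episode ends. The names `env`,
# `actions`, and `total_reward` are illustrative only. With RLlib, one would
# typically register the class instead, e.g. via
# `ray.tune.registry.register_env("kaggle_football",
#  lambda cfg: KaggleFootballMultiAgentEnv(cfg))`, and refer to that key in
# the Trainer config.
if __name__ == "__main__":
    env = KaggleFootballMultiAgentEnv()
    action_space, _ = env.build_agent_spaces()

    obs = env.reset()
    total_reward = {agent_id: 0.0 for agent_id in obs}
    done = {"__all__": False}
    while not done["__all__"]:
        # Sample one random action per currently active agent; cast to int so
        # the underlying kaggle environment receives plain Python integers.
        actions = {agent_id: int(action_space.sample()) for agent_id in obs}
        obs, reward, done, _ = env.step(actions)
        for agent_id, step_reward in reward.items():
            # Guard against missing rewards, which kaggle_environments may
            # report as None for agents in a terminal or errored state.
            if step_reward is not None:
                total_reward[agent_id] = (
                    total_reward.get(agent_id, 0.0) + step_reward)
    print("Episode return per agent:", total_reward)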