# unity3d_env.py
from gym.spaces import Box, MultiDiscrete, Tuple as TupleSpace
import logging
import numpy as np
import random
import time
from typing import Callable, Optional, Tuple

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID

logger = logging.getLogger(__name__)


class Unity3DEnv(MultiAgentEnv):
    """A MultiAgentEnv representing a single Unity3D game instance.

    For an example on how to use this Env with a running Unity3D editor
    or with a compiled game, see:
    `rllib/examples/unity3d_env_local.py`
    For an example on how to use it inside a Unity game client, which
    connects to an RLlib Policy server, see:
    `rllib/examples/serving/unity3d_[client|server].py`

    Supports all Unity3D (MLAgents) examples, multi- or single-agent and
    gets converted automatically into an ExternalMultiAgentEnv, when used
    inside an RLlib PolicyClient for cloud/distributed training of Unity
    games.
    """

    # Default base port when connecting directly to the Editor
    _BASE_PORT_EDITOR = 5004
    # Default base port when connecting to a compiled environment
    _BASE_PORT_ENVIRONMENT = 5005
    # The worker_id for each environment instance
    _WORKER_ID = 0

    def __init__(self,
                 file_name: Optional[str] = None,
                 port: Optional[int] = None,
                 seed: int = 0,
                 no_graphics: bool = False,
                 timeout_wait: int = 300,
                 episode_horizon: int = 1000):
  36. """Initializes a Unity3DEnv object.
  37. Args:
  38. file_name (Optional[str]): Name of the Unity game binary.
  39. If None, will assume a locally running Unity3D editor
  40. to be used, instead.
  41. port (Optional[int]): Port number to connect to Unity environment.
  42. seed (int): A random seed value to use for the Unity3D game.
  43. no_graphics (bool): Whether to run the Unity3D simulator in
  44. no-graphics mode. Default: False.
  45. timeout_wait (int): Time (in seconds) to wait for connection from
  46. the Unity3D instance.
  47. episode_horizon (int): A hard horizon to abide to. After at most
  48. this many steps (per-agent episode `step()` calls), the
  49. Unity3D game is reset and will start again (finishing the
  50. multi-agent episode that the game represents).
  51. Note: The game itself may contain its own episode length
  52. limits, which are always obeyed (on top of this value here).
  53. """
        super().__init__()

        if file_name is None:
            print(
                "No game binary provided, will use a running Unity editor "
                "instead.\nMake sure you are pressing the Play (|>) button in "
                "your editor to start.")

        import mlagents_envs
        from mlagents_envs.environment import UnityEnvironment

        # Try connecting to the Unity3D game instance. If a port is blocked,
        # retry after a random sleep (and, for compiled games, with the next
        # worker_id).
        port_ = None
        while True:
            # Sleep for random time to allow for concurrent startup of many
            # environments (num_workers >> 1). Otherwise, would lead to port
            # conflicts sometimes.
            if port_ is not None:
                time.sleep(random.randint(1, 10))
            port_ = port or (self._BASE_PORT_ENVIRONMENT
                             if file_name else self._BASE_PORT_EDITOR)
            # Cache the worker_id and increase it for the next environment.
            worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0
            Unity3DEnv._WORKER_ID += 1
            try:
                self.unity_env = UnityEnvironment(
                    file_name=file_name,
                    worker_id=worker_id_,
                    base_port=port_,
                    seed=seed,
                    no_graphics=no_graphics,
                    timeout_wait=timeout_wait,
                )
                print("Created UnityEnvironment for port {}".format(
                    port_ + worker_id_))
            except mlagents_envs.exception.UnityWorkerInUseException:
                pass
            else:
                break

        # ML-Agents API version.
        self.api_version = self.unity_env.API_VERSION.split(".")
        self.api_version = [int(s) for s in self.api_version]

        # Reset entire env every this number of step calls.
        self.episode_horizon = episode_horizon
        # Keep track of how many times we have called `step` so far.
        self.episode_timesteps = 0

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        """Performs one multi-agent step through the game.

        Args:
            action_dict (dict): Multi-agent action dict with:
                keys=agent identifier consisting of
                [MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
                [Agent index, a unique MLAgent-assigned index per single agent]

        Returns:
            tuple:
                - obs: Multi-agent observation dict.
                    Only those observations for which to get new actions are
                    returned.
                - rewards: Rewards dict matching `obs`.
                - dones: Done dict with only an __all__ multi-agent entry in
                    it. __all__=True, if episode is done for all agents.
                - infos: An (empty) info dict.
        """
        from mlagents_envs.base_env import ActionTuple

        # Set only the required actions (from the DecisionSteps) in Unity3D.
        all_agents = []
        for behavior_name in self.unity_env.behavior_specs:
            # New ML-Agents API: Set all agents actions at the same time
            # via an ActionTuple. Since API v1.4.0.
            if self.api_version[0] > 1 or (self.api_version[0] == 1
                                           and self.api_version[1] >= 4):
                actions = []
                for agent_id in self.unity_env.get_steps(behavior_name)[
                        0].agent_id:
                    key = behavior_name + "_{}".format(agent_id)
                    all_agents.append(key)
                    actions.append(action_dict[key])
                if actions:
                    if actions[0].dtype == np.float32:
                        action_tuple = ActionTuple(
                            continuous=np.array(actions))
                    else:
                        action_tuple = ActionTuple(
                            discrete=np.array(actions))
                    self.unity_env.set_actions(behavior_name, action_tuple)
            # Old behavior: Do not use an ActionTuple and set each agent's
            # action individually.
            else:
                for agent_id in self.unity_env.get_steps(behavior_name)[
                        0].agent_id_to_index.keys():
                    key = behavior_name + "_{}".format(agent_id)
                    all_agents.append(key)
                    self.unity_env.set_action_for_agent(
                        behavior_name, agent_id, action_dict[key])

        # Do the step.
        self.unity_env.step()

        obs, rewards, dones, infos = self._get_step_results()

        # Global horizon reached? -> Return __all__ done=True, so user
        # can reset. Set all agents' individual `done` to True as well.
        self.episode_timesteps += 1
        if self.episode_timesteps > self.episode_horizon:
            return obs, rewards, dict(
                {"__all__": True},
                **{agent_id: True for agent_id in all_agents}), infos

        return obs, rewards, dones, infos

    def reset(self) -> MultiAgentDict:
        """Resets the entire Unity3D scene (a single multi-agent episode)."""
        self.episode_timesteps = 0
        self.unity_env.reset()
        obs, _, _, _ = self._get_step_results()
        return obs

    def _get_step_results(self):
        """Collects those agents' obs/rewards that have to act in next `step`.

        Returns:
            Tuple:
                obs: Multi-agent observation dict.
                    Only those observations for which to get new actions are
                    returned.
                rewards: Rewards dict matching `obs`.
                dones: Done dict with only an __all__ multi-agent entry in it.
                    __all__=True, if episode is done for all agents.
                infos: An (empty) info dict.
        """
        obs = {}
        rewards = {}
        infos = {}
        for behavior_name in self.unity_env.behavior_specs:
            decision_steps, terminal_steps = self.unity_env.get_steps(
                behavior_name)
            # Only collect obs/rewards for agents that currently request an
            # action (DecisionSteps) or that have just terminated
            # (TerminalSteps); loop through them and fill in whatever
            # information we have.
            for agent_id, idx in decision_steps.agent_id_to_index.items():
                key = behavior_name + "_{}".format(agent_id)
                os = tuple(o[idx] for o in decision_steps.obs)
                os = os[0] if len(os) == 1 else os
                obs[key] = os
                rewards[key] = decision_steps.reward[idx]  # rewards vector
            for agent_id, idx in terminal_steps.agent_id_to_index.items():
                key = behavior_name + "_{}".format(agent_id)
                # Only overwrite rewards (last reward in episode), b/c obs
                # here is the last obs (which doesn't matter anyways).
                # Unless key does not exist in obs.
                if key not in obs:
                    os = tuple(o[idx] for o in terminal_steps.obs)
                    obs[key] = os = os[0] if len(os) == 1 else os
                rewards[key] = terminal_steps.reward[idx]  # rewards vector

        # Only use dones if all agents are done, then we should do a reset.
        return obs, rewards, {"__all__": False}, infos

    @staticmethod
    def get_policy_configs_for_game(
            game_name: str) -> Tuple[dict, Callable[[AgentID], PolicyID]]:
        """Returns policies and a policy_mapping_fn for a built-in game."""
        # The RLlib server must know about the Spaces that the Client will be
        # using inside Unity3D, up-front.
        obs_spaces = {
            # 3DBall.
            "3DBall": Box(float("-inf"), float("inf"), (8, )),
            # 3DBallHard.
            "3DBallHard": Box(float("-inf"), float("inf"), (45, )),
            # GridFoodCollector.
            "GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)),
            # Pyramids.
            "Pyramids": TupleSpace([
                Box(float("-inf"), float("inf"), (56, )),
                Box(float("-inf"), float("inf"), (56, )),
                Box(float("-inf"), float("inf"), (56, )),
                Box(float("-inf"), float("inf"), (4, )),
            ]),
            # SoccerStrikersVsGoalie.
            "Goalie": Box(float("-inf"), float("inf"), (738, )),
            "Striker": TupleSpace([
                Box(float("-inf"), float("inf"), (231, )),
                Box(float("-inf"), float("inf"), (63, )),
            ]),
            # Sorter.
            "Sorter": TupleSpace([
                Box(float("-inf"), float("inf"), (20, 23)),
                Box(float("-inf"), float("inf"), (10, )),
                Box(float("-inf"), float("inf"), (8, )),
            ]),
            # Tennis.
            "Tennis": Box(float("-inf"), float("inf"), (27, )),
            # VisualHallway.
            "VisualHallway": Box(float("-inf"), float("inf"), (84, 84, 3)),
            # Walker.
            "Walker": Box(float("-inf"), float("inf"), (212, )),
            # FoodCollector.
            "FoodCollector": TupleSpace([
                Box(float("-inf"), float("inf"), (49, )),
                Box(float("-inf"), float("inf"), (4, )),
            ]),
        }
        action_spaces = {
            # 3DBall.
            "3DBall": Box(
                float("-inf"), float("inf"), (2, ), dtype=np.float32),
            # 3DBallHard.
            "3DBallHard": Box(
                float("-inf"), float("inf"), (2, ), dtype=np.float32),
            # GridFoodCollector.
            "GridFoodCollector": MultiDiscrete([3, 3, 3, 2]),
            # Pyramids.
            "Pyramids": MultiDiscrete([5]),
            # SoccerStrikersVsGoalie.
            "Goalie": MultiDiscrete([3, 3, 3]),
            "Striker": MultiDiscrete([3, 3, 3]),
            # Sorter.
            "Sorter": MultiDiscrete([3, 3, 3]),
            # Tennis.
            "Tennis": Box(float("-inf"), float("inf"), (3, )),
            # VisualHallway.
            "VisualHallway": MultiDiscrete([5]),
            # Walker.
            "Walker": Box(float("-inf"), float("inf"), (39, )),
            # FoodCollector.
            "FoodCollector": MultiDiscrete([3, 3, 3, 2]),
        }

        # Policies (Unity: "behaviors") and agent-to-policy mapping fns.
        if game_name == "SoccerStrikersVsGoalie":
            policies = {
                "Goalie": PolicySpec(
                    observation_space=obs_spaces["Goalie"],
                    action_space=action_spaces["Goalie"]),
                "Striker": PolicySpec(
                    observation_space=obs_spaces["Striker"],
                    action_space=action_spaces["Striker"]),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return "Striker" if "Striker" in agent_id else "Goalie"

        else:
            policies = {
                game_name: PolicySpec(
                    observation_space=obs_spaces[game_name],
                    action_space=action_spaces[game_name]),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return game_name

        return policies, policy_mapping_fn
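

# A minimal usage sketch (not part of the original module): connect to a
# Unity editor that is in Play mode and step through a few frames with
# randomly sampled actions. The `game_name`, loop length, and variable names
# below are illustrative assumptions; spaces and policies come from the
# helpers above.
if __name__ == "__main__":
    game_name = "3DBall"  # assumed scene currently loaded in the editor
    policies, policy_mapping_fn = Unity3DEnv.get_policy_configs_for_game(
        game_name)
    env = Unity3DEnv(episode_horizon=100)
    obs = env.reset()
    for _ in range(10):
        # Sample one random action per agent that currently requests one.
        actions = {
            agent_id: policies[policy_mapping_fn(
                agent_id, episode=None, worker=None)].action_space.sample()
            for agent_id in obs
        }
        obs, rewards, dones, infos = env.step(actions)
        if dones["__all__"]:
            obs = env.reset()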