remote_base_env.py

import logging
from typing import Callable, Dict, List, Optional, Tuple

import gym
import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, \
    ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID

logger = logging.getLogger(__name__)


@PublicAPI
class RemoteBaseEnv(BaseEnv):
    """BaseEnv that executes its sub-environments as @ray.remote actors.

    This provides dynamic batching of inference as observations are returned
    from the remote simulator actors. Both single- and multi-agent child envs
    are supported, and envs can be stepped synchronously or asynchronously.

    NOTE: This class implicitly assumes that the remote envs are gym.Envs.

    You shouldn't need to instantiate this class directly. It's automatically
    inserted when you use the `remote_worker_envs=True` option in your
    Trainer's config.
    """

    def __init__(self,
                 make_env: Callable[[int], EnvType],
                 num_envs: int,
                 multiagent: bool,
                 remote_env_batch_wait_ms: int,
                 existing_envs: Optional[List[ray.actor.ActorHandle]] = None):
        """Initializes a RemoteBaseEnv instance.

        Args:
            make_env: Callable that produces a single (non-vectorized) env,
                given the vector env index as only arg.
            num_envs: The number of sub-environments to create for the
                vectorization.
            multiagent: Whether this is a multi-agent env or not.
            remote_env_batch_wait_ms: Time to wait for (ray.remote)
                sub-environments to have new observations available when
                polled. Only when none of the sub-environments is ready,
                repeat the `ray.wait()` call until at least one sub-env
                is ready. Then return only the observations of the ready
                sub-environment(s).
            existing_envs: Optional list of already created sub-environments.
                These will be used as-is and only as many new sub-envs as
                necessary (`num_envs - len(existing_envs)`) will be created.
        """
        # Could be creating local or remote envs.
        self.make_env = make_env
        # Whether the given `make_env` callable already returns ray.remote
        # objects or not.
        self.make_env_creates_actors = False
        # Already existing env objects (generated by the RolloutWorker).
        self.existing_envs = existing_envs or []
        self.num_envs = num_envs
        self.multiagent = multiagent
        self.poll_timeout = remote_env_batch_wait_ms / 1000

        # List of ray actor handles (each handle points to one @ray.remote
        # sub-environment).
        self.actors: Optional[List[ray.actor.ActorHandle]] = None
        self._observation_space = None
        self._action_space = None
        # Dict mapping object refs (return values of @ray.remote calls),
        # whose actual values we are waiting for (via ray.wait in
        # `self.poll()`) to their corresponding actor handles (the actors
        # that created these return values).
        self.pending: Optional[Dict[ray.ObjectRef,
                                    ray.actor.ActorHandle]] = None
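        # For illustration only (hypothetical contents): once `poll()` has
        # been called, `self.pending` looks something like
        #     {ObjectRef(<reset/step of sub-env 0>): <actor handle 0>,
        #      ObjectRef(<reset/step of sub-env 3>): <actor handle 3>, ...}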

    @override(BaseEnv)
    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        if self.actors is None:
            # `self.make_env` already produces Actors: Use it directly.
            if len(self.existing_envs) > 0 and isinstance(
                    self.existing_envs[0], ray.actor.ActorHandle):
                self.make_env_creates_actors = True
                self.actors = []
                while len(self.actors) < self.num_envs:
                    self.actors.append(self.make_env(len(self.actors)))
            # `self.make_env` produces gym.Envs (or children thereof, such
            # as MultiAgentEnv): Need to auto-wrap it here. The problem with
            # this is that custom methods will get lost. If you would like to
            # keep your custom methods in your envs, you should provide the
            # env class directly in your config (w/o tune.register_env()),
            # such that your class will directly be made a @ray.remote
            # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
            else:

                def make_remote_env(i):
                    logger.info("Launching env {} in remote actor".format(i))
                    if self.multiagent:
                        return _RemoteMultiAgentEnv.remote(self.make_env, i)
                    else:
                        return _RemoteSingleAgentEnv.remote(self.make_env, i)

                self.actors = [
                    make_remote_env(i) for i in range(self.num_envs)
                ]
                self._observation_space = ray.get(
                    self.actors[0].observation_space.remote())
                self._action_space = ray.get(
                    self.actors[0].action_space.remote())

        # Lazy initialization. Call `reset()` on all @ray.remote
        # sub-environment actors at the beginning.
        if self.pending is None:
            # Initialize our pending object ref -> actor handle mapping dict.
            self.pending = {a.reset.remote(): a for a in self.actors}

        # Each keyed by env_id in [0, num_remote_envs).
        obs, rewards, dones, infos = {}, {}, {}, {}
        ready = []

        # Wait for at least 1 env to be ready here.
        while not ready:
            ready, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self.poll_timeout)

        # Get and return observations for each of the ready envs.
        env_ids = set()
        for obj_ref in ready:
            # Get the corresponding actor handle from our dict and remove the
            # object ref (we will call `ray.get()` on it and it will no longer
            # be "pending").
            actor = self.pending.pop(obj_ref)
            env_id = self.actors.index(actor)
            env_ids.add(env_id)
            # Get the ready object ref (this may be return value(s) of
            # `reset()` or `step()`).
            ret = ray.get(obj_ref)

            # Our sub-envs are simple Actor-turned gym.Envs or MultiAgentEnvs.
            if self.make_env_creates_actors:
                rew, done, info = None, None, None
                if self.multiagent:
                    if isinstance(ret, tuple) and len(ret) == 4:
                        ob, rew, done, info = ret
                    else:
                        ob = ret
                else:
                    if isinstance(ret, tuple) and len(ret) == 4:
                        ob = {_DUMMY_AGENT_ID: ret[0]}
                        rew = {_DUMMY_AGENT_ID: ret[1]}
                        done = {_DUMMY_AGENT_ID: ret[2], "__all__": ret[2]}
                        info = {_DUMMY_AGENT_ID: ret[3]}
                    else:
                        ob = {_DUMMY_AGENT_ID: ret}
                # If this is a `reset()` return value, we only have the
                # initial observations: Set rewards, dones, and infos to
                # dummy values.
                if rew is None:
                    rew = {agent_id: 0 for agent_id in ob.keys()}
                    done = {"__all__": False}
                    info = {agent_id: {} for agent_id in ob.keys()}
            # Our sub-envs are auto-wrapped (by `_RemoteSingleAgentEnv` or
            # `_RemoteMultiAgentEnv`) and already behave like multi-agent
            # envs.
            else:
                ob, rew, done, info = ret
            obs[env_id] = ob
            rewards[env_id] = rew
            dones[env_id] = done
            infos[env_id] = info

        logger.debug("Got obs batch for actors {}".format(env_ids))
        return obs, rewards, dones, infos, {}
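
    # For illustration (hypothetical values): with two single-agent sub-envs,
    # a `poll()` call may return MultiEnvDicts like
    #     obs     = {0: {_DUMMY_AGENT_ID: <obs0>}, 1: {_DUMMY_AGENT_ID: <obs1>}}
    #     rewards = {0: {_DUMMY_AGENT_ID: 1.0}, 1: {_DUMMY_AGENT_ID: 0.0}}
    #     dones   = {0: {_DUMMY_AGENT_ID: False, "__all__": False}, ...}
    # i.e. keyed first by env_id, then by agent_id.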

    @override(BaseEnv)
    @PublicAPI
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        for env_id, actions in action_dict.items():
            actor = self.actors[env_id]
            # `actor` is a simple single-agent (remote) env, e.g. a gym.Env
            # that was made a @ray.remote.
            if not self.multiagent and self.make_env_creates_actors:
                obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID])
            # `actor` is already a _RemoteSingleAgentEnv or
            # _RemoteMultiAgentEnv wrapper
            # (handles the multi-agent action_dict automatically).
            else:
                obj_ref = actor.step.remote(actions)
            self.pending[obj_ref] = actor

    @override(BaseEnv)
    @PublicAPI
    def try_reset(self,
                  env_id: Optional[EnvID] = None) -> Optional[MultiEnvDict]:
        actor = self.actors[env_id]
        obj_ref = actor.reset.remote()
        self.pending[obj_ref] = actor
        return ASYNC_RESET_RETURN

    @override(BaseEnv)
    @PublicAPI
    def stop(self) -> None:
        if self.actors is not None:
            for actor in self.actors:
                actor.__ray_terminate__.remote()

    @override(BaseEnv)
    @PublicAPI
    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
        if as_dict:
            return {env_id: actor for env_id, actor in enumerate(self.actors)}
        return self.actors

    @property
    @override(BaseEnv)
    @PublicAPI
    def observation_space(self) -> gym.spaces.Dict:
        return self._observation_space

    @property
    @override(BaseEnv)
    @PublicAPI
    def action_space(self) -> gym.Space:
        return self._action_space


@ray.remote(num_cpus=0)
class _RemoteMultiAgentEnv:
    """Wrapper class for making a multi-agent env a remote actor."""

    def __init__(self, make_env, i):
        self.env = make_env(i)

    def reset(self):
        obs = self.env.reset()
        # Each keyed by agent_id in the env.
        rew = {agent_id: 0 for agent_id in obs.keys()}
        info = {agent_id: {} for agent_id in obs.keys()}
        done = {"__all__": False}
        return obs, rew, done, info

    def step(self, action_dict):
        return self.env.step(action_dict)

    # The spaces are exposed as methods (rather than plain attributes) so
    # that this information can be queried from the driver via `ray.get()`.
    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space
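

# Note: attributes of a @ray.remote actor cannot be read directly from the
# driver process, which is why _RemoteMultiAgentEnv (and _RemoteSingleAgentEnv
# below) expose their spaces as methods. A caller holding such an actor handle
# (hypothetical variable name) would fetch them the same way
# `RemoteBaseEnv.poll()` does above, e.g.:
#
#     space = ray.get(env_actor.observation_space.remote())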


@ray.remote(num_cpus=0)
class _RemoteSingleAgentEnv:
    """Wrapper class for making a gym env a remote actor."""

    def __init__(self, make_env, i):
        self.env = make_env(i)

    def reset(self):
        obs = {_DUMMY_AGENT_ID: self.env.reset()}
        rew = {agent_id: 0 for agent_id in obs.keys()}
        done = {"__all__": False}
        info = {agent_id: {} for agent_id in obs.keys()}
        return obs, rew, done, info

    def step(self, action):
        obs, rew, done, info = self.env.step(action[_DUMMY_AGENT_ID])
        obs, rew, done, info = [{
            _DUMMY_AGENT_ID: x
        } for x in [obs, rew, done, info]]
        done["__all__"] = done[_DUMMY_AGENT_ID]
        return obs, rew, done, info

    # The spaces are exposed as methods (rather than plain attributes) so
    # that this information can be queried from the driver via `ray.get()`.
    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space
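

# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the poll()/send_actions() protocol
# implemented above. It is illustrative only and assumes a local Ray instance
# can be started and that gym's "CartPole-v0" is installed; the env name, the
# number of sub-envs, and the step budget are arbitrary choices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ray.init()

    env = RemoteBaseEnv(
        make_env=lambda i: gym.make("CartPole-v0"),
        num_envs=2,
        multiagent=False,
        remote_env_batch_wait_ms=10)

    # The first `poll()` lazily creates the remote actors and `reset()`s them.
    obs, rewards, dones, infos, _ = env.poll()

    for _ in range(20):
        actions = {}
        for env_id, agent_obs in obs.items():
            if dones[env_id]["__all__"]:
                # Episode finished in this sub-env -> request an async reset.
                env.try_reset(env_id)
            else:
                # Sample a random action for every agent in this sub-env
                # (just `_DUMMY_AGENT_ID` in the single-agent case).
                actions[env_id] = {
                    agent_id: env.action_space.sample()
                    for agent_id in agent_obs
                }
        env.send_actions(actions)
        # Returns only those sub-envs that finished within the wait budget.
        obs, rewards, dones, infos, _ = env.poll()

    print("Finished stepping {} remote sub-environments.".format(
        env.num_envs))
    env.stop()
    ray.shutdown()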