import logging
import gym
import numpy as np
from typing import Callable, List, Optional, Tuple, Union

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import EnvActionType, EnvID, EnvInfoDict, \
    EnvObsType, EnvType, MultiEnvDict

logger = logging.getLogger(__name__)


@PublicAPI
class VectorEnv:
    """An environment that supports batch evaluation using clones of sub-envs.
    """

    def __init__(self, observation_space: gym.Space, action_space: gym.Space,
                 num_envs: int):
        """Initializes a VectorEnv instance.

        Args:
            observation_space: The observation Space of a single
                sub-env.
            action_space: The action Space of a single sub-env.
            num_envs: The number of clones to make of the given sub-env.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_envs = num_envs

    @staticmethod
    def vectorize_gym_envs(
            make_env: Optional[Callable[[int], EnvType]] = None,
            existing_envs: Optional[List[gym.Env]] = None,
            num_envs: int = 1,
            action_space: Optional[gym.Space] = None,
            observation_space: Optional[gym.Space] = None,
            # Deprecated. These seem to have never been used.
            env_config=None,
            policy_config=None) -> "_VectorizedGymEnv":
        """Translates any given gym.Env(s) into a VectorizedEnv object.

        Args:
            make_env: Factory that produces a new gym.Env taking the sub-env's
                vector index as only arg. Must be defined if the
                number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.

        Returns:
            The resulting _VectorizedGymEnv object (subclass of VectorEnv).
        """
        return _VectorizedGymEnv(
            make_env=make_env,
            existing_envs=existing_envs or [],
            num_envs=num_envs,
            observation_space=observation_space,
            action_space=action_space,
        )
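
    # Illustrative usage sketch (not part of the original file): building a
    # vectorized env from a standard gym environment and stepping it. The
    # "CartPole-v0" id is an assumption; any registered gym env would work.
    #
    #   vec_env = VectorEnv.vectorize_gym_envs(
    #       make_env=lambda idx: gym.make("CartPole-v0"), num_envs=4)
    #   obs_batch = vec_env.vector_reset()  # -> list of 4 observations
    #   actions = [vec_env.action_space.sample() for _ in range(4)]
    #   obs, rewards, dones, infos = vec_env.vector_step(actions)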

    @PublicAPI
    def vector_reset(self) -> List[EnvObsType]:
        """Resets all sub-environments.

        Returns:
            List of observations from each environment.
        """
        raise NotImplementedError

    @PublicAPI
    def reset_at(self, index: Optional[int] = None) -> EnvObsType:
        """Resets a single environment.

        Args:
            index: An optional sub-env index to reset.

        Returns:
            Observations from the reset sub environment.
        """
        raise NotImplementedError

    @PublicAPI
    def vector_step(
            self, actions: List[EnvActionType]
    ) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]:
        """Performs a vectorized step on all sub environments using `actions`.

        Args:
            actions: List of actions (one for each sub-env).

        Returns:
            A tuple consisting of
            1) New observations for each sub-env.
            2) Reward values for each sub-env.
            3) Done values for each sub-env.
            4) Info values for each sub-env.
        """
        raise NotImplementedError

    @PublicAPI
    def get_sub_environments(self) -> List[EnvType]:
        """Returns the underlying sub environments.

        Returns:
            List of all underlying sub environments.
        """
        return []

    # TODO: (sven) Experimental method. Make @PublicAPI at some point.
    def try_render_at(self, index: Optional[int] = None) -> \
            Optional[np.ndarray]:
        """Renders a single environment.

        Args:
            index: An optional sub-env index to render.

        Returns:
            Either a numpy RGB image (shape=(w x h x 3) dtype=uint8) or
            None in case rendering is handled directly by this method.
        """
        pass

    @Deprecated(new="vectorize_gym_envs", error=False)
    def wrap(self, *args, **kwargs) -> "_VectorizedGymEnv":
        return self.vectorize_gym_envs(*args, **kwargs)

    @Deprecated(new="get_sub_environments", error=False)
    def get_unwrapped(self) -> List[EnvType]:
        return self.get_sub_environments()

    @PublicAPI
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
    ) -> "BaseEnv":
        """Converts an RLlib VectorEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (which indicates
                the number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.

        Returns:
            The resulting BaseEnv object.
        """
        del make_env, num_envs, remote_envs, remote_env_batch_wait_ms
        env = VectorEnvWrapper(self)
        return env
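
# Illustrative usage sketch (not part of the original file): any VectorEnv can
# be turned into a BaseEnv for RLlib's async sampling loop. `vec_env` below is
# assumed to be a VectorEnv instance, e.g. built via
# `VectorEnv.vectorize_gym_envs`.
#
#   base_env = vec_env.to_base_env()
#   obs, rewards, dones, infos, off_policy_actions = base_env.poll()
#   # ... build an action dict keyed by env id and (dummy) agent id, then:
#   # base_env.send_actions(action_dict)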


class _VectorizedGymEnv(VectorEnv):
    """Internal wrapper to translate any gym.Envs into a VectorEnv object.
    """

    def __init__(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        existing_envs: Optional[List[gym.Env]] = None,
        num_envs: int = 1,
        *,
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        # Deprecated. These seem to have never been used.
        env_config=None,
        policy_config=None,
    ):
        """Initializes a _VectorizedGymEnv object.

        Args:
            make_env: Factory that produces a new gym.Env taking the sub-env's
                vector index as only arg. Must be defined if the
                number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
        """
        # Fall back to an empty list if no pre-instantiated sub-envs were
        # given.
        self.envs = existing_envs or []
        # Fill up missing envs (so we have exactly num_envs sub-envs in this
        # VectorEnv).
        while len(self.envs) < num_envs:
            self.envs.append(make_env(len(self.envs)))

        super().__init__(
            observation_space=observation_space
            or self.envs[0].observation_space,
            action_space=action_space or self.envs[0].action_space,
            num_envs=num_envs)

    @override(VectorEnv)
    def vector_reset(self):
        return [e.reset() for e in self.envs]

    @override(VectorEnv)
    def reset_at(self, index: Optional[int] = None) -> EnvObsType:
        if index is None:
            index = 0
        return self.envs[index].reset()

    @override(VectorEnv)
    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            obs, r, done, info = self.envs[i].step(actions[i])
            if not np.isscalar(r) or not np.isreal(r) or not np.isfinite(r):
                raise ValueError(
                    "Reward should be finite scalar, got {} ({}). "
                    "Actions={}.".format(r, type(r), actions[i]))
            if not isinstance(info, dict):
                raise ValueError("Info should be a dict, got {} ({})".format(
                    info, type(info)))
            obs_batch.append(obs)
            rew_batch.append(r)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        return self.envs

    @override(VectorEnv)
    def try_render_at(self, index: Optional[int] = None):
        if index is None:
            index = 0
        return self.envs[index].render()
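

# A minimal sketch (illustrative only, not part of RLlib) of implementing the
# VectorEnv interface directly instead of wrapping gym.Envs through
# `_VectorizedGymEnv`. The class name and spaces below are made up; the point
# is that a subclass only needs `vector_reset`, `reset_at`, and `vector_step`
# to be usable through `VectorEnvWrapper`.
class _ExampleRandomVectorEnv(VectorEnv):
    def __init__(self, num_envs: int = 2):
        super().__init__(
            observation_space=gym.spaces.Box(-1.0, 1.0, (4, )),
            action_space=gym.spaces.Discrete(2),
            num_envs=num_envs)

    @override(VectorEnv)
    def vector_reset(self):
        # One (random) observation per sub-env.
        return [self.observation_space.sample() for _ in range(self.num_envs)]

    @override(VectorEnv)
    def reset_at(self, index=None):
        return self.observation_space.sample()

    @override(VectorEnv)
    def vector_step(self, actions):
        # Dummy transition: random obs, zero reward, never done, empty infos.
        obs = [self.observation_space.sample() for _ in range(self.num_envs)]
        return obs, [0.0] * self.num_envs, [False] * self.num_envs, \
            [{} for _ in range(self.num_envs)]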


class VectorEnvWrapper(BaseEnv):
    """Internal adapter of VectorEnv to BaseEnv.

    We assume the caller will always send the full vector of actions in each
    call to send_actions(), and that they call reset_at() on all completed
    environments before calling send_actions().
    """

    def __init__(self, vector_env: VectorEnv):
        self.vector_env = vector_env
        self.num_envs = vector_env.num_envs
        self.new_obs = None  # lazily initialized
        self.cur_rewards = [None for _ in range(self.num_envs)]
        self.cur_dones = [False for _ in range(self.num_envs)]
        self.cur_infos = [None for _ in range(self.num_envs)]
        self._observation_space = vector_env.observation_space
        self._action_space = vector_env.action_space

    @override(BaseEnv)
    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        from ray.rllib.env.base_env import with_dummy_agent_id
        if self.new_obs is None:
            self.new_obs = self.vector_env.vector_reset()
        new_obs = dict(enumerate(self.new_obs))
        rewards = dict(enumerate(self.cur_rewards))
        dones = dict(enumerate(self.cur_dones))
        infos = dict(enumerate(self.cur_infos))
        self.new_obs = []
        self.cur_rewards = []
        self.cur_dones = []
        self.cur_infos = []
        return with_dummy_agent_id(new_obs), \
            with_dummy_agent_id(rewards), \
            with_dummy_agent_id(dones, "__all__"), \
            with_dummy_agent_id(infos), {}

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID
        action_vector = [None] * self.num_envs
        for i in range(self.num_envs):
            action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
        self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \
            self.vector_env.vector_step(action_vector)

    @override(BaseEnv)
    def try_reset(self, env_id: Optional[EnvID] = None) -> MultiEnvDict:
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID
        assert env_id is None or isinstance(env_id, int)
        return {
            env_id if env_id is not None else 0: {
                _DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)
            }
        }

    @override(BaseEnv)
    def get_sub_environments(
            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
        if not as_dict:
            return self.vector_env.get_sub_environments()
        else:
            return {
                _id: env
                for _id, env in enumerate(
                    self.vector_env.get_sub_environments())
            }

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        assert env_id is None or isinstance(env_id, int)
        return self.vector_env.try_render_at(env_id)

    @property
    @override(BaseEnv)
    @PublicAPI
    def observation_space(self) -> gym.spaces.Dict:
        return self._observation_space

    @property
    @override(BaseEnv)
    @PublicAPI
    def action_space(self) -> gym.Space:
        return self._action_space
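

# A small, self-contained smoke test (illustrative sketch, not part of the
# original module). It assumes gym has "CartPole-v0" registered and simply
# exercises the VectorEnv -> BaseEnv adapter: the first `poll()` lazily resets
# all sub-envs; `send_actions()` then expects a full action dict keyed by env
# id and (dummy) agent id.
if __name__ == "__main__":
    from ray.rllib.env.base_env import _DUMMY_AGENT_ID

    vec_env = VectorEnv.vectorize_gym_envs(
        make_env=lambda idx: gym.make("CartPole-v0"), num_envs=4)
    base_env = vec_env.to_base_env()

    # First poll() returns the initial observations of all four sub-envs.
    obs, rewards, dones, infos, _ = base_env.poll()
    for _ in range(10):
        actions = {
            env_id: {_DUMMY_AGENT_ID: vec_env.action_space.sample()}
            for env_id in obs.keys()
        }
        base_env.send_actions(actions)
        obs, rewards, dones, infos, _ = base_env.poll()
        # Per the VectorEnvWrapper contract, reset all completed sub-envs
        # before the next send_actions() call.
        for env_id, done in dones.items():
            if done["__all__"]:
                base_env.try_reset(env_id)
    print("Sampled 10 vectorized steps from {} sub-envs.".format(
        vec_env.num_envs))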