vector_env.py

import logging
from typing import Callable, List, Optional, Tuple

import gym
import numpy as np

from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import EnvActionType, EnvInfoDict, \
    EnvObsType, EnvType

logger = logging.getLogger(__name__)


@PublicAPI
class VectorEnv:
    """An environment that supports batch evaluation using clones of sub-envs."""

    def __init__(self, observation_space: gym.Space, action_space: gym.Space,
                 num_envs: int):
        """Initializes a VectorEnv instance.

        Args:
            observation_space: The observation Space of a single sub-env.
            action_space: The action Space of a single sub-env.
            num_envs: The number of clones to make of the given sub-env.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_envs = num_envs

    @staticmethod
    def vectorize_gym_envs(
            make_env: Optional[Callable[[int], EnvType]] = None,
            existing_envs: Optional[List[gym.Env]] = None,
            num_envs: int = 1,
            action_space: Optional[gym.Space] = None,
            observation_space: Optional[gym.Space] = None,
            # Deprecated. These args seem to have never been used.
            env_config=None,
            policy_config=None,
    ) -> "_VectorizedGymEnv":
        """Translates any given gym.Env(s) into a VectorizedEnv object.

        Args:
            make_env: Factory that produces a new gym.Env, taking the
                sub-env's vector index as its only arg. Must be defined if
                the number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.

        Returns:
            The resulting _VectorizedGymEnv object (a subclass of VectorEnv).
        """
        return _VectorizedGymEnv(
            make_env=make_env,
            existing_envs=existing_envs or [],
            num_envs=num_envs,
            observation_space=observation_space,
            action_space=action_space,
        )
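
    # A minimal usage sketch (illustrative, not part of the original module):
    #
    #     vec_env = VectorEnv.vectorize_gym_envs(
    #         make_env=lambda vector_index: gym.make("CartPole-v1"),
    #         num_envs=4)
    #     obs_batch = vec_env.vector_reset()  # -> list of 4 observations
    #
    # Note that `make_env` is only called for indices not already covered by
    # `existing_envs`, so passing 4 existing envs with num_envs=4 never
    # invokes the factory.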

    @PublicAPI
    def vector_reset(self) -> List[EnvObsType]:
        """Resets all sub-environments.

        Returns:
            List of observations from each environment.
        """
        raise NotImplementedError

    @PublicAPI
    def reset_at(self, index: Optional[int] = None) -> EnvObsType:
        """Resets a single environment.

        Args:
            index: An optional sub-env index to reset. If None, reset the
                first (index=0) sub-environment.

        Returns:
            Observations from the reset sub-environment.
        """
        raise NotImplementedError

    @PublicAPI
    def vector_step(
            self, actions: List[EnvActionType]
    ) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]:
        """Performs a vectorized step on all sub-environments using `actions`.

        Args:
            actions: List of actions (one for each sub-env).

        Returns:
            A tuple consisting of:
            1) New observations for each sub-env.
            2) Reward values for each sub-env.
            3) Done values for each sub-env.
            4) Info values for each sub-env.
        """
        raise NotImplementedError

    @PublicAPI
    def get_sub_environments(self) -> List[EnvType]:
        """Returns the underlying sub-environments.

        Returns:
            List of all underlying sub-environments.
        """
        return []

    # TODO: (sven) Experimental method. Make @PublicAPI at some point.
    def try_render_at(self, index: Optional[int] = None) -> \
            Optional[np.ndarray]:
        """Renders a single environment.

        Args:
            index: An optional sub-env index to render.

        Returns:
            Either a numpy RGB image (shape=(w, h, 3), dtype=uint8) or
            None, in case rendering is handled directly by this method.
        """
        pass

    @Deprecated(new="vectorize_gym_envs", error=False)
    def wrap(self, *args, **kwargs) -> "_VectorizedGymEnv":
        return self.vectorize_gym_envs(*args, **kwargs)

    @Deprecated(new="get_sub_environments", error=False)
    def get_unwrapped(self) -> List[EnvType]:
        return self.get_sub_environments()


class _VectorizedGymEnv(VectorEnv):
    """Internal wrapper to translate any gym.Envs into a VectorEnv object."""

    def __init__(
            self,
            make_env: Optional[Callable[[int], EnvType]] = None,
            existing_envs: Optional[List[gym.Env]] = None,
            num_envs: int = 1,
            *,
            observation_space: Optional[gym.Space] = None,
            action_space: Optional[gym.Space] = None,
            # Deprecated. These args seem to have never been used.
            env_config=None,
            policy_config=None,
    ):
        """Initializes a _VectorizedGymEnv object.

        Args:
            make_env: Factory that produces a new gym.Env, taking the
                sub-env's vector index as its only arg. Must be defined if
                the number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
        """
        # Guard against the None default (`existing_envs` is Optional).
        self.envs = existing_envs or []
        # Fill up missing envs (so we have exactly `num_envs` sub-envs in
        # this VectorEnv).
        while len(self.envs) < num_envs:
            self.envs.append(make_env(len(self.envs)))
        super().__init__(
            observation_space=observation_space
            or self.envs[0].observation_space,
            action_space=action_space or self.envs[0].action_space,
            num_envs=num_envs)

    @override(VectorEnv)
    def vector_reset(self):
        return [e.reset() for e in self.envs]

    @override(VectorEnv)
    def reset_at(self, index: Optional[int] = None) -> EnvObsType:
        if index is None:
            index = 0
        return self.envs[index].reset()

    @override(VectorEnv)
    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            obs, r, done, info = self.envs[i].step(actions[i])
            # Sanity-check the reward and info dict returned by the sub-env.
            if not np.isscalar(r) or not np.isreal(r) or not np.isfinite(r):
                raise ValueError(
                    "Reward should be finite scalar, got {} ({}). "
                    "Actions={}.".format(r, type(r), actions[i]))
            if not isinstance(info, dict):
                raise ValueError("Info should be a dict, got {} ({})".format(
                    info, type(info)))
            obs_batch.append(obs)
            rew_batch.append(r)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        return self.envs

    @override(VectorEnv)
    def try_render_at(self, index: Optional[int] = None):
        if index is None:
            index = 0
        return self.envs[index].render()
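

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module): builds a
# vectorized CartPole-v1 env via the factory above and steps it once with
# random actions. Assumes the classic gym (<0.26) API, i.e. `step()` returns
# the 4-tuple (obs, reward, done, info) that `vector_step` expects.
if __name__ == "__main__":
    vec_env = VectorEnv.vectorize_gym_envs(
        make_env=lambda vector_index: gym.make("CartPole-v1"),
        num_envs=4)
    obs_batch = vec_env.vector_reset()
    actions = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]
    obs_batch, rew_batch, done_batch, info_batch = vec_env.vector_step(actions)
    print("rewards after one vector step:", rew_batch)
    # Reset only the first sub-env, leaving the others mid-episode.
    first_obs = vec_env.reset_at(0)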