preprocessors.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. from collections import OrderedDict
  2. import logging
  3. import numpy as np
  4. import gymnasium as gym
  5. from typing import Any, List
  6. from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
  7. from ray.rllib.utils.spaces.repeated import Repeated
  8. from ray.rllib.utils.typing import TensorType
  9. from ray.rllib.utils.images import resize
  10. from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
  11. ATARI_OBS_SHAPE = (210, 160, 3)
  12. ATARI_RAM_OBS_SHAPE = (128,)
  13. # Only validate env observations vs the observation space every n times in a
  14. # Preprocessor.
  15. OBS_VALIDATION_INTERVAL = 100
  16. logger = logging.getLogger(__name__)
  17. @PublicAPI
  18. class Preprocessor:
  19. """Defines an abstract observation preprocessor function.
  20. Attributes:
  21. shape (List[int]): Shape of the preprocessed output.
  22. """
  23. @PublicAPI
  24. def __init__(self, obs_space: gym.Space, options: dict = None):
  25. _legacy_patch_shapes(obs_space)
  26. self._obs_space = obs_space
  27. if not options:
  28. from ray.rllib.models.catalog import MODEL_DEFAULTS
  29. self._options = MODEL_DEFAULTS.copy()
  30. else:
  31. self._options = options
  32. self.shape = self._init_shape(obs_space, self._options)
  33. self._size = int(np.product(self.shape))
  34. self._i = 0
  35. self._obs_for_type_matching = self._obs_space.sample()
  36. @PublicAPI
  37. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  38. """Returns the shape after preprocessing."""
  39. raise NotImplementedError
  40. @PublicAPI
  41. def transform(self, observation: TensorType) -> np.ndarray:
  42. """Returns the preprocessed observation."""
  43. raise NotImplementedError
  44. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  45. """Alternative to transform for more efficient flattening."""
  46. array[offset : offset + self._size] = self.transform(observation)
  47. def check_shape(self, observation: Any) -> None:
  48. """Checks the shape of the given observation."""
  49. if self._i % OBS_VALIDATION_INTERVAL == 0:
  50. # Convert lists to np.ndarrays.
  51. if type(observation) is list and isinstance(
  52. self._obs_space, gym.spaces.Box
  53. ):
  54. observation = np.array(observation).astype(np.float32)
  55. if not self._obs_space.contains(observation):
  56. observation = convert_element_to_space_type(
  57. observation, self._obs_for_type_matching
  58. )
  59. try:
  60. if not self._obs_space.contains(observation):
  61. raise ValueError(
  62. "Observation ({} dtype={}) outside given space ({})!".format(
  63. observation,
  64. observation.dtype
  65. if isinstance(self._obs_space, gym.spaces.Box)
  66. else None,
  67. self._obs_space,
  68. )
  69. )
  70. except AttributeError as e:
  71. raise ValueError(
  72. "Observation for a Box/MultiBinary/MultiDiscrete space "
  73. "should be an np.array, not a Python list.",
  74. observation,
  75. ) from e
  76. self._i += 1
  77. @property
  78. @PublicAPI
  79. def size(self) -> int:
  80. return self._size
  81. @property
  82. @PublicAPI
  83. def observation_space(self) -> gym.Space:
  84. obs_space = gym.spaces.Box(-1.0, 1.0, self.shape, dtype=np.float32)
  85. # Stash the unwrapped space so that we can unwrap dict and tuple spaces
  86. # automatically in modelv2.py
  87. classes = (
  88. DictFlatteningPreprocessor,
  89. OneHotPreprocessor,
  90. RepeatedValuesPreprocessor,
  91. TupleFlatteningPreprocessor,
  92. AtariRamPreprocessor,
  93. GenericPixelPreprocessor,
  94. )
  95. if isinstance(self, classes):
  96. obs_space.original_space = self._obs_space
  97. return obs_space
  98. @DeveloperAPI
  99. class GenericPixelPreprocessor(Preprocessor):
  100. """Generic image preprocessor.
  101. Note: for Atari games, use config {"preprocessor_pref": "deepmind"}
  102. instead for deepmind-style Atari preprocessing.
  103. """
  104. @override(Preprocessor)
  105. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  106. self._grayscale = options.get("grayscale")
  107. self._zero_mean = options.get("zero_mean")
  108. self._dim = options.get("dim")
  109. if self._grayscale:
  110. shape = (self._dim, self._dim, 1)
  111. else:
  112. shape = (self._dim, self._dim, 3)
  113. return shape
  114. @override(Preprocessor)
  115. def transform(self, observation: TensorType) -> np.ndarray:
  116. """Downsamples images from (210, 160, 3) by the configured factor."""
  117. self.check_shape(observation)
  118. scaled = observation[25:-25, :, :]
  119. if self._dim < 84:
  120. scaled = resize(scaled, height=84, width=84)
  121. # OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
  122. # If we resize directly we lose pixels that, when mapped to 42x42,
  123. # aren't close enough to the pixel boundary.
  124. scaled = resize(scaled, height=self._dim, width=self._dim)
  125. if self._grayscale:
  126. scaled = scaled.mean(2)
  127. scaled = scaled.astype(np.float32)
  128. # Rescale needed for maintaining 1 channel
  129. scaled = np.reshape(scaled, [self._dim, self._dim, 1])
  130. if self._zero_mean:
  131. scaled = (scaled - 128) / 128
  132. else:
  133. scaled *= 1.0 / 255.0
  134. return scaled
  135. @DeveloperAPI
  136. class AtariRamPreprocessor(Preprocessor):
  137. @override(Preprocessor)
  138. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  139. return (128,)
  140. @override(Preprocessor)
  141. def transform(self, observation: TensorType) -> np.ndarray:
  142. self.check_shape(observation)
  143. return (observation.astype("float32") - 128) / 128
  144. @DeveloperAPI
  145. class OneHotPreprocessor(Preprocessor):
  146. """One-hot preprocessor for Discrete and MultiDiscrete spaces.
  147. Examples:
  148. >>> self.transform(Discrete(3).sample())
  149. ... np.array([0.0, 1.0, 0.0])
  150. >>> self.transform(MultiDiscrete([2, 3]).sample())
  151. ... np.array([0.0, 1.0, 0.0, 0.0, 1.0])
  152. """
  153. @override(Preprocessor)
  154. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  155. if isinstance(obs_space, gym.spaces.Discrete):
  156. return (self._obs_space.n,)
  157. else:
  158. return (np.sum(self._obs_space.nvec),)
  159. @override(Preprocessor)
  160. def transform(self, observation: TensorType) -> np.ndarray:
  161. self.check_shape(observation)
  162. return gym.spaces.utils.flatten(self._obs_space, observation).astype(np.float32)
  163. @override(Preprocessor)
  164. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  165. array[offset : offset + self.size] = self.transform(observation)
  166. @PublicAPI
  167. class NoPreprocessor(Preprocessor):
  168. @override(Preprocessor)
  169. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  170. return self._obs_space.shape
  171. @override(Preprocessor)
  172. def transform(self, observation: TensorType) -> np.ndarray:
  173. self.check_shape(observation)
  174. return observation
  175. @override(Preprocessor)
  176. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  177. array[offset : offset + self._size] = np.array(observation, copy=False).ravel()
  178. @property
  179. @override(Preprocessor)
  180. def observation_space(self) -> gym.Space:
  181. return self._obs_space
  182. @PublicAPI
  183. class MultiBinaryPreprocessor(Preprocessor):
  184. """Preprocessor that turns a MultiBinary space into a Box.
  185. Note: Before RLModules were introduced, RLlib's ModelCatalogV2 would produce
  186. ComplexInputNetworks that treat MultiBinary spaces as Boxes. This preprocessor is
  187. needed to get rid of the ComplexInputNetworks and use RLModules instead because
  188. RLModules lack the logic to handle MultiBinary or other non-Box spaces.
  189. """
  190. @override(Preprocessor)
  191. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  192. return self._obs_space.shape
  193. @override(Preprocessor)
  194. def transform(self, observation: TensorType) -> np.ndarray:
  195. # The shape stays the same, but the dtype changes.
  196. self.check_shape(observation)
  197. return observation.astype(np.float32)
  198. @override(Preprocessor)
  199. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  200. array[offset : offset + self._size] = np.array(observation, copy=False).ravel()
  201. @property
  202. @override(Preprocessor)
  203. def observation_space(self) -> gym.Space:
  204. obs_space = gym.spaces.Box(0.0, 1.0, self.shape, dtype=np.float32)
  205. obs_space.original_space = self._obs_space
  206. return obs_space
  207. @DeveloperAPI
  208. class TupleFlatteningPreprocessor(Preprocessor):
  209. """Preprocesses each tuple element, then flattens it all into a vector.
  210. RLlib models will unpack the flattened output before _build_layers_v2().
  211. """
  212. @override(Preprocessor)
  213. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  214. assert isinstance(self._obs_space, gym.spaces.Tuple)
  215. size = 0
  216. self.preprocessors = []
  217. for i in range(len(self._obs_space.spaces)):
  218. space = self._obs_space.spaces[i]
  219. logger.debug("Creating sub-preprocessor for {}".format(space))
  220. preprocessor_class = get_preprocessor(space)
  221. if preprocessor_class is not None:
  222. preprocessor = preprocessor_class(space, self._options)
  223. size += preprocessor.size
  224. else:
  225. preprocessor = None
  226. size += int(np.product(space.shape))
  227. self.preprocessors.append(preprocessor)
  228. return (size,)
  229. @override(Preprocessor)
  230. def transform(self, observation: TensorType) -> np.ndarray:
  231. self.check_shape(observation)
  232. array = np.zeros(self.shape, dtype=np.float32)
  233. self.write(observation, array, 0)
  234. return array
  235. @override(Preprocessor)
  236. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  237. assert len(observation) == len(self.preprocessors), observation
  238. for o, p in zip(observation, self.preprocessors):
  239. p.write(o, array, offset)
  240. offset += p.size
  241. @DeveloperAPI
  242. class DictFlatteningPreprocessor(Preprocessor):
  243. """Preprocesses each dict value, then flattens it all into a vector.
  244. RLlib models will unpack the flattened output before _build_layers_v2().
  245. """
  246. @override(Preprocessor)
  247. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  248. assert isinstance(self._obs_space, gym.spaces.Dict)
  249. size = 0
  250. self.preprocessors = []
  251. for space in self._obs_space.spaces.values():
  252. logger.debug("Creating sub-preprocessor for {}".format(space))
  253. preprocessor_class = get_preprocessor(space)
  254. if preprocessor_class is not None:
  255. preprocessor = preprocessor_class(space, self._options)
  256. size += preprocessor.size
  257. else:
  258. preprocessor = None
  259. size += int(np.product(space.shape))
  260. self.preprocessors.append(preprocessor)
  261. return (size,)
  262. @override(Preprocessor)
  263. def transform(self, observation: TensorType) -> np.ndarray:
  264. self.check_shape(observation)
  265. array = np.zeros(self.shape, dtype=np.float32)
  266. self.write(observation, array, 0)
  267. return array
  268. @override(Preprocessor)
  269. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  270. if not isinstance(observation, OrderedDict):
  271. observation = OrderedDict(sorted(observation.items()))
  272. assert len(observation) == len(self.preprocessors), (
  273. len(observation),
  274. len(self.preprocessors),
  275. )
  276. for o, p in zip(observation.values(), self.preprocessors):
  277. p.write(o, array, offset)
  278. offset += p.size
  279. @DeveloperAPI
  280. class RepeatedValuesPreprocessor(Preprocessor):
  281. """Pads and batches the variable-length list value."""
  282. @override(Preprocessor)
  283. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  284. assert isinstance(self._obs_space, Repeated)
  285. child_space = obs_space.child_space
  286. self.child_preprocessor = get_preprocessor(child_space)(
  287. child_space, self._options
  288. )
  289. # The first slot encodes the list length.
  290. size = 1 + self.child_preprocessor.size * obs_space.max_len
  291. return (size,)
  292. @override(Preprocessor)
  293. def transform(self, observation: TensorType) -> np.ndarray:
  294. array = np.zeros(self.shape)
  295. if isinstance(observation, list):
  296. for elem in observation:
  297. self.child_preprocessor.check_shape(elem)
  298. else:
  299. pass # ValueError will be raised in write() below.
  300. self.write(observation, array, 0)
  301. return array
  302. @override(Preprocessor)
  303. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  304. if not isinstance(observation, (list, np.ndarray)):
  305. raise ValueError(
  306. "Input for {} must be list type, got {}".format(self, observation)
  307. )
  308. elif len(observation) > self._obs_space.max_len:
  309. raise ValueError(
  310. "Input {} exceeds max len of space {}".format(
  311. observation, self._obs_space.max_len
  312. )
  313. )
  314. # The first slot encodes the list length.
  315. array[offset] = len(observation)
  316. for i, elem in enumerate(observation):
  317. offset_i = offset + 1 + i * self.child_preprocessor.size
  318. self.child_preprocessor.write(elem, array, offset_i)
  319. @PublicAPI
  320. def get_preprocessor(space: gym.Space, include_multi_binary=False) -> type:
  321. """Returns an appropriate preprocessor class for the given space."""
  322. _legacy_patch_shapes(space)
  323. obs_shape = space.shape
  324. if isinstance(space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)):
  325. preprocessor = OneHotPreprocessor
  326. elif obs_shape == ATARI_OBS_SHAPE:
  327. logger.debug(
  328. "Defaulting to RLlib's GenericPixelPreprocessor because input "
  329. "space has the atari-typical shape {}. Turn this behaviour off by setting "
  330. "`preprocessor_pref=None` or "
  331. "`preprocessor_pref='deepmind'` or disabling the preprocessing API "
  332. "altogether with `_disable_preprocessor_api=True`.".format(ATARI_OBS_SHAPE)
  333. )
  334. preprocessor = GenericPixelPreprocessor
  335. elif obs_shape == ATARI_RAM_OBS_SHAPE:
  336. logger.debug(
  337. "Defaulting to RLlib's AtariRamPreprocessor because input "
  338. "space has the atari-typical shape {}. Turn this behaviour off by setting "
  339. "`preprocessor_pref=None` or "
  340. "`preprocessor_pref='deepmind' or disabling the preprocessing API "
  341. "altogether with `_disable_preprocessor_api=True`."
  342. "`.".format(ATARI_OBS_SHAPE)
  343. )
  344. preprocessor = AtariRamPreprocessor
  345. elif isinstance(space, gym.spaces.Tuple):
  346. preprocessor = TupleFlatteningPreprocessor
  347. elif isinstance(space, gym.spaces.Dict):
  348. preprocessor = DictFlatteningPreprocessor
  349. elif isinstance(space, Repeated):
  350. preprocessor = RepeatedValuesPreprocessor
  351. # We usually only want to include this when using RLModules
  352. elif isinstance(space, gym.spaces.MultiBinary) and include_multi_binary:
  353. preprocessor = MultiBinaryPreprocessor
  354. else:
  355. preprocessor = NoPreprocessor
  356. return preprocessor
  357. def _legacy_patch_shapes(space: gym.Space) -> List[int]:
  358. """Assigns shapes to spaces that don't have shapes.
  359. This is only needed for older gym versions that don't set shapes properly
  360. for Tuple and Discrete spaces.
  361. """
  362. if not hasattr(space, "shape"):
  363. if isinstance(space, gym.spaces.Discrete):
  364. space.shape = ()
  365. elif isinstance(space, gym.spaces.Tuple):
  366. shapes = []
  367. for s in space.spaces:
  368. shape = _legacy_patch_shapes(s)
  369. shapes.append(shape)
  370. space.shape = tuple(shapes)
  371. return space.shape