from collections import OrderedDict
import contextlib
import gym
from gym.spaces import Space
import numpy as np
from typing import Dict, List, Any, Union

from ray.rllib.models.preprocessors import get_preprocessor, \
    RepeatedValuesPreprocessor
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import NullContextManager
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
    TensorType
from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.utils.typing import ModelConfigDict, ModelInputDict, \
    TensorStructType

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


@PublicAPI
class ModelV2:
    """Defines an abstract neural network model for use with RLlib.

    Custom models should extend either TFModelV2 or TorchModelV2 instead of
    this class directly.

    Data flow:
        obs -> forward() -> model_out
                         \-> value_function() -> V(s)
    """

    def __init__(self, obs_space: Space, action_space: Space,
                 num_outputs: int, model_config: ModelConfigDict, name: str,
                 framework: str):
        """Initializes a ModelV2 instance.

        This method should create any variables used by the model.

        Args:
            obs_space: Observation space of the target gym env. This may
                have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space: Action space of the target gym env.
            num_outputs: Number of output units of the model.
            model_config: Config for the model, documented in ModelCatalog.
            name: Name (scope) for the model.
            framework: Either "tf" or "torch".
        """
        self.obs_space: Space = obs_space
        self.action_space: Space = action_space
        self.num_outputs: int = num_outputs
        self.model_config: ModelConfigDict = model_config
        self.name: str = name or "default_model"
        self.framework: str = framework
        self._last_output = None
        self.time_major = self.model_config.get("_time_major")
        # Basic view requirement for all models: Use the observation as input.
        self.view_requirements = {
            SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
        }
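
    # A minimal sketch of how a subclass could extend `view_requirements` in
    # its own __init__ to also receive the previous action as a model input.
    # Assumes the ViewRequirement(data_col, space=..., shift=...) call style
    # used elsewhere in RLlib:
    #
    # >>> self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
    # ...     SampleBatch.ACTIONS, space=self.action_space, shift=-1)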

    # TODO: (sven): Get rid of `get_initial_state` once Trajectory
    #  View API is supported across all of RLlib.
    @PublicAPI
    def get_initial_state(self) -> List[np.ndarray]:
        """Get the initial recurrent state values for the model.

        Returns:
            List of np.array objects containing the initial hidden state
            of an RNN, if applicable.

        Examples:
            >>> def get_initial_state(self):
            >>>     return [
            >>>         np.zeros(self.cell_size, np.float32),
            >>>         np.zeros(self.cell_size, np.float32),
            >>>     ]
        """
        return []

    @PublicAPI
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """Call the model with the given input tensors and state.

        Any complex observations (dicts, tuples, etc.) will be unpacked by
        __call__ before being passed to forward(). To access the flattened
        observation tensor, refer to input_dict["obs_flat"].

        This method can be called any number of times. In eager execution,
        each call to forward() will eagerly evaluate the model. In symbolic
        execution, each call to forward creates a computation graph that
        operates over the variables of this model (i.e., shares weights).

        Custom models should override this instead of __call__.

        Args:
            input_dict: Dictionary of input tensors, including "obs",
                "obs_flat", "prev_action", "prev_reward", "is_training",
                "eps_id", "agent_id", "infos", and "t".
            state: List of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension.
            seq_lens: 1D tensor holding input sequence lengths.

        Returns:
            A tuple consisting of the model output tensor of size
            [BATCH, num_outputs] and the list of new RNN state(s) if any.

        Examples:
            >>> def forward(self, input_dict, state, seq_lens):
            >>>     model_out, self._value_out = self.base_model(
            ...         input_dict["obs"])
            >>>     return model_out, state
        """
        raise NotImplementedError

    @PublicAPI
    def value_function(self) -> TensorType:
        """Returns the value function output for the most recent forward pass.

        Note that a `forward` call has to be performed first, before this
        method can return anything; calling this method does not itself
        trigger an extra forward pass through the network.

        Returns:
            Value estimate tensor of shape [BATCH].
        """
        raise NotImplementedError
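
    # A minimal sketch of the forward()/value_function() pairing described
    # above: forward() caches its value-branch output and value_function()
    # just returns that cache (`self.base_model` returning both heads and
    # `self._value_out` are assumptions of this sketch):
    #
    # >>> def forward(self, input_dict, state, seq_lens):
    # ...     model_out, self._value_out = self.base_model(
    # ...         input_dict["obs_flat"])
    # ...     return model_out, state
    # >>>
    # >>> def value_function(self):
    # ...     return tf.reshape(self._value_out, [-1])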

    @PublicAPI
    def custom_loss(self, policy_loss: TensorType,
                    loss_inputs: Dict[str, TensorType]) -> \
            Union[List[TensorType], TensorType]:
        """Override to customize the loss function used to optimize this model.

        This can be used to incorporate self-supervised losses (by defining
        a loss over existing input and output tensors of this model), and
        supervised losses (by defining losses over a variable-sharing copy of
        this model's layers).

        You can find a runnable example in examples/custom_loss.py.

        Args:
            policy_loss: List of or single policy loss(es) from the policy.
            loss_inputs: Map of input placeholders for rollout data.

        Returns:
            List of or scalar tensor for the customized loss(es) for this
            model.
        """
        return policy_loss
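
    # A minimal sketch of a custom_loss() override that adds a
    # self-supervised reconstruction term on top of the policy loss
    # (`self.recon_head` and the 0.1 weight are assumptions of this sketch,
    # and policy_loss is assumed to arrive as a list here):
    #
    # >>> def custom_loss(self, policy_loss, loss_inputs):
    # ...     obs = loss_inputs[SampleBatch.OBS]
    # ...     recon_loss = torch.mean((self.recon_head(obs) - obs) ** 2)
    # ...     return [loss + 0.1 * recon_loss for loss in policy_loss]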

    @PublicAPI
    def metrics(self) -> Dict[str, TensorType]:
        """Override to return custom metrics from your model.

        The stats will be reported as part of the learner stats, i.e.,
        info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1

        Returns:
            The custom metrics for this model.
        """
        return {}
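
    # A minimal sketch of a metrics() override: report any scalar tensor the
    # model tracked during its last forward/loss pass (the attribute name
    # `self._recon_loss` is an assumption of this sketch):
    #
    # >>> def metrics(self):
    # ...     return {"recon_loss": self._recon_loss}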

    def __call__(
            self,
            input_dict: Union[SampleBatch, ModelInputDict],
            state: List[Any] = None,
            seq_lens: TensorType = None) -> (TensorType, List[TensorType]):
        """Call the model with the given input tensors and state.

        This is the method used by RLlib to execute the forward pass. It calls
        forward() internally after unpacking nested observation tensors.

        Custom models should override forward() instead of __call__.

        Args:
            input_dict: Dictionary of input tensors.
            state: List of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension.
            seq_lens: 1D tensor holding input sequence lengths.

        Returns:
            A tuple consisting of the model output tensor of size
            [BATCH, output_spec.size] or a list of tensors corresponding to
            output_spec.shape_list, and a list of state tensors of
            [BATCH, state_size_i] if any.
        """
        # Original observations will be stored in "obs".
        # Flattened (preprocessed) obs will be stored in "obs_flat".

        # SampleBatch case: Models can now be called directly with a
        # SampleBatch (which also covers the now-deprecated tracking-dict
        # case, where tensors get automatically converted).
        if isinstance(input_dict, SampleBatch):
            restored = input_dict.copy(shallow=True)
        else:
            restored = input_dict.copy()

        # Backward compatibility.
        if not state:
            state = []
            i = 0
            while "state_in_{}".format(i) in input_dict:
                state.append(input_dict["state_in_{}".format(i)])
                i += 1
        if seq_lens is None:
            seq_lens = input_dict.get(SampleBatch.SEQ_LENS)

        # No Preprocessor used: `config._disable_preprocessor_api`=True.
        # TODO: This is unnecessary for when no preprocessor is used.
        #  Obs are not flat then anymore. However, we'll keep this
        #  here for backward-compatibility until Preprocessors have
        #  been fully deprecated.
        if self.model_config.get("_disable_preprocessor_api"):
            restored["obs_flat"] = input_dict["obs"]
        # Input to this Model went through a Preprocessor.
        # Generate extra keys: "obs_flat" (vs "obs", which will hold the
        # original obs).
        else:
            restored["obs"] = restore_original_dimensions(
                input_dict["obs"], self.obs_space, self.framework)
            try:
                if len(input_dict["obs"].shape) > 2:
                    restored["obs_flat"] = flatten(input_dict["obs"],
                                                   self.framework)
                else:
                    restored["obs_flat"] = input_dict["obs"]
            except AttributeError:
                restored["obs_flat"] = input_dict["obs"]

        with self.context():
            res = self.forward(restored, state or [], seq_lens)

        if isinstance(input_dict, SampleBatch):
            input_dict.accessed_keys = restored.accessed_keys - {"obs_flat"}
            input_dict.deleted_keys = restored.deleted_keys
            input_dict.added_keys = restored.added_keys - {"obs_flat"}

        if ((not isinstance(res, list) and not isinstance(res, tuple))
                or len(res) != 2):
            raise ValueError(
                "forward() must return a tuple of (output, state) tensors, "
                "got {}".format(res))
        outputs, state_out = res

        if not isinstance(state_out, list):
            raise ValueError(
                "State output is not a list: {}".format(state_out))

        self._last_output = outputs
        return outputs, state_out if len(state_out) > 0 else (state or [])
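
    # A minimal usage sketch of the __call__ flow above, assuming `model` is
    # an already-built ModelV2 and its obs went through RLlib's Preprocessor
    # (so both "obs" and "obs_flat" get filled in before forward() runs):
    #
    # >>> batch = SampleBatch({"obs": obs_tensor})
    # >>> model_out, state_out = model(batch, state=[], seq_lens=None)
    # >>> value = model.value_function()  # uses the forward pass just made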

    def import_from_h5(self, h5_file: str) -> None:
        """Imports weights from an h5 file.

        Args:
            h5_file: The h5 file name to import weights from.

        Example:
            >>> trainer = MyTrainer()
            >>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
            >>> for _ in range(10):
            >>>     trainer.train()
        """
        raise NotImplementedError

    @PublicAPI
    def last_output(self) -> TensorType:
        """Returns the last output returned from calling the model."""
        return self._last_output

    @PublicAPI
    def context(self) -> contextlib.AbstractContextManager:
        """Returns a contextmanager for the current forward pass."""
        return NullContextManager()

    @PublicAPI
    def variables(self, as_dict: bool = False
                  ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Returns the list (or a dict) of variables for this model.

        Args:
            as_dict: Whether variables should be returned as dict-values
                (using descriptive str keys).

        Returns:
            The list (or dict if `as_dict` is True) of all variables of this
            ModelV2.
        """
        raise NotImplementedError

    @PublicAPI
    def trainable_variables(
            self, as_dict: bool = False
    ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Returns the list of trainable variables for this model.

        Args:
            as_dict: Whether variables should be returned as dict-values
                (using descriptive keys).

        Returns:
            The list (or dict if `as_dict` is True) of all trainable
            (tf)/requires_grad (torch) variables of this ModelV2.
        """
        raise NotImplementedError

    @PublicAPI
    def is_time_major(self) -> bool:
        """If True, data for calling this ModelV2 must be in time-major format.

        Returns:
            Whether this ModelV2 requires a time-major (TxBx...) data
            format.
        """
        return self.time_major is True

    @Deprecated(new="ModelV2.__call__()", error=False)
    def from_batch(self, train_batch: SampleBatch,
                   is_training: bool = True) -> (TensorType, List[TensorType]):
        """Convenience function that calls this model with a tensor batch.

        All this does is unpack the tensor batch to call this model with the
        right input dict, state, and seq len arguments.
        """
        input_dict = train_batch.copy()
        input_dict.set_training(is_training)
        states = []
        i = 0
        while "state_in_{}".format(i) in input_dict:
            states.append(input_dict["state_in_{}".format(i)])
            i += 1
        ret = self.__call__(input_dict, states,
                            input_dict.get(SampleBatch.SEQ_LENS))
        return ret


@DeveloperAPI
def flatten(obs: TensorType, framework: str) -> TensorType:
    """Flatten the given tensor."""
    if framework in ["tf2", "tf", "tfe"]:
        return tf1.keras.layers.Flatten()(obs)
    elif framework == "torch":
        assert torch is not None
        return torch.flatten(obs, start_dim=1)
    else:
        raise NotImplementedError("flatten", framework)
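
# A quick shape sketch for flatten(): everything but the batch dimension is
# collapsed, e.g. a [B, N, M, 4] observation becomes [B, N * M * 4]:
#
# >>> flat = flatten(obs, framework="torch")  # [32, 5, 3, 4] -> [32, 60]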


@DeveloperAPI
def restore_original_dimensions(obs: TensorType,
                                obs_space: Space,
                                tensorlib: Any = tf) -> TensorStructType:
    """Unpacks Dict and Tuple space observations into their original form.

    This is needed since we flatten Dict and Tuple observations in transit
    within a SampleBatch. Before sending them to the model though, we should
    unflatten them into Dicts or Tuples of tensors.

    Args:
        obs: The flattened observation tensor.
        obs_space: The flattened obs space. If this has the
            `original_space` attribute, we will unflatten the tensor to that
            shape.
        tensorlib: The library used to unflatten (reshape) the array/tensor.

    Returns:
        Single tensor or dict / tuple of tensors matching the original
        observation space.
    """
    if tensorlib in ["tf", "tfe", "tf2"]:
        assert tf is not None
        tensorlib = tf
    elif tensorlib == "torch":
        assert torch is not None
        tensorlib = torch
    elif tensorlib == "numpy":
        assert np is not None
        tensorlib = np
    original_space = getattr(obs_space, "original_space", obs_space)
    return _unpack_obs(obs, original_space, tensorlib=tensorlib)
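
# A minimal sketch of restore_original_dimensions(): a Dict obs space whose
# samples were flattened into a single [B, flat_size] tensor comes back as a
# dict of per-key tensors (the space and shapes are assumptions of this
# sketch):
#
# >>> space = gym.spaces.Dict({
# ...     "pos": gym.spaces.Box(-1.0, 1.0, (3,)),
# ...     "vel": gym.spaces.Box(-1.0, 1.0, (2,)),
# ... })
# >>> # flat_obs: [B, 5] -> {"pos": [B, 3] tensor, "vel": [B, 2] tensor}
# >>> restored = restore_original_dimensions(flat_obs, space, "numpy")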


# Cache of preprocessors, for if the user is calling unpack obs often.
_cache = {}


def _unpack_obs(obs: TensorType, space: Space,
                tensorlib: Any = tf) -> TensorStructType:
    """Unpack a flattened Dict or Tuple observation array/tensor.

    Args:
        obs: The flattened observation tensor, with last dimension equal to
            the flat size and any number of batch dimensions. For example, for
            Box(4,), the obs may have shape [B, 4], or [B, N, M, 4] in case
            the Box was nested under two Repeated spaces.
        space: The original space prior to flattening.
        tensorlib: The library used to unflatten (reshape) the array/tensor.
    """
    if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple, Repeated)):
        if id(space) in _cache:
            prep = _cache[id(space)]
        else:
            prep = get_preprocessor(space)(space)
            # Make an attempt to cache the result, if enough space left.
            if len(_cache) < 999:
                _cache[id(space)] = prep

        # Already unpacked?
        if (isinstance(space, gym.spaces.Tuple) and
                isinstance(obs, (list, tuple))) or \
                (isinstance(space, gym.spaces.Dict) and isinstance(obs, dict)):
            return obs
        elif len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
            raise ValueError(
                "Expected flattened obs shape of [..., {}], got {}".format(
                    prep.shape[0], obs.shape))

        offset = 0
        if tensorlib == tf:
            batch_dims = [
                v if isinstance(v, int) else v.value for v in obs.shape[:-1]
            ]
            batch_dims = [-1 if v is None else v for v in batch_dims]
        else:
            batch_dims = list(obs.shape[:-1])

        if isinstance(space, gym.spaces.Tuple):
            assert len(prep.preprocessors) == len(space.spaces), \
                (len(prep.preprocessors) == len(space.spaces))
            u = []
            for p, v in zip(prep.preprocessors, space.spaces):
                obs_slice = obs[..., offset:offset + p.size]
                offset += p.size
                u.append(
                    _unpack_obs(
                        tensorlib.reshape(obs_slice,
                                          batch_dims + list(p.shape)),
                        v,
                        tensorlib=tensorlib))
        elif isinstance(space, gym.spaces.Dict):
            assert len(prep.preprocessors) == len(space.spaces), \
                (len(prep.preprocessors) == len(space.spaces))
            u = OrderedDict()
            for p, (k, v) in zip(prep.preprocessors, space.spaces.items()):
                obs_slice = obs[..., offset:offset + p.size]
                offset += p.size
                u[k] = _unpack_obs(
                    tensorlib.reshape(obs_slice, batch_dims + list(p.shape)),
                    v,
                    tensorlib=tensorlib)
        # Repeated space.
        else:
            assert isinstance(prep, RepeatedValuesPreprocessor), prep
            child_size = prep.child_preprocessor.size
            # The list lengths are stored in the first slot of the flat obs.
            lengths = obs[..., 0]
            # [B, ..., 1 + max_len * child_sz] -> [B, ..., max_len, child_sz]
            with_repeat_dim = tensorlib.reshape(
                obs[..., 1:], batch_dims + [space.max_len, child_size])
            # Retry the unpack, dropping the List container space.
            u = _unpack_obs(
                with_repeat_dim, space.child_space, tensorlib=tensorlib)
            return RepeatedValues(
                u, lengths=lengths, max_len=prep._obs_space.max_len)
        return u
    else:
        return obs
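
# A short layout sketch for the Repeated-space branch of _unpack_obs(): for a
# Repeated space with max_len=3 over a child Box of size 2, each flattened
# sample stores its list length in slot 0, followed by
# max_len * child_size = 6 child values:
#
# >>> # flat row: [length, x0, y0, x1, y1, x2, y2]  -> shape [..., 1 + 3 * 2]
# >>> # _unpack_obs() returns a RepeatedValues with values of shape
# >>> # [..., 3, 2], lengths of shape [...], and max_len=3.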