# modelv2.py
  1. from collections import OrderedDict
  2. import contextlib
  3. import gym
  4. import numpy as np
  5. from typing import Dict, List, Any, Union
  6. from ray.rllib.models.preprocessors import get_preprocessor, \
  7. RepeatedValuesPreprocessor
  8. from ray.rllib.models.repeated_values import RepeatedValues
  9. from ray.rllib.policy.sample_batch import SampleBatch
  10. from ray.rllib.policy.view_requirement import ViewRequirement
  11. from ray.rllib.utils import NullContextManager
  12. from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
  13. from ray.rllib.utils.deprecation import Deprecated
  14. from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
  15. TensorType
  16. from ray.rllib.utils.spaces.repeated import Repeated
  17. from ray.rllib.utils.typing import ModelConfigDict, ModelInputDict, \
  18. TensorStructType
# Framework module handles used throughout this file. Code below asserts
# `tf is not None` / `torch is not None` before using them, so presumably a
# handle is None when that framework is not installed (confirm against the
# try_import_* helpers).
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@PublicAPI
class ModelV2:
    """Defines an abstract neural network model for use with RLlib.

    Custom models should extend either TFModelV2 or TorchModelV2 instead of
    this class directly.

    Data flow:
        obs -> forward() -> model_out
               value_function() -> V(s)
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str, framework: str):
        """Initializes a ModelV2 object.

        This method should create any variables used by the model.

        Args:
            obs_space (gym.spaces.Space): Observation space of the target gym
                env. This may have an `original_space` attribute that
                specifies how to unflatten the tensor into a ragged tensor.
            action_space (gym.spaces.Space): Action space of the target gym
                env.
            num_outputs (int): Number of output units of the model.
            model_config (ModelConfigDict): Config for the model, documented
                in ModelCatalog.
            name (str): Name (scope) for the model. A falsy value (None or
                "") is replaced by "default_model".
            framework (str): Either "tf" or "torch".
        """

        self.obs_space: gym.spaces.Space = obs_space
        self.action_space: gym.spaces.Space = action_space
        self.num_outputs: int = num_outputs
        self.model_config: ModelConfigDict = model_config
        self.name: str = name or "default_model"
        self.framework: str = framework
        # Output of the most recent `__call__`; stays None until then.
        self._last_output = None
        # None if the "_time_major" key is not present in the model config.
        self.time_major = self.model_config.get("_time_major")
        # Basic view requirement for all models: Use the observation as input.
        self.view_requirements = {
            SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
        }

    # TODO: (sven): Get rid of `get_initial_state` once Trajectory
    #  View API is supported across all of RLlib.
    @PublicAPI
    def get_initial_state(self) -> List[np.ndarray]:
        """Get the initial recurrent state values for the model.

        The base implementation returns an empty list (no recurrent state).

        Returns:
            List[np.ndarray]: List of np.array objects containing the initial
                hidden state of an RNN, if applicable.

        Examples:
            >>> def get_initial_state(self):
            >>>    return [
            >>>        np.zeros(self.cell_size, np.float32),
            >>>        np.zeros(self.cell_size, np.float32),
            >>>    ]
        """
        return []

    @PublicAPI
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """Call the model with the given input tensors and state.

        Any complex observations (dicts, tuples, etc.) will be unpacked by
        __call__ before being passed to forward(). To access the flattened
        observation tensor, refer to input_dict["obs_flat"].

        This method can be called any number of times. In eager execution,
        each call to forward() will eagerly evaluate the model. In symbolic
        execution, each call to forward creates a computation graph that
        operates over the variables of this model (i.e., shares weights).

        Custom models should override this instead of __call__.

        Args:
            input_dict (dict): dictionary of input tensors, including "obs",
                "obs_flat", "prev_action", "prev_reward", "is_training",
                "eps_id", "agent_id", "infos", and "t".
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension
            seq_lens (Tensor): 1d tensor holding input sequence lengths

        Returns:
            (outputs, state): The model output tensor of size
                [BATCH, num_outputs], and the new RNN state.

        Raises:
            NotImplementedError: Always, in this base class; sub-classes
                must override this method.

        Examples:
            >>> def forward(self, input_dict, state, seq_lens):
            >>>     model_out, self._value_out = self.base_model(
            ...         input_dict["obs"])
            >>>     return model_out, state
        """
        raise NotImplementedError

    @PublicAPI
    def value_function(self) -> TensorType:
        """Returns the value function output for the most recent forward pass.

        Note that a `forward` call has to be performed first, before this
        method can return anything, and thus that calling this method does
        not cause an extra forward pass through the network.

        Returns:
            value estimate tensor of shape [BATCH].

        Raises:
            NotImplementedError: Always, in this base class; sub-classes
                must override this method.
        """
        raise NotImplementedError

    @PublicAPI
    def custom_loss(self, policy_loss: Union[List[TensorType], TensorType],
                    loss_inputs: Dict[str, TensorType]
                    ) -> Union[List[TensorType], TensorType]:
        """Override to customize the loss function used to optimize this model.

        This can be used to incorporate self-supervised losses (by defining
        a loss over existing input and output tensors of this model), and
        supervised losses (by defining losses over a variable-sharing copy of
        this model's layers).

        You can find a runnable example in examples/custom_loss.py.

        The base implementation returns `policy_loss` unchanged.

        Args:
            policy_loss (Union[List[Tensor],Tensor]): List of or single policy
                loss(es) from the policy.
            loss_inputs (dict): map of input placeholders for rollout data.

        Returns:
            Union[List[Tensor],Tensor]: List of or scalar tensor for the
                customized loss(es) for this model.
        """
        return policy_loss

    @PublicAPI
    def metrics(self) -> Dict[str, TensorType]:
        """Override to return custom metrics from your model.

        The stats will be reported as part of the learner stats, i.e.,
        info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1

        The base implementation returns an empty dict.

        Returns:
            Dict[str, TensorType]: The custom metrics for this model.
        """
        return {}

    def __call__(
            self,
            input_dict: Union[SampleBatch, ModelInputDict],
            state: List[Any] = None,
            seq_lens: TensorType = None) -> (TensorType, List[TensorType]):
        """Call the model with the given input tensors and state.

        This is the method used by RLlib to execute the forward pass. It calls
        forward() internally after unpacking nested observation tensors.

        Custom models should override forward() instead of __call__.

        Args:
            input_dict (Union[SampleBatch, ModelInputDict]): Dictionary of
                input tensors.
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension. If
                empty or None, the states are collected from the
                "state_in_[i]" keys of `input_dict`.
            seq_lens (Tensor): 1D tensor holding input sequence lengths. If
                None, looked up in `input_dict` under SampleBatch.SEQ_LENS.

        Returns:
            (outputs, state): The model output tensor of size
                [BATCH, output_spec.size] or a list of tensors corresponding to
                output_spec.shape_list, and a list of state tensors of
                [BATCH, state_size_i]. If forward() returns an empty state
                list, the input state is returned instead.

        Raises:
            ValueError: If forward() does not return a list/tuple of length
                2, or if the returned state is not a list.
        """

        # Original observations will be stored in "obs".
        # Flattened (preprocessed) obs will be stored in "obs_flat".

        # SampleBatch case: Models can now be called directly with a
        # SampleBatch (which also includes tracking-dict case (deprecated now),
        # where tensors get automatically converted).
        if isinstance(input_dict, SampleBatch):
            restored = input_dict.copy(shallow=True)
        else:
            restored = input_dict.copy()

        # Backward compatibility.
        # A fresh list is built here, so a caller-provided (empty) list is
        # never mutated.
        if not state:
            state = []
            i = 0
            while "state_in_{}".format(i) in input_dict:
                state.append(input_dict["state_in_{}".format(i)])
                i += 1
        if seq_lens is None:
            seq_lens = input_dict.get(SampleBatch.SEQ_LENS)

        # No Preprocessor used: `config._disable_preprocessor_api`=True.
        # TODO: This is unnecessary for when no preprocessor is used.
        #  Obs are not flat then anymore. However, we'll keep this
        #  here for backward-compatibility until Preprocessors have
        #  been fully deprecated.
        if self.model_config.get("_disable_preprocessor_api"):
            restored["obs_flat"] = input_dict["obs"]
        # Input to this Model went through a Preprocessor.
        # Generate extra keys: "obs_flat" (vs "obs", which will hold the
        # original obs).
        else:
            restored["obs"] = restore_original_dimensions(
                input_dict["obs"], self.obs_space, self.framework)
            try:
                if len(input_dict["obs"].shape) > 2:
                    restored["obs_flat"] = flatten(input_dict["obs"],
                                                   self.framework)
                else:
                    restored["obs_flat"] = input_dict["obs"]
            # An AttributeError raised inside the try (e.g. obs has no
            # `.shape`): pass the obs through as-is.
            except AttributeError:
                restored["obs_flat"] = input_dict["obs"]

        with self.context():
            res = self.forward(restored, state or [], seq_lens)

        # Copy the key-tracking sets of the working copy back onto the
        # caller's batch, leaving out the "obs_flat" key generated above.
        if isinstance(input_dict, SampleBatch):
            input_dict.accessed_keys = restored.accessed_keys - {"obs_flat"}
            input_dict.deleted_keys = restored.deleted_keys
            input_dict.added_keys = restored.added_keys - {"obs_flat"}

        if ((not isinstance(res, list) and not isinstance(res, tuple))
                or len(res) != 2):
            raise ValueError(
                "forward() must return a tuple of (output, state) tensors, "
                "got {}".format(res))
        outputs, state_out = res

        if not isinstance(state_out, list):
            raise ValueError(
                "State output is not a list: {}".format(state_out))

        self._last_output = outputs
        # Returns a 2-tuple; the conditional only applies to the second item.
        return outputs, state_out if len(state_out) > 0 else (state or [])

    @Deprecated(new="ModelV2.__call__()", error=False)
    def from_batch(self, train_batch: SampleBatch,
                   is_training: bool = True) -> (TensorType, List[TensorType]):
        """Convenience function that calls this model with a tensor batch.

        All this does is unpack the tensor batch to call this model with the
        right input dict, state, and seq len arguments.

        Args:
            train_batch (SampleBatch): The batch to run through the model. It
                is copied before the "is_training" key is set.
            is_training (bool): Value stored under the "is_training" key of
                the input dict.

        Returns:
            The return value of `__call__` for the unpacked batch.
        """

        input_dict = train_batch.copy()
        input_dict["is_training"] = is_training
        states = []
        i = 0
        # Collect "state_in_0", "state_in_1", ... up to the first missing
        # index.
        while "state_in_{}".format(i) in input_dict:
            states.append(input_dict["state_in_{}".format(i)])
            i += 1
        ret = self.__call__(input_dict, states,
                            input_dict.get(SampleBatch.SEQ_LENS))
        return ret

    def import_from_h5(self, h5_file: str) -> None:
        """Imports weights from an h5 file.

        Args:
            h5_file (str): The h5 file name to import weights from.

        Raises:
            NotImplementedError: Always, in this base class.

        Example:
            >>> trainer = MyTrainer()
            >>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
            >>> for _ in range(10):
            >>>     trainer.train()
        """
        raise NotImplementedError

    @PublicAPI
    def last_output(self) -> TensorType:
        """Returns the last output returned from calling the model.

        This is None if the model has not been called yet.
        """
        return self._last_output

    @PublicAPI
    def context(self) -> contextlib.AbstractContextManager:
        """Returns a contextmanager for the current forward pass.

        `__call__` runs forward() inside this context. The base
        implementation returns a NullContextManager.
        """
        return NullContextManager()

    @PublicAPI
    def variables(self, as_dict: bool = False
                  ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Returns the list (or a dict) of variables for this model.

        Args:
            as_dict(bool): Whether variables should be returned as dict-values
                (using descriptive str keys).

        Returns:
            Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is
                True) of all variables of this ModelV2.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @PublicAPI
    def trainable_variables(
            self, as_dict: bool = False
    ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Returns the list of trainable variables for this model.

        Args:
            as_dict(bool): Whether variables should be returned as dict-values
                (using descriptive keys).

        Returns:
            Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is
                True) of all trainable (tf)/requires_grad (torch) variables
                of this ModelV2.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @PublicAPI
    def is_time_major(self) -> bool:
        """If True, data for calling this ModelV2 must be in time-major format.

        Returns:
            bool: Whether this ModelV2 requires a time-major (TxBx...) data
                format.
        """
        # Identity check: anything other than the literal True (including
        # None when "_time_major" is unset) yields False.
        return self.time_major is True
  289. @DeveloperAPI
  290. def flatten(obs: TensorType, framework: str) -> TensorType:
  291. """Flatten the given tensor."""
  292. if framework in ["tf2", "tf", "tfe"]:
  293. return tf1.keras.layers.Flatten()(obs)
  294. elif framework == "torch":
  295. assert torch is not None
  296. return torch.flatten(obs, start_dim=1)
  297. else:
  298. raise NotImplementedError("flatten", framework)
  299. @DeveloperAPI
  300. def restore_original_dimensions(obs: TensorType,
  301. obs_space: gym.spaces.Space,
  302. tensorlib: Any = tf) -> TensorStructType:
  303. """Unpacks Dict and Tuple space observations into their original form.
  304. This is needed since we flatten Dict and Tuple observations in transit
  305. within a SampleBatch. Before sending them to the model though, we should
  306. unflatten them into Dicts or Tuples of tensors.
  307. Args:
  308. obs (TensorType): The flattened observation tensor.
  309. obs_space (gym.spaces.Space): The flattened obs space. If this has the
  310. `original_space` attribute, we will unflatten the tensor to that
  311. shape.
  312. tensorlib: The library used to unflatten (reshape) the array/tensor.
  313. Returns:
  314. single tensor or dict / tuple of tensors matching the original
  315. observation space.
  316. """
  317. if tensorlib in ["tf", "tfe", "tf2"]:
  318. assert tf is not None
  319. tensorlib = tf
  320. elif tensorlib == "torch":
  321. assert torch is not None
  322. tensorlib = torch
  323. original_space = getattr(obs_space, "original_space", obs_space)
  324. return _unpack_obs(obs, original_space, tensorlib=tensorlib)
  325. # Cache of preprocessors, for if the user is calling unpack obs often.
  326. _cache = {}
  327. def _unpack_obs(obs: TensorType, space: gym.Space,
  328. tensorlib: Any = tf) -> TensorStructType:
  329. """Unpack a flattened Dict or Tuple observation array/tensor.
  330. Args:
  331. obs: The flattened observation tensor, with last dimension equal to
  332. the flat size and any number of batch dimensions. For example, for
  333. Box(4,), the obs may have shape [B, 4], or [B, N, M, 4] in case
  334. the Box was nested under two Repeated spaces.
  335. space: The original space prior to flattening
  336. tensorlib: The library used to unflatten (reshape) the array/tensor
  337. """
  338. if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple, Repeated)):
  339. if id(space) in _cache:
  340. prep = _cache[id(space)]
  341. else:
  342. prep = get_preprocessor(space)(space)
  343. # Make an attempt to cache the result, if enough space left.
  344. if len(_cache) < 999:
  345. _cache[id(space)] = prep
  346. # Already unpacked?
  347. if (isinstance(space, gym.spaces.Tuple) and
  348. isinstance(obs, (list, tuple))) or \
  349. (isinstance(space, gym.spaces.Dict) and isinstance(obs, dict)):
  350. return obs
  351. elif len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
  352. raise ValueError(
  353. "Expected flattened obs shape of [..., {}], got {}".format(
  354. prep.shape[0], obs.shape))
  355. offset = 0
  356. if tensorlib == tf:
  357. batch_dims = [
  358. v if isinstance(v, int) else v.value for v in obs.shape[:-1]
  359. ]
  360. batch_dims = [-1 if v is None else v for v in batch_dims]
  361. else:
  362. batch_dims = list(obs.shape[:-1])
  363. if isinstance(space, gym.spaces.Tuple):
  364. assert len(prep.preprocessors) == len(space.spaces), \
  365. (len(prep.preprocessors) == len(space.spaces))
  366. u = []
  367. for p, v in zip(prep.preprocessors, space.spaces):
  368. obs_slice = obs[..., offset:offset + p.size]
  369. offset += p.size
  370. u.append(
  371. _unpack_obs(
  372. tensorlib.reshape(obs_slice,
  373. batch_dims + list(p.shape)),
  374. v,
  375. tensorlib=tensorlib))
  376. elif isinstance(space, gym.spaces.Dict):
  377. assert len(prep.preprocessors) == len(space.spaces), \
  378. (len(prep.preprocessors) == len(space.spaces))
  379. u = OrderedDict()
  380. for p, (k, v) in zip(prep.preprocessors, space.spaces.items()):
  381. obs_slice = obs[..., offset:offset + p.size]
  382. offset += p.size
  383. u[k] = _unpack_obs(
  384. tensorlib.reshape(obs_slice, batch_dims + list(p.shape)),
  385. v,
  386. tensorlib=tensorlib)
  387. # Repeated space.
  388. else:
  389. assert isinstance(prep, RepeatedValuesPreprocessor), prep
  390. child_size = prep.child_preprocessor.size
  391. # The list lengths are stored in the first slot of the flat obs.
  392. lengths = obs[..., 0]
  393. # [B, ..., 1 + max_len * child_sz] -> [B, ..., max_len, child_sz]
  394. with_repeat_dim = tensorlib.reshape(
  395. obs[..., 1:], batch_dims + [space.max_len, child_size])
  396. # Retry the unpack, dropping the List container space.
  397. u = _unpack_obs(
  398. with_repeat_dim, space.child_space, tensorlib=tensorlib)
  399. return RepeatedValues(
  400. u, lengths=lengths, max_len=prep._obs_space.max_len)
  401. return u
  402. else:
  403. return obs