RLlib Models, Preprocessors, and Action Distributions
======================================================

The following diagram provides a conceptual overview of data flow between different components in RLlib.
We start with an ``Environment``, which - given an action - produces an observation.
The observation is preprocessed by a ``Preprocessor`` and ``Filter`` (e.g. for running mean normalization)
before being sent to a neural network ``Model``. The model output is in turn
interpreted by an ``ActionDistribution`` to determine the next action.

.. image:: rllib-components.svg

The components highlighted in green can be replaced with custom user-defined
implementations, as described in the next sections. The purple components are
RLlib-internal, which means they can only be modified by changing the algorithm
source code.

Default Behaviors
-----------------

Built-in Preprocessors
~~~~~~~~~~~~~~~~~~~~~~

RLlib tries to pick one of its built-in preprocessors based on the environment's
observation space. Thereby, the following simple rules apply:

- Discrete observations are one-hot encoded, e.g. ``Discrete(3) and value=1 -> [0, 1, 0]``.
- MultiDiscrete observations are "multi" one-hot encoded,
  e.g. ``MultiDiscrete([3, 4]) and value=[1, 0] -> [0 1 0 1 0 0 0]``.
- Tuple and Dict observations are flattened; thereby, Discrete and MultiDiscrete
  sub-spaces are handled as described above.
  Also, the original dict/tuple observations are still available inside a) the Model via the input
  dict's "obs" key (the flattened observations are in "obs_flat"), as well as b) the Policy
  via the following line of code (e.g. put this into your loss function to access the original
  observations): ``dict_or_tuple_obs = restore_original_dimensions(input_dict["obs"], self.obs_space, "tf|torch")``
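
For illustration, here is a small sketch of what the default flattening does to a ``Dict`` observation
(this uses the ``ModelCatalog.get_preprocessor_for_space`` helper simply to show the resulting layout):

.. code-block:: python

    import gym
    import numpy as np

    from ray.rllib.models import ModelCatalog

    # A Dict space with a 1D Box- and a Discrete sub-space.
    space = gym.spaces.Dict({
        "position": gym.spaces.Box(-1.0, 1.0, shape=(2, )),
        "status": gym.spaces.Discrete(3),
    })

    # Resolve the built-in (default) preprocessor for this space.
    prep = ModelCatalog.get_preprocessor_for_space(space)

    # The Dict is flattened into a single 1D vector:
    # 2 (Box) + 3 (one-hot encoded Discrete) = 5.
    print(prep.shape)  # -> (5, )
    print(prep.transform({"position": np.array([0.5, -0.2]), "status": 1}))
    # -> e.g. [0.5, -0.2, 0.0, 1.0, 0.0]

Inside a Model, ``input_dict["obs_flat"]`` holds this flattened vector, while ``input_dict["obs"]``
still contains the original (structured) observation.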

For Atari observation spaces, RLlib defaults to using the `DeepMind preprocessors <https://github.com/ray-project/ray/blob/master/rllib/env/atari_wrappers.py>`__
(``preprocessor_pref=deepmind``). However, if the Trainer's config key ``preprocessor_pref`` is set to "rllib",
the following mappings apply for Atari-type observation spaces:

- Images of shape ``(210, 160, 3)`` are downscaled to ``dim x dim``, where
  ``dim`` is a model config key (see the default model config below). Also, you can set
  ``grayscale=True`` to reduce the color channels to 1, or ``zero_mean=True`` to
  produce values between -1.0 and 1.0 (instead of the default 0.0 to 1.0 values).
- Atari RAM observations (1D space of shape ``(128, )``) are zero-averaged
  (values between -1.0 and 1.0).

In all other cases, no preprocessor will be used and the raw observations from the environment
will be sent directly into your model.

Default Model Config Settings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In the following paragraphs, we will first describe RLlib's default behavior for automatically constructing
models (if you don't set up a custom one), then dive into how you can customize your models by changing these
settings or writing your own model classes.

By default, RLlib will use the following config settings for your models.
These include options for the ``FullyConnectedNetworks`` (``fcnet_hiddens`` and ``fcnet_activation``),
``VisionNetworks`` (``conv_filters`` and ``conv_activation``), auto-RNN wrapping, auto-Attention (`GTrXL <https://arxiv.org/abs/1910.06764>`__) wrapping,
and some special options for Atari environments:

.. literalinclude:: ../../rllib/models/catalog.py
   :language: python
   :start-after: __sphinx_doc_begin__
   :end-before: __sphinx_doc_end__

The dict above (or an overriding sub-set) is handed to the Trainer via the ``model`` key within
the main config dict like so:

.. code-block:: python

    algo_config = {
        # All model-related settings go into this sub-dict.
        "model": {
            # By default, the MODEL_DEFAULTS dict above will be used.
            # Change individual keys in that dict by overriding them, e.g.
            "fcnet_hiddens": [512, 512, 512],
            "fcnet_activation": "relu",
        },

        # ... other Trainer config keys, e.g. "lr" ...
        "lr": 0.00001,
    }
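
To start training with these settings, you can hand the config to Ray Tune (the Trainer classes shown
further below accept the same dict directly). A minimal sketch, assuming ``CartPole-v0`` as the environment:

.. code-block:: python

    import ray
    from ray import tune

    ray.init()

    # Train PPO with the model settings from `algo_config` above.
    tune.run(
        "PPO",
        config=dict(algo_config, env="CartPole-v0"),
        stop={"training_iteration": 10},
    )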

Built-in Models
~~~~~~~~~~~~~~~

After preprocessing (if applicable) the raw environment outputs, the processed observations are fed through the policy's model.
In case no custom model is specified (see further below on how to customize models), RLlib will pick a default model
based on simple heuristics:

- A vision network (`TF <https://github.com/ray-project/ray/blob/master/rllib/models/tf/visionnet.py>`__ or `Torch <https://github.com/ray-project/ray/blob/master/rllib/models/torch/visionnet.py>`__)
  for observations that have a shape of length larger than 2, for example, ``(84 x 84 x 3)``.
- A fully connected network (`TF <https://github.com/ray-project/ray/blob/master/rllib/models/tf/fcnet.py>`__ or `Torch <https://github.com/ray-project/ray/blob/master/rllib/models/torch/fcnet.py>`__)
  for everything else.

These default model types can further be configured via the ``model`` config key inside your Trainer config (as discussed above).
Available settings are `listed above <#default-model-config-settings>`__ and also documented in the `model catalog file <https://github.com/ray-project/ray/blob/master/rllib/models/catalog.py>`__.

Note that for the vision network case, you'll probably have to configure ``conv_filters``, if your environment observations
have custom sizes. For example, ``"model": {"dim": 42, "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]}`` for 42x42 observations.
Thereby, always make sure that the last Conv2D output has an output shape of ``[B, 1, 1, X]`` (``[B, X, 1, 1]`` for PyTorch), where B=batch and
X=last Conv2D layer's number of filters, so that RLlib can flatten it. An informative error will be thrown if this is not the case.
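
As a sanity check, you can walk the spatial dimensions through your filter specs layer by layer.
The sketch below annotates the 42x42 example from above (this assumes the default padding behavior of
RLlib's vision nets: "same" padding for all but the last layer, "valid" padding for the last one):

.. code-block:: python

    # Each filter spec is [num_output_channels, [kernel_width, kernel_height], stride].
    model_config = {
        "dim": 42,  # Observations are (re-)sized to 42x42.
        "conv_filters": [
            [16, [4, 4], 2],     # 42x42 -> 21x21 ("same" padding, stride 2)
            [32, [4, 4], 2],     # 21x21 -> 11x11 ("same" padding, stride 2)
            [512, [11, 11], 1],  # 11x11 -> 1x1   (kernel covers the full 11x11 input)
        ],
    }
    # Final conv output: [B, 1, 1, 512] ([B, 512, 1, 1] for PyTorch), which RLlib
    # can flatten into a 512-dim feature vector.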

.. _auto_lstm_and_attention:

Built-in auto-LSTM and auto-Attention Wrappers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In addition, if you set ``"use_lstm": True`` or ``"use_attention": True`` in your model config,
your model's output will be further processed by an LSTM cell
(`TF <https://github.com/ray-project/ray/blob/master/rllib/models/tf/recurrent_net.py>`__ or `Torch <https://github.com/ray-project/ray/blob/master/rllib/models/torch/recurrent_net.py>`__),
or an attention (`GTrXL <https://arxiv.org/abs/1910.06764>`__) network
(`TF <https://github.com/ray-project/ray/blob/master/rllib/models/tf/attention_net.py>`__ or
`Torch <https://github.com/ray-project/ray/blob/master/rllib/models/torch/attention_net.py>`__), respectively.

More generally, RLlib supports the use of recurrent/attention models for all
its policy-gradient algorithms (A3C, PPO, PG, IMPALA), and the necessary sequence processing support
is built into its policy evaluation utilities.

See above for which additional config keys to use to configure these two auto-wrappers in more detail
(e.g. you can specify the size of the LSTM layer by ``lstm_cell_size`` or the attention dim by ``attention_dim``).
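
For example, a minimal model config for either auto-wrapper could look like this (any keys not shown keep
their defaults; the full list of ``lstm_*`` and ``attention_*`` keys is part of the default model config above):

.. code-block:: python

    # Auto-wrap your (default or custom) model with an LSTM cell:
    config = {
        "model": {
            "use_lstm": True,
            "lstm_cell_size": 256,
        },
    }

    # ... or auto-wrap it with a GTrXL attention net instead:
    config = {
        "model": {
            "use_attention": True,
            "attention_dim": 64,
        },
    }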

For fully customized RNN/LSTM/Attention-Net setups see the `Recurrent Models <#rnns>`_ and
`Attention Networks/Transformers <#attention>`_ sections below.

.. note::
    It is not possible to use both auto-wrappers (lstm and attention) at the same time. Doing so will create an error.

Customizing Preprocessors and Models
------------------------------------

Custom Preprocessors and Environment Filters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. warning::
    Custom preprocessors are deprecated, since they sometimes conflict with the built-in preprocessors for handling complex observation spaces.
    Please use `wrapper classes <https://github.com/openai/gym/tree/master/gym/wrappers>`__ around your environment instead of preprocessors.
    Note that the built-in **default** Preprocessors described above will still be used and won't be deprecated.

Instead of using the deprecated custom Preprocessors, you should use ``gym.Wrappers`` to preprocess your environment's output (observations and rewards),
as well as your Model's computed actions before they are sent back to the environment.

For example, for manipulating your env's observations or rewards, do:

.. code-block:: python

    import gym
    import numpy as np

    from ray.rllib.utils.numpy import one_hot


    class OneHotEnv(gym.core.ObservationWrapper):
        # Override `observation` to custom process the original observation
        # coming from the env.
        def observation(self, observation):
            # E.g. one-hot encoding an int obs in the range [0, 5).
            return one_hot(observation, depth=5)


    class ClipRewardEnv(gym.core.RewardWrapper):
        def __init__(self, env, min_, max_):
            super().__init__(env)
            self.min = min_
            self.max = max_

        # Override `reward` to custom process the original reward coming
        # from the env.
        def reward(self, reward):
            # E.g. simple clipping between min and max.
            return np.clip(reward, self.min, self.max)
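
To actually use such wrappers, construct the wrapped environment inside an env-creator function and
register it, e.g. via ``ray.tune.registry.register_env``. A minimal sketch (``"MyDiscreteObsEnv-v0"`` is
a placeholder for whatever base env you want to wrap):

.. code-block:: python

    import gym

    from ray.tune.registry import register_env


    def env_creator(env_config):
        # Wrap the base env with the observation- and reward-wrappers from above.
        env = gym.make("MyDiscreteObsEnv-v0")  # <- placeholder env name
        env = OneHotEnv(env)
        env = ClipRewardEnv(env, min_=-1.0, max_=1.0)
        return env


    register_env("my-wrapped-env", env_creator)

    # The registered name can now be used as the `env` in your Trainer config:
    # trainer = ppo.PPOTrainer(env="my-wrapped-env", config={...})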

Custom Models: Implementing your own Forward Logic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you would like to provide your own model logic (instead of using RLlib's built-in defaults), you
can sub-class either ``TFModelV2`` (for TensorFlow) or ``TorchModelV2`` (for PyTorch) and then
register and specify your sub-class in the config as follows:

Custom TensorFlow Models
````````````````````````

Custom TensorFlow models should subclass `TFModelV2 <https://github.com/ray-project/ray/blob/master/rllib/models/tf/tf_modelv2.py>`__ and implement the ``__init__()`` and ``forward()`` methods.
``forward()`` takes a dict of tensor inputs (mapping str to Tensor types), whose keys and values depend on
the `view requirements <rllib-sample-collection.html>`__ of the model.
Normally, this input dict contains only the current observation ``obs`` and an ``is_training`` boolean flag, as well as an optional list of RNN states.
``forward()`` should return the model output (of size ``self.num_outputs``) and - if applicable - a new list of internal
states (in case of RNNs or attention nets). You can also override extra methods of the model such as ``value_function`` to implement
a custom value branch.

Additional supervised/self-supervised losses can be added via the ``TFModelV2.custom_loss`` method:

.. autoclass:: ray.rllib.models.tf.tf_modelv2.TFModelV2

    .. automethod:: __init__
    .. automethod:: forward
    .. automethod:: value_function
    .. automethod:: custom_loss
    .. automethod:: metrics
    .. automethod:: update_ops
    .. automethod:: register_variables
    .. automethod:: variables
    .. automethod:: trainable_variables

Once implemented, your TF model can then be registered and used in place of a built-in default one:

.. code-block:: python

    import ray
    import ray.rllib.agents.ppo as ppo
    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.tf.tf_modelv2 import TFModelV2


    class MyModelClass(TFModelV2):
        def __init__(self, obs_space, action_space, num_outputs, model_config, name): ...
        def forward(self, input_dict, state, seq_lens): ...
        def value_function(self): ...


    ModelCatalog.register_custom_model("my_tf_model", MyModelClass)

    ray.init()
    trainer = ppo.PPOTrainer(env="CartPole-v0", config={
        "model": {
            "custom_model": "my_tf_model",
            # Extra kwargs to be passed to your model's c'tor.
            "custom_model_config": {},
        },
    })

See the `keras model example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_model.py>`__ for a full example of a TF custom model.

More examples and explanations on how to implement custom Tuple/Dict processing models
(also check out `this test case here <https://github.com/ray-project/ray/blob/master/rllib/tests/test_nested_observation_spaces.py>`__),
custom RNNs, custom model APIs (on top of default models) follow further below.

Custom PyTorch Models
`````````````````````

Similarly, you can create and register custom PyTorch models by subclassing
`TorchModelV2 <https://github.com/ray-project/ray/blob/master/rllib/models/torch/torch_modelv2.py>`__ and implementing the ``__init__()`` and ``forward()`` methods.
``forward()`` takes a dict of tensor inputs (mapping str to PyTorch tensor types), whose keys and values depend on
the `view requirements <rllib-sample-collection.html>`__ of the model.
Usually, the dict contains only the current observation ``obs`` and an ``is_training`` boolean flag, as well as an optional list of RNN states.
``forward()`` should return the model output (of size ``self.num_outputs``) and - if applicable - a new list of internal
states (in case of RNNs or attention nets). You can also override extra methods of the model such as ``value_function`` to implement
a custom value branch.

Additional supervised/self-supervised losses can be added via the ``TorchModelV2.custom_loss`` method.

See these examples of `fully connected <https://github.com/ray-project/ray/blob/master/rllib/models/torch/fcnet.py>`__, `convolutional <https://github.com/ray-project/ray/blob/master/rllib/models/torch/visionnet.py>`__, and `recurrent <https://github.com/ray-project/ray/blob/master/rllib/models/torch/recurrent_net.py>`__ torch models.

.. autoclass:: ray.rllib.models.torch.torch_modelv2.TorchModelV2

    .. automethod:: __init__
    .. automethod:: forward
    .. automethod:: value_function
    .. automethod:: custom_loss
    .. automethod:: metrics
    .. automethod:: get_initial_state
    .. automethod:: variables
    .. automethod:: trainable_variables

Once implemented, your PyTorch model can then be registered and used in place of a built-in model:

.. code-block:: python

    import torch.nn as nn

    import ray
    from ray.rllib.agents import ppo
    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


    class CustomTorchModel(TorchModelV2, nn.Module):
        def __init__(self, obs_space, action_space, num_outputs, model_config, name): ...
        def forward(self, input_dict, state, seq_lens): ...
        def value_function(self): ...


    ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel)

    ray.init()
    trainer = ppo.PPOTrainer(env="CartPole-v0", config={
        "framework": "torch",
        "model": {
            "custom_model": "my_torch_model",
            # Extra kwargs to be passed to your model's c'tor.
            "custom_model_config": {},
        },
    })

See the `torch model examples <https://github.com/ray-project/ray/blob/master/rllib/examples/models/>`__ for various examples on how to build a custom
PyTorch model (including recurrent ones).

More examples and explanations on how to implement custom Tuple/Dict processing models (also check out `this test case here <https://github.com/ray-project/ray/blob/master/rllib/tests/test_nested_observation_spaces.py>`__),
custom RNNs, custom model APIs (on top of default models) follow further below.

Wrapping a Custom Model (TF and PyTorch) with an LSTM- or Attention Net
`````````````````````````````````````````````````````````````````````````

You can also use a custom (TF or PyTorch) model with our auto-wrappers for LSTMs (``use_lstm=True``) or Attention networks (``use_attention=True``).
For example, if you would like to wrap some non-default model logic with an LSTM, simply do:

.. literalinclude:: ../../rllib/examples/lstm_auto_wrapping.py
   :language: python
   :start-after: __sphinx_doc_begin__
   :end-before: __sphinx_doc_end__

.. _rnns:

Implementing custom Recurrent Networks
``````````````````````````````````````

Instead of using the ``use_lstm: True`` option, it may be preferable to use a custom recurrent model.
This provides more control over postprocessing the LSTM's output and can also allow the use of multiple LSTM cells to process different portions of the input.

For an RNN model it is recommended to subclass ``RecurrentNetwork`` (either the `TF <https://github.com/ray-project/ray/blob/master/rllib/models/tf/recurrent_net.py>`__
or `PyTorch <https://github.com/ray-project/ray/blob/master/rllib/models/torch/recurrent_net.py>`__ versions) and then implement ``__init__()``,
``get_initial_state()``, and ``forward_rnn()``.

.. autoclass:: ray.rllib.models.tf.recurrent_net.RecurrentNetwork

    .. automethod:: __init__
    .. automethod:: get_initial_state
    .. automethod:: forward_rnn

Note that the ``inputs`` arg entering ``forward_rnn`` is already a time-ranked single tensor (not an ``input_dict``!) with shape ``(B x T x ...)``.
If you further want to customize and need more direct access to the complete (non time-ranked) ``input_dict``, you can also override
your Model's ``forward`` method directly (as you would do with a non-RNN ModelV2). In that case, though, you are responsible for reshaping your inputs
and adding the time rank to the incoming data (usually you just have to reshape).
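
Here is a rough sketch of what such a ``forward()`` override has to take care of (illustrative only, not
RLlib's exact internal code; RLlib passes the batch of flattened observations as ``[B*T, obs_dim]`` together
with a ``seq_lens`` tensor):

.. code-block:: python

    import tensorflow as tf

    def forward(self, input_dict, state, seq_lens):
        flat_obs = input_dict["obs_flat"]
        # Infer the max sequence length from the batch- and seq_lens sizes.
        max_seq_len = tf.shape(flat_obs)[0] // tf.shape(seq_lens)[0]
        # Add the time rank: [B*T, obs_dim] -> [B, T, obs_dim].
        obs_w_time_rank = tf.reshape(
            flat_obs, [-1, max_seq_len, flat_obs.shape.as_list()[-1]])
        output, new_state = self.forward_rnn(obs_w_time_rank, state, seq_lens)
        # Fold the time rank back in before returning: [B, T, out] -> [B*T, out].
        return tf.reshape(output, [-1, self.num_outputs]), new_state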

You can check out the `rnn_model.py <https://github.com/ray-project/ray/blob/master/rllib/examples/models/rnn_model.py>`__ models as examples to implement
your own (either TF or Torch).

.. _attention:

Implementing custom Attention Networks
``````````````````````````````````````

Similar to the RNN case described above, you could also implement your own attention-based networks, instead of using the
``use_attention: True`` flag in your model config.

Check out RLlib's `GTrXL (Attention Net) <https://arxiv.org/abs/1910.06764>`__ implementations
(for `TF <https://github.com/ray-project/ray/blob/master/rllib/models/tf/attention_net.py>`__ and `PyTorch <https://github.com/ray-project/ray/blob/master/rllib/models/torch/attention_net.py>`__)
to get a better idea on how to write your own models of this type. These are the models we use
as wrappers when ``use_attention=True``.

You can run `this example script <https://github.com/ray-project/ray/blob/master/rllib/examples/attention_net.py>`__ to see these nets in action within some of our algorithms.
`There is also a test case <https://github.com/ray-project/ray/blob/master/rllib/tests/test_attention_net_learning.py>`__, which confirms their learning capabilities in PPO and IMPALA.

Batch Normalization
```````````````````

You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model
(see a `code example here <https://github.com/ray-project/ray/blob/master/rllib/examples/batch_norm_model.py>`__).
RLlib will automatically run the update ops for the batch norm layers during optimization
(see `tf_policy.py <https://github.com/ray-project/ray/blob/master/rllib/policy/tf_policy.py>`__ and
`multi_gpu_impl.py <https://github.com/ray-project/ray/blob/master/rllib/execution/multi_gpu_impl.py>`__ for the exact handling of these updates).

In case RLlib does not properly detect the update ops for your custom model, you can override the ``update_ops()`` method to return the list of ops to run for updates.
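
A minimal sketch of how such a layer could be placed inside a custom ``TFModelV2``'s ``forward()``
(layer sizes and names here are arbitrary placeholders):

.. code-block:: python

    import tensorflow as tf

    def forward(self, input_dict, state, seq_lens):
        x = tf.layers.dense(input_dict["obs"], 64, activation=tf.nn.relu)
        # `is_training` tells the layer whether to update its running
        # mean/variance statistics (training) or just apply them (inference).
        x = tf.layers.batch_normalization(x, training=input_dict["is_training"])
        logits = tf.layers.dense(x, self.num_outputs, activation=None)
        return logits, state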

Custom Model APIs (on Top of Default- or Custom Models)
````````````````````````````````````````````````````````

So far we talked about a) the default models that are built into RLlib and are being provided
automatically if you don't specify anything in your Trainer's config and b) custom Models through
which you can define any arbitrary forward passes.

Another typical situation in which you would have to customize a model would be to
add a new API that your algorithm needs in order to learn, for example a Q-value
calculating head on top of your policy model. In order to expand a Model's API, simply
define and implement a new method (e.g. ``get_q_values()``) in your TF- or TorchModelV2 sub-class.

You can now wrap this new API either around RLlib's default models or around
your custom (``forward()``-overriding) model classes. Here are two examples that illustrate how to do this:

**The Q-head API: Adding a dueling layer on top of a default RLlib model**.

The following code adds a ``get_q_values()`` method to the automatically chosen
default Model (e.g. a ``FullyConnectedNetwork`` if the observation space is a 1D Box
or Discrete):

.. literalinclude:: ../../rllib/examples/models/custom_model_api.py
   :language: python
   :start-after: __sphinx_doc_model_api_1_begin__
   :end-before: __sphinx_doc_model_api_1_end__

Now, for your algorithm that needs to have this model API to work properly (e.g. DQN),
you use the following code to construct the complete final Model using the
``ModelCatalog.get_model_v2`` factory function (`code here <https://github.com/ray-project/ray/blob/master/rllib/models/catalog.py>`__):

.. literalinclude:: ../../rllib/examples/custom_model_api.py
   :language: python
   :start-after: __sphinx_doc_model_construct_1_begin__
   :end-before: __sphinx_doc_model_construct_1_end__

With the model object constructed above, you can get the underlying intermediate output (before the dueling head)
by calling ``my_dueling_model`` directly (``out = my_dueling_model([input_dict])``), and then passing ``out`` into
your custom ``get_q_values`` method: ``q_values = my_dueling_model.get_q_values(out)``.
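
In block form (a sketch using the names from the example above):

.. code-block:: python

    # Intermediate output of the underlying default model (before the dueling head).
    out = my_dueling_model([input_dict])
    # Query the newly added API to compute the actual Q-values.
    q_values = my_dueling_model.get_q_values(out)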

**The single Q-value API for SAC**.

Our DQN model from above takes an observation and outputs one Q-value per (discrete) action.
Continuous SAC - on the other hand - uses Models that calculate one Q-value only
for a single (**continuous**) action, given an observation and that particular action.

Let's take a look at how we would construct this API and wrap it around a custom model:

.. literalinclude:: ../../rllib/examples/models/custom_model_api.py
   :language: python
   :start-after: __sphinx_doc_model_api_2_begin__
   :end-before: __sphinx_doc_model_api_2_end__

Now, for your algorithm that needs to have this model API to work properly (e.g. SAC),
you use the following code to construct the complete final Model using the
``ModelCatalog.get_model_v2`` factory function (`code here <https://github.com/ray-project/ray/blob/master/rllib/models/catalog.py>`__):

.. literalinclude:: ../../rllib/examples/custom_model_api.py
   :language: python
   :start-after: __sphinx_doc_model_construct_2_begin__
   :end-before: __sphinx_doc_model_construct_2_end__

With the model object constructed above, you can get the underlying intermediate output (before the q-head)
by calling ``my_cont_action_q_model`` directly (``out = my_cont_action_q_model([input_dict])``), and then passing ``out``
and some action into your custom ``get_single_q_value`` method:
``q_value = my_cont_action_q_model.get_single_q_value(out, action)``.

More examples for Building Custom Models
````````````````````````````````````````

**A multi-input capable model for Tuple observation spaces (for PPO)**

RLlib's default preprocessor for Tuple and Dict spaces is to flatten incoming observations
into one flat **1D** array, and then pick a fully connected network (by default) to
process this flattened vector. This is usually fine if you have only 1D Box or
Discrete/MultiDiscrete sub-spaces in your observations.

However, what if you had a complex observation space with one or more image components in
it (besides 1D Boxes and discrete spaces)? You would probably want to preprocess each of the
image components using some convolutional network, and then concatenate their outputs
with the remaining non-image (flat) inputs (the 1D Box and discrete/one-hot components).

Take a look at this model example that does exactly that:

.. literalinclude:: ../../rllib/models/tf/complex_input_net.py
   :language: python
   :start-after: __sphinx_doc_begin__
   :end-before: __sphinx_doc_end__

**Using the Trajectory View API: Passing in the last n actions (or rewards or observations) as inputs to a custom Model**

It is sometimes helpful for learning not only to look at the current observation
in order to calculate the next action, but also at the past n observations.
In other cases, you may want to provide the most recent rewards or actions to the model as well
(like our LSTM wrapper does if you specify ``use_lstm=True`` and ``lstm_use_prev_action/reward=True``).
All this may even be useful when not working with partially observable environments (POMDPs)
and/or RNN/Attention models, as for example in classic Atari runs, where we usually use framestacking of
the last four observed images.

The `trajectory view API <rllib-sample-collection.html#trajectory-view-api>`__ allows your models
to specify these more complex "view requirements".

Here is a simple (non-RNN/Attention) example of a Model that takes as input
the last 3 observations (very similar to the recommended "framestacking" for
learning in Atari environments):

.. literalinclude:: ../../rllib/examples/models/trajectory_view_utilizing_models.py
   :language: python
   :start-after: __sphinx_doc_begin__
   :end-before: __sphinx_doc_end__

A PyTorch version of the above model is also `given in the same file <https://github.com/ray-project/ray/blob/master/rllib/examples/models/trajectory_view_utilizing_models.py>`__.

Custom Action Distributions
---------------------------

Similar to custom models and preprocessors, you can also specify a custom action distribution class as follows. The action dist class is passed a reference to the ``model``, which you can use to access ``model.model_config`` or other attributes of the model. This is commonly used to implement `autoregressive action outputs <#autoregressive-action-distributions>`__.

.. code-block:: python

    import ray
    import ray.rllib.agents.ppo as ppo
    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.action_dist import ActionDistribution


    class MyActionDist(ActionDistribution):
        @staticmethod
        def required_model_output_shape(action_space, model_config):
            return 7  # controls model output feature vector size

        def __init__(self, inputs, model):
            super(MyActionDist, self).__init__(inputs, model)
            assert model.num_outputs == 7

        def sample(self): ...
        def logp(self, actions): ...
        def entropy(self): ...


    ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)

    ray.init()
    trainer = ppo.PPOTrainer(env="CartPole-v0", config={
        "model": {
            "custom_action_dist": "my_dist",
        },
    })

Supervised Model Losses
-----------------------

You can mix supervised losses into any RLlib algorithm through custom models. For example, you can add an imitation learning loss on expert experiences, or a self-supervised autoencoder loss within the model. These losses can be defined over either policy evaluation inputs, or data read from `offline storage <rllib-offline.html#input-pipeline-for-supervised-losses>`__.

**TensorFlow**: To add a supervised loss to a custom TF model, you need to override the ``custom_loss()`` method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the ``metrics()`` method. Here is a `runnable example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_loss.py>`__ of adding an imitation loss to CartPole training that is defined over an `offline dataset <rllib-offline.html#input-pipeline-for-supervised-losses>`__.

**PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass.
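
As a rough sketch of what such a ``custom_loss()`` override can look like for a TF model (the added term
here is a simple L2 penalty standing in for an actual supervised loss; the complete imitation-learning
setup is in the runnable example linked above):

.. code-block:: python

    import tensorflow as tf

    from ray.rllib.models.tf.tf_modelv2 import TFModelV2


    class MyModelWithExtraLoss(TFModelV2):
        ...

        def custom_loss(self, policy_loss, loss_inputs):
            # Example extra term: L2 regularization over all trainable variables.
            self.l2_loss = 0.001 * tf.add_n(
                [tf.nn.l2_loss(v) for v in self.trainable_variables()])
            # `policy_loss` may be a single tensor or a list of tensors.
            if isinstance(policy_loss, list):
                return [loss + self.l2_loss for loss in policy_loss]
            return policy_loss + self.l2_loss

        def metrics(self):
            # Report the extra term as a scalar metric for debugging.
            return {"l2_loss": self.l2_loss}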

Self-Supervised Model Losses
----------------------------

You can also use the ``custom_loss()`` API to add in self-supervised losses such as VAE reconstruction loss and L2-regularization.

Variable-length / Complex Observation Spaces
--------------------------------------------

RLlib supports complex and variable-length observation spaces, including ``gym.spaces.Tuple``, ``gym.spaces.Dict``, and ``rllib.utils.spaces.Repeated``. The handling of these spaces is transparent to the user. RLlib internally will insert preprocessors to insert padding for repeated elements, flatten complex observations into a fixed-size vector during transit, and unpack the vector into the structured tensor before sending it to the model. The flattened observation is available to the model as ``input_dict["obs_flat"]``, and the unpacked observation as ``input_dict["obs"]``.

To enable batching of struct observations, RLlib unpacks them in a `StructTensor-like format <https://github.com/tensorflow/community/blob/master/rfcs/20190910-struct-tensor.md>`__. In summary, repeated fields are "pushed down" and become the outer dimensions of tensor batches, as illustrated in this figure from the StructTensor RFC.

.. image:: struct-tensor.png
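
For instance, a variable-length list of "item" observations could be declared with the ``Repeated`` space
roughly as follows (a sketch; the examples linked below contain complete, working environments and models):

.. code-block:: python

    import gym

    from ray.rllib.utils.spaces.repeated import Repeated

    # Each item is described by a small Dict; any single observation may contain
    # up to `max_len` such items (shorter lists are padded internally by RLlib).
    item_space = gym.spaces.Dict({
        "location": gym.spaces.Box(-1.0, 1.0, shape=(2, )),
        "type": gym.spaces.Discrete(4),
    })
    observation_space = Repeated(item_space, max_len=8)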

For further information about complex observation spaces, see:

* A custom environment and model that uses `repeated struct fields <https://github.com/ray-project/ray/blob/master/rllib/examples/complex_struct_space.py>`__.
* The pydoc of the `Repeated space <https://github.com/ray-project/ray/blob/master/rllib/utils/spaces/repeated.py>`__.
* The pydoc of the batched `repeated values tensor <https://github.com/ray-project/ray/blob/master/rllib/models/repeated_values.py>`__.
* The `unit tests <https://github.com/ray-project/ray/blob/master/rllib/tests/test_nested_observation_spaces.py>`__ for Tuple and Dict spaces.

Variable-length / Parametric Action Spaces
------------------------------------------

Custom models can be used to work with environments where (1) the set of valid actions `varies per step <https://neuro.cs.ut.ee/the-use-of-embeddings-in-openai-five>`__, and/or (2) the number of valid actions is `very large <https://arxiv.org/abs/1811.00260>`__. The general idea is that the meaning of actions can be completely conditioned on the observation, i.e., the ``a`` in ``Q(s, a)`` becomes just a token in ``[0, MAX_AVAIL_ACTIONS)`` that only has meaning in the context of ``s``. This works with algorithms in the `DQN and policy-gradient families <rllib-env.html>`__ and can be implemented as follows:

1. The environment should return a mask and/or list of valid action embeddings as part of the observation for each step. To enable batching, the number of actions can be allowed to vary from 1 to some max number:

.. code-block:: python

    class MyParamActionEnv(gym.Env):
        def __init__(self, max_avail_actions):
            self.action_space = Discrete(max_avail_actions)
            self.observation_space = Dict({
                "action_mask": Box(0, 1, shape=(max_avail_actions, )),
                "avail_actions": Box(-1, 1, shape=(max_avail_actions, action_embedding_sz)),
                "real_obs": ...,
            })

2. A custom model can be defined that can interpret the ``action_mask`` and ``avail_actions`` portions of the observation. Here the model computes the action logits via the dot product of some network output and each action embedding. Invalid actions can be masked out of the softmax by scaling the probability to zero:

.. code-block:: python

    class ParametricActionsModel(TFModelV2):
        def __init__(self,
                     obs_space,
                     action_space,
                     num_outputs,
                     model_config,
                     name,
                     true_obs_shape=(4,),
                     action_embed_size=2):
            super(ParametricActionsModel, self).__init__(
                obs_space, action_space, num_outputs, model_config, name)
            self.action_embed_model = FullyConnectedNetwork(...)

        def forward(self, input_dict, state, seq_lens):
            # Extract the available actions tensor from the observation.
            avail_actions = input_dict["obs"]["avail_actions"]
            action_mask = input_dict["obs"]["action_mask"]

            # Compute the predicted action embedding.
            action_embed, _ = self.action_embed_model({
                "obs": input_dict["obs"]["cart"]
            })

            # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
            # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
            intent_vector = tf.expand_dims(action_embed, 1)

            # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
            action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

            # Mask out invalid actions (use tf.float32.min for stability).
            inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)
            return action_logits + inf_mask, state

Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_actions_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_actions_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``model.vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `algorithm overview <rllib-algorithms.html#available-algorithms-overview>`__.

Autoregressive Action Distributions
-----------------------------------

In an action space with multiple components (e.g., ``Tuple(a1, a2)``), you might want ``a2`` to be conditioned on the sampled value of ``a1``, i.e., ``a2_sampled ~ P(a2 | a1_sampled, obs)``. Normally, ``a1`` and ``a2`` would be sampled independently, reducing the expressivity of the policy.

To do this, you need both a custom model that implements the autoregressive pattern, and a custom action distribution class that leverages that model. The `autoregressive_action_dist.py <https://github.com/ray-project/ray/blob/master/rllib/examples/autoregressive_action_dist.py>`__ example shows how this can be implemented for a simple binary action space. For a more complex space, a more efficient architecture such as `MADE <https://arxiv.org/abs/1502.03509>`__ is recommended. Note that sampling an `N-part` action requires `N` forward passes through the model, however computing the log probability of an action can be done in one pass:

.. code-block:: python

    class BinaryAutoregressiveOutput(ActionDistribution):
        """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)"""

        @staticmethod
        def required_model_output_shape(action_space, model_config):
            return 16  # controls model output feature vector size

        def sample(self):
            # First, sample a1.
            a1_dist = self._a1_distribution()
            a1 = a1_dist.sample()

            # Sample a2 conditioned on a1.
            a2_dist = self._a2_distribution(a1)
            a2 = a2_dist.sample()

            # Return the action tuple.
            return TupleActions([a1, a2])

        def logp(self, actions):
            a1, a2 = actions[:, 0], actions[:, 1]
            a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
            a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec])
            return (Categorical(a1_logits, None).logp(a1) + Categorical(
                a2_logits, None).logp(a2))

        def _a1_distribution(self):
            BATCH = tf.shape(self.inputs)[0]
            a1_logits, _ = self.model.action_model(
                [self.inputs, tf.zeros((BATCH, 1))])
            a1_dist = Categorical(a1_logits, None)
            return a1_dist

        def _a2_distribution(self, a1):
            a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
            _, a2_logits = self.model.action_model([self.inputs, a1_vec])
            a2_dist = Categorical(a2_logits, None)
            return a2_dist


    class AutoregressiveActionsModel(TFModelV2):
        """Implements the `.action_model` branch required above."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            super(AutoregressiveActionsModel, self).__init__(
                obs_space, action_space, num_outputs, model_config, name)
            if action_space != Tuple([Discrete(2), Discrete(2)]):
                raise ValueError(
                    "This model only supports the [2, 2] action space")

            # Inputs
            obs_input = tf.keras.layers.Input(
                shape=obs_space.shape, name="obs_input")
            a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input")
            ctx_input = tf.keras.layers.Input(
                shape=(num_outputs, ), name="ctx_input")

            # Output of the model (normally 'logits', but for an autoregressive
            # dist this is more like a context/feature layer encoding the obs).
            context = tf.keras.layers.Dense(
                num_outputs,
                name="hidden",
                activation=tf.nn.tanh,
                kernel_initializer=normc_initializer(1.0))(obs_input)

            # P(a1 | obs)
            a1_logits = tf.keras.layers.Dense(
                2,
                name="a1_logits",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(ctx_input)

            # P(a2 | a1)
            # --note: typically you'd want to implement P(a2 | a1, obs) as follows:
            # a2_context = tf.keras.layers.Concatenate(axis=1)(
            #     [ctx_input, a1_input])
            a2_context = a1_input
            a2_hidden = tf.keras.layers.Dense(
                16,
                name="a2_hidden",
                activation=tf.nn.tanh,
                kernel_initializer=normc_initializer(1.0))(a2_context)
            a2_logits = tf.keras.layers.Dense(
                2,
                name="a2_logits",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(a2_hidden)

            # Base layers
            self.base_model = tf.keras.Model(obs_input, context)
            self.register_variables(self.base_model.variables)
            self.base_model.summary()

            # Autoregressive action sampler
            self.action_model = tf.keras.Model([ctx_input, a1_input],
                                               [a1_logits, a2_logits])
            self.action_model.summary()
            self.register_variables(self.action_model.variables)

.. note::
    Not all algorithms support autoregressive action distributions; see the `algorithm overview table <rllib-algorithms.html#available-algorithms-overview>`__ for more information.