.. include:: /_includes/rllib/announcement.rst

.. include:: /_includes/rllib/we_are_hiring.rst

.. |tensorflow| image:: images/tensorflow.png
    :class: inline-figure
    :width: 16

.. |pytorch| image:: images/pytorch.png
    :class: inline-figure
    :width: 16

RL Modules (Alpha)
==================
.. note::

    This is an experimental module that serves as a general replacement for ModelV2, and is subject to change. It will eventually match the functionality of the previous stack. If you only use high-level RLlib APIs such as :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` you should not experience significant changes, except for a few new parameters to the configuration object. If you've used custom models or policies before, you'll need to migrate them to the new modules. Check the Migration guide for more information.

The table below shows the list of migrated algorithms and their current supported features, which will be updated as we progress.
.. list-table::
    :header-rows: 1
    :widths: 20 20 20 20 20 20

    * - Algorithm
      - Independent MARL
      - Fully-connected
      - Image inputs (CNN)
      - RNN support (LSTM)
      - Complex observations (ComplexNet)
    * - **PPO**
      - |pytorch| |tensorflow|
      - |pytorch| |tensorflow|
      - |pytorch|
      -
      - |pytorch|
    * - **Impala**
      - |pytorch| |tensorflow|
      - |pytorch| |tensorflow|
      - |pytorch|
      -
      - |pytorch|
    * - **APPO**
      - |tensorflow|
      - |tensorflow|
      -
      -
      -
RL Module is a neural network container that implements three public methods: :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_train`, :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration`, and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`. Each method corresponds to a distinct reinforcement learning phase.

:py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` handles acting and data collection, balancing exploration and exploitation. :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`, on the other hand, serves the learned model during evaluation and is often less stochastic.

:py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_train` manages the training phase, handling calculations exclusive to computing losses, such as learning Q values in a DQN model.
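In practice, calling a module therefore looks roughly like the following minimal sketch, where ``my_module`` stands in for any concrete RL Module instance and ``batch`` for a dictionary with at least an ``obs`` key:

.. code-block:: python

    # `my_module` is any concrete RLModule instance; `batch` is a dict
    # holding at least an "obs" key with a batch of observations.

    # Acting / data collection: balances exploration and exploitation.
    exploration_out = my_module.forward_exploration(batch)

    # Evaluation / serving: typically greedier and less stochastic.
    inference_out = my_module.forward_inference(batch)

    # Training: may compute extra quantities that the loss needs.
    train_out = my_module.forward_train(batch)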
Enabling RL Modules in the Configuration
----------------------------------------

Enable RL Modules by setting the ``_enable_rl_module_api`` flag to ``True`` in the configuration object.

.. literalinclude:: doc_code/rlmodule_guide.py
    :language: python
    :start-after: __enabling-rlmodules-in-configs-begin__
    :end-before: __enabling-rlmodules-in-configs-end__
Constructing RL Modules
-----------------------

The RLModule API provides a unified way to define custom reinforcement learning models in RLlib. This API enables you to design and implement your own models to suit specific needs.

To maintain consistency and usability, RLlib offers a standardized approach for defining module objects for both single-agent and multi-agent reinforcement learning environments. This is achieved through the :py:class:`~ray.rllib.core.rl_module.rl_module.SingleAgentRLModuleSpec` and :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModuleSpec` classes. The built-in RLModules in RLlib follow this consistent design pattern, making it easier for you to understand and utilize these modules.

.. tab-set::

    .. tab-item:: Single Agent

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __constructing-rlmodules-sa-begin__
            :end-before: __constructing-rlmodules-sa-end__

    .. tab-item:: Multi Agent

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __constructing-rlmodules-ma-begin__
            :end-before: __constructing-rlmodules-ma-end__
You can pass RL Module specs to the algorithm configuration to be used by the algorithm.

.. tab-set::

    .. tab-item:: Single Agent

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __pass-specs-to-configs-sa-begin__
            :end-before: __pass-specs-to-configs-sa-end__

        .. note::

            When passing an RL Module spec, you don't have to fill in every field. Missing fields are filled in based on the described environment or other algorithm configuration parameters (e.g., ``observation_space``, ``action_space``, and ``model_config_dict`` are not required fields when passing a custom RL Module spec to the algorithm config).

    .. tab-item:: Multi Agent

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __pass-specs-to-configs-ma-begin__
            :end-before: __pass-specs-to-configs-ma-end__
Writing Custom Single Agent RL Modules
--------------------------------------

For single-agent algorithms (e.g., PPO, DQN) or independent multi-agent algorithms (e.g., PPO-MultiAgent), use :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`. For more advanced multi-agent use cases with shared communication between agents, extend the :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` class.

RLlib treats single-agent modules as a special case of :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` with only one module. Create the multi-agent representation of all RLModules by calling :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.as_multi_agent`. For example:

.. literalinclude:: doc_code/rlmodule_guide.py
    :language: python
    :start-after: __convert-sa-to-ma-begin__
    :end-before: __convert-sa-to-ma-end__
RLlib implements the following abstract framework-specific base classes:

- :class:`TorchRLModule <ray.rllib.core.rl_module.torch.torch_rl_module.TorchRLModule>`: For PyTorch-based RL Modules.
- :class:`TfRLModule <ray.rllib.core.rl_module.tf.tf_rl_module.TfRLModule>`: For TensorFlow-based RL Modules.

The minimum requirement for sub-classes of :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` is to implement the following methods:

- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule._forward_train`: Forward pass for training.
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule._forward_inference`: Forward pass for inference.
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule._forward_exploration`: Forward pass for exploration.

Also, the class's constructor requires a dataclass config object called :py:class:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig`, which contains the following fields:

- :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig.observation_space`: The observation space of the environment (either processed or raw).
- :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig.action_space`: The action space of the environment.
- :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig.model_config_dict`: The model config dictionary of the algorithm. Model hyper-parameters such as the number of layers or the type of activation are defined here.
- :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig.catalog_class`: The :py:class:`~ray.rllib.core.models.catalog.Catalog` object of the algorithm.

When writing RL Modules, you need to use these fields to construct your model.
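For example, assembling such a config by hand could look roughly like the following minimal sketch, which assumes a CartPole-like environment (during a normal training run, RLlib builds this config for you from the algorithm configuration):

.. code-block:: python

    import gymnasium as gym

    from ray.rllib.core.rl_module.rl_module import RLModuleConfig

    # Hypothetical config for a CartPole-like environment.
    config = RLModuleConfig(
        observation_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)),
        action_space=gym.spaces.Discrete(2),
        model_config_dict={"fcnet_hiddens": [128, 128]},
    )

    # A custom RL Module (such as the ones in the tabs below) then reads
    # these fields in its constructor, e.g. `config.observation_space.shape[0]`.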
.. tab-set::

    .. tab-item:: Single Agent (torch)

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __write-custom-sa-rlmodule-torch-begin__
            :end-before: __write-custom-sa-rlmodule-torch-end__

    .. tab-item:: Single Agent (tensorflow)

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __write-custom-sa-rlmodule-tf-begin__
            :end-before: __write-custom-sa-rlmodule-tf-end__
In :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` you can enforce checking for the existence of certain input or output keys in the data that is communicated into and out of RL Modules. This serves multiple purposes:

- It makes the I/O requirements of each method self-documenting.
- It makes failures happen quickly. If users extend the modules and implement something that does not match the assumptions of the I/O specs, the check reports missing keys and their expected format. For example, an RLModule should always have an ``obs`` key in the input batch and an ``action_dist`` key in the output.

.. tab-set::

    .. tab-item:: Single Level Keys

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __extend-spec-checking-single-level-begin__
            :end-before: __extend-spec-checking-single-level-end__

    .. tab-item:: Nested Keys

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __extend-spec-checking-nested-begin__
            :end-before: __extend-spec-checking-nested-end__

    .. tab-item:: TensorShape Spec

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __extend-spec-checking-torch-specs-begin__
            :end-before: __extend-spec-checking-torch-specs-end__

    .. tab-item:: Type Spec

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __extend-spec-checking-type-specs-begin__
            :end-before: __extend-spec-checking-type-specs-end__
:py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` has two spec methods for each forward method, totaling six methods that can be overridden to describe the expected inputs and outputs of each forward method:

- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_inference`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_inference`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_exploration`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_exploration`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_train`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_train`

To learn more, see the `SpecType` documentation.
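For instance, a module whose exploration pass needs only the observation and produces only action distribution inputs could be annotated roughly as in the following minimal sketch. Returning a plain list of required keys is the simplest spec form; see the `SpecType` documentation for richer options such as nested keys or tensor shapes:

.. code-block:: python

    from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


    class MySpecCheckedModule(TorchRLModule):

        def input_specs_exploration(self):
            # _forward_exploration() expects an "obs" key in its input batch.
            return ["obs"]

        def output_specs_exploration(self):
            # _forward_exploration() must return an "action_dist" key.
            return ["action_dist"]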
Writing Custom Multi-Agent RL Modules (Advanced)
------------------------------------------------

For multi-agent modules, RLlib implements :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule`, which is a dictionary of :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` objects, one for each policy, and possibly some shared modules. The base-class implementation works for most use cases that need to define independent neural networks for sub-groups of agents. For more complex, multi-agent use cases, where the agents share some part of their neural network, you should inherit from this class and override the default implementation.

The :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` offers an API for constructing custom models tailored to specific needs. The key method for this customization is :py:meth:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule.build`.

The following example creates a custom multi-agent RL module with underlying modules. The modules share an encoder, which gets applied to the global part of the observation space. The local part passes through a separate encoder, specific to each policy.

.. tab-set::

    .. tab-item:: Multi agent with shared encoder (Torch)

        .. literalinclude:: doc_code/rlmodule_guide.py
            :language: python
            :start-after: __write-custom-marlmodule-shared-enc-begin__
            :end-before: __write-custom-marlmodule-shared-enc-end__

To construct this custom multi-agent RL module, pass the class to the :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModuleSpec` constructor. Also, pass the :py:class:`~ray.rllib.core.rl_module.rl_module.SingleAgentRLModuleSpec` for each agent because RLlib requires the observation and action spaces and the model hyper-parameters for each agent.

.. literalinclude:: doc_code/rlmodule_guide.py
    :language: python
    :start-after: __pass-custom-marlmodule-shared-enc-begin__
    :end-before: __pass-custom-marlmodule-shared-enc-end__
Extending Existing RLlib RL Modules
-----------------------------------

RLlib provides a number of RL Modules for different frameworks (e.g., PyTorch, TensorFlow, etc.). Extend these modules by inheriting from them and overriding the methods you need to customize. For example, extend :py:class:`~ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` and augment it with your own customization. Then pass the new customized class into the algorithm configuration.

There are two possible ways to extend existing RL Modules:

.. tab-set::

    .. tab-item:: Inheriting existing RL Modules

        One way to extend existing RL Modules is to inherit from them and override the methods you need to customize. For example, extend :py:class:`~ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` and augment it with your own customization. Then pass the new customized class into the algorithm configuration to use the PPO algorithm to optimize your custom RL Module.

        .. code-block:: python

            class MyPPORLModule(PPORLModule):

                def __init__(self, config: RLModuleConfig):
                    super().__init__(config)
                    ...

            # Pass in the custom RL Module class to the spec.
            algo_config = algo_config.rl_module(
                rl_module_spec=SingleAgentRLModuleSpec(module_class=MyPPORLModule)
            )

    .. tab-item:: Extending RL Module Catalog

        Another way to customize your module is by extending its :py:class:`~ray.rllib.core.models.catalog.Catalog`. The :py:class:`~ray.rllib.core.models.catalog.Catalog` is a component that defines the default architecture and behavior of a model based on factors such as ``observation_space``, ``action_space``, etc. To modify sub-components of an existing RL Module, extend the corresponding Catalog class.

        For instance, to adapt the existing ``PPORLModule`` for a custom graph observation space not supported by RLlib out-of-the-box, extend the :py:class:`~ray.rllib.core.models.catalog.Catalog` class used to create the ``PPORLModule`` and override the method responsible for returning the encoder component to ensure that your custom encoder replaces the default one initially provided by RLlib. For more information on the :py:class:`~ray.rllib.core.models.catalog.Catalog` class, refer to the `Catalog user guide <rllib-catalogs.html>`__.

        .. code-block:: python

            class MyAwesomeCatalog(PPOCatalog):

                def get_actor_critic_encoder_config(self):
                    # Create your awesome graph encoder config here and return it.
                    ...

            # Pass in the custom catalog class to the spec.
            algo_config = algo_config.rl_module(
                rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyAwesomeCatalog)
            )
Migrating from Custom Policies and Models to RL Modules
--------------------------------------------------------

This document is for those who have implemented custom policies and models in RLlib and want to migrate to the new :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` API. If you have implemented custom policies that extended the :py:class:`~ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2` or :py:class:`~ray.rllib.policy.torch_policy_v2.TorchPolicyV2` classes, you likely did so in order to modify the behavior of constructing models and distributions (via overriding :py:meth:`~ray.rllib.policy.torch_policy_v2.TorchPolicyV2.make_model` or :py:meth:`~ray.rllib.policy.torch_policy_v2.TorchPolicyV2.make_model_and_action_dist`), control the action sampling logic (via overriding :py:meth:`~ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2.action_distribution_fn` or :py:meth:`~ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2.action_sampler_fn`), or control the logic for inference (via overriding :py:meth:`~ray.rllib.policy.policy.Policy.compute_actions_from_input_dict`, :py:meth:`~ray.rllib.policy.policy.Policy.compute_actions`, or :py:meth:`~ray.rllib.policy.policy.Policy.compute_log_likelihoods`). These APIs were built with :py:class:`~ray.rllib.models.modelv2.ModelV2` models in mind to enable you to customize the behavior of those functions. However, :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` is a more general abstraction that reduces the number of functions you need to override.

In the new :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` API, the construction of the models and of the action distribution class that should be used is best defined in the constructor. That RL Module is constructed automatically if users follow the instructions outlined in the sections `Enabling RL Modules in the Configuration`_ and `Constructing RL Modules`_. :py:meth:`~ray.rllib.policy.policy.Policy.compute_actions` and :py:meth:`~ray.rllib.policy.policy.Policy.compute_actions_from_input_dict` can still be used for sampling actions for inference or exploration by using the ``explore=True|False`` parameter. If called with ``explore=True``, these functions invoke :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration`; if ``explore=False``, they call :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`.
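For example, the same policy call site can target either forward method through the ``explore`` flag, as in the following minimal sketch (``my_policy`` and ``input_dict`` stand in for whatever policy instance and sample batch you already have):

.. code-block:: python

    # `my_policy` and `input_dict` are placeholders for your own policy
    # instance and observation batch.

    # Calls RLModule.forward_exploration() under the hood.
    actions, state_outs, infos = my_policy.compute_actions_from_input_dict(
        input_dict, explore=True
    )

    # Calls RLModule.forward_inference() under the hood.
    actions, state_outs, infos = my_policy.compute_actions_from_input_dict(
        input_dict, explore=False
    )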
What your customization could have looked like before:

.. tab-set::

    .. tab-item:: ModelV2

        .. code-block:: python

            from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
            from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


            class MyCustomModel(TorchModelV2):
                """Code for your previous custom model"""
                ...


            class CustomPolicy(TorchPolicyV2):

                @DeveloperAPI
                @OverrideToImplementCustomLogic
                def make_model(self) -> ModelV2:
                    """Create model.

                    Note: only one of make_model or make_model_and_action_dist
                    can be overridden.

                    Returns:
                        ModelV2 model.
                    """
                    return MyCustomModel(...)

    .. tab-item:: ModelV2 + Distribution

        .. code-block:: python

            from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
            from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


            class MyCustomModel(TorchModelV2):
                """Code for your previous custom model"""
                ...


            class CustomPolicy(TorchPolicyV2):

                @DeveloperAPI
                @OverrideToImplementCustomLogic
                def make_model_and_action_dist(self):
                    """Create model and action distribution function.

                    Returns:
                        ModelV2 model.
                        ActionDistribution class.
                    """
                    my_model = MyCustomModel(...)  # construct some ModelV2 instance here
                    dist_class = ...  # Action distribution cls
                    return my_model, dist_class
    .. tab-item:: Sampler functions

        .. code-block:: python

            from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
            from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


            class CustomPolicy(TorchPolicyV2):

                @DeveloperAPI
                @OverrideToImplementCustomLogic
                def action_sampler_fn(
                    self,
                    model: ModelV2,
                    *,
                    obs_batch: TensorType,
                    state_batches: TensorType,
                    **kwargs,
                ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
                    """Custom function for sampling new actions given policy.

                    Args:
                        model: Underlying model.
                        obs_batch: Observation tensor batch.
                        state_batches: Action sampling state batch.

                    Returns:
                        Sampled action
                        Log-likelihood
                        Action distribution inputs
                        Updated state
                    """
                    return None, None, None, None

                @DeveloperAPI
                @OverrideToImplementCustomLogic
                def action_distribution_fn(
                    self,
                    model: ModelV2,
                    *,
                    obs_batch: TensorType,
                    state_batches: TensorType,
                    **kwargs,
                ) -> Tuple[TensorType, type, List[TensorType]]:
                    """Action distribution function for this Policy.

                    Args:
                        model: Underlying model.
                        obs_batch: Observation tensor batch.
                        state_batches: Action sampling state batch.

                    Returns:
                        Distribution input.
                        ActionDistribution class.
                        State outs.
                    """
                    return None, None, None
All of the ``Policy.compute_***`` functions expect that :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference` return a dictionary that contains the key ``"action_dist_inputs"``, whose value is the parameters (inputs) of a ``ray.rllib.models.distributions.Distribution`` class. Commonly used distribution implementations can be found under ``ray.rllib.models.tf.tf_distributions`` for TensorFlow and ``ray.rllib.models.torch.torch_distributions`` for torch. You can choose to return deterministic actions by creating a deterministic distribution instance. See `Writing Custom Single Agent RL Modules`_ for more details on how to implement your own custom RL Module.
.. tab-set::

    .. tab-item:: The Equivalent RL Module

        .. code-block:: python

            """
            No need to override any policy functions. Instead, implement any custom logic in your custom RL Module.
            """
            from ray.rllib.models.torch.torch_distributions import YOUR_DIST_CLASS


            class MyRLModule(TorchRLModule):

                def __init__(self, config: RLModuleConfig):
                    # Construct any custom networks here using `config`.
                    # Specify an action distribution class here.
                    ...

                def _forward_inference(self, batch):
                    # Return a dict that contains the "action_dist_inputs" key.
                    ...

                def _forward_exploration(self, batch):
                    # Return a dict that contains the "action_dist_inputs" key.
                    ...
Notable TODOs
-------------

- [] Add support for RNNs.
- [] Checkpointing.
- [] End to end example for custom RL Modules extending PPORLModule (e.g. LLM)