.. include:: /_includes/rllib/announcement.rst

.. include:: /_includes/rllib/we_are_hiring.rst

.. TODO: We need algorithms, environments, policies, models here. Likely in that order.
   Execution plans are not a "core" concept for users. Sample batches should probably also be left out.

.. _rllib-core-concepts:

Key Concepts
============

On this page, we'll cover the key concepts to help you understand how RLlib works and
how to use it. In RLlib, you use ``Algorithm``\s to learn how to solve problem ``environments``.
The algorithms use ``policies`` to select actions. Given a policy,
``rollouts`` through an ``environment`` produce
``sample batches`` (or ``trajectories``) of experiences.
You can also customize the ``training_step``\s of your RL experiments.

.. _environments:

Environments
------------

Solving a problem in RL begins with an **environment**. In the simplest definition of RL,
an **agent** interacts with an **environment** and receives a reward.
An environment in RL is the agent's world; it is a simulation of the problem to be solved.

.. image:: images/env_key_concept1.png

An RLlib environment consists of:

1. all possible actions (**action space**)
2. a complete description of the environment, nothing hidden (**state space**)
3. an observation by the agent of certain parts of the state (**observation space**)
4. **reward**, which is the only feedback the agent receives per action.

The model that tries to maximize the expected sum over all future rewards is called a **policy**. The policy is a function mapping the environment's observations to an action to take, usually written **π** (s(t)) -> a(t). Below is a diagram of the RL iterative learning process.

.. image:: images/env_key_concept2.png

The RL simulation feedback loop repeatedly collects data, for one (single-agent case) or multiple (multi-agent case) policies, trains the policies on the collected data, and makes sure the policies' weights are kept in sync. The collected environment data contains observations, actions taken, rewards received, and so-called **done** flags, indicating the boundaries of the different episodes the agents play through in the simulation.

The simulation iterations of action -> reward -> next state -> train -> repeat, until the end state, are called an **episode**, or in RLlib, a **rollout**.
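
To make this loop concrete, the following is a minimal sketch of the environment interaction
loop, using Gymnasium's ``CartPole-v1`` with random actions standing in for a learned policy
**π** (s(t)) -> a(t):

.. code-block:: python

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    obs, info = env.reset()

    episode_reward = 0.0
    terminated = truncated = False
    while not (terminated or truncated):
        # A learned policy would map the observation to an action here;
        # we simply sample a random action from the action space.
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        episode_reward += reward

    print("Episode reward:", episode_reward)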

.. _algorithms:

Algorithms
----------

Algorithms bring all RLlib components together, making learning of different tasks
accessible via RLlib's Python API and its command line interface (CLI).
Each ``Algorithm`` class is managed by its respective ``AlgorithmConfig``; for example, to
configure a ``PPO`` instance, you should use the ``PPOConfig`` class.
An ``Algorithm`` sets up its rollout workers and optimizers, and collects training metrics.
``Algorithms`` also implement the :ref:`Tune Trainable API <tune-60-seconds>` for
easy experiment management.

You have three ways to interact with an algorithm. You can use the basic Python API or the command line to train it, or you
can use Ray Tune to tune hyperparameters of your reinforcement learning algorithm.
The following example shows three equivalent ways of interacting with ``PPO``,
which implements the proximal policy optimization algorithm in RLlib.

.. tab-set::

    .. tab-item:: Basic RLlib Algorithm

        .. code-block:: python

            # Configure.
            from ray.rllib.algorithms.ppo import PPOConfig
            config = PPOConfig().environment(env="CartPole-v1").training(train_batch_size=4000)

            # Build.
            algo = config.build()

            # Train.
            while True:
                print(algo.train())

    .. tab-item:: RLlib Algorithms and Tune

        .. code-block:: python

            from ray import tune

            # Configure.
            from ray.rllib.algorithms.ppo import PPOConfig
            config = PPOConfig().environment(env="CartPole-v1").training(train_batch_size=4000)

            # Train via Ray Tune.
            tune.run("PPO", config=config)

    .. tab-item:: RLlib Command Line

        .. code-block:: bash

            rllib train --run=PPO --env=CartPole-v1 --config='{"train_batch_size": 4000}'

RLlib `Algorithm classes <rllib-concepts.html#algorithms>`__ coordinate the distributed workflow of running rollouts and optimizing policies.
Algorithm classes leverage parallel iterators to implement the desired computation pattern.
The following figure shows *synchronous sampling*, the simplest of `these patterns <rllib-algorithms.html>`__:

.. figure:: images/a2c-arch.svg

    Synchronous Sampling (e.g., A2C, PG, PPO)

RLlib uses `Ray actors <actors.html>`__ to scale training from a single core to many thousands of cores in a cluster.
You can `configure the parallelism <rllib-training.html#specifying-resources>`__ used for training by changing the ``num_workers`` parameter.
Check out our `scaling guide <rllib-training.html#scaling-guide>`__ for more details here.
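
As a minimal sketch of scaling out sample collection with the ``PPOConfig`` used above
(depending on your RLlib version, this setting is exposed as ``num_workers`` in the config
dict or as ``num_rollout_workers`` on the config object):

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    # Collect samples with 4 parallel rollout workers (Ray actors),
    # each running its own copy of the environment.
    config = (
        PPOConfig()
        .environment(env="CartPole-v1")
        .rollouts(num_rollout_workers=4)
    )
    algo = config.build()
    print(algo.train())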

Policies
--------

`Policies <rllib-concepts.html#policies>`__ are a core concept in RLlib. In a nutshell, policies are
Python classes that define how an agent acts in an environment.
`Rollout workers <rllib-concepts.html#policy-evaluation>`__ query the policy to determine agent actions.
In a `Farama-Foundation Gymnasium <rllib-env.html#gymnasium>`__ environment, there is a single agent and policy.
In `vector envs <rllib-env.html#vectorized>`__, policy inference is for multiple agents at once,
and in `multi-agent <rllib-env.html#multi-agent-and-hierarchical>`__, there may be multiple policies,
each controlling one or more agents:

.. image:: images/multi-flat.svg

Policies can be implemented using `any framework <https://github.com/ray-project/ray/blob/master/rllib/policy/policy.py>`__.
However, for TensorFlow and PyTorch, RLlib has
`build_tf_policy <rllib-concepts.html#building-policies-in-tensorflow>`__ and
`build_torch_policy <rllib-concepts.html#building-policies-in-pytorch>`__ helper functions that let you
define a trainable policy with a functional-style API, for example:

.. TODO: test this code snippet

.. code-block:: python

    import tensorflow as tf
    from ray.rllib.policy.tf_policy_template import build_tf_policy

    def policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        return -tf.reduce_mean(
            action_dist.logp(train_batch["actions"]) * train_batch["rewards"])

    # <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
    MyTFPolicy = build_tf_policy(
        name="MyTFPolicy",
        loss_fn=policy_gradient_loss)
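
For comparison, a roughly equivalent PyTorch version could be sketched with the
``build_torch_policy`` helper mentioned above (an untested sketch; the helper's exact
module path may differ across RLlib versions):

.. code-block:: python

    import torch
    from ray.rllib.policy.torch_policy_template import build_torch_policy

    def policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        return -torch.mean(
            action_dist.logp(train_batch["actions"]) * train_batch["rewards"])

    MyTorchPolicy = build_torch_policy(
        name="MyTorchPolicy",
        loss_fn=policy_gradient_loss)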

Policy Evaluation
-----------------

Given an environment and policy, policy evaluation produces `batches <https://github.com/ray-project/ray/blob/master/rllib/policy/sample_batch.py>`__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `RolloutWorker <https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py>`__ class that manages all of this, and this class is used in most RLlib algorithms.

You can use rollout workers standalone to produce batches of experiences. This can be done by calling ``worker.sample()`` on a worker instance, or ``worker.sample.remote()`` in parallel on worker instances created as Ray actors (see `WorkerSet <https://github.com/ray-project/ray/blob/master/rllib/evaluation/worker_set.py>`__).

Here is an example of creating a set of rollout workers and using them to gather experiences in parallel. The trajectories are concatenated, the policy learns on the trajectory batch, and then we broadcast the policy weights to the workers for the next round of rollouts:

.. code-block:: python

    import gymnasium as gym
    import ray
    from ray.rllib.evaluation.worker_set import WorkerSet
    from ray.rllib.policy.sample_batch import SampleBatch

    # `CustomPolicy` is assumed to be a user-defined Policy subclass.

    # Setup policy and rollout workers.
    env = gym.make("CartPole-v1")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = WorkerSet(
        policy_class=CustomPolicy,
        env_creator=lambda c: gym.make("CartPole-v1"),
        num_workers=10)

    while True:
        # Gather a batch of samples.
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers.remote_workers()]))

        # Improve the policy using the T1 batch.
        policy.learn_on_batch(T1)

        # The local worker acts as a "parameter server" here.
        # We put the weights of its `policy` into the Ray object store once (`ray.put`)...
        weights = ray.put({"default_policy": policy.get_weights()})
        for w in workers.remote_workers():
            # ... so that we can broadcast these weights to all rollout workers once.
            w.set_weights.remote(weights)

Sample Batches
--------------

Whether running in a single process or a `large cluster <rllib-training.html#specifying-resources>`__,
all data in RLlib is interchanged in the form of `sample batches <https://github.com/ray-project/ray/blob/master/rllib/policy/sample_batch.py>`__.
Sample batches encode one or more fragments of a trajectory.
Typically, RLlib collects batches of size ``rollout_fragment_length`` from rollout workers, and concatenates one or
more of these batches into a batch of size ``train_batch_size`` that is the input to SGD.

A typical sample batch looks something like the following when summarized.
Since all values are kept in arrays, this allows for efficient encoding and transmission across the network:

.. code-block:: python

    sample_batch = {
        'action_logp': np.ndarray((200,), dtype=float32, min=-0.701, max=-0.685, mean=-0.694),
        'actions': np.ndarray((200,), dtype=int64, min=0.0, max=1.0, mean=0.495),
        'dones': np.ndarray((200,), dtype=bool, min=0.0, max=1.0, mean=0.055),
        'infos': np.ndarray((200,), dtype=object, head={}),
        'new_obs': np.ndarray((200, 4), dtype=float32, min=-2.46, max=2.259, mean=0.018),
        'obs': np.ndarray((200, 4), dtype=float32, min=-2.46, max=2.259, mean=0.016),
        'rewards': np.ndarray((200,), dtype=float32, min=1.0, max=1.0, mean=1.0),
        't': np.ndarray((200,), dtype=int64, min=0.0, max=34.0, mean=9.14),
    }

In `multi-agent mode <rllib-concepts.html#policies-in-multi-agent>`__,
sample batches are collected separately for each individual policy.
These batches are wrapped up together in a ``MultiAgentBatch``,
serving as a container for the individual agents' sample batches.
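
As a minimal sketch (with hypothetical toy data) of how per-policy ``SampleBatch``\es
can be wrapped into a ``MultiAgentBatch``:

.. code-block:: python

    import numpy as np
    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

    # Build a tiny SampleBatch per policy (toy data, 3 timesteps each).
    batch_p1 = SampleBatch({
        "obs": np.random.random((3, 4)),
        "actions": np.array([0, 1, 0]),
        "rewards": np.array([1.0, 1.0, 1.0]),
    })
    batch_p2 = SampleBatch({
        "obs": np.random.random((3, 4)),
        "actions": np.array([1, 1, 0]),
        "rewards": np.array([0.0, 1.0, 1.0]),
    })

    # Wrap the per-policy batches into one MultiAgentBatch.
    ma_batch = MultiAgentBatch(
        policy_batches={"policy_1": batch_p1, "policy_2": batch_p2},
        env_steps=3,
    )
    print(ma_batch.policy_batches.keys())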

Training Step Method (``Algorithm.training_step()``)
-----------------------------------------------------

.. TODO: all training_step snippets below must be tested

.. note::

    It's important to have a good understanding of the basic :ref:`ray core methods <core-walkthrough>` before reading this section.
    Furthermore, we utilize concepts such as the ``SampleBatch`` (and its more advanced sibling: the ``MultiAgentBatch``),
    ``RolloutWorker``, and ``Algorithm``, which can be read about on this page
    and the :ref:`rollout worker reference docs <rolloutworker-reference-docs>`.
    Finally, developers who are looking to implement custom algorithms should familiarize themselves with the :ref:`Policy <rllib-policy-walkthrough>` and
    :ref:`Model <rllib-models-walkthrough>` classes.

What is it?
~~~~~~~~~~~

The ``training_step()`` method of the ``Algorithm`` class defines the repeatable
execution logic that sits at the core of any algorithm. Think of it as the Python implementation
of an algorithm's pseudocode as you would find it in a research paper.
You can use ``training_step()`` to express how you want to
coordinate the collection of samples from the environment(s), the movement of this data to other
parts of the algorithm, and the updates and management of your policy's weights
across the different distributed components.

**In short:** you need to override or modify the ``training_step`` method if you want to
make custom changes to an existing algorithm, write your own algorithm from scratch, or implement an algorithm from a paper.
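
For example, a minimal sketch of overriding ``training_step`` by subclassing an existing
algorithm (here PPO; deferring to the parent implementation is just for illustration):

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPO

    class MyPPO(PPO):
        def training_step(self):
            # Insert custom sampling / update / weight-sync logic here.
            # For illustration, we simply defer to PPO's own implementation.
            return super().training_step()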

When is ``training_step()`` invoked?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``Algorithm``'s ``training_step()`` method is called:

1. when the ``train()`` method of ``Algorithm`` is called (e.g. "manually" by a user that has constructed an ``Algorithm`` instance).
2. when an RLlib Algorithm is being run by Ray Tune. ``training_step()`` will be called continuously until the
   :ref:`ray tune stop criteria <tune-run-ref>` is met.

Key Subconcepts
~~~~~~~~~~~~~~~

In the following, using the example of VPG ("vanilla policy gradient"), we will try to illustrate
how to use the ``training_step()`` method to implement this algorithm in RLlib.
The "vanilla policy gradient" algo can be thought of as a sequence of repeating steps, or *dataflow*, of:

1. Sampling (to collect data from an env)
2. Updating the Policy (to learn a behavior)
3. Broadcasting the updated Policy's weights (to make sure all distributed units have the same weights again)
4. Metrics reporting (returning relevant stats from all the above operations with regards to performance and runtime)

An example implementation of VPG could look like the following:

.. code-block:: python

    from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
    from ray.rllib.execution.train_ops import train_one_step
    from ray.rllib.utils.typing import ResultDict

    def training_step(self) -> ResultDict:
        # 1. Sampling.
        train_batch = synchronous_parallel_sample(
            worker_set=self.workers,
            max_env_steps=self.config["train_batch_size"]
        )

        # 2. Updating the Policy.
        train_results = train_one_step(self, train_batch)

        # 3. Synchronize worker weights.
        self.workers.sync_weights()

        # 4. Return results.
        return train_results

.. note::

    The ``training_step`` method is deep learning framework agnostic.
    This means that you should not write PyTorch- or TensorFlow-specific code inside this method,
    allowing for a strict separation of concerns and enabling us to use the same ``training_step()``
    method for both TF and PyTorch versions of your algorithms.
    DL framework specific code should only be added to the
    :ref:`Policy <rllib-policy-walkthrough>` (e.g. in its loss function(s)) and
    :ref:`Model <rllib-models-walkthrough>` (e.g. tf.keras or torch.nn neural network code) classes.

Let's further break down our above ``training_step()`` code.
In the first step, we collect trajectory data from the environment(s):

.. code-block:: python

    train_batch = synchronous_parallel_sample(
        worker_set=self.workers,
        max_env_steps=self.config["train_batch_size"]
    )

Here, ``self.workers`` is a set of ``RolloutWorkers`` that are created in the ``Algorithm``'s ``setup()`` method
(prior to calling ``training_step()``).
This ``WorkerSet`` is covered in greater depth on the :ref:`WorkerSet documentation page <workerset-reference-docs>`.
The utility function ``synchronous_parallel_sample`` can be used for parallel sampling in a blocking
fashion across multiple rollout workers (it returns once all rollout workers are done sampling).
It returns one final ``MultiAgentBatch`` resulting from concatenating n smaller ``MultiAgentBatch``\es
(exactly one from each remote rollout worker).

The ``train_batch`` is then passed to another utility function: ``train_one_step``.

.. code-block:: python

    train_results = train_one_step(self, train_batch)

Methods like ``train_one_step`` and ``multi_gpu_train_one_step`` are used for training our Policy.
Further documentation with examples can be found on the :ref:`train ops documentation page <train-ops-docs>`.

The training updates on the policy are only applied to its version inside ``self.workers.local_worker``.
Note that each ``WorkerSet`` has n remote workers and exactly one "local worker", and that each worker (remote and local)
holds a copy of the policy.

Now that we have updated the local policy (the copy in ``self.workers.local_worker``), we need to make sure
that the copies in all remote workers (``self.workers.remote_workers``) have their weights synchronized
(from the local one):

.. code-block:: python

    self.workers.sync_weights()

By calling ``self.workers.sync_weights()``,
weights are broadcast from the local worker to the remote workers. See the :ref:`rollout worker
reference docs <rolloutworker-reference-docs>` for further details.

.. code-block:: python

    return train_results

A dictionary is expected to be returned that contains the results of the training update.
It maps keys of type ``str`` to values that are of type ``float`` or to dictionaries of
the same form, allowing for a nested structure.
For example, a results dictionary could map policy IDs to learning and sampling statistics for that policy:

.. code-block:: python

    {
        'policy_1': {
            'learner_stats': {'policy_loss': 6.7291455},
            'num_agent_steps_trained': 32,
        },
        'policy_2': {
            'learner_stats': {'policy_loss': 3.554927},
            'num_agent_steps_trained': 32,
        },
    }

Training Step Method Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RLlib provides a collection of utilities that abstract away common tasks in RL training.
In particular, if you would like to work with the various ``training_step`` methods or implement your
own, it's recommended to first familiarize yourself with the following concepts:

`Sample Batch <core-concepts.html#sample-batches>`__:
``SampleBatch`` and ``MultiAgentBatch`` are the two types that we use for storing trajectory data in RLlib. All of our
RLlib abstractions (policies, replay buffers, etc.) operate on these two types.

:ref:`Rollout Workers <rolloutworker-reference-docs>`:
Rollout workers are an abstraction that wraps a policy (or policies in the case of multi-agent) and an environment.
From a high level, we can use rollout workers to collect experiences from the environment by calling
their ``sample()`` method, and we can train their policies by calling their ``learn_on_batch()`` method.
By default, in RLlib, we create a set of workers that can be used for sampling and training.
We create a ``WorkerSet`` object inside of ``setup``, which is called when an RLlib algorithm is created. The ``WorkerSet`` has a ``local_worker``
and ``remote_workers`` if ``num_workers > 0`` in the experiment config. In RLlib we typically use ``local_worker``
for training and ``remote_workers`` for sampling.

:ref:`Train Ops <train-ops-docs>`:
These are methods that improve the policy and update workers. The most basic operator, ``train_one_step``, takes in as
input a batch of experiences and emits a ``ResultDict`` with metrics as output. For training with GPUs, use
``multi_gpu_train_one_step``. These methods use the ``learn_on_batch`` method of rollout workers to complete the
training update.

:ref:`Replay Buffers <replay-buffer-reference-docs>`:
RLlib provides `a collection <https://github.com/ray-project/ray/tree/master/rllib/utils/replay_buffers>`__ of replay
buffers that can be used for storing and sampling experiences.
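
To tie these utilities together, here is a rough, untested sketch of an off-policy-style
``training_step`` that adds a replay buffer between sampling and training. The buffer class
and its ``add()``/``sample()`` methods follow RLlib's replay buffer API, but treat the exact
signatures as an assumption for your RLlib version:

.. code-block:: python

    from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
    from ray.rllib.execution.train_ops import train_one_step
    from ray.rllib.utils.replay_buffers import ReplayBuffer
    from ray.rllib.utils.typing import ResultDict

    def setup(self, config):
        super().setup(config)
        # A simple (single-policy) replay buffer, created once per Algorithm.
        self.replay_buffer = ReplayBuffer(capacity=50_000)

    def training_step(self) -> ResultDict:
        # 1. Sample from the environment(s) and store the experiences.
        new_batch = synchronous_parallel_sample(worker_set=self.workers)
        self.replay_buffer.add(new_batch)

        # 2. Re-sample a training batch from the buffer and update the policy.
        train_batch = self.replay_buffer.sample(self.config["train_batch_size"])
        train_results = train_one_step(self, train_batch)

        # 3. Broadcast the updated weights to all remote workers.
        self.workers.sync_weights()

        # 4. Return the training metrics.
        return train_results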