.. include:: /_includes/rllib/announcement.rst

.. include:: /_includes/rllib/we_are_hiring.rst

Sample Collections and Trajectory Views
========================================

The SampleCollector Class is Used to Store and Retrieve Temporary Data
------------------------------------------------------------------------

RLlib's `RolloutWorkers <https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py>`__,
when running against a live environment, use the ``SamplerInput`` class to interact
with that environment and produce batches of experiences.
The two implemented sub-classes of ``SamplerInput`` are ``SyncSampler`` and ``AsyncSampler``
(residing under the ``RolloutWorker.sampler`` property).
If the "_use_trajectory_view_api" top-level config key is set to True
(the default since version 1.1.0), every such sampler object uses the
``SampleCollector`` API to store and retrieve temporary environment-, model-, and other data
during rollouts (see the figure below).

.. Edit figure below at: https://docs.google.com/drawings/d/1ZdNUU3ChwiUeT-DBRxvLAsbEPEqEFWSPZcOyVy3KxVg/edit

.. image:: images/rllib-sample-collection.svg

**Sample collection process implemented by RLlib:**
The Policy's model tells the Sampler and its SampleCollector object which data to store and
how to present it back to the dependent methods (e.g. `Model.compute_actions()`).
This is done using a dict that maps strings (column names) to `ViewRequirement` objects (see details below).

The exact behavior for a single such rollout and the number of environment transitions therein
are determined by the following ``AlgorithmConfig.rollouts(..)`` args:

**batch_mode [truncate_episodes|complete_episodes]**:

*truncate_episodes (default value)*:
Rollouts are performed over exactly ``rollout_fragment_length`` (see below) steps. Steps are
counted either as environment steps or as individual agent steps (see ``count_steps_by`` below).
It does not matter whether one or more episodes end within this rollout or whether
the rollout starts in the middle of an already ongoing episode.

*complete_episodes*:
Each rollout only ever contains **full** episodes (from beginning to terminal state), never any episode fragments.
The number of episodes in the rollout is 1 or larger.
The ``rollout_fragment_length`` setting defines the minimum number of timesteps that will be covered in the rollout.
For example, if ``rollout_fragment_length=100`` and your episodes are always 98 timesteps long,
then rollouts will happen over two complete episodes and always be 196 timesteps long:
98 < 100 -> too short, keep the rollout going; 98 + 98 >= 100 -> good, stop the rollout after 2 episodes (196 timesteps).
Note that you have to be careful when choosing ``complete_episodes`` as your batch_mode: If your environment does not
terminate easily, this setting could lead to enormous batch sizes.

**rollout_fragment_length [int]**:

The exact number of environment or agent steps to be performed per rollout,
if the ``batch_mode`` setting (see above) is "truncate_episodes".
If ``batch_mode`` is "complete_episodes", ``rollout_fragment_length`` is ignored.
The unit in which fragments are counted is set via ``multiagent.count_steps_by=[env_steps|agent_steps]``
(within the ``multiagent`` config dict).
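
To make these two settings concrete, here is a minimal sketch using the plain config-dict keys described
above (the values are illustrative only; defaults may differ between RLlib versions):

.. code-block:: python

    config = {
        # Each call to `RolloutWorker.sample()` returns a fragment of exactly
        # 100 steps, possibly crossing episode boundaries.
        "batch_mode": "truncate_episodes",
        "rollout_fragment_length": 100,
    }

    # With `batch_mode="complete_episodes"` instead, `rollout_fragment_length=100`
    # only acts as a lower bound: e.g. with 98-step episodes, each fragment would
    # contain two full episodes (196 steps), as explained above.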

.. Edit figure below at: https://docs.google.com/drawings/d/1uRNGImBNq8gv3bBoFX_HernGyeovtCB3wKpZ71c0VE4/edit

.. image:: images/rllib-batch-modes.svg

**Above:** The two supported batch modes in RLlib. For "truncate_episodes",
batches can a) span more than one episode, b) end in the middle of an episode, and
c) start in the middle of an episode. Also, `Policy.postprocess_trajectory()` is always
called at the end of a rollout fragment (red lines on the right side) as well as at the end
of each episode (arrow heads). This way, RLlib makes sure that the
`Policy.postprocess_trajectory()` method never sees data from more than one episode.

... as well as by the ``AlgorithmConfig.multi_agent(count_steps_by=..)`` setting:

**count_steps_by [env_steps|agent_steps]**:

Within the Algorithm's ``multiagent`` config dict, you can set the unit by which RLlib counts
a) rollout fragment lengths as well as b) the size of the final train batch (see below).
The two supported values are:

*env_steps (default)*:
Each call to ``[Env].step()`` is counted as one step. It does not matter
how many individual agents step simultaneously in this very call
(not all agents existing in the environment may step at the same time).

*agent_steps*:
In a multi-agent environment, count each individual agent's step as one.
For example, if N agents are in an environment and all these N agents
always step at the same time, a single env step corresponds to N agent steps.
Note that in the single-agent case, ``env_steps`` and ``agent_steps`` are identical.
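
As an illustration, here is a sketch of how the counting unit affects fragment sizes in a hypothetical
two-agent environment in which both agents act on every ``env.step()`` call (config keys as described above):

.. code-block:: python

    config = {
        "batch_mode": "truncate_episodes",
        "rollout_fragment_length": 100,
        "multiagent": {
            # "env_steps": one `env.step()` call counts as 1
            #   -> a fragment holds 100 env steps (= 200 individual agent steps).
            # "agent_steps": each individual agent step counts as 1
            #   -> a fragment holds 100 agent steps (= 50 env steps).
            "count_steps_by": "env_steps",
        },
    }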

To trigger a single rollout, RLlib calls ``RolloutWorker.sample()``, which returns
a SampleBatch or MultiAgentBatch object representing all the data collected during that
rollout. These batches are then usually concatenated (across the ``num_workers``
parallelized RolloutWorkers) to form a final train batch, whose size is determined
by the ``train_batch_size`` config parameter. Train batches are usually sent to the Policy's
``learn_on_batch`` method, which handles loss and gradient calculations, and optimizer stepping.
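
The following is a rough, simplified sketch of this flow (the real logic lives inside RLlib's execution
and training code; ``workers`` and ``policy`` are assumed to already exist here):

.. code-block:: python

    from ray.rllib.policy.sample_batch import SampleBatch

    # One rollout fragment per (remote or local) RolloutWorker.
    fragments = [worker.sample() for worker in workers]
    # Concatenate the fragments into a single train batch (RLlib keeps sampling
    # until `train_batch_size` is reached).
    train_batch = SampleBatch.concat_samples(fragments)
    # Compute the loss and gradients and apply the optimizer step.
    results = policy.learn_on_batch(train_batch)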

RLlib's default ``SampleCollector`` class is the ``SimpleListCollector``, which appends single-timestep data
(e.g. actions) to lists, then builds SampleBatches from these and sends them to the downstream processing functions.
It thereby tries to avoid collecting duplicate data separately (OBS and NEXT_OBS, for example, use the same underlying list).
If you want to implement your own collection logic and data structures, you can sub-class ``SampleCollector``
and specify that new class under the Algorithm's "sample_collector" config key.
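
For instance, a minimal sketch of plugging in your own collector class (here simply sub-classing the default
``SimpleListCollector`` without changing its behavior; the import path may vary slightly between RLlib versions)
could look like this:

.. code-block:: python

    from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector


    class MySampleCollector(SimpleListCollector):
        """Custom collector; override methods here to change how per-timestep
        data is stored and how SampleBatches are assembled."""
        pass


    config = {
        # Tell the Algorithm (and its RolloutWorkers) to use the custom collector class.
        "sample_collector": MySampleCollector,
    }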

Let's now look at how the Policy's Model lets the RolloutWorker and its SampleCollector
know what data in the ongoing episode/trajectory to use for the different required method calls
during rollouts. In particular, these method calls are:

#. ``Policy.compute_actions_from_input_dict()`` to compute actions to be taken in an episode.
#. ``Policy.postprocess_trajectory()``, which is called after an episode ends or a rollout hits its
   ``rollout_fragment_length`` limit (in ``batch_mode=truncate_episodes``).
#. ``Policy.learn_on_batch()``, which is called with a "train_batch" to improve the policy.

Trajectory View API
--------------------

The trajectory view API allows custom models to define what parts of the trajectory they
require in order to execute the forward pass. For example, in the simplest case, a model might
only look at the latest observation. However, an RNN- or attention-based model could look
at previous states emitted by the model, concatenate previously seen rewards with the current observation,
or require the entire range of the n most recent observations.

The trajectory view API lets models define these requirements and lets RLlib gather the required
data for the forward pass in an efficient way.

Since the following methods all call into the model class, they all use the trajectory view API indirectly.
It is important to note that the API is only accessible to the user via the model classes
(see below on how to set up trajectory view requirements for a custom model).
In particular, the methods receiving inputs that depend on a Model's trajectory view rules are:

a) ``Policy.compute_actions_from_input_dict()``
b) ``Policy.postprocess_trajectory()`` and
c) ``Policy.learn_on_batch()`` (and subsequently the Policy's loss function).

The input data to these methods can stem from either the environment (observations, rewards, and env infos),
the model itself (previously computed actions, internal state outputs, action probs, etc.),
or the Sampler (e.g. agent index, env ID, episode ID, timestep, etc.).

All data has an associated time axis, which is 0-based, meaning that the first action taken, the
first reward received in an episode, and the first observation (directly after a reset)
all have t=0.

The idea is to allow more flexibility and standardization in how a model defines required
"views" on the ongoing trajectory (during action computations/inference), on past episodes (training
on a batch), or even on trajectories of other agents in the same episode, some of which
may even use a different policy.

Such a "view requirements" formalism is helpful when having to support more complex model
setups like RNNs, attention nets, observation image framestacking (e.g. for Atari),
and building multi-agent communication channels.

The way to make the Model see certain data is to define rules in a "view requirements dict",
which resides in the ``Policy.model.view_requirements`` property.
View requirements dicts map strings (column names), such as "obs" or "actions", to
a ``ViewRequirement`` object, which defines the exact conditions by which this column
should be populated with data.

View Requirement Dictionaries
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View requirements are stored within the ``view_requirements`` property of the ``ModelV2``
class.
You can access it like this:

.. code-block:: python

    my_simple_model = ModelV2(...)
    print(my_simple_model.view_requirements)
    >>> {"obs": ViewRequirement(shift=0, space=[observation space])}

    my_lstm_model = LSTMModel(...)
    print(my_lstm_model.view_requirements)
    >>> {
    >>>    "obs": ViewRequirement(shift=0, space=[observation space]),
    >>>    "prev_actions": ViewRequirement(shift=-1, data_col="actions", space=[action space]),
    >>>    "prev_rewards": ViewRequirement(shift=-1, data_col="rewards"),
    >>> }

The ``view_requirements`` property holds a dictionary mapping
string keys (e.g. "actions", "rewards", "next_obs", etc.)
to a ``ViewRequirement`` object. This ``ViewRequirement`` object determines what exact data to
provide under the given key in case a SampleBatch or a single-timestep (action computing) "input dict"
needs to be built and fed into one of the above ModelV2 or Policy methods.

.. Edit figure below at: https://docs.google.com/drawings/d/1YEPUtMrRXmWfvM0E6mD3VsOaRlLV7DtctF-yL96VHGg/edit

.. image:: images/rllib-trajectory-view-example.svg

**Above:** An example `ViewRequirements` dict that causes the current observation
and the previous action to be available in each `compute_actions()` call, as
well as in the Policy's `postprocess_trajectory()` function (and train batch).
A similar setup is often used by LSTM/RNN-based models.

The ViewRequirement class
~~~~~~~~~~~~~~~~~~~~~~~~~~

Here is a description of the constructor-settable properties of a ViewRequirement
object and what each of these properties controls.

**data_col**:

An optional string key referencing the underlying data to use to
create the view. If not provided, assumes that there is data under the
dict key under which this ViewRequirement resides.

Examples:

.. code-block:: python

    ModelV2.view_requirements = {"rewards": ViewRequirement(shift=0)}
    # -> implies that the underlying data to use are the collected rewards
    #    from the environment.

    ModelV2.view_requirements = {"prev_rewards": ViewRequirement(data_col="rewards", shift=-1)}
    # -> means that the actual data used to create the "prev_rewards" column
    #    is the "rewards" data from the environment (shifted back by one timestep).

**space**:

An optional gym.Space used as a hint for the SampleCollector to know
how to fill timesteps before the episode actually started (e.g. if
shift=-2, we need dummy data at timesteps -2 and -1).
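
For example (a small, hypothetical sketch): at the start of an episode, the slots for t=-2 and t=-1
do not exist yet, so the collector fills them with zero-like dummy values built from the given space.

.. code-block:: python

    from gym.spaces import Discrete
    from ray.rllib.policy.view_requirement import ViewRequirement

    # "Give me the action from two timesteps ago; before enough steps exist,
    # fill in dummy (zero) actions based on this Discrete(2) space."
    view_req = ViewRequirement(data_col="actions", shift=-2, space=Discrete(2))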

**shift [int]**:

An int, a list of ints, or a range string (e.g. "-50:-1") to indicate
which time offsets or ranges of the underlying data to use for the view.

Examples:

.. code-block:: python

    shift=0  # -> Use the data under ``data_col`` as is.
    shift=1  # -> Use the data under ``data_col``, but shifted by +1 timestep
             #    (used by e.g. next_obs views).
    shift=-1  # -> Use the data under ``data_col``, but shifted by -1 timestep
              #    (used by e.g. prev_actions views).
    shift=[-2, -1]  # -> Use the data under ``data_col``, but always provide 2 values
                    #    at each timestep (the two previous ones).
                    #    Could be used e.g. to feed the last two actions or rewards into an LSTM.
    shift="-50:-1"  # -> Use the data under ``data_col``, but always provide a range of
                    #    the last 50 timesteps (used by our attention nets).

**used_for_training [bool]**:

True by default. If False, the column will not be available inside the train batch (arriving in the
Policy's loss function).
RLlib automatically switches this to False for a given column if it detects during
Policy initialization that the column is not accessed inside the loss function (see below).
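
As a small, hypothetical example:

.. code-block:: python

    # A column that the Model needs when computing actions, but that should not
    # be shipped into the train batch (and hence not into the loss function).
    ViewRequirement(data_col="rewards", shift=-1, used_for_training=False)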

How does RLlib determine which Views are required?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When initializing a Policy, it automatically determines how to later build batches
for postprocessing, loss function calls, and action computations, based on
the Model's ``view_requirements`` dict. It does so by sending generic dummy batches
through its ``compute_actions_from_input_dict``, ``postprocess_trajectory``, and loss functions
and then checking which fields in these dummy batches get accessed, overwritten, deleted, or added.
Based on these test passes, the Policy then throws out those ViewRequirements from an initially
very broad list that it deems unnecessary. This procedure saves a lot of data copying
during later rollouts, batch transfers (via Ray), and loss calculations, and makes things like
manually deleting columns from a SampleBatch (e.g. PPO used to delete the "next_obs" column
inside the postprocessing function) unnecessary.

Note that the "rewards" and "dones" columns are never discarded and thus should always
arrive in your loss function's SampleBatch (``train_batch`` arg).

Setting ViewRequirements manually in your Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you need to specify special view requirements for your model, you can add
columns to the Model's ``view_requirements`` dict in the
Model's constructor.

For example, our auto-LSTM wrapper classes (tf and torch) have these additional
lines in their constructors (torch version shown here):

.. literalinclude:: ../../../rllib/models/torch/recurrent_net.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__

This makes sure that, if the user requires this via the model config, previous rewards
and/or previous actions are added properly to the ``compute_actions`` input dicts and SampleBatches
used for postprocessing and training.

Another example are our attention nets, which make sure the last n (memory) model outputs
are always fed back into the model on the next time step (tf version shown here):

.. literalinclude:: ../../../rllib/models/tf/attention_net.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
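
As another illustration, here is a minimal, hypothetical sketch (not taken from the RLlib code base) of a
custom torch model that adds a view requirement for the previous reward in its constructor and then uses
that column in its forward pass:

.. code-block:: python

    import numpy as np
    import torch
    import torch.nn as nn

    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
    from ray.rllib.policy.view_requirement import ViewRequirement


    class PrevRewardModel(TorchModelV2, nn.Module):
        """Concatenates the previous reward to the current observation."""

        def __init__(self, obs_space, action_space, num_outputs, model_config, name):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)
            nn.Module.__init__(self)
            in_size = int(np.product(obs_space.shape)) + 1  # +1 for the prev. reward
            self.fc = nn.Linear(in_size, num_outputs)
            self.value_branch = nn.Linear(in_size, 1)
            # Ask the SampleCollector to also provide the reward from t-1 in every
            # input dict (and in postprocessing/train batches).
            self.view_requirements["prev_rewards"] = ViewRequirement(
                data_col="rewards", shift=-1)

        def forward(self, input_dict, state, seq_lens):
            prev_r = input_dict["prev_rewards"].float().unsqueeze(-1)
            self._features = torch.cat([input_dict["obs_flat"], prev_r], dim=-1)
            return self.fc(self._features), state

        def value_function(self):
            return self.value_branch(self._features).squeeze(-1)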

Setting ViewRequirements manually after Policy construction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Here is a simple example of how you can modify and add to the ViewRequirements dict
even after policy (or RolloutWorker) creation. However, note that it's usually better to
make these modifications to your batches in your postprocessing function:

.. code-block:: python

    import gym
    import numpy as np
    from gym.spaces import Discrete

    # Note: exact import paths may differ slightly between RLlib versions.
    from ray.rllib.algorithms import ppo
    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement
    from ray.rllib.utils.annotations import override

    # Modify view_requirements in the Policy object.
    action_space = Discrete(2)
    rollout_worker = RolloutWorker(
        env_creator=lambda _: gym.make("CartPole-v1"),
        policy_config=ppo.DEFAULT_CONFIG,
        policy_spec=ppo.PPOTorchPolicy,
    )
    policy = rollout_worker.policy_map["default_policy"]
    # Add the next action to the view reqs of the policy.
    # This should then be visible in postprocessing and train batches.
    policy.view_requirements["next_actions"] = ViewRequirement(
        SampleBatch.ACTIONS, shift=1, space=action_space)
    # Check whether a sampled batch has the requested `next_actions` view.
    batch = rollout_worker.sample()
    assert "next_actions" in batch.data

    # Doing the same in a custom postprocessing callback function:
    class MyCallback(DefaultCallbacks):
        # ...

        @override(DefaultCallbacks)
        def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id,
                                      policies, postprocessed_batch, original_batches,
                                      **kwargs):
            postprocessed_batch["next_actions"] = np.concatenate(
                [postprocessed_batch["actions"][1:],
                 np.zeros_like([policies[policy_id].action_space.sample()])])

The above two examples add a "next_actions" view to the postprocessed SampleBatch
used by the Policy for training. The "next_actions" column is not fed into the Model's
``compute_actions`` calls, though (it can't be, because the next action is of course
not yet known at that point).

.. include:: /_includes/rllib/announcement_bottom.rst