.. include:: /_includes/rllib/announcement.rst

.. include:: /_includes/rllib/we_are_hiring.rst

.. _rllib-policy-walkthrough:

How To Customize Policies
=========================

This page describes the internal concepts used to implement algorithms in RLlib.
You might find this useful if modifying or adding new algorithms to RLlib.

Policy classes encapsulate the core numerical components of RL algorithms.
This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given post-processed experiences.
For a simple example, see the policy gradients `policy definition <https://github.com/ray-project/ray/blob/master/rllib/algorithms/pg/pg_tf_policy.py>`__.

Most interaction with deep learning frameworks is isolated to the `Policy interface <https://github.com/ray-project/ray/blob/master/rllib/policy/policy.py>`__, allowing RLlib to support multiple frameworks.
To simplify the definition of policies, RLlib includes `TensorFlow <#building-policies-in-tensorflow>`__ and `PyTorch-specific <#building-policies-in-pytorch>`__ templates.
You can also write your own from scratch. Here is an example:

.. code-block:: python

    class CustomPolicy(Policy):
        """Example of a custom policy written from scratch.

        You might find it more convenient to use the `build_tf_policy` and
        `build_torch_policy` helpers instead for a real policy, which are
        described in the next sections.
        """

        def __init__(self, observation_space, action_space, config):
            Policy.__init__(self, observation_space, action_space, config)
            # example parameter
            self.w = 1.0

        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            **kwargs):
            # return action batch, RNN states, extra values to include in batch
            return [self.action_space.sample() for _ in obs_batch], [], {}

        def learn_on_batch(self, samples):
            # implement your learning code here
            return {}  # return stats

        def get_weights(self):
            return {"w": self.w}

        def set_weights(self, weights):
            self.w = weights["w"]

The above basic policy, when run, will produce batches of observations with the basic ``obs``, ``new_obs``, ``actions``, ``rewards``, ``dones``, and ``infos`` columns.
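
For example, a quick sanity check inside ``learn_on_batch`` could assert that exactly these columns arrive (a minimal sketch mirroring the custom policy above):

.. code-block:: python

    def learn_on_batch(self, samples):
        # Default columns produced by rollouts with the basic policy above.
        for column in ["obs", "new_obs", "actions", "rewards", "dones", "infos"]:
            assert column in samples.keys()
        return {}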

There are two more mechanisms to pass along and emit extra information:

**Policy recurrent state**: Suppose you want to compute actions based on the current timestep of the episode.
While it is possible to have the environment provide this as part of the observation, we can instead compute and store it as part of the Policy recurrent state:

.. code-block:: python

    def get_initial_state(self):
        """Returns initial RNN state for the current policy."""
        return [0]  # list of single state element (t=0)
        # you could also return multiple values, e.g., [0, "foo"]

    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        assert len(state_batches) == len(self.get_initial_state())
        new_state_batches = [[
            t + 1 for t in state_batches[0]
        ]]
        return ..., new_state_batches, {}

    def learn_on_batch(self, samples):
        # can access array of the state elements at each timestep
        # or state_in_1, 2, etc. if there are multiple state elements
        assert "state_in_0" in samples.keys()
        assert "state_out_0" in samples.keys()

**Extra action info output**: You can also emit extra outputs at each step, which will then be available for learning. For example, you might want to output the behaviour policy logits as extra action info, which can be used for importance weighting; in general, arbitrary values can be stored here (as long as they are convertible to numpy arrays):

.. code-block:: python

    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        action_info_batch = {
            "some_value": ["foo" for _ in obs_batch],
            "other_value": [12345 for _ in obs_batch],
        }
        return ..., [], action_info_batch

    def learn_on_batch(self, samples):
        # can access array of the extra values at each timestep
        assert "some_value" in samples.keys()
        assert "other_value" in samples.keys()

Policies in Multi-Agent
~~~~~~~~~~~~~~~~~~~~~~~

Beyond being agnostic of the underlying deep learning framework, one of the main reasons to have a Policy abstraction is for use in multi-agent environments. For example, the `rock-paper-scissors example <rllib-env.html#rock-paper-scissors-example>`__ shows how you can leverage the Policy abstraction to evaluate heuristic policies against learned policies.
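
For instance, here is a rough sketch of how a hand-written policy like the ``CustomPolicy`` class above could be paired with a learned policy via RLlib's multi-agent config (the env name and the ``player_0`` / ``player_1`` agent IDs are placeholders for illustration):

.. code-block:: python

    # Sketch: map one agent to the algorithm's default (learned) policy and the
    # other to the hand-written CustomPolicy; only the learned policy is trained.
    # `obs_space` and `act_space` stand in for the env's actual spaces.
    config = {
        "env": "my_two_player_env",  # placeholder env name
        "multiagent": {
            "policies": {
                # policy_id: (policy_cls or None for the default, obs_space, act_space, config)
                "learned": (None, obs_space, act_space, {}),
                "heuristic": (CustomPolicy, obs_space, act_space, {}),
            },
            "policy_mapping_fn": lambda agent_id, **kwargs: (
                "learned" if agent_id == "player_0" else "heuristic"),
            "policies_to_train": ["learned"],
        },
    }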

Building Policies in TensorFlow
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This section covers how to build a TensorFlow RLlib policy using ``tf_policy_template.build_tf_policy()``.

To start, you first have to define a loss function. In RLlib, loss functions are defined over batches of trajectory data produced by policy evaluation. A basic policy gradient loss that only tries to maximize the 1-step reward can be defined as follows:

.. code-block:: python

    import tensorflow as tf
    from ray.rllib.policy.sample_batch import SampleBatch

    def policy_gradient_loss(policy, model, dist_class, train_batch):
        actions = train_batch[SampleBatch.ACTIONS]
        rewards = train_batch[SampleBatch.REWARDS]
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        return -tf.reduce_mean(action_dist.logp(actions) * rewards)

In the above snippet, ``actions`` is a Tensor placeholder of shape ``[batch_size, action_dim...]``, and ``rewards`` is a placeholder of shape ``[batch_size]``. The ``action_dist`` object is an :ref:`ActionDistribution <rllib-models-walkthrough>` that is parameterized by the output of the neural network policy model. Passing this loss function to ``build_tf_policy`` is enough to produce a very basic TF policy:

.. code-block:: python

    from ray.rllib.policy.tf_policy_template import build_tf_policy

    # <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
    MyTFPolicy = build_tf_policy(
        name="MyTFPolicy",
        loss_fn=policy_gradient_loss)

We can create an `Algorithm <#algorithms>`__ and try running this policy on a toy env with two parallel rollout workers:

.. code-block:: python

    import ray
    from ray import tune
    from ray.rllib.algorithms.algorithm import Algorithm

    class MyAlgo(Algorithm):
        def get_default_policy_class(self, config):
            return MyTFPolicy

    ray.init()
    tune.Tuner(MyAlgo, param_space={"env": "CartPole-v1", "num_workers": 2}).fit()

If you run the above snippet `(runnable file here) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_tf_policy.py>`__, you'll probably notice that CartPole doesn't learn so well:

.. code-block:: bash

    == Status ==
    Using FIFO scheduling algorithm.
    Resources requested: 3/4 CPUs, 0/0 GPUs
    Memory usage on this node: 4.6/12.3 GB
    Result logdir: /home/ubuntu/ray_results/MyAlgTrainer
    Number of trials: 1 ({'RUNNING': 1})
    RUNNING trials:
     - MyAlgTrainer_CartPole-v0_0:  RUNNING, [3 CPUs, 0 GPUs], [pid=26784],
                                    32 s, 156 iter, 62400 ts, 23.1 rew

Let's modify our policy loss to include rewards summed over time. To enable this advantage calculation, we need to define a *trajectory postprocessor* for the policy. This can be done by defining ``postprocess_fn``:

.. code-block:: python

    from ray.rllib.evaluation.postprocessing import compute_advantages, \
        Postprocessing

    def postprocess_advantages(policy,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        return compute_advantages(
            sample_batch, 0.0, policy.config["gamma"], use_gae=False)

    def policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        return -tf.reduce_mean(
            action_dist.logp(train_batch[SampleBatch.ACTIONS]) *
            train_batch[Postprocessing.ADVANTAGES])

    MyTFPolicy = build_tf_policy(
        name="MyTFPolicy",
        loss_fn=policy_gradient_loss,
        postprocess_fn=postprocess_advantages)

The ``postprocess_advantages()`` function above uses RLlib's ``compute_advantages`` function to compute advantages for each timestep. If you re-run the algorithm with this improved policy, you'll find that it quickly achieves the max reward of 200.

You might be wondering how RLlib makes the advantages placeholder automatically available as ``train_batch[Postprocessing.ADVANTAGES]``. When building your policy, RLlib will create a "dummy" trajectory batch where all observations, actions, rewards, etc. are zeros. It then calls your ``postprocess_fn``, and generates TF placeholders based on the numpy shapes of the postprocessed batch. RLlib tracks which placeholders ``loss_fn`` and ``stats_fn`` access, and then feeds the corresponding sample data into those placeholders during loss optimization. You can also access these placeholders via ``policy.get_placeholder(<name>)`` after loss initialization.
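
For instance, continuing the advantages example, a sketch of looking up the generated placeholder after the policy has been built (assuming the placeholder is registered under the same ``Postprocessing.ADVANTAGES`` column name) could look like this:

.. code-block:: python

    # Sketch: build the algorithm as before, then fetch the generated
    # placeholder for the advantages column from its policy.
    algo = MyAlgo(config={"env": "CartPole-v1"})
    policy = algo.get_policy()
    advantages_ph = policy.get_placeholder(Postprocessing.ADVANTAGES)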

**Example 1: Proximal Policy Optimization**

In the above section you saw how to compose a simple policy gradient algorithm with RLlib.
In this example, we'll dive into how PPO is defined within RLlib and how you can modify it.
First, check out the `PPO definition <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py>`__:

.. code-block:: python

    class PPO(Algorithm):
        @classmethod
        @override(Algorithm)
        def get_default_config(cls) -> AlgorithmConfigDict:
            return DEFAULT_CONFIG

        @override(Algorithm)
        def validate_config(self, config: AlgorithmConfigDict) -> None:
            ...

        @override(Algorithm)
        def get_default_policy_class(self, config):
            return PPOTFPolicy

        @override(Algorithm)
        def training_step(self):
            ...

Besides some boilerplate for defining the PPO configuration and some warnings, the most important method to take note of is ``training_step``.
The algorithm's `training step method <core-concepts.html#training-step-method>`__ defines the distributed training workflow.
Depending on the ``simple_optimizer`` config setting, PPO switches between a simple, synchronous optimizer and a multi-GPU optimizer that pre-loads the batch to the GPU for higher performance when running repeated minibatch updates on that same pre-loaded batch:

.. code-block:: python

    def training_step(self) -> ResultDict:
        # Collect SampleBatches from sample workers until we have a full batch.
        if self._by_agent_steps:
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_agent_steps=self.config["train_batch_size"]
            )
        else:
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_env_steps=self.config["train_batch_size"]
            )
        train_batch = train_batch.as_multi_agent()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

        # Standardize advantages.
        train_batch = standardize_fields(train_batch, ["advantages"])

        # Train.
        if self.config["simple_optimizer"]:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
        }

        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with self._timers[WORKER_UPDATE_TIMER]:
                self.workers.sync_weights(global_vars=global_vars)

        # For each policy: update KL scale and warn about possible issues.
        for policy_id, policy_info in train_results.items():
            # Update KL loss with dynamic scaling
            # for each (possibly multiagent) policy we are training.
            kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl")
            self.get_policy(policy_id).update_kl(kl_divergence)

        # Update global vars on local worker as well.
        self.workers.local_worker().set_global_vars(global_vars)

        return train_results

Now let's look at the PPO TF policy definition:

.. code-block:: python

    PPOTFPolicy = build_tf_policy(
        name="PPOTFPolicy",
        get_default_config=lambda: ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(),
        loss_fn=ppo_surrogate_loss,
        stats_fn=kl_and_loss_stats,
        extra_action_out_fn=vf_preds_and_logits_fetches,
        postprocess_fn=postprocess_ppo_gae,
        gradients_fn=clip_gradients,
        before_loss_init=setup_mixins,
        mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin])

``stats_fn``: The stats function returns a dictionary of Tensors that will be reported with the training results. This also includes the ``kl`` metric, which is used by the algorithm to adjust the KL penalty. Note that many of the values below reference ``policy.loss_obj``, which is assigned by ``loss_fn`` (not shown here since the PPO loss is quite complex). RLlib will always call ``stats_fn`` after ``loss_fn``, so you can rely on using values saved by ``loss_fn`` as part of your statistics:

.. code-block:: python

    def kl_and_loss_stats(policy, train_batch):
        policy.explained_variance = explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], policy.model.value_function())

        stats_fetches = {
            "cur_kl_coeff": policy.kl_coeff,
            "cur_lr": tf.cast(policy.cur_lr, tf.float64),
            "total_loss": policy.loss_obj.loss,
            "policy_loss": policy.loss_obj.mean_policy_loss,
            "vf_loss": policy.loss_obj.mean_vf_loss,
            "vf_explained_var": policy.explained_variance,
            "kl": policy.loss_obj.mean_kl,
            "entropy": policy.loss_obj.mean_entropy,
        }

        return stats_fetches

``extra_action_out_fn``: This function defines extra outputs that will be recorded when generating actions with the policy. For example, this enables saving the raw policy logits in the experience batch, so that they can be referenced in the PPO loss function via ``batch[BEHAVIOUR_LOGITS]``. Other values, such as the current value prediction, can also be emitted for debugging or optimization purposes:

.. code-block:: python

    def vf_preds_and_logits_fetches(policy):
        return {
            SampleBatch.VF_PREDS: policy.model.value_function(),
            BEHAVIOUR_LOGITS: policy.model.last_output(),
        }

``gradients_fn``: If defined, this function returns TF gradients for the loss function. You'd typically only want to override this to apply transformations such as gradient clipping:

.. code-block:: python

    def clip_gradients(policy, optimizer, loss):
        if policy.config["grad_clip"] is not None:
            grads = tf.gradients(loss, policy.model.trainable_variables())
            policy.grads, _ = tf.clip_by_global_norm(grads,
                                                     policy.config["grad_clip"])
            clipped_grads = list(zip(policy.grads, policy.model.trainable_variables()))
            return clipped_grads
        else:
            return optimizer.compute_gradients(
                loss, colocate_gradients_with_ops=True)

``mixins``: To add arbitrary stateful components, you can add mixin classes to the policy. Methods defined by these mixins will have higher priority than the base policy class, so you can use these to override methods (as in the case of ``LearningRateSchedule``), or define extra methods and attributes (e.g., ``KLCoeffMixin``, ``ValueNetworkMixin``). Like any other Python superclass, these should be initialized at some point, which is what the ``setup_mixins`` function does:

.. code-block:: python

    def setup_mixins(policy, obs_space, action_space, config):
        ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
        KLCoeffMixin.__init__(policy, config)
        LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])

In PPO we run ``setup_mixins`` before the loss function is called (i.e., ``before_loss_init``), but other callbacks you can use include ``before_init`` and ``after_init``.
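
For reference, here is a sketch of how those hooks could be attached when building a policy (the hook bodies are placeholders; their signatures are assumed to match the ``setup_mixins`` example above):

.. code-block:: python

    def before_init(policy, obs_space, action_space, config):
        # runs before the model and loss are constructed
        ...

    def after_init(policy, obs_space, action_space, config):
        # runs once loss initialization has completed
        ...

    MyPPOLikePolicy = build_tf_policy(
        name="MyPPOLikePolicy",
        loss_fn=ppo_surrogate_loss,
        before_init=before_init,
        before_loss_init=setup_mixins,
        after_init=after_init,
        mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin])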

**Example 2: Deep Q Networks**

Let's look at how to implement a different family of policies, by looking at the `SimpleQ policy definition <https://github.com/ray-project/ray/blob/master/rllib/algorithms/simple_q/simple_q_tf_policy.py>`__:

.. code-block:: python

    SimpleQPolicy = build_tf_policy(
        name="SimpleQPolicy",
        get_default_config=lambda: ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG,
        make_model=build_q_models,
        action_sampler_fn=build_action_sampler,
        loss_fn=build_q_losses,
        extra_action_feed_fn=exploration_setting_inputs,
        extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
        extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
        before_init=setup_early_mixins,
        after_init=setup_late_mixins,
        obs_include_prev_action_reward=False,
        mixins=[
            ExplorationStateMixin,
            TargetNetworkMixin,
        ])

The biggest difference from the policy gradient policies you saw previously is that SimpleQPolicy defines its own ``make_model`` and ``action_sampler_fn``. This means that the policy builder will not internally create a model and action distribution; rather, it will call ``build_q_models`` and ``build_action_sampler`` to get the output action tensors.

The model creation function actually creates two different models for DQN: the base Q network, and also a target network. It requires each model to be of type ``SimpleQModel``, which implements a ``get_q_values()`` method. The model catalog will raise an error if you try to use a custom ModelV2 model that isn't a subclass of SimpleQModel. Similarly, the full DQN policy requires models to subclass ``DistributionalQModel``, which implements ``get_q_value_distributions()`` and ``get_state_value()``:

.. code-block:: python

    def build_q_models(policy, obs_space, action_space, config):
        ...

        policy.q_model = ModelCatalog.get_model_v2(
            obs_space,
            action_space,
            num_outputs,
            config["model"],
            framework="tf",
            name=Q_SCOPE,
            model_interface=SimpleQModel,
            q_hiddens=config["hiddens"])

        policy.target_q_model = ModelCatalog.get_model_v2(
            obs_space,
            action_space,
            num_outputs,
            config["model"],
            framework="tf",
            name=Q_TARGET_SCOPE,
            model_interface=SimpleQModel,
            q_hiddens=config["hiddens"])

        return policy.q_model

The action sampler is straightforward: it just takes the Q model, runs a forward pass, and returns the argmax over the actions:

.. code-block:: python

    def build_action_sampler(policy, q_model, input_dict, obs_space, action_space,
                             config):
        # do max over Q values...
        ...
        return action, action_logp

The remainder of DQN is similar to other algorithms. Target updates are handled by an ``after_optimizer_step`` callback that periodically copies the weights of the Q network to the target network.
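
As a rough illustration only (not the actual SimpleQ code; the callback signature, the ``last_target_update_ts`` attribute, and the ``target_network_update_freq`` config key are assumptions here), such a periodic target update could look like:

.. code-block:: python

    def update_target_if_needed(policy, *args):
        # Copy the Q network weights into the target network every
        # `target_network_update_freq` timesteps (assumed config key).
        ts = policy.global_timestep
        last_update = getattr(policy, "last_target_update_ts", 0)
        if ts - last_update >= policy.config["target_network_update_freq"]:
            policy.update_target()  # provided by TargetNetworkMixin
            policy.last_target_update_ts = ts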

Finally, note that you do not have to use ``build_tf_policy`` to define a TensorFlow policy. You can alternatively subclass ``Policy``, ``TFPolicy``, or ``DynamicTFPolicy`` as convenient.

Building Policies in TensorFlow Eager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Policies built with ``build_tf_policy`` (most of the reference algorithms are)
can be run in eager mode by setting
the ``"framework": "tf2"`` / ``"eager_tracing": true`` config options or
using ``rllib train '{"framework": "tf2", "eager_tracing": true}'``.
This will tell RLlib to execute the model forward pass, action distribution,
loss, and stats functions in eager mode.
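
For example, reusing the ``MyAlgo`` Tuner setup from earlier, a sketch of enabling eager mode (with tracing) through the config could look like this:

.. code-block:: python

    from ray import tune

    tune.Tuner(
        MyAlgo,
        param_space={
            "env": "CartPole-v1",
            "framework": "tf2",
            "eager_tracing": True,
        },
    ).fit()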

Eager mode makes debugging much easier, since you can now use line-by-line
debugging with breakpoints or Python ``print()`` to inspect
intermediate tensor values.
However, eager can be slower than graph mode unless tracing is enabled.

You can also selectively leverage eager operations within graph mode
execution with `tf.py_function <https://www.tensorflow.org/api_docs/python/tf/py_function>`__.
Here's an example of using eager ops embedded
`within a loss function <https://github.com/ray-project/ray/blob/master/rllib/examples/eager_execution.py>`__.

Building Policies in PyTorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Defining a policy in PyTorch is quite similar to doing so in TensorFlow (and the process of defining an algorithm given a Torch policy is exactly the same).
Here's a simple example of a trivial torch policy `(runnable file here) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_policy.py>`__:

.. code-block:: python

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.torch_policy_template import build_torch_policy

    def policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits)
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        return -train_batch[SampleBatch.REWARDS].dot(log_probs)

    # <class 'ray.rllib.policy.torch_policy_template.MyTorchPolicy'>
    MyTorchPolicy = build_torch_policy(
        name="MyTorchPolicy",
        loss_fn=policy_gradient_loss)

Now, building on the TF examples above, let's look at how the `A3C torch policy <https://github.com/ray-project/ray/blob/master/rllib/algorithms/a3c/a3c_torch_policy.py>`__ is defined:

.. code-block:: python

    A3CTorchPolicy = build_torch_policy(
        name="A3CTorchPolicy",
        get_default_config=lambda: ray.rllib.algorithms.a3c.a3c.DEFAULT_CONFIG,
        loss_fn=actor_critic_loss,
        stats_fn=loss_and_entropy_stats,
        postprocess_fn=add_advantages,
        extra_action_out_fn=model_value_predictions,
        extra_grad_process_fn=apply_grad_clipping,
        optimizer_fn=torch_optimizer,
        mixins=[ValueNetworkMixin])

``loss_fn``: Similar to the TF example, the actor critic loss is defined over ``train_batch``. We imperatively execute the forward pass by calling ``model()`` on the observations followed by ``dist_class()`` on the output logits. The output Tensors are saved as attributes of the policy object (e.g., ``policy.entropy = dist.entropy.mean()``), and we return the scalar loss:

.. code-block:: python

    def actor_critic_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        values = model.value_function()
        action_dist = dist_class(logits)
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        policy.entropy = action_dist.entropy().mean()
        ...
        return overall_err

``stats_fn``: The stats function references ``entropy``, ``pi_err``, and ``value_err`` saved from the call to the loss function, similar to the PPO TF example:

.. code-block:: python

    def loss_and_entropy_stats(policy, train_batch):
        return {
            "policy_entropy": policy.entropy.item(),
            "policy_loss": policy.pi_err.item(),
            "vf_loss": policy.value_err.item(),
        }

``extra_action_out_fn``: We save value function predictions given model outputs. This makes the value function predictions of the model available in the trajectory as ``batch[SampleBatch.VF_PREDS]``:

.. code-block:: python

    def model_value_predictions(policy, input_dict, state_batches, model):
        return {SampleBatch.VF_PREDS: model.value_function().cpu().numpy()}

``postprocess_fn`` and ``mixins``: Similar to the PPO example, we need access to the value function during postprocessing (i.e., ``add_advantages`` below calls ``policy._value()``). The value function is exposed through a mixin class that defines the method:

.. code-block:: python

    def add_advantages(policy,
                       sample_batch,
                       other_agent_batches=None,
                       episode=None):
        completed = sample_batch[SampleBatch.DONES][-1]
        if completed:
            last_r = 0.0
        else:
            last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
        return compute_advantages(sample_batch, last_r, policy.config["gamma"],
                                  policy.config["lambda"])

    class ValueNetworkMixin(object):
        def _value(self, obs):
            with self.lock:
                obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
                _, _, vf, _ = self.model({"obs": obs}, [])
                return vf.detach().cpu().numpy().squeeze()

You can find the full policy definition in `a3c_torch_policy.py <https://github.com/ray-project/ray/blob/master/rllib/algorithms/a3c/a3c_torch_policy.py>`__.

In summary, the main difference between the PyTorch and TensorFlow policy builder functions is that the TF loss and stats functions are built symbolically when the policy is initialized, whereas for PyTorch (or TensorFlow Eager) these functions are called imperatively each time they are used.

Extending Existing Policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can use the ``with_updates`` method on Trainers and Policy objects built with ``make_*`` to create a copy of the object with some changes, for example:

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPO
    from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTFPolicy

    CustomPolicy = PPOTFPolicy.with_updates(
        name="MyCustomPPOTFPolicy",
        loss_fn=some_custom_loss_fn)

    CustomTrainer = PPO.with_updates(
        default_policy=CustomPolicy)

.. include:: /_includes/rllib/announcement_bottom.rst