- """
- TensorFlow policy class used for PG.
- """
- from typing import Dict, List, Type, Union
- import ray
- from ray.rllib.agents.pg.utils import post_process_advantages
- from ray.rllib.evaluation.postprocessing import Postprocessing
- from ray.rllib.models.action_dist import ActionDistribution
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.policy import Policy
- from ray.rllib.policy.tf_policy_template import build_tf_policy
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.typing import TensorType
- tf1, tf, tfv = try_import_tf()
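# Note: `try_import_tf()` (above) returns the `tf.compat.v1` module as
# `tf1`, the installed TensorFlow module as `tf`, and the major version
# (1 or 2) as `tfv`, so this policy can run under both TF1 and TF2.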

def pg_tf_loss(
        policy: Policy, model: ModelV2, dist_class: Type[ActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """The basic policy gradients loss function.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Pass the training data through our model to get distribution parameters.
    dist_inputs, _ = model(train_batch)

    # Create an action distribution object.
    action_dist = dist_class(dist_inputs, model)

    # Calculate the vanilla PG loss based on:
    # L = -E[ log(pi(a|s)) * A ]
    loss = -tf.reduce_mean(
        action_dist.logp(train_batch[SampleBatch.ACTIONS]) * tf.cast(
            train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32))

    policy.policy_loss = loss
    return loss
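# A tiny numeric sketch of the loss above (hypothetical values, not part
# of the policy): with per-timestep log-probs [-0.1, -2.3] and advantages
# [1.0, -0.5],
#   loss = -mean([-0.1 * 1.0, -2.3 * -0.5]) = -mean([-0.1, 1.15]) = -0.525
# Minimizing this loss raises the log-likelihood of actions with positive
# advantage and lowers it for actions with negative advantage.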

def pg_loss_stats(policy: Policy,
                  train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns the calculated loss in a stats dict.

    Args:
        policy (Policy): The Policy object.
        train_batch (SampleBatch): The data used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    return {
        "policy_loss": policy.policy_loss,
    }

# Build a child class of `DynamicTFPolicy`, given the extra options:
# - trajectory post-processing function (to calculate advantages)
# - PG loss function
PGTFPolicy = build_tf_policy(
    name="PGTFPolicy",
    get_default_config=lambda: ray.rllib.agents.pg.DEFAULT_CONFIG,
    postprocess_fn=post_process_advantages,
    stats_fn=pg_loss_stats,
    loss_fn=pg_tf_loss)
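# A minimal usage sketch (assuming an old-style RLlib 1.x setup; the
# trainer and environment below are illustrative, not part of this
# module): `PGTrainer` from `ray.rllib.agents.pg` picks up this policy
# when run with the TF framework.
#
#   import ray
#   from ray.rllib.agents.pg import PGTrainer
#
#   ray.init()
#   trainer = PGTrainer(env="CartPole-v0", config={"framework": "tf"})
#   print(trainer.train())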