pg_tf_policy.py

  1. """
  2. TensorFlow policy class used for PG.
  3. """
  4. from typing import Dict, List, Type, Union
  5. import ray
  6. from ray.rllib.agents.pg.utils import post_process_advantages
  7. from ray.rllib.evaluation.postprocessing import Postprocessing
  8. from ray.rllib.models.action_dist import ActionDistribution
  9. from ray.rllib.models.modelv2 import ModelV2
  10. from ray.rllib.policy import Policy
  11. from ray.rllib.policy.tf_policy_template import build_tf_policy
  12. from ray.rllib.policy.sample_batch import SampleBatch
  13. from ray.rllib.utils.framework import try_import_tf
  14. from ray.rllib.utils.typing import TensorType
  15. tf1, tf, tfv = try_import_tf()
  16. def pg_tf_loss(
  17. policy: Policy, model: ModelV2, dist_class: Type[ActionDistribution],
  18. train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
  19. """The basic policy gradients loss function.
  20. Args:
  21. policy (Policy): The Policy to calculate the loss for.
  22. model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Pass the training data through our model to get distribution parameters.
    dist_inputs, _ = model(train_batch)

    # Create an action distribution object.
    action_dist = dist_class(dist_inputs, model)

    # Calculate the vanilla PG loss based on:
    # L = -E[ log(pi(a|s)) * A]
    loss = -tf.reduce_mean(
        action_dist.logp(train_batch[SampleBatch.ACTIONS]) * tf.cast(
            train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32))
    # Store the loss on the Policy object so `pg_loss_stats` can report it.
    policy.policy_loss = loss

    return loss
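

# A toy illustration (not part of the original module) of what the loss above
# computes: for per-sample log-probabilities and advantages, the vanilla PG
# objective is just the negative mean of their product. A minimal sketch,
# assuming TensorFlow 2.x is importable as `tf`; the names `logp`,
# `advantages`, and `toy_loss` are illustrative only:
#
#     logp = tf.math.log(tf.constant([0.9, 0.2, 0.6]))  # log pi(a|s) per sample
#     advantages = tf.constant([1.0, -0.5, 2.0])
#     toy_loss = -tf.reduce_mean(logp * advantages)      # scalar loss tensor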


def pg_loss_stats(policy: Policy,
                  train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns the calculated loss in a stats dict.

    Args:
        policy (Policy): The Policy object.
        train_batch (SampleBatch): The data used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    return {
        "policy_loss": policy.policy_loss,
    }


# Build a child class of `DynamicTFPolicy`, given the extra options:
# - trajectory post-processing function (to calculate advantages)
# - PG loss function
PGTFPolicy = build_tf_policy(
    name="PGTFPolicy",
    get_default_config=lambda: ray.rllib.agents.pg.DEFAULT_CONFIG,
    postprocess_fn=post_process_advantages,
    stats_fn=pg_loss_stats,
    loss_fn=pg_tf_loss)
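

# Usage sketch (an assumption, not part of the original file): under the Ray
# 1.x `agents` API that this module targets, PGTFPolicy is normally exercised
# through the PG trainer rather than instantiated directly. A minimal example,
# assuming `ray[rllib]` 1.x and Gym's CartPole-v0 are available:
#
#     import ray
#     from ray.rllib.agents.pg import PGTrainer
#
#     ray.init()
#     trainer = PGTrainer(env="CartPole-v0", config={"framework": "tf"})
#     print(trainer.get_policy())  # with framework="tf", a PGTFPolicy instance
#     for _ in range(3):
#         print(trainer.train()["episode_reward_mean"])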