pg_tf_policy.py

  1. """
  2. TensorFlow policy class used for PG.
  3. """
  4. import logging
  5. from typing import Dict, List, Type, Union, Optional, Tuple
  6. from ray.rllib.evaluation.episode import Episode
  7. from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
  8. from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
  9. from ray.rllib.algorithms.pg.pg import PGConfig
  10. from ray.rllib.algorithms.pg.utils import post_process_advantages
  11. from ray.rllib.utils.typing import AgentID
  12. from ray.rllib.utils.annotations import override
  13. from ray.rllib.utils.typing import (
  14. TFPolicyV2Type,
  15. )
  16. from ray.rllib.evaluation.postprocessing import Postprocessing
  17. from ray.rllib.models.action_dist import ActionDistribution
  18. from ray.rllib.models.modelv2 import ModelV2
  19. from ray.rllib.policy import Policy
  20. from ray.rllib.policy.sample_batch import SampleBatch
  21. from ray.rllib.policy.tf_mixins import LearningRateSchedule
  22. from ray.rllib.utils.framework import try_import_tf
  23. from ray.rllib.utils.typing import TensorType
  24. tf1, tf, tfv = try_import_tf()
  25. logger = logging.getLogger(__name__)


# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_pg_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
    """Construct a PGTFPolicy inheriting either dynamic or eager base policies.

    Args:
        name: Name for the resulting class (used as its __name__/__qualname__).
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy to be used with PG.
    """

    class PGTFPolicy(
        LearningRateSchedule,
        base,
    ):
        def __init__(
            self,
            observation_space,
            action_space,
            config: PGConfig,
            existing_model=None,
            existing_inputs=None,
        ):
            # First things first, enable eager execution if necessary.
            base.enable_eager_execution_if_necessary()

            # Enforce AlgorithmConfig for PG Policies.
            if isinstance(config, dict):
                config = PGConfig.from_dict(config)

            # Initialize base class.
            base.__init__(
                self,
                observation_space,
                action_space,
                config,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
            )

            LearningRateSchedule.__init__(self, config.lr, config.lr_schedule)

            # Note: this is a bit ugly, but loss and optimizer initialization must
            # happen after all the MixIns are initialized.
            self.maybe_initialize_optimizer_and_loss()

        @override(base)
        def loss(
            self,
            model: ModelV2,
            dist_class: Type[ActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:
            """The basic policy gradients loss function.

            Calculates the vanilla policy gradient loss based on:
            L = -E[ log(pi(a|s)) * A]

            Args:
                model: The Model to calculate the loss for.
                dist_class: The action distr. class.
                train_batch: The training data.

            Returns:
                Union[TensorType, List[TensorType]]: A single loss tensor or a list
                    of loss tensors.
            """
            # Pass the training data through our model to get distribution parameters.
            dist_inputs, _ = model(train_batch)

            # Create an action distribution object.
            action_dist = dist_class(dist_inputs, model)

            # Calculate the vanilla PG loss based on:
            # L = -E[ log(pi(a|s)) * A]
            loss = -tf.reduce_mean(
                action_dist.logp(train_batch[SampleBatch.ACTIONS])
                * tf.cast(train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32)
            )

            self.policy_loss = loss

            return loss

        @override(base)
        def postprocess_trajectory(
            self,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[
                Dict[AgentID, Tuple["Policy", SampleBatch]]
            ] = None,
            episode: Optional["Episode"] = None,
        ) -> SampleBatch:
            sample_batch = super().postprocess_trajectory(
                sample_batch, other_agent_batches, episode
            )
            # Add the advantages column (for PG: simple discounted returns)
            # to the collected sample batch.
            return post_process_advantages(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
            # Also fetch the current learning rate with each learner update.
            return {
                "learner_stats": {"cur_lr": self.cur_lr},
            }

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            """Returns the calculated loss and learning rate in a stats dict.

            Args:
                train_batch: The data used for training.

            Returns:
                Dict[str, TensorType]: The stats dict.
            """
            return {
                "policy_loss": self.policy_loss,
                "cur_lr": self.cur_lr,
            }

    PGTFPolicy.__name__ = name
    PGTFPolicy.__qualname__ = name

    return PGTFPolicy
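

# Illustration (editor's sketch, not part of RLlib's API): the vanilla PG loss
# above is an advantage-weighted negative log-likelihood. The helper below
# reproduces the same computation with plain TF tensors for a 2-action
# categorical policy; the function name `_vanilla_pg_loss_example` and the
# numbers are hypothetical and purely illustrative.
def _vanilla_pg_loss_example() -> TensorType:
    logits = tf.constant([[2.0, 0.5], [0.1, 1.5]])  # model outputs, 2 timesteps
    actions = tf.constant([0, 1])  # sampled actions
    advantages = tf.constant([1.2, -0.4])  # post-processed advantages
    # log pi(a|s): negative sparse softmax cross-entropy of the taken actions.
    logp = -tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=actions, logits=logits
    )
    # L = -E[ log(pi(a|s)) * A ]
    return -tf.reduce_mean(logp * advantages)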


PGTF1Policy = get_pg_tf_policy("PGTF1Policy", DynamicTFPolicyV2)
PGTF2Policy = get_pg_tf_policy("PGTF2Policy", EagerTFPolicyV2)
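

# Usage sketch (editor's addition): these policy classes are normally selected
# by the PG algorithm based on the configured framework, so the usual entry
# point is the algorithm config rather than direct instantiation. A minimal
# end-to-end run might look like the following; "CartPole-v1" and printing a
# single training result are illustrative choices, and TensorFlow must be
# installed for the "tf2" framework setting to pick PGTF2Policy.
if __name__ == "__main__":
    algo = (
        PGConfig()
        .environment("CartPole-v1")
        .framework("tf2")
        .build()
    )
    print(algo.train())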