pg_torch_policy.py

  1. """
  2. PyTorch policy class used for PG.
  3. """
  4. import logging
  5. from typing import Dict, List, Type, Union, Optional, Tuple
  6. from ray.rllib.evaluation.episode import Episode
  7. from ray.rllib.utils.typing import AgentID
  8. from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
  9. from ray.rllib.utils.annotations import override
  10. from ray.rllib.utils.numpy import convert_to_numpy
  11. from ray.rllib.algorithms.pg.pg import PGConfig
  12. from ray.rllib.algorithms.pg.utils import post_process_advantages
  13. from ray.rllib.evaluation.postprocessing import Postprocessing
  14. from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
  15. from ray.rllib.models.modelv2 import ModelV2
  16. from ray.rllib.policy import Policy
  17. from ray.rllib.policy.sample_batch import SampleBatch
  18. from ray.rllib.policy.torch_mixins import LearningRateSchedule
  19. from ray.rllib.utils.framework import try_import_torch
  20. from ray.rllib.utils.typing import TensorType
  21. torch, nn = try_import_torch()
  22. logger = logging.getLogger(__name__)


class PGTorchPolicy(LearningRateSchedule, TorchPolicyV2):
    """PyTorch policy class used with PG."""

    def __init__(self, observation_space, action_space, config: PGConfig):
        # Enforce AlgorithmConfig for PG policies.
        if isinstance(config, dict):
            config = PGConfig.from_dict(config)

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config.model["max_seq_len"],
        )

        LearningRateSchedule.__init__(self, config.lr, config.lr_schedule)

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """The basic policy gradients loss function.

        Calculates the vanilla policy gradient loss based on:
        L = -E[ log(pi(a|s)) * A]

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distribution class.
            train_batch: The training data.

        Returns:
            Union[TensorType, List[TensorType]]: A single loss tensor or a list
                of loss tensors.
        """
        # Pass the training data through our model to get distribution parameters.
        dist_inputs, _ = model(train_batch)

        # Create an action distribution object.
        action_dist = dist_class(dist_inputs, model)

        # Calculate the vanilla PG loss based on:
        # L = -E[ log(pi(a|s)) * A]
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])

        # Final policy loss.
        policy_loss = -torch.mean(log_probs * train_batch[Postprocessing.ADVANTAGES])

        # Store the loss on the model (tower) so that, in the multi-GPU case, the
        # values are not overwritten during the parallel loss phase.
        model.tower_stats["policy_loss"] = policy_loss

        return policy_loss
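
    # Illustrative arithmetic for the loss above (a hedged sketch with made-up
    # values): for per-timestep log_probs = [-0.1, -2.3] and advantages =
    # [1.0, -0.5], policy_loss = -mean([-0.1 * 1.0, -2.3 * -0.5])
    # = -mean([-0.1, 1.15]) = -0.525. Minimizing this raises the log-probability
    # of actions with positive advantage and lowers it for those with negative
    # advantage.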

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Returns the calculated loss in a stats dict.

        Args:
            train_batch: The data used for training.

        Returns:
            Dict[str, TensorType]: The stats dict.
        """
        return convert_to_numpy(
            {
                "policy_loss": torch.mean(
                    torch.stack(self.get_tower_stats("policy_loss"))
                ),
                "cur_lr": self.cur_lr,
            }
        )

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[
            Dict[AgentID, Tuple["Policy", SampleBatch]]
        ] = None,
        episode: Optional["Episode"] = None,
    ) -> SampleBatch:
        sample_batch = super().postprocess_trajectory(
            sample_batch, other_agent_batches, episode
        )
        # Add the Postprocessing.ADVANTAGES column (plain discounted returns for
        # vanilla PG, no value-function baseline) that the loss above consumes.
        return post_process_advantages(self, sample_batch, other_agent_batches, episode)
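

# A minimal usage sketch, not an authoritative recipe: in practice this policy is
# picked up by the PG algorithm when the torch framework is configured, rather than
# being instantiated directly. The environment "CartPole-v1" and the learning rate
# below are illustrative assumptions.
if __name__ == "__main__":
    config = (
        PGConfig()
        .environment("CartPole-v1")
        .framework("torch")  # the torch framework path uses PGTorchPolicy
        .training(lr=0.0004)
    )
    algo = config.build()
    print(algo.train())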