import gym
from typing import Dict, List, Optional

import ray
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
    Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import LearningRateSchedule, \
    EntropyCoeffSchedule
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping, sequence_mask
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
    PolicyID, LocalOptimizer

torch, nn = try_import_torch()


@Deprecated(
    old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=False)
def add_advantages(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
        episode: Optional[Episode] = None) -> SampleBatch:
    return compute_gae_for_sample_batch(policy, sample_batch,
                                        other_agent_batches, episode)


def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    logits, _ = model(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(log_probs * train_batch[Postprocessing.ADVANTAGES],
                            valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS], valid_mask),
                2.0))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.entropy_coeff)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss
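
# In symbols, the total loss computed above (summed over the valid, i.e.
# non-padded, timesteps t) is:
#     L = -sum_t log pi(a_t | s_t) * A_t
#         + vf_loss_coeff * 0.5 * sum_t (V(s_t) - V_target_t)^2
#         - entropy_coeff * sum_t H[pi(. | s_t)]
# where the advantages A_t and value targets V_target_t are produced by
# `compute_gae_for_sample_batch` (the `postprocess_fn` registered below).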


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    return {
        "cur_lr": policy.cur_lr,
        "entropy_coeff": policy.entropy_coeff,
        "policy_entropy": torch.mean(
            torch.stack(policy.get_tower_stats("entropy"))),
        "policy_loss": torch.mean(
            torch.stack(policy.get_tower_stats("pi_err"))),
        "vf_loss": torch.mean(
            torch.stack(policy.get_tower_stats("value_err"))),
    }


def vf_preds_fetches(
        policy: Policy, input_dict: Dict[str, TensorType],
        state_batches: List[TensorType], model: ModelV2,
        action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
    """Defines extra fetches per action computation.

    Args:
        policy (Policy): The Policy to perform the extra action fetch on.
        input_dict (Dict[str, TensorType]): The input dict used for the action
            computing forward pass.
        state_batches (List[TensorType]): List of state tensors (empty for
            non-RNNs).
        model (ModelV2): The Model object of the Policy.
        action_dist (TorchDistributionWrapper): The instantiated distribution
            object, resulting from the model's outputs and the given
            distribution class.

    Returns:
        Dict[str, TensorType]: Dict with extra fetches to perform per
            action computation.
    """
    # Return value function outputs. VF estimates will hence be added to the
    # SampleBatches produced by the sampler(s) to generate the train batches
    # going into the loss function.
    return {
        SampleBatch.VF_PREDS: model.value_function(),
    }


def torch_optimizer(policy: Policy,
                    config: TrainerConfigDict) -> LocalOptimizer:
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])


class ValueNetworkMixin:
    """Assigns the `_value()` method to this Policy.

    This way, the Policy can call `_value()` to get the current VF estimate
    on a single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, obs_space, action_space, config):
        # When doing GAE, we need the value function estimate on the
        # observation.
        if config["use_gae"]:
            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                input_dict = self._lazy_tensor_dict(input_dict)
                model_out, _ = self.model(input_dict)
                # [0] = remove the batch dim.
                return self.model.value_function()[0].item()

        # When not doing GAE, we do not require the value function's output.
        else:

            def value(*args, **kwargs):
                return 0.0

        self._value = value
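
# Rough usage sketch (an assumption about this RLlib version, not verified
# here): during GAE postprocessing, the bootstrap value for the last timestep
# of a non-terminated trajectory is obtained approximately like this:
#
#     input_dict = sample_batch.get_single_step_input_dict(
#         policy.model.view_requirements, index="last")
#     last_r = policy._value(**input_dict)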


def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Calls all mixin classes' constructors before Policy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)


A3CTorchPolicy = build_policy_class(
    name="A3CTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=stats,
    postprocess_fn=compute_gae_for_sample_batch,
    extra_action_out_fn=vf_preds_fetches,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=torch_optimizer,
    before_loss_init=setup_mixins,
    mixins=[ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule],
)
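

if __name__ == "__main__":
    # Minimal usage sketch, assuming the A3CTrainer shipped with this RLlib
    # version and the standard Gym "CartPole-v0" environment are available.
    # With framework="torch" configured, the trainer uses the A3CTorchPolicy
    # built above.
    from ray.rllib.agents.a3c import A3CTrainer

    trainer = A3CTrainer(
        env="CartPole-v0",
        config={"framework": "torch", "num_workers": 1})
    print(trainer.train())  # Run a single training iteration.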