simple_q_torch_policy.py

  1. """PyTorch policy class used for Simple Q-Learning"""
  2. import logging
  3. from typing import Dict, Tuple
  4. import gym
  5. import ray
  6. from ray.rllib.agents.dqn.simple_q_tf_policy import (
  7. build_q_models, compute_q_values, get_distribution_inputs_and_class)
  8. from ray.rllib.models.modelv2 import ModelV2
  9. from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
  10. TorchDistributionWrapper
  11. from ray.rllib.policy import Policy
  12. from ray.rllib.policy.policy_template import build_policy_class
  13. from ray.rllib.policy.sample_batch import SampleBatch
  14. from ray.rllib.policy.torch_policy import TorchPolicy
  15. from ray.rllib.utils.annotations import override
  16. from ray.rllib.utils.framework import try_import_torch
  17. from ray.rllib.utils.torch_utils import concat_multi_gpu_td_errors, huber_loss
  18. from ray.rllib.utils.typing import TensorType, TrainerConfigDict
  19. torch, nn = try_import_torch()
  20. F = None
  21. if nn:
  22. F = nn.functional
  23. logger = logging.getLogger(__name__)

class TargetNetworkMixin:
    """Assigns the `update_target` method to the SimpleQTorchPolicy.

    The method is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(self):
        # Hard initial update from Q-net(s) to target Q-net(s).
        self.update_target()

    def update_target(self):
        # Called periodically to copy the online Q-network's weights into
        # the target Q-network(s).
        state_dict = self.model.state_dict()
        for target in self.target_models.values():
            target.load_state_dict(state_dict)

    @override(TorchPolicy)
    def set_weights(self, weights):
        # Makes sure that whenever we restore weights for this policy's
        # model, we sync the target network (from the main model)
        # at the same time.
        TorchPolicy.set_weights(self, weights)
        self.update_target()
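

# NOTE: SimpleQ uses hard target updates: `update_target()` above copies the
# online Q-network's full `state_dict` into each target model, rather than
# soft/Polyak-averaging the weights as some other off-policy algorithms do.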


def build_q_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    return build_q_models(policy, obs_space, action_space, config), \
        TorchCategorical


def build_q_losses(policy: Policy, model, dist_class,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for SimpleQTorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]

    # Q-network evaluation on the current observations.
    q_t = compute_q_values(
        policy,
        model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True)

    # Target Q-network evaluation on the next observations.
    q_tp1 = compute_q_values(
        policy,
        target_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True)

    # Q scores for the actions that were actually taken in the given states.
    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                  policy.action_space.n)
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # Compute an estimate of the best possible value starting from the state
    # at t + 1; terminal transitions are masked out.
    dones = train_batch[SampleBatch.DONES].float()
    q_tp1_best_one_hot_selection = F.one_hot(
        torch.argmax(q_tp1, 1), policy.action_space.n)
    q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           policy.config["gamma"] * q_tp1_best_masked)
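    # The target above is the one-step Q-learning (Bellman) target:
    #     y_t = r_t + gamma * (1 - done_t) * max_a Q_target(s_{t+1}, a)
    # where the max over actions uses the target network's Q-values and
    # terminal transitions contribute no bootstrap value.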

    # Compute the TD-error and apply a Huber loss to it.
    td_error = q_t_selected - q_t_selected_target.detach()
    loss = torch.mean(huber_loss(td_error))

    # Store values for the stats function in the model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["loss"] = loss
    # The TD-error tensors of all towers will be concatenated in the final
    # stats and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    return loss


def stats_fn(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
    return {"loss": torch.mean(torch.stack(policy.get_tower_stats("loss")))}


def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                        action_dist) -> Dict[str, TensorType]:
    """Adds q-values to the action out dict."""
    return {"q_values": policy.q_values}


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors before SimpleQTorchPolicy
    initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy)


SimpleQTorchPolicy = build_policy_class(
    name="SimpleQPolicy",
    framework="torch",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.agents.dqn.simple_q.DEFAULT_CONFIG,
    stats_fn=stats_fn,
    extra_action_out_fn=extra_action_out_fn,
    after_init=setup_late_mixins,
    make_model_and_action_dist=build_q_model_and_distribution,
    mixins=[TargetNetworkMixin],
    action_distribution_fn=get_distribution_inputs_and_class,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
)
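

# Usage sketch (illustrative, not part of the policy definition above):
# running this module directly builds the matching SimpleQTrainer, which
# instantiates SimpleQTorchPolicy when `framework` is set to "torch". This
# assumes the Ray 1.x `ray.rllib.agents.*` layout that the imports above
# come from; the env name and iteration count are arbitrary.
if __name__ == "__main__":
    from ray.rllib.agents.dqn import SimpleQTrainer

    ray.init(ignore_reinit_error=True)
    trainer = SimpleQTrainer(env="CartPole-v0", config={"framework": "torch"})
    for i in range(3):
        result = trainer.train()
        print(i, result["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()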