  1. """PyTorch policy class used for R2D2."""
  2. from typing import Dict, Tuple
  3. import gym
  4. import ray
  5. from ray.rllib.agents.dqn.dqn_tf_policy import (PRIO_WEIGHTS,
  6. postprocess_nstep_and_prio)
  7. from ray.rllib.agents.dqn.dqn_torch_policy import adam_optimizer, \
  8. build_q_model_and_distribution, compute_q_values
  9. from ray.rllib.agents.dqn.r2d2_tf_policy import \
  10. get_distribution_inputs_and_class
  11. from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin
  12. from ray.rllib.models.modelv2 import ModelV2
  13. from ray.rllib.models.torch.torch_action_dist import \
  14. TorchDistributionWrapper
  15. from ray.rllib.policy.policy import Policy
  16. from ray.rllib.policy.policy_template import build_policy_class
  17. from ray.rllib.policy.sample_batch import SampleBatch
  18. from ray.rllib.policy.torch_policy import LearningRateSchedule
  19. from ray.rllib.utils.framework import try_import_torch
  20. from ray.rllib.utils.torch_utils import apply_grad_clipping, \
  21. concat_multi_gpu_td_errors, FLOAT_MIN, huber_loss, sequence_mask
  22. from ray.rllib.utils.typing import TensorType, TrainerConfigDict
  23. torch, nn = try_import_torch()
  24. F = None
  25. if nn:
  26. F = nn.functional


def build_r2d2_model_and_distribution(
    policy: Policy, obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict) -> \
        Tuple[ModelV2, TorchDistributionWrapper]:
    """Build q_model and target_model for R2D2.

    Args:
        policy (Policy): The policy, which will use the model for
            optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        (q_model, TorchCategorical)

    Note: The target q model will not be returned, just assigned to
        `policy.target_model`.
    """
    # Create the policy's models and action dist class.
    model, distribution_cls = build_q_model_and_distribution(
        policy, obs_space, action_space, config)

    # Assert correct model type by checking the init state to be present.
    # For attention nets: These don't necessarily publish their init state via
    # Model.get_initial_state, but may only use the trajectory view API
    # (view_requirements).
    assert (model.get_initial_state() != [] or
            model.view_requirements.get("state_in_0") is not None), \
        "R2D2 requires its model to be a recurrent one! Try using " \
        "`model.use_lstm` or `model.use_attention` in your config " \
        "to auto-wrap your model with an LSTM- or attention net."

    return model, distribution_cls
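
# Example model config (sketch; `use_lstm`, `lstm_cell_size`, and `max_seq_len`
# are standard RLlib model-config keys) that satisfies the recurrence assert in
# `build_r2d2_model_and_distribution` by letting RLlib auto-wrap the default
# model with an LSTM:
#
#     config["model"] = {
#         "use_lstm": True,
#         "lstm_cell_size": 64,
#         "max_seq_len": 20,
#     }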


def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        target_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=q.device)) *
        one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=q_target.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat([
            q_target_best[1:],
            torch.tensor([0.0], device=q_target_best.device)
        ])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        total_loss = reduce_mean_valid(weights * huber_loss(td_error))

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_q"] = reduce_mean_valid(q_selected)
        model.tower_stats["min_q"] = torch.min(q_selected)
        model.tower_stats["max_q"] = torch.max(q_selected)
        model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error)
        # Store per time chunk (b/c we need only one mean
        # prioritized replay weight per stored sequence).
        model.tower_stats["td_error"] = torch.mean(td_error, dim=-1)

    return total_loss
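
# Note on the shape bookkeeping in `r2d2_loss` above: `q` and `q_target` come
# back time-flattened as [B*T, num_actions], where B is the number of stored
# sequences (taken from the first state-in tensor) and T the padded sequence
# length. Per-timestep terms are therefore reshaped to [B, T]; the last
# timestep is dropped because the target uses values shifted by one step
# (`q_target_best[1:]`), and the first `burn_in` steps are masked out of the
# loss.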


def h_function(x, epsilon=1.0):
    """h-function to normalize target Qs, described in the paper [1].

    h(x) = sign(x) * [sqrt(abs(x) + 1) - 1] + epsilon * x

    Used in [1] in combination with h_inverse:
        targets = h(r + gamma * h_inverse(Q^))
    """
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x


def h_inverse(x, epsilon=1.0):
    """Inverse of the above h-function, described in the paper [1].

    If x > 0.0:
        h-1(x) = [2eps * x + (2eps + 1) - sqrt(4eps x + (2eps + 1)^2)] /
            (2 * eps^2)

    If x < 0.0:
        h-1(x) = [2eps * x - (2eps + 1) + sqrt(-4eps x + (2eps + 1)^2)] /
            (2 * eps^2)
    """
    two_epsilon = epsilon * 2
    if_x_pos = (two_epsilon * x + (two_epsilon + 1.0) -
                torch.sqrt(4.0 * epsilon * x +
                           (two_epsilon + 1.0)**2)) / (2.0 * epsilon**2)
    if_x_neg = (two_epsilon * x - (two_epsilon + 1.0) +
                torch.sqrt(-4.0 * epsilon * x +
                           (two_epsilon + 1.0)**2)) / (2.0 * epsilon**2)
    return torch.where(x < 0.0, if_x_neg, if_x_pos)
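

# Quick sanity-check sketch (illustrative only; runs solely when this module is
# executed directly): the value rescaling and its inverse should round-trip,
# i.e. h_inverse(h_function(x, eps), eps) ~= x.
if __name__ == "__main__" and torch is not None:
    _x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
    _eps = 1.0
    assert torch.allclose(h_inverse(h_function(_x, _eps), _eps), _x,
                          atol=1e-3), "h_inverse should invert h_function"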


class ComputeTDErrorMixin:
    """Assigns the `compute_td_error` method to the R2D2TorchPolicy.

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
            input_dict[SampleBatch.ACTIONS] = act_t
            input_dict[SampleBatch.REWARDS] = rew_t
            input_dict[SampleBatch.NEXT_OBS] = obs_tp1
            input_dict[SampleBatch.DONES] = done_mask
            input_dict[PRIO_WEIGHTS] = importance_weights

            # Do forward pass on loss to update td error attribute.
            r2d2_loss(self, self.model, None, input_dict)

            return self.model.tower_stats["td_error"]

        self.compute_td_error = compute_td_error
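
# Note on the mixin above: `r2d2_loss` asserts that recurrent state inputs
# ("state_in_*") are present in the train batch, while `compute_td_error` only
# fills in the flat transition columns; the sampled sequences' state (and
# SEQ_LENS) columns must therefore also be available in the input dict for the
# loss call to succeed.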


def build_q_stats(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
    return {
        "cur_lr": policy.cur_lr,
        "total_loss": torch.mean(
            torch.stack(policy.get_tower_stats("total_loss"))),
        "mean_q": torch.mean(torch.stack(policy.get_tower_stats("mean_q"))),
        "min_q": torch.mean(torch.stack(policy.get_tower_stats("min_q"))),
        "max_q": torch.mean(torch.stack(policy.get_tower_stats("max_q"))),
        "mean_td_error": torch.mean(
            torch.stack(policy.get_tower_stats("mean_td_error"))),
    }


def setup_early_mixins(policy: Policy, obs_space, action_space,
                       config: TrainerConfigDict) -> None:
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space,
                     config: TrainerConfigDict) -> None:
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy)


def grad_process_and_td_error_fn(policy: Policy,
                                 optimizer: "torch.optim.Optimizer",
                                 loss: TensorType) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(policy, optimizer, loss)


def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                        action_dist) -> Dict[str, TensorType]:
    return {"q_values": policy.q_values}


R2D2TorchPolicy = build_policy_class(
    name="R2D2TorchPolicy",
    framework="torch",
    loss_fn=r2d2_loss,
    get_default_config=lambda: ray.rllib.agents.dqn.r2d2.R2D2_DEFAULT_CONFIG,
    make_model_and_action_dist=build_r2d2_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    extra_grad_process_fn=grad_process_and_td_error_fn,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    extra_action_out_fn=extra_action_out_fn,
    before_init=setup_early_mixins,
    before_loss_init=before_loss_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])
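

# Usage sketch (assumes the `R2D2Trainer` exported by
# `ray.rllib.agents.dqn.r2d2`, the same module this policy's default config
# comes from; with `framework="torch"` the trainer builds this policy class):
#
#     import ray
#     from ray.rllib.agents.dqn.r2d2 import R2D2Trainer
#
#     ray.init()
#     trainer = R2D2Trainer(
#         env="CartPole-v0",
#         config={
#             "framework": "torch",
#             "model": {"use_lstm": True},
#         },
#     )
#     print(trainer.train())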