r2d2_tf_policy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. """TensorFlow policy class used for R2D2."""
  2. from typing import Dict, List, Optional, Tuple
  3. import gym
  4. import ray
  5. from ray.rllib.agents.dqn.dqn_tf_policy import clip_gradients, \
  6. compute_q_values, PRIO_WEIGHTS, postprocess_nstep_and_prio
  7. from ray.rllib.agents.dqn.dqn_tf_policy import build_q_model
  8. from ray.rllib.agents.dqn.simple_q_tf_policy import TargetNetworkMixin
  9. from ray.rllib.models.action_dist import ActionDistribution
  10. from ray.rllib.models.modelv2 import ModelV2
  11. from ray.rllib.models.tf.tf_action_dist import Categorical
  12. from ray.rllib.models.torch.torch_action_dist import TorchCategorical
  13. from ray.rllib.policy.policy import Policy
  14. from ray.rllib.policy.tf_policy_template import build_tf_policy
  15. from ray.rllib.policy.sample_batch import SampleBatch
  16. from ray.rllib.policy.tf_policy import LearningRateSchedule
  17. from ray.rllib.utils.framework import try_import_tf
  18. from ray.rllib.utils.tf_utils import huber_loss
  19. from ray.rllib.utils.typing import ModelInputDict, TensorType, \
  20. TrainerConfigDict
  21. tf1, tf, tfv = try_import_tf()
  22. def build_r2d2_model(policy: Policy, obs_space: gym.spaces.Space,
  23. action_space: gym.spaces.Space, config: TrainerConfigDict
  24. ) -> Tuple[ModelV2, ActionDistribution]:
  25. """Build q_model and target_model for DQN
  26. Args:
  27. policy (Policy): The policy, which will use the model for optimization.
  28. obs_space (gym.spaces.Space): The policy's observation space.
  29. action_space (gym.spaces.Space): The policy's action space.
  30. config (TrainerConfigDict):
  31. Returns:
  32. q_model
  33. Note: The target q model will not be returned, just assigned to
  34. `policy.target_model`.
  35. """
  36. # Create the policy's models.
  37. model = build_q_model(policy, obs_space, action_space, config)
  38. # Assert correct model type by checking the init state to be present.
  39. # For attention nets: These don't necessarily publish their init state via
  40. # Model.get_initial_state, but may only use the trajectory view API
  41. # (view_requirements).
  42. assert (model.get_initial_state() != [] or
  43. model.view_requirements.get("state_in_0") is not None), \
  44. "R2D2 requires its model to be a recurrent one! Try using " \
  45. "`model.use_lstm` or `model.use_attention` in your config " \
  46. "to auto-wrap your model with an LSTM- or attention net."
  47. return model
  48. def r2d2_loss(policy: Policy, model, _,
  49. train_batch: SampleBatch) -> TensorType:
  50. """Constructs the loss for R2D2TFPolicy.
  51. Args:
  52. policy (Policy): The Policy to calculate the loss for.
  53. model (ModelV2): The Model to calculate the loss for.
  54. train_batch (SampleBatch): The training data.
  55. Returns:
  56. TensorType: A single loss tensor.
  57. """
  58. config = policy.config
  59. # Construct internal state inputs.
  60. i = 0
  61. state_batches = []
  62. while "state_in_{}".format(i) in train_batch:
  63. state_batches.append(train_batch["state_in_{}".format(i)])
  64. i += 1
  65. assert state_batches
  66. # Q-network evaluation (at t).
  67. q, _, _, _ = compute_q_values(
  68. policy,
  69. model,
  70. train_batch,
  71. state_batches=state_batches,
  72. seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
  73. explore=False,
  74. is_training=True)
  75. # Target Q-network evaluation (at t+1).
  76. q_target, _, _, _ = compute_q_values(
  77. policy,
  78. policy.target_model,
  79. train_batch,
  80. state_batches=state_batches,
  81. seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
  82. explore=False,
  83. is_training=True)
  84. if not hasattr(policy, "target_q_func_vars"):
  85. policy.target_q_func_vars = policy.target_model.variables()
  86. actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.int64)
  87. dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
  88. rewards = train_batch[SampleBatch.REWARDS]
  89. weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
  90. B = tf.shape(state_batches[0])[0]
  91. T = tf.shape(q)[0] // B
  92. # Q scores for actions which we know were selected in the given state.
  93. one_hot_selection = tf.one_hot(actions, policy.action_space.n)
  94. q_selected = tf.reduce_sum(
  95. tf.where(q > tf.float32.min, q, tf.zeros_like(q)) * one_hot_selection,
  96. axis=1)
  97. if config["double_q"]:
  98. best_actions = tf.argmax(q, axis=1)
  99. else:
  100. best_actions = tf.argmax(q_target, axis=1)
  101. best_actions_one_hot = tf.one_hot(best_actions, policy.action_space.n)
  102. q_target_best = tf.reduce_sum(
  103. tf.where(q_target > tf.float32.min, q_target, tf.zeros_like(q_target))
  104. * best_actions_one_hot,
  105. axis=1)
  106. if config["num_atoms"] > 1:
  107. raise ValueError("Distributional R2D2 not supported yet!")
  108. else:
  109. q_target_best_masked_tp1 = (1.0 - dones) * tf.concat(
  110. [q_target_best[1:], tf.constant([0.0])], axis=0)
  111. if config["use_h_function"]:
  112. h_inv = h_inverse(q_target_best_masked_tp1,
  113. config["h_function_epsilon"])
  114. target = h_function(
  115. rewards + config["gamma"]**config["n_step"] * h_inv,
  116. config["h_function_epsilon"])
  117. else:
  118. target = rewards + \
  119. config["gamma"] ** config["n_step"] * q_target_best_masked_tp1
  120. # Seq-mask all loss-related terms.
  121. seq_mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS],
  122. T)[:, :-1]
  123. # Mask away also the burn-in sequence at the beginning.
  124. burn_in = policy.config["burn_in"]
  125. # Making sure, this works for both static graph and eager.
  126. if burn_in > 0:
  127. seq_mask = tf.cond(
  128. pred=tf.convert_to_tensor(burn_in, tf.int32) < T,
  129. true_fn=lambda: tf.concat([tf.fill([B, burn_in], False),
  130. seq_mask[:, burn_in:]], 1),
  131. false_fn=lambda: seq_mask,
  132. )
  133. def reduce_mean_valid(t):
  134. return tf.reduce_mean(tf.boolean_mask(t, seq_mask))
  135. # Make sure to use the correct time indices:
  136. # Q(t) - [gamma * r + Q^(t+1)]
  137. q_selected = tf.reshape(q_selected, [B, T])[:, :-1]
  138. td_error = q_selected - tf.stop_gradient(
  139. tf.reshape(target, [B, T])[:, :-1])
  140. td_error = td_error * tf.cast(seq_mask, tf.float32)
  141. weights = tf.reshape(weights, [B, T])[:, :-1]
  142. policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
  143. # Store the TD-error per time chunk (b/c we need only one mean
  144. # prioritized replay weight per stored sequence).
  145. policy._td_error = tf.reduce_mean(td_error, axis=-1)
  146. policy._loss_stats = {
  147. "mean_q": reduce_mean_valid(q_selected),
  148. "min_q": tf.reduce_min(q_selected),
  149. "max_q": tf.reduce_max(q_selected),
  150. "mean_td_error": reduce_mean_valid(td_error),
  151. }
  152. return policy._total_loss
  153. def h_function(x, epsilon=1.0):
  154. """h-function to normalize target Qs, described in the paper [1].
  155. h(x) = sign(x) * [sqrt(abs(x) + 1) - 1] + epsilon * x
  156. Used in [1] in combination with h_inverse:
  157. targets = h(r + gamma * h_inverse(Q^))
  158. """
  159. return tf.sign(x) * (tf.sqrt(tf.abs(x) + 1.0) - 1.0) + epsilon * x
  160. def h_inverse(x, epsilon=1.0):
  161. """Inverse if the above h-function, described in the paper [1].
  162. If x > 0.0:
  163. h-1(x) = [2eps * x + (2eps + 1) - sqrt(4eps x + (2eps + 1)^2)] /
  164. (2 * eps^2)
  165. If x < 0.0:
  166. h-1(x) = [2eps * x + (2eps + 1) + sqrt(-4eps x + (2eps + 1)^2)] /
  167. (2 * eps^2)
  168. """
  169. two_epsilon = epsilon * 2
  170. if_x_pos = (two_epsilon * x + (two_epsilon + 1.0) -
  171. tf.sqrt(4.0 * epsilon * x +
  172. (two_epsilon + 1.0)**2)) / (2.0 * epsilon**2)
  173. if_x_neg = (two_epsilon * x - (two_epsilon + 1.0) +
  174. tf.sqrt(-4.0 * epsilon * x +
  175. (two_epsilon + 1.0)**2)) / (2.0 * epsilon**2)
  176. return tf.where(x < 0.0, if_x_neg, if_x_pos)
  177. class ComputeTDErrorMixin:
  178. """Assign the `compute_td_error` method to the R2D2TFPolicy
  179. This allows us to prioritize on the worker side.
  180. """
  181. def __init__(self):
  182. def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
  183. importance_weights):
  184. input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
  185. input_dict[SampleBatch.ACTIONS] = act_t
  186. input_dict[SampleBatch.REWARDS] = rew_t
  187. input_dict[SampleBatch.NEXT_OBS] = obs_tp1
  188. input_dict[SampleBatch.DONES] = done_mask
  189. input_dict[PRIO_WEIGHTS] = importance_weights
  190. # Do forward pass on loss to update td error attribute
  191. r2d2_loss(self, self.model, None, input_dict)
  192. return self._td_error
  193. self.compute_td_error = compute_td_error
  194. def get_distribution_inputs_and_class(
  195. policy: Policy,
  196. model: ModelV2,
  197. *,
  198. input_dict: ModelInputDict,
  199. state_batches: Optional[List[TensorType]] = None,
  200. seq_lens: Optional[TensorType] = None,
  201. explore: bool = True,
  202. is_training: bool = False,
  203. **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
  204. if policy.config["framework"] == "torch":
  205. from ray.rllib.agents.dqn.r2d2_torch_policy import \
  206. compute_q_values as torch_compute_q_values
  207. func = torch_compute_q_values
  208. else:
  209. func = compute_q_values
  210. q_vals, logits, probs_or_logits, state_out = func(
  211. policy, model, input_dict, state_batches, seq_lens, explore,
  212. is_training)
  213. policy.q_values = q_vals
  214. if not hasattr(policy, "q_func_vars"):
  215. policy.q_func_vars = model.variables()
  216. action_dist_class = TorchCategorical if \
  217. policy.config["framework"] == "torch" else Categorical
  218. return policy.q_values, action_dist_class, state_out
  219. def adam_optimizer(policy: Policy, config: TrainerConfigDict
  220. ) -> "tf.keras.optimizers.Optimizer":
  221. return tf1.train.AdamOptimizer(
  222. learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])
  223. def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
  224. return dict({
  225. "cur_lr": policy.cur_lr,
  226. }, **policy._loss_stats)
  227. def setup_early_mixins(policy: Policy, obs_space, action_space,
  228. config: TrainerConfigDict) -> None:
  229. LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
  230. def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
  231. action_space: gym.spaces.Space,
  232. config: TrainerConfigDict) -> None:
  233. ComputeTDErrorMixin.__init__(policy)
  234. TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
  235. R2D2TFPolicy = build_tf_policy(
  236. name="R2D2TFPolicy",
  237. loss_fn=r2d2_loss,
  238. get_default_config=lambda: ray.rllib.agents.dqn.r2d2.R2D2_DEFAULT_CONFIG,
  239. postprocess_fn=postprocess_nstep_and_prio,
  240. stats_fn=build_q_stats,
  241. make_model=build_r2d2_model,
  242. action_distribution_fn=get_distribution_inputs_and_class,
  243. optimizer_fn=adam_optimizer,
  244. extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
  245. compute_gradients_fn=clip_gradients,
  246. extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
  247. before_init=setup_early_mixins,
  248. before_loss_init=before_loss_init,
  249. mixins=[
  250. TargetNetworkMixin,
  251. ComputeTDErrorMixin,
  252. LearningRateSchedule,
  253. ])