dqn_tf_policy.py

  1. """TensorFlow policy class used for DQN"""
  2. from typing import Dict
  3. import gym
  4. import numpy as np
  5. import ray
  6. from ray.rllib.agents.dqn.distributional_q_tf_model import \
  7. DistributionalQTFModel
  8. from ray.rllib.agents.dqn.simple_q_tf_policy import TargetNetworkMixin
  9. from ray.rllib.evaluation.postprocessing import adjust_nstep
  10. from ray.rllib.models import ModelCatalog
  11. from ray.rllib.models.modelv2 import ModelV2
  12. from ray.rllib.models.tf.tf_action_dist import Categorical
  13. from ray.rllib.policy.policy import Policy
  14. from ray.rllib.policy.sample_batch import SampleBatch
  15. from ray.rllib.policy.tf_policy import LearningRateSchedule
  16. from ray.rllib.policy.tf_policy_template import build_tf_policy
  17. from ray.rllib.utils.error import UnsupportedSpaceException
  18. from ray.rllib.utils.exploration import ParameterNoise
  19. from ray.rllib.utils.framework import try_import_tf
  20. from ray.rllib.utils.numpy import convert_to_numpy
  21. from ray.rllib.utils.tf_utils import (
  22. huber_loss, make_tf_callable, minimize_and_clip, reduce_mean_ignore_inf)
  23. from ray.rllib.utils.typing import (ModelGradients, TensorType,
  24. TrainerConfigDict)
  25. tf1, tf, tfv = try_import_tf()
  26. Q_SCOPE = "q_func"
  27. Q_TARGET_SCOPE = "target_q_func"
  28. # Importance sampling weights for prioritized replay
  29. PRIO_WEIGHTS = "weights"


class QLoss:
    def __init__(self,
                 q_t_selected: TensorType,
                 q_logits_t_selected: TensorType,
                 q_tp1_best: TensorType,
                 q_dist_tp1_best: TensorType,
                 importance_weights: TensorType,
                 rewards: TensorType,
                 done_mask: TensorType,
                 gamma: float = 0.99,
                 n_step: int = 1,
                 num_atoms: int = 1,
                 v_min: float = -10.0,
                 v_max: float = 10.0):

        if num_atoms > 1:
            # Distributional Q-learning, which corresponds to a cross-entropy
            # loss on the projected target distribution.
            z = tf.range(num_atoms, dtype=tf.float32)
            z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
            r_tau = tf.expand_dims(
                rewards, -1) + gamma**n_step * tf.expand_dims(
                    1.0 - done_mask, -1) * tf.expand_dims(z, 0)
            r_tau = tf.clip_by_value(r_tau, v_min, v_max)
            b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
            lb = tf.floor(b)
            ub = tf.math.ceil(b)

            # Indispensable check that many implementations miss: when b
            # happens to be an integer, lb == ub, so pr_j(s', a*) would be
            # discarded because (ub - b) == (b - lb) == 0.
            floor_equal_ceil = tf.cast(tf.less(ub - lb, 0.5), tf.float32)

            l_project = tf.one_hot(
                tf.cast(lb, dtype=tf.int32),
                num_atoms)  # (batch_size, num_atoms, num_atoms)
            u_project = tf.one_hot(
                tf.cast(ub, dtype=tf.int32),
                num_atoms)  # (batch_size, num_atoms, num_atoms)
            ml_delta = q_dist_tp1_best * (ub - b + floor_equal_ceil)
            mu_delta = q_dist_tp1_best * (b - lb)
            ml_delta = tf.reduce_sum(
                l_project * tf.expand_dims(ml_delta, -1), axis=1)
            mu_delta = tf.reduce_sum(
                u_project * tf.expand_dims(mu_delta, -1), axis=1)
            m = ml_delta + mu_delta

            # The Rainbow paper claims that using this cross-entropy loss for
            # priority is robust and insensitive to `prioritized_replay_alpha`.
            self.td_error = tf.nn.softmax_cross_entropy_with_logits(
                labels=m, logits=q_logits_t_selected)
            self.loss = tf.reduce_mean(
                self.td_error * tf.cast(importance_weights, tf.float32))
            self.stats = {
                # TODO: better Q stats for dist dqn
                "mean_td_error": tf.reduce_mean(self.td_error),
            }
        else:
            q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

            # Compute RHS of the Bellman equation.
            q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

            # Compute the error (potentially clipped).
            self.td_error = (
                q_t_selected - tf.stop_gradient(q_t_selected_target))
            self.loss = tf.reduce_mean(
                tf.cast(importance_weights, tf.float32) * huber_loss(
                    self.td_error))
            self.stats = {
                "mean_q": tf.reduce_mean(q_t_selected),
                "min_q": tf.reduce_min(q_t_selected),
                "max_q": tf.reduce_max(q_t_selected),
                "mean_td_error": tf.reduce_mean(self.td_error),
            }
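

# Illustration only (this helper is not part of RLlib and is never called by
# the policy): the distributional branch of QLoss above is the C51 categorical
# projection, i.e. the Bellman-shifted atoms r + gamma^n * z are projected
# back onto the fixed support z, and the result is used as the cross-entropy
# target. The numpy sketch below mirrors that projection; the function name
# and argument names are ours.
def _example_categorical_projection(rewards, done_mask, probs_tp1_best,
                                    gamma=0.99, n_step=1, num_atoms=51,
                                    v_min=-10.0, v_max=10.0):
    """Numpy sketch of the projection done in QLoss (for intuition only)."""
    delta_z = (v_max - v_min) / float(num_atoms - 1)
    z = v_min + np.arange(num_atoms, dtype=np.float32) * delta_z
    # Bellman-shifted atoms r + gamma^n * z, clipped to the support range.
    r_tau = np.clip(
        rewards[:, None] +
        gamma**n_step * (1.0 - done_mask[:, None]) * z[None, :], v_min, v_max)
    b = (r_tau - v_min) / delta_z
    lb, ub = np.floor(b), np.ceil(b)
    # Same integer-atom correction as in QLoss: if b is an integer, lb == ub
    # and the full probability mass goes to that single atom.
    floor_equal_ceil = (ub - lb < 0.5).astype(np.float32)
    m = np.zeros_like(probs_tp1_best)
    for i in range(rewards.shape[0]):
        for j in range(num_atoms):
            # Split the next-state probability mass between the two
            # neighboring atoms of the shifted value.
            m[i, int(lb[i, j])] += probs_tp1_best[i, j] * (
                ub[i, j] - b[i, j] + floor_equal_ceil[i, j])
            m[i, int(ub[i, j])] += probs_tp1_best[i, j] * (b[i, j] - lb[i, j])
    # `m` plays the role of the `labels` fed to
    # tf.nn.softmax_cross_entropy_with_logits in QLoss.
    return m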


class ComputeTDErrorMixin:
    """Assign the `compute_td_error` method to the DQNTFPolicy

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        @make_tf_callable(self.get_session(), dynamic_shape=True)
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            # Do forward pass on loss to update td error attribute
            build_q_losses(
                self, self.model, None, {
                    SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
                    SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
                    SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
                    SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
                    SampleBatch.DONES: tf.convert_to_tensor(done_mask),
                    PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
                })

            return self.q_loss.td_error

        self.compute_td_error = compute_td_error


def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
                  action_space: gym.spaces.Space,
                  config: TrainerConfigDict) -> ModelV2:
    """Build q_model and target_model for DQN

    Args:
        policy (Policy): The Policy, which will use the model for
            optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        ModelV2: The Model for the Policy to use.
            Note: The target q model will not be returned, just assigned to
            `policy.target_model`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    if config["hiddens"]:
        # Try to infer the last layer size, otherwise fall back to 256.
        num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="tf",
        model_interface=DistributionalQTFModel,
        name=Q_SCOPE,
        num_atoms=config["num_atoms"],
        dueling=config["dueling"],
        q_hiddens=config["hiddens"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        add_layer_norm=isinstance(
            getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise")

    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="tf",
        model_interface=DistributionalQTFModel,
        name=Q_TARGET_SCOPE,
        num_atoms=config["num_atoms"],
        dueling=config["dueling"],
        q_hiddens=config["hiddens"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        add_layer_norm=isinstance(
            getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise")

    return q_model
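

# Rough guide to the config keys consumed by `build_q_model` above (see
# DistributionalQTFModel for the authoritative behavior):
#   "num_atoms" > 1 -> distributional (C51) heads, with "v_min"/"v_max"
#                      bounding the support of the return distribution.
#   "dueling"       -> separate advantage and state-value streams.
#   "hiddens"       -> sizes of the Q-head layers appended after the base
#                      model.
#   "noisy"         -> NoisyNet layers for exploration; "sigma0" is their
#                      initial noise scale.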


def get_distribution_inputs_and_class(policy: Policy,
                                      model: ModelV2,
                                      input_dict: SampleBatch,
                                      *,
                                      explore=True,
                                      **kwargs):
    q_vals = compute_q_values(
        policy, model, input_dict, state_batches=None, explore=explore)
    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

    policy.q_values = q_vals
    return policy.q_values, Categorical, []  # state-out


def build_q_losses(policy: Policy, model, _,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config
    # Q-network evaluation.
    q_t, q_logits_t, q_dist_t, _ = compute_q_values(
        policy,
        model,
        SampleBatch({
            "obs": train_batch[SampleBatch.CUR_OBS]
        }),
        state_batches=None,
        explore=False)

    # Target Q-network evaluation.
    q_tp1, q_logits_tp1, q_dist_tp1, _ = compute_q_values(
        policy,
        policy.target_model,
        SampleBatch({
            "obs": train_batch[SampleBatch.NEXT_OBS]
        }),
        state_batches=None,
        explore=False)

    if not hasattr(policy, "target_q_func_vars"):
        policy.target_q_func_vars = policy.target_model.variables()

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = tf.one_hot(
        tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32),
        policy.action_space.n)
    q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
    q_logits_t_selected = tf.reduce_sum(
        q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)

    # Compute estimate of the best possible value starting from state at t+1.
    if config["double_q"]:
        # Double DQN: pick the argmax action with the online network, but
        # evaluate its value with the target network.
        q_tp1_using_online_net, q_logits_tp1_using_online_net, \
            q_dist_tp1_using_online_net, _ = compute_q_values(
                policy, model,
                SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}),
                state_batches=None,
                explore=False)
        q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = tf.one_hot(
            q_tp1_best_using_online_net, policy.action_space.n)
        q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
        q_dist_tp1_best = tf.reduce_sum(
            q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)
    else:
        q_tp1_best_one_hot_selection = tf.one_hot(
            tf.argmax(q_tp1, 1), policy.action_space.n)
        q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
        q_dist_tp1_best = tf.reduce_sum(
            q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)

    policy.q_loss = QLoss(
        q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best,
        train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS],
        tf.cast(train_batch[SampleBatch.DONES],
                tf.float32), config["gamma"], config["n_step"],
        config["num_atoms"], config["v_min"], config["v_max"])

    return policy.q_loss.loss


def adam_optimizer(policy: Policy, config: TrainerConfigDict
                   ) -> "tf.keras.optimizers.Optimizer":
    if policy.config["framework"] in ["tf2", "tfe"]:
        return tf.keras.optimizers.Adam(
            learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])
    else:
        return tf1.train.AdamOptimizer(
            learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])


def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer",
                   loss: TensorType) -> ModelGradients:
    if not hasattr(policy, "q_func_vars"):
        policy.q_func_vars = policy.model.variables()

    return minimize_and_clip(
        optimizer,
        loss,
        var_list=policy.q_func_vars,
        clip_val=policy.config["grad_clip"])


def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    return dict({
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
    }, **policy.q_loss.stats)


def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    ComputeTDErrorMixin.__init__(policy)


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


def compute_q_values(policy: Policy,
                     model: ModelV2,
                     input_batch: SampleBatch,
                     state_batches=None,
                     seq_lens=None,
                     explore=None,
                     is_training: bool = False):
    config = policy.config

    model_out, state = model(input_batch, state_batches or [], seq_lens)

    if config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         dist) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action - tf.expand_dims(
                    support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            value = tf.reduce_sum(
                input_tensor=z * support_prob_per_action, axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            # Dueling aggregation: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist, state


def postprocess_nstep_and_prio(policy: Policy,
                               batch: SampleBatch,
                               other_agent=None,
                               episode=None) -> SampleBatch:
    # N-step Q adjustments.
    if policy.config["n_step"] > 1:
        adjust_nstep(policy.config["n_step"], policy.config["gamma"], batch)

    # Create dummy prio-weights (1.0) in case we don't have any in
    # the batch.
    if PRIO_WEIGHTS not in batch:
        batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])

    # Prioritize on the worker side.
    if batch.count > 0 and policy.config["worker_side_prioritization"]:
        td_errors = policy.compute_td_error(
            batch[SampleBatch.OBS], batch[SampleBatch.ACTIONS],
            batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
            batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
        new_priorities = (np.abs(convert_to_numpy(td_errors)) +
                          policy.config["prioritized_replay_eps"])
        batch[PRIO_WEIGHTS] = new_priorities

    return batch
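

# Note on the n-step adjustment above, for intuition: `adjust_nstep` rewrites
# each reward in place as the discounted sum of the next `n_step` rewards and
# shifts NEXT_OBS/DONES `n_step` transitions ahead (truncating at the end of
# the trajectory). E.g. with n_step=3 and gamma=0.9, the rewards [1, 1, 1, 1]
# of a 4-step episode become [2.71, 2.71, 1.9, 1.0].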


DQNTFPolicy = build_tf_policy(
    name="DQNTFPolicy",
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    make_model=build_q_model,
    action_distribution_fn=get_distribution_inputs_and_class,
    loss_fn=build_q_losses,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    compute_gradients_fn=clip_gradients,
    extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])
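

# Minimal usage sketch (illustration, not part of this module): DQNTFPolicy is
# normally not constructed by hand; the DQN Trainer plugs it in when
# framework="tf". The env name and config values below are assumptions for the
# example only.
if __name__ == "__main__":
    from ray.rllib.agents.dqn import DQNTrainer

    ray.init()
    trainer = DQNTrainer(
        env="CartPole-v0", config={"framework": "tf", "num_workers": 0})
    # Run a single training iteration and print the mean episode reward.
    print(trainer.train()["episode_reward_mean"])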