  1. """PyTorch policy class used for DQN"""
  2. from typing import Dict, List, Tuple
  3. import gymnasium as gym
  4. import ray
  5. from ray.rllib.algorithms.dqn.dqn_tf_policy import (
  6. PRIO_WEIGHTS,
  7. Q_SCOPE,
  8. Q_TARGET_SCOPE,
  9. postprocess_nstep_and_prio,
  10. )
  11. from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
  12. from ray.rllib.models.catalog import ModelCatalog
  13. from ray.rllib.models.modelv2 import ModelV2
  14. from ray.rllib.models.torch.torch_action_dist import (
  15. get_torch_categorical_class_with_temperature,
  16. TorchDistributionWrapper,
  17. )
  18. from ray.rllib.policy.policy import Policy
  19. from ray.rllib.policy.policy_template import build_policy_class
  20. from ray.rllib.policy.sample_batch import SampleBatch
  21. from ray.rllib.policy.torch_mixins import (
  22. LearningRateSchedule,
  23. TargetNetworkMixin,
  24. )
  25. from ray.rllib.utils.error import UnsupportedSpaceException
  26. from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
  27. from ray.rllib.utils.framework import try_import_torch
  28. from ray.rllib.utils.torch_utils import (
  29. apply_grad_clipping,
  30. concat_multi_gpu_td_errors,
  31. FLOAT_MIN,
  32. huber_loss,
  33. l2_loss,
  34. reduce_mean_ignore_inf,
  35. softmax_cross_entropy_with_logits,
  36. )
  37. from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
  38. torch, nn = try_import_torch()
  39. F = None
  40. if nn:
  41. F = nn.functional
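# The QLoss helper below implements two variants of the DQN objective:
#  - num_atoms == 1: the standard (n-step) TD error
#        td_error = Q(s, a) - (r + gamma^n * (1 - done) * max_a' Q_target(s', a')),
#    passed through a Huber (or L2) loss and weighted by importance weights.
#  - num_atoms > 1: distributional Q-learning in the style of C51/Rainbow. The
#    shifted/scaled support r + gamma^n * z is projected back onto the fixed
#    atoms z_i in [v_min, v_max], and the projected target distribution is
#    trained against the predicted per-atom logits with a cross-entropy loss.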
class QLoss:
    def __init__(
        self,
        q_t_selected: TensorType,
        q_logits_t_selected: TensorType,
        q_tp1_best: TensorType,
        q_probs_tp1_best: TensorType,
        importance_weights: TensorType,
        rewards: TensorType,
        done_mask: TensorType,
        gamma=0.99,
        n_step=1,
        num_atoms=1,
        v_min=-10.0,
        v_max=10.0,
        loss_fn=huber_loss,
    ):
        if num_atoms > 1:
            # Distributional Q-learning, which corresponds to an entropy loss.
            z = torch.arange(0.0, num_atoms, dtype=torch.float32).to(rewards.device)
            z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
            r_tau = torch.unsqueeze(rewards, -1) + gamma**n_step * torch.unsqueeze(
                1.0 - done_mask, -1
            ) * torch.unsqueeze(z, 0)
            r_tau = torch.clamp(r_tau, v_min, v_max)
            b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
            lb = torch.floor(b)
            ub = torch.ceil(b)

            # Indispensable check that many implementations miss: when b
            # happens to be an integer, lb == ub, so pr_j(s', a*) would be
            # discarded because (ub - b) == (b - lb) == 0.
            floor_equal_ceil = ((ub - lb) < 0.5).float()

            # (batch_size, num_atoms, num_atoms)
            l_project = F.one_hot(lb.long(), num_atoms)
            # (batch_size, num_atoms, num_atoms)
            u_project = F.one_hot(ub.long(), num_atoms)
            ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
            mu_delta = q_probs_tp1_best * (b - lb)
            ml_delta = torch.sum(l_project * torch.unsqueeze(ml_delta, -1), dim=1)
            mu_delta = torch.sum(u_project * torch.unsqueeze(mu_delta, -1), dim=1)
            m = ml_delta + mu_delta

            # The Rainbow paper claims that using this cross-entropy loss for
            # priority is robust and insensitive to `prioritized_replay_alpha`.
            self.td_error = softmax_cross_entropy_with_logits(
                logits=q_logits_t_selected, labels=m.detach()
            )
            self.loss = torch.mean(self.td_error * importance_weights)
            self.stats = {
                # TODO: better Q stats for dist dqn
            }
        else:
            q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

            # Compute RHS of the Bellman equation.
            q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

            # Compute the error (potentially clipped).
            self.td_error = q_t_selected - q_t_selected_target.detach()
            self.loss = torch.mean(importance_weights.float() * loss_fn(self.td_error))
            self.stats = {
                "mean_q": torch.mean(q_t_selected),
                "min_q": torch.min(q_t_selected),
                "max_q": torch.max(q_t_selected),
            }


class ComputeTDErrorMixin:
    """Assigns the `compute_td_error` method to the DQNTorchPolicy.

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        def compute_td_error(
            obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights
        ):
            input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
            input_dict[SampleBatch.ACTIONS] = act_t
            input_dict[SampleBatch.REWARDS] = rew_t
            input_dict[SampleBatch.NEXT_OBS] = obs_tp1
            input_dict[SampleBatch.TERMINATEDS] = terminateds_mask
            input_dict[PRIO_WEIGHTS] = importance_weights

            # Do forward pass on loss to update the TD-error attribute.
            build_q_losses(self, self.model, None, input_dict)

            return self.model.tower_stats["q_loss"].td_error

        self.compute_td_error = compute_td_error

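# Illustrative only (hypothetical variable names): a worker can refresh replay
# priorities from freshly computed TD errors, e.g.
#     td_errors = policy.compute_td_error(
#         obs, actions, rewards, next_obs, terminateds, weights
#     )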
def build_q_model_and_distribution(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build q_model and target_model for DQN.

    Args:
        policy: The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (AlgorithmConfigDict): The algorithm's config dict.

    Returns:
        (q_model, TorchCategorical)

    Note: The target q model will not be returned, just assigned to
    `policy.target_model`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space)
        )

    if config["hiddens"]:
        # Try to infer the last layer size, otherwise fall back to 256.
        num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    # TODO(sven): Move option to add LayerNorm after each Dense
    #  generically into ModelCatalog.
    add_layer_norm = (
        isinstance(getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise"
    )

    model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_SCOPE,
        q_hiddens=config["hiddens"],
        dueling=config["dueling"],
        num_atoms=config["num_atoms"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        add_layer_norm=add_layer_norm,
    )

    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_TARGET_SCOPE,
        q_hiddens=config["hiddens"],
        dueling=config["dueling"],
        num_atoms=config["num_atoms"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        add_layer_norm=add_layer_norm,
    )

    # Return a TorchCategorical distribution class whose temperature
    # parameter is partially bound to the configured value.
    temperature = config["categorical_distribution_temperature"]

    return model, get_torch_categorical_class_with_temperature(temperature)

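# Sketch (assuming the standard DQNConfig API) of how the config keys read by
# this builder are typically set; the values below are illustrative:
#     config = DQNConfig().training(
#         hiddens=[256], dueling=True, num_atoms=51,
#         v_min=-10.0, v_max=10.0, noisy=False, sigma0=0.5,
#     )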
def get_distribution_inputs_and_class(
    policy: Policy,
    model: ModelV2,
    input_dict: SampleBatch,
    *,
    explore: bool = True,
    is_training: bool = False,
    **kwargs
) -> Tuple[TensorType, type, List[TensorType]]:
    q_vals = compute_q_values(
        policy, model, input_dict, explore=explore, is_training=is_training
    )
    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

    model.tower_stats["q_values"] = q_vals

    # Return a TorchCategorical distribution class whose temperature
    # parameter is partially bound to the configured value.
    temperature = policy.config["categorical_distribution_temperature"]
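    # The Q-values serve as logits of the temperature-scaled categorical
    # distribution returned below; the configured exploration component then
    # picks actions from that distribution.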
    return (
        q_vals,
        get_torch_categorical_class_with_temperature(temperature),
        [],  # state-out
    )


def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Q-network evaluation.
    q_t, q_logits_t, q_probs_t, _ = compute_q_values(
        policy,
        model,
        {"obs": train_batch[SampleBatch.CUR_OBS]},
        explore=False,
        is_training=True,
    )

    # Target Q-network evaluation.
    q_tp1, q_logits_tp1, q_probs_tp1, _ = compute_q_values(
        policy,
        policy.target_models[model],
        {"obs": train_batch[SampleBatch.NEXT_OBS]},
        explore=False,
        is_training=True,
    )

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(
        train_batch[SampleBatch.ACTIONS].long(), policy.action_space.n
    )
    q_t_selected = torch.sum(
        torch.where(q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=q_t.device))
        * one_hot_selection,
        1,
    )
    q_logits_t_selected = torch.sum(
        q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1
    )

    # Compute estimate of best possible value starting from state at t + 1.
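    # Double Q-learning: the online network selects the argmax action for
    # s(t+1), while the target network provides that action's value estimate,
    # which reduces the overestimation bias of vanilla DQN.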
    if config["double_q"]:
        (
            q_tp1_using_online_net,
            q_logits_tp1_using_online_net,
            q_dist_tp1_using_online_net,
            _,
        ) = compute_q_values(
            policy,
            model,
            {"obs": train_batch[SampleBatch.NEXT_OBS]},
            explore=False,
            is_training=True,
        )
        q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = F.one_hot(
            q_tp1_best_using_online_net, policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device)
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )
    else:
        q_tp1_best_one_hot_selection = F.one_hot(
            torch.argmax(q_tp1, 1), policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device)
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )

    loss_fn = huber_loss if policy.config["td_error_loss_fn"] == "huber" else l2_loss

    q_loss = QLoss(
        q_t_selected,
        q_logits_t_selected,
        q_tp1_best,
        q_probs_tp1_best,
        train_batch[PRIO_WEIGHTS],
        train_batch[SampleBatch.REWARDS],
        train_batch[SampleBatch.TERMINATEDS].float(),
        config["gamma"],
        config["n_step"],
        config["num_atoms"],
        config["v_min"],
        config["v_max"],
        loss_fn,
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["td_error"] = q_loss.td_error
    # The TD-error tensor in the final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["q_loss"] = q_loss

    return q_loss.loss

def adam_optimizer(
    policy: Policy, config: AlgorithmConfigDict
) -> "torch.optim.Optimizer":

    # By this time, the models have been moved to the GPU - if any - and we
    # can define our optimizers using the correct CUDA variables.
    if not hasattr(policy, "q_func_vars"):
        policy.q_func_vars = policy.model.variables()

    return torch.optim.Adam(
        policy.q_func_vars, lr=policy.cur_lr, eps=config["adam_epsilon"]
    )


def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    stats = {}
    for stats_key in policy.model_gpu_towers[0].tower_stats["q_loss"].stats.keys():
        stats[stats_key] = torch.mean(
            torch.stack(
                [
                    t.tower_stats["q_loss"].stats[stats_key].to(policy.device)
                    for t in policy.model_gpu_towers
                    if "q_loss" in t.tower_stats
                ]
            )
        )
    stats["cur_lr"] = policy.cur_lr
    return stats


def setup_early_mixins(
    policy: Policy, obs_space, action_space, config: AlgorithmConfigDict
) -> None:
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def before_loss_init(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy)


def compute_q_values(
    policy: Policy,
    model: ModelV2,
    input_dict,
    state_batches=None,
    seq_lens=None,
    explore=None,
    is_training: bool = False,
):
    config = policy.config

    model_out, state = model(input_dict, state_batches or [], seq_lens)

    if config["num_atoms"] > 1:
        (
            action_scores,
            z,
            support_logits_per_action,
            logits,
            probs_or_logits,
        ) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits, probs_or_logits) = model.get_q_value_distributions(
            model_out
        )
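    # Dueling architecture: combine the state-value stream V(s) with the
    # mean-centered advantage stream, Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').
    # For distributional DQN, the same centering is applied per support atom.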
    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if policy.config["num_atoms"] > 1:
            support_logits_per_action_mean = torch.mean(
                support_logits_per_action, dim=1
            )
            support_logits_per_action_centered = (
                support_logits_per_action
                - torch.unsqueeze(support_logits_per_action_mean, dim=1)
            )
            support_logits_per_action = (
                torch.unsqueeze(state_score, dim=1) + support_logits_per_action_centered
            )
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1
            )
            value = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs_or_logits = support_prob_per_action
        else:
            advantages_mean = reduce_mean_ignore_inf(action_scores, 1)
            advantages_centered = action_scores - torch.unsqueeze(advantages_mean, 1)
            value = state_score + advantages_centered
    else:
        value = action_scores

    return value, logits, probs_or_logits, state


def grad_process_and_td_error_fn(
    policy: Policy, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(policy, optimizer, loss)


def extra_action_out_fn(
    policy: Policy, input_dict, state_batches, model, action_dist
) -> Dict[str, TensorType]:
    return {"q_values": model.tower_stats["q_values"]}


DQNTorchPolicy = build_policy_class(
    name="DQNTorchPolicy",
    framework="torch",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.algorithms.dqn.dqn.DQNConfig(),
    make_model_and_action_dist=build_q_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    extra_grad_process_fn=grad_process_and_td_error_fn,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    extra_action_out_fn=extra_action_out_fn,
    before_init=setup_early_mixins,
    before_loss_init=before_loss_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ],
)
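# Note: This policy is normally created by the DQN algorithm itself when the
# torch framework is selected. A minimal usage sketch (environment name is
# illustrative), assuming the standard config/builder API:
#     from ray.rllib.algorithms.dqn import DQNConfig
#     algo = DQNConfig().environment("CartPole-v1").framework("torch").build()
#     algo.train()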