# rnnsac_torch_policy.py

import gym
import numpy as np
from typing import List, Optional, Tuple, Type, Union

import ray
from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS
from ray.rllib.agents.sac import SACTorchPolicy
from ray.rllib.agents.sac.rnnsac_torch_model import RNNSACTorchModel
from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import huber_loss, sequence_mask
from ray.rllib.utils.typing import \
    ModelInputDict, TensorType, TrainerConfigDict

torch, nn = try_import_torch()
F = None
if nn:
    F = nn.functional
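
# `try_import_torch()` returns None stubs when torch is not installed, so `F`
# is only bound to `nn.functional` if torch is actually available and this
# module can still be imported without it.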


def build_rnnsac_model(policy: Policy, obs_space: gym.spaces.Space,
                       action_space: gym.spaces.Space,
                       config: TrainerConfigDict) -> ModelV2:
    """Constructs the necessary ModelV2 for the Policy and returns it.

    Args:
        policy (Policy): The Policy that will use the models.
        obs_space (gym.spaces.Space): The observation space.
        action_space (gym.spaces.Space): The action space.
        config (TrainerConfigDict): The SAC trainer's config dict.

    Returns:
        ModelV2: The ModelV2 to be used by the Policy. Note: An additional
            target model will be created in this function and assigned to
            `policy.target_model`.
    """
    # With separate state-preprocessor (before obs+action concat).
    num_outputs = int(np.product(obs_space.shape))

    # Force-ignore any additionally provided hidden layer sizes.
    # Everything should be configured using SAC's "Q_model" and "policy_model"
    # settings.
    policy_model_config = MODEL_DEFAULTS.copy()
    policy_model_config.update(config["policy_model"])
    q_model_config = MODEL_DEFAULTS.copy()
    q_model_config.update(config["Q_model"])

    default_model_cls = RNNSACTorchModel

    model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework=config["framework"],
        default_model=default_model_cls,
        name="sac_model",
        policy_model_config=policy_model_config,
        q_model_config=q_model_config,
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"])

    assert isinstance(model, default_model_cls)

    # Create an exact copy of the model and store it in `policy.target_model`.
    # This will be used for tau-synched Q-target models that run behind the
    # actual Q-networks and are used for target q-value calculations in the
    # loss terms.
    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework=config["framework"],
        default_model=default_model_cls,
        name="target_sac_model",
        policy_model_config=policy_model_config,
        q_model_config=q_model_config,
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"])

    assert isinstance(policy.target_model, default_model_cls)

    return model
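
# The "Q_model" and "policy_model" sub-configs used above are ordinary RLlib
# model configs. A minimal sketch of the recurrent settings one would pass
# (the keys are standard MODEL_DEFAULTS options; the concrete values are
# illustrative assumptions only, not RLlib defaults):
#
#     config["policy_model"] = {
#         "use_lstm": True,
#         "lstm_cell_size": 256,
#         "max_seq_len": 20,
#     }
#     config["Q_model"] = {
#         "use_lstm": True,
#         "lstm_cell_size": 256,
#         "max_seq_len": 20,
#     }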


def build_sac_model_and_action_dist(
        policy: Policy,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> \
        Tuple[ModelV2, Type[TorchDistributionWrapper]]:
    """Constructs the necessary ModelV2 and action dist class for the Policy.

    Args:
        policy (Policy): The Policy that will use the models.
        obs_space (gym.spaces.Space): The observation space.
        action_space (gym.spaces.Space): The action space.
        config (TrainerConfigDict): The SAC trainer's config dict.

    Returns:
        Tuple[ModelV2, Type[TorchDistributionWrapper]]: The ModelV2 to be used
            by the Policy and the action distribution class. Note: An
            additional target model will be created in this function and
            assigned to `policy.target_model`.
    """
    model = build_rnnsac_model(policy, obs_space, action_space, config)
    assert model.get_initial_state() != [], \
        "RNNSAC requires its model to be a recurrent one!"
    action_dist_class = _get_dist_class(policy, config, action_space)
    return model, action_dist_class


def action_distribution_fn(
        policy: Policy,
        model: ModelV2,
        input_dict: ModelInputDict,
        *,
        state_batches: Optional[List[TensorType]] = None,
        seq_lens: Optional[TensorType] = None,
        prev_action_batch: Optional[TensorType] = None,
        prev_reward_batch=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        is_training: Optional[bool] = None) -> \
        Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]:
    """The action distribution function to be used by the algorithm.

    An action distribution function is used to customize the choice of action
    distribution class and the resulting action distribution inputs (to
    parameterize the distribution object).
    After parameterizing the distribution, a `sample()` call
    will be made on it to generate actions.

    Args:
        policy (Policy): The Policy being queried for actions and calling this
            function.
        model (TorchModelV2): The SAC specific Model to use to generate the
            distribution inputs (see sac_tf|torch_model.py). Must support the
            `get_policy_output` method.
        input_dict (ModelInputDict): The input-dict to be used for the model
            call.
        state_batches (Optional[List[TensorType]]): The list of internal state
            tensor batches.
        seq_lens (Optional[TensorType]): The tensor of sequence lengths used
            in RNNs.
        prev_action_batch (Optional[TensorType]): Optional batch of prev
            actions used by the model.
        prev_reward_batch (Optional[TensorType]): Optional batch of prev
            rewards used by the model.
        explore (Optional[bool]): Whether to activate exploration or not. If
            None, use value of `config.explore`.
        timestep (Optional[int]): An optional timestep.
        is_training (Optional[bool]): An optional is-training flag.

    Returns:
        Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]:
            The dist inputs, dist class, and a list of internal state outputs
            (in the RNN case).
    """
    # Get base-model output (w/o the SAC specific parts of the network).
    model_out, state_in = model(input_dict, state_batches, seq_lens)
    # Use the base output to get the policy outputs from the SAC model's
    # policy components.
    states_in = model.select_state(state_in, ["policy", "q", "twin_q"])
    distribution_inputs, policy_state_out = \
        model.get_policy_output(model_out, states_in["policy"], seq_lens)
    _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens)
    if model.twin_q_net:
        _, twin_q_state_out = \
            model.get_twin_q_values(model_out, states_in["twin_q"], seq_lens)
    else:
        twin_q_state_out = []
    # Get a distribution class to be used with the just calculated dist-inputs.
    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    states_out = policy_state_out + q_state_out + twin_q_state_out
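    # The state outputs are concatenated in a fixed order (policy, then Q,
    # then twin-Q). RLlib feeds this flat list back in as the `state_in_*`
    # batches on the next call, where `model.select_state()` splits it up
    # again, so the ordering here must stay consistent with that split.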
    return distribution_inputs, action_dist_class, states_out


def actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get(SampleBatch.SEQ_LENS)
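    # RLlib stores the initial recurrent state of each sampled sequence in the
    # `state_in_0`, `state_in_1`, ... columns of the train batch and the
    # per-sequence lengths under SEQ_LENS; both are needed to unroll the RNN
    # over the stored sequences below.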

    model_out_t, state_in_t = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS],
            prev_actions=train_batch[SampleBatch.PREV_ACTIONS],
            prev_rewards=train_batch[SampleBatch.PREV_REWARDS],
            _is_training=True), state_batches, seq_lens)
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True), state_batches, seq_lens)
    states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True), state_batches, seq_lens)
    target_states_in_tp1 = target_model.select_state(
        target_state_in_tp1, ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = F.log_softmax(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0],
            dim=-1)
        policy_t = torch.exp(log_pis_t)
        log_pis_tp1 = F.log_softmax(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)[0]
        # Target Q-values.
        q_tp1 = target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens)[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens)[0]
            twin_q_tp1 = target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)[0]
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(
            train_batch[SampleBatch.ACTIONS].long(),
            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0], model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                 train_batch[SampleBatch.ACTIONS])[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens,
                train_batch[SampleBatch.ACTIONS])[0]

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, states_in_t["q"],
                                            seq_lens, policy_t)[0]
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)[0]
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1 = target_model.get_q_values(target_model_out_tp1,
                                          target_states_in_tp1["q"], seq_lens,
                                          policy_tp1)[0]
        if policy.config["twin_q"]:
            twin_q_tp1 = target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens,
                policy_tp1)[0]
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked
    ).detach()
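    # In equation form (soft Bellman backup over the n-step horizon):
    #     y_t = r_t + gamma^n * (1 - done) *
    #           (min_i Q_target_i(s_{t+n}, a') - alpha * log pi(a' | s_{t+n})),
    # where a' is drawn from (or, in the discrete case, averaged under) the
    # current policy at s_{t+n}; the entropy term was already folded into
    # `q_tp1` above.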

    # BURNIN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False
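    # The first `burn_in` timesteps of every stored sequence only serve to
    # warm up the recurrent state (R2D2-style burn-in); masking them out keeps
    # them from contributing to any of the loss terms below.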
    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(
            train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(
                train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    td_error = td_error * seq_mask

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach()),
                dim=-1))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t * seq_mask[..., None]
    model.tower_stats["policy_t"] = policy_t * seq_mask[..., None]
    model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None]
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    # Store per time chunk (b/c we need only one mean
    # prioritized replay weight per stored sequence).
    model.tower_stats["td_error"] = torch.mean(
        td_error.reshape([-1, T]), dim=-1)

    # Return all loss terms corresponding to our optimizers.
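    # The ordering (actor, critic(s), alpha) matches the separate optimizers
    # that SACTorchPolicy builds, so each loss term is minimized only by its
    # own optimizer.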
    return tuple([actor_loss] + critic_loss + [alpha_loss])


RNNSACTorchPolicy = SACTorchPolicy.with_updates(
    name="RNNSACPolicy",
    get_default_config=lambda: ray.rllib.agents.sac.rnnsac.DEFAULT_CONFIG,
    action_distribution_fn=action_distribution_fn,
    make_model_and_action_dist=build_sac_model_and_action_dist,
    loss_fn=actor_critic_loss,
)
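
# Example usage (a minimal sketch): this policy is normally not instantiated
# directly but picked up by the RNNSAC trainer. Assuming this Ray version
# exports `RNNSACTrainer` from `ray.rllib.agents.sac`, and with illustrative
# (not default) config values, a run could look like:
#
#     import ray
#     from ray.rllib.agents.sac import RNNSACTrainer
#
#     ray.init()
#     trainer = RNNSACTrainer(
#         env="CartPole-v1",
#         config={
#             "framework": "torch",
#             "burn_in": 4,  # warm-up steps per stored sequence (assumed value)
#             "Q_model": {"use_lstm": True, "lstm_cell_size": 64},
#             "policy_model": {"use_lstm": True, "lstm_cell_size": 64},
#         })
#     for _ in range(10):
#         print(trainer.train()["episode_reward_mean"])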