cql_torch_policy.py

  1. """
  2. PyTorch policy class used for CQL.
  3. """
  4. import numpy as np
  5. import gym
  6. import logging
  7. from typing import Dict, List, Tuple, Type, Union
  8. import ray
  9. import ray.experimental.tf_utils
  10. from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \
  11. validate_spaces
  12. from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \
  13. build_sac_model_and_action_dist, optimizer_fn, ComputeTDErrorMixin, \
  14. TargetNetworkMixin, setup_late_mixins, action_distribution_fn
  15. from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
  16. from ray.rllib.policy.policy_template import build_policy_class
  17. from ray.rllib.models.modelv2 import ModelV2
  18. from ray.rllib.policy.policy import Policy
  19. from ray.rllib.policy.sample_batch import SampleBatch
  20. from ray.rllib.utils.framework import try_import_torch
  21. from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
  22. from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
  23. TrainerConfigDict
  24. from ray.rllib.utils.torch_utils import apply_grad_clipping, \
  25. convert_to_torch_tensor, concat_multi_gpu_td_errors
  26. torch, nn = try_import_torch()
  27. F = nn.functional
  28. logger = logging.getLogger(__name__)
  29. MEAN_MIN = -9.0
  30. MEAN_MAX = 9.0


# Returns policy tiled actions and log probabilities for CQL Loss
def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
    obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
        obs.shape[0] * num_repeat, obs.shape[1])
    logits = model.get_policy_output(obs_temp)
    policy_dist = action_dist(logits, model)
    actions, logp_ = policy_dist.sample_logp()
    logp = logp_.unsqueeze(-1)
    return actions, logp.view(obs.shape[0], num_repeat, 1)
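

# Shape sketch (illustrative numbers, not taken from any config): with a batch
# of B=256 observations of dimension D=17 and num_repeat=10, `obs_temp` has
# shape [2560, 17], the sampled `actions` have shape [2560, act_dim], and the
# returned log-probabilities are reshaped to [256, 10, 1] so they line up with
# the tiled Q-values of shape [256, 10, 1] computed in `cql_loss` below.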


def q_values_repeat(model, obs, actions, twin=False):
    action_shape = actions.shape[0]
    obs_shape = obs.shape[0]
    num_repeat = int(action_shape / obs_shape)
    obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
        obs.shape[0] * num_repeat, obs.shape[1])
    if not twin:
        preds_ = model.get_q_values(obs_temp, actions)
    else:
        preds_ = model.get_twin_q_values(obs_temp, actions)
    preds = preds_.view(obs.shape[0], num_repeat, 1)
    return preds
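

# Note: `num_repeat` is recovered from the row counts, so `actions` is assumed
# to hold exactly `num_repeat` stacked action samples per observation (as
# produced by `policy_actions_repeat`, or the uniform `rand_actions` tensor in
# `cql_loss`). The Q-value predictions are returned with shape
# [batch, num_repeat, 1].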


def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS].float()
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    # Only step the optimizers on full-size train batches (the optimizers are
    # driven manually from within this loss function).
    if obs.shape[0] == policy.config["train_batch_size"]:
        policy.alpha_optim.zero_grad()
        alpha_loss.backward()
        policy.alpha_optim.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
        # actor_loss = -bc_logp.mean()

    if obs.shape[0] == policy.config["train_batch_size"]:
        policy.actor_optim.zero_grad()
        actor_loss.backward(retain_graph=True)
        policy.actor_optim.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = target_model.get_twin_q_values(target_model_out_tp1,
                                                    policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped) for the priority replay
    # buffer.
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (here we use the entropy-regularized variant of CQL, CQL(H)).
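    #
    # In the notation of the CQL paper (Kumar et al., 2020), the penalty added
    # to each critic loss below is approximately
    #
    #   min_q_weight * ( E_s[ temp * logsumexp_a( Q(s, a) / temp ) ]
    #                    - E_{(s, a) ~ D}[ Q(s, a) ] )
    #
    # The logsumexp over the continuous action space is estimated by
    # importance sampling with `num_actions` uniform-random actions plus
    # `num_actions` current-policy actions at s_t and at s_{t+1}; the
    # subtracted log-densities (`random_density`, `curr_logp`, `next_logp`)
    # act as the importance weights.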
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device)
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1, num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    min_qf1_loss_ = torch.logsumexp(
        cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = torch.logsumexp(
            cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if obs.shape[0] == policy.config["train_batch_size"]:
        policy.critic_optims[0].zero_grad()
        critic_loss[0].backward(retain_graph=True)
        policy.critic_optims[0].step()

        if twin_q:
            policy.critic_optims[1].zero_grad()
            critic_loss[1].backward(retain_graph=False)
            policy.critic_optims[1].step()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    # SAC stats.
    model.tower_stats["q_t"] = q_t_selected
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    model.tower_stats["log_alpha_value"] = model.log_alpha
    model.tower_stats["alpha_value"] = alpha
    model.tower_stats["target_entropy"] = model.target_entropy
    # CQL stats.
    model.tower_stats["cql_loss"] = cql_loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    if use_lagrange:
        model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
        model.tower_stats["alpha_prime_value"] = alpha_prime
        model.tower_stats["alpha_prime_loss"] = alpha_prime_loss

        if obs.shape[0] == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss] +
                 ([alpha_prime_loss] if use_lagrange else []))


def cql_stats(policy: Policy,
              train_batch: SampleBatch) -> Dict[str, TensorType]:
    # Get SAC loss stats.
    stats_dict = stats(policy, train_batch)

    # Add CQL loss stats to the dict.
    stats_dict["cql_loss"] = torch.mean(
        torch.stack(*policy.get_tower_stats("cql_loss")))

    if policy.config["lagrangian"]:
        stats_dict["log_alpha_prime_value"] = torch.mean(
            torch.stack(policy.get_tower_stats("log_alpha_prime_value")))
        stats_dict["alpha_prime_value"] = torch.mean(
            torch.stack(policy.get_tower_stats("alpha_prime_value")))
        stats_dict["alpha_prime_loss"] = torch.mean(
            torch.stack(policy.get_tower_stats("alpha_prime_loss")))
    return stats_dict


def cql_optimizer_fn(policy: Policy, config: TrainerConfigDict) -> \
        Tuple[LocalOptimizer]:
    policy.cur_iter = 0
    opt_list = optimizer_fn(policy, config)
    if config["lagrangian"]:
        log_alpha_prime = nn.Parameter(
            torch.zeros(1, requires_grad=True).float())
        policy.model.register_parameter("log_alpha_prime", log_alpha_prime)
        policy.alpha_prime_optim = torch.optim.Adam(
            params=[policy.model.log_alpha_prime],
            lr=config["optimization"]["critic_learning_rate"],
            eps=1e-7,  # to match tf.keras.optimizers.Adam's epsilon default
        )
        return tuple([policy.actor_optim] + policy.critic_optims +
                     [policy.alpha_optim] + [policy.alpha_prime_optim])
    return opt_list
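

# Note: the optimizer ordering returned above ([actor] + critics + [alpha],
# plus [alpha_prime] when `lagrangian=True`) mirrors the tuple of loss terms
# returned by `cql_loss`, which is how the policy template pairs each loss
# with its optimizer.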


def cql_setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                          action_space: gym.spaces.Space,
                          config: TrainerConfigDict) -> None:
    setup_late_mixins(policy, obs_space, action_space, config)
    if config["lagrangian"]:
        policy.model.log_alpha_prime = policy.model.log_alpha_prime.to(
            policy.device)


def compute_gradients_fn(policy, postprocessed_batch):
    batches = [policy._lazy_tensor_dict(postprocessed_batch)]
    model = policy.model
    # Gradients are computed and applied inside `cql_loss` (the optimizers are
    # stepped there), so this only runs the loss and collects the stats.
    policy._loss(policy, model, policy.dist_class, batches[0])
    stats = {
        LEARNER_STATS_KEY: policy._convert_to_non_torch_type(
            cql_stats(policy, batches[0]))
    }
    return [None, stats]


def apply_gradients_fn(policy, gradients):
    # No-op: gradients have already been applied within `cql_loss`.
    return


# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
CQLTorchPolicy = build_policy_class(
    name="CQLTorchPolicy",
    framework="torch",
    loss_fn=cql_loss,
    get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
    stats_fn=cql_stats,
    postprocess_fn=postprocess_trajectory,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=cql_optimizer_fn,
    validate_spaces=validate_spaces,
    before_loss_init=cql_setup_late_mixins,
    make_model_and_action_dist=build_sac_model_and_action_dist,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
    action_distribution_fn=action_distribution_fn,
    compute_gradients_fn=compute_gradients_fn,
    apply_gradients_fn=apply_gradients_fn,
)
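

# Illustrative usage sketch (kept commented out so this module stays free of
# side effects). It assumes the accompanying `CQLTrainer` exposed by
# `ray.rllib.agents.cql` and an offline dataset readable through RLlib's
# "input" API; the path and hyperparameter values below are placeholders, not
# defaults taken from this file.
#
#   from ray.rllib.agents.cql import CQLTrainer
#
#   config = {
#       "framework": "torch",
#       "env": "Pendulum-v1",              # used for evaluation only
#       "input": "/path/to/offline/data",  # placeholder offline dataset
#       "twin_q": True,
#       "lagrangian": False,
#       "bc_iters": 20000,
#       "num_actions": 10,
#       "min_q_weight": 5.0,
#   }
#   trainer = CQLTrainer(config=config)
#   for _ in range(10):
#       results = trainer.train()
#       print(results["info"]["learner"])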