  1. """
  2. PyTorch policy class used for APPO.
  3. Adapted from VTraceTFPolicy to use the PPO surrogate loss.
  4. Keep in sync with changes to VTraceTFPolicy.
  5. """
  6. import gym
  7. import numpy as np
  8. import logging
  9. from typing import Type
  10. from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin
  11. import ray.rllib.agents.impala.vtrace_torch as vtrace
  12. from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
  13. choose_optimizer
  14. from ray.rllib.agents.ppo.appo_tf_policy import make_appo_model, \
  15. postprocess_trajectory
  16. from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin, \
  17. KLCoeffMixin
  18. from ray.rllib.evaluation.postprocessing import Postprocessing
  19. from ray.rllib.models.modelv2 import ModelV2
  20. from ray.rllib.models.torch.torch_action_dist import \
  21. TorchDistributionWrapper, TorchCategorical
  22. from ray.rllib.policy.policy import Policy
  23. from ray.rllib.policy.policy_template import build_policy_class
  24. from ray.rllib.policy.sample_batch import SampleBatch
  25. from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
  26. LearningRateSchedule
  27. from ray.rllib.utils.framework import try_import_torch
  28. from ray.rllib.utils.torch_utils import apply_grad_clipping, \
  29. explained_variance, global_norm, sequence_mask
  30. from ray.rllib.utils.typing import TensorType, TrainerConfigDict
  31. torch, nn = try_import_torch()
  32. logger = logging.getLogger(__name__)


def appo_surrogate_loss(policy: Policy, model: ModelV2,
                        dist_class: Type[TorchDistributionWrapper],
                        train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1
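
    # `_make_time_major` reshapes flat [B*T, ...] batch tensors into
    # time-major [T, B, ...] tensors (the layout V-trace expects), using the
    # batch's sequence lengths when the policy is recurrent.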
    def _make_time_major(*args, **kwargs):
        return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                               *args, **kwargs)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values = model.value_function()
    values_time_major = _make_time_major(values)

    drop_last = policy.config["vtrace"] and \
        policy.config["vtrace_drop_last_ts"]
    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = torch.reshape(mask, [-1])
        mask = _make_time_major(mask, drop_last=drop_last)
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True; "
                     f"drop_last={drop_last})")
        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(
                behaviour_logits, list(output_hidden_shape), dim=1)
            unpacked_old_policy_behaviour_logits = torch.split(
                old_policy_behaviour_logits, list(output_hidden_shape), dim=1)
        else:
            unpacked_behaviour_logits = torch.chunk(
                behaviour_logits, output_hidden_shape, dim=1)
            unpacked_old_policy_behaviour_logits = torch.chunk(
                old_policy_behaviour_logits, output_hidden_shape, dim=1)

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Prepare KL for loss.
        action_kl = _make_time_major(
            old_policy_action_dist.kl(action_dist), drop_last=drop_last)

        # Compute vtrace on the CPU for better perf.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=_make_time_major(
                unpacked_behaviour_logits, drop_last=drop_last),
            target_policy_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=drop_last),
            actions=torch.unbind(
                _make_time_major(loss_actions, drop_last=drop_last), dim=2),
            discounts=(1.0 - _make_time_major(
                dones, drop_last=drop_last).float()) * policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=drop_last),
            values=values_time_major[:-1] if drop_last else values_time_major,
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config[
                "vtrace_clip_pg_rho_threshold"])
        actions_logp = _make_time_major(
            action_dist.logp(actions), drop_last=drop_last)
        prev_actions_logp = _make_time_major(
            prev_action_dist.logp(actions), drop_last=drop_last)
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=drop_last)
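        # Importance ratio between the behaviour policy that collected the
        # samples and the (lagging) target policy, clamped to [0.0, 2.0] for
        # stability. Multiplying it into the PPO ratio below effectively
        # re-bases that ratio onto the target ("old") policy instead of the
        # behaviour policy.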
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = vtrace_returns.vs.to(values_time_major.device)
        if drop_last:
            delta = values_time_major[:-1] - value_targets
        else:
            delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy(), drop_last=drop_last))

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")
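
        # Without V-trace, advantages and value targets come straight from
        # the standard (GAE) trajectory postprocessing.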
        # Prepare KL for Loss
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = _make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy()))

    # The summed weighted loss
    total_loss = mean_policy_loss + \
        mean_vf_loss * policy.config["vf_loss_coeff"] - \
        mean_entropy * policy.entropy_coeff

    # Optional additional KL Loss
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl_loss

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["total_loss"] = total_loss
    model.tower_stats["mean_policy_loss"] = mean_policy_loss
    model.tower_stats["mean_kl_loss"] = mean_kl_loss
    model.tower_stats["mean_vf_loss"] = mean_vf_loss
    model.tower_stats["mean_entropy"] = mean_entropy
    model.tower_stats["value_targets"] = value_targets
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(
            values_time_major[:-1] if drop_last else values_time_major, [-1]),
    )

    return total_loss


def stats(policy: Policy, train_batch: SampleBatch):
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    stats_dict = {
        "cur_lr": policy.cur_lr,
        "total_loss": torch.mean(
            torch.stack(policy.get_tower_stats("total_loss"))),
        "policy_loss": torch.mean(
            torch.stack(policy.get_tower_stats("mean_policy_loss"))),
        "entropy": torch.mean(
            torch.stack(policy.get_tower_stats("mean_entropy"))),
        "entropy_coeff": policy.entropy_coeff,
        "var_gnorm": global_norm(policy.model.trainable_variables()),
        "vf_loss": torch.mean(
            torch.stack(policy.get_tower_stats("mean_vf_loss"))),
        "vf_explained_var": torch.mean(
            torch.stack(policy.get_tower_stats("vf_explained_var"))),
    }

    if policy.config["vtrace"]:
        is_stat_mean = torch.mean(policy._is_ratio, [0, 1])
        is_stat_var = torch.var(policy._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy.get_tower_stats("mean_kl_loss")
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict
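

# When V-trace is disabled, the trajectory postprocessing (GAE) needs the
# value function predictions, so they are added to the action outputs here.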
def add_values(policy, input_dict, state_batches, model, action_dist):
    out = {}
    if not policy.config["vtrace"]:
        out[SampleBatch.VF_PREDS] = model.value_function()
    return out


def setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space,
                       action_space: gym.spaces.Space,
                       config: TrainerConfigDict):
    """Call all mixin classes' constructors before APPOPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict):
    """Call all mixin classes' constructors after APPOPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    KLCoeffMixin.__init__(policy, config)
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    TargetNetworkMixin.__init__(policy)


# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
AsyncPPOTorchPolicy = build_policy_class(
    name="AsyncPPOTorchPolicy",
    framework="torch",
    loss_fn=appo_surrogate_loss,
    stats_fn=stats,
    postprocess_fn=postprocess_trajectory,
    extra_action_out_fn=add_values,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=choose_optimizer,
    before_init=setup_early_mixins,
    before_loss_init=setup_late_mixins,
    make_model=make_appo_model,
    mixins=[
        LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
        ValueNetworkMixin, EntropyCoeffSchedule
    ],
    get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
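

# Usage sketch (illustrative, not part of the original module): the APPO
# trainer picks up `AsyncPPOTorchPolicy` automatically when `framework` is
# set to "torch". The config values below are examples only.
if __name__ == "__main__":
    import ray
    from ray.rllib.agents.ppo import APPOTrainer

    ray.init()
    trainer = APPOTrainer(
        env="CartPole-v0",
        config={
            "framework": "torch",  # -> selects AsyncPPOTorchPolicy.
            "vtrace": True,        # Use the V-trace surrogate loss branch.
            "num_workers": 1,
        })
    print(trainer.train())
    trainer.stop()
    ray.shutdown()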