impala_torch_policy.py

import gymnasium as gym
import logging
import numpy as np
from typing import Dict, List, Optional, Type, Union

import ray
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_bootstrap_value
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import (
    EntropyCoeffSchedule,
    LearningRateSchedule,
    ValueNetworkMixin,
)
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import (
    apply_grad_clipping,
    explained_variance,
    global_norm,
    sequence_mask,
)
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class VTraceLoss:
    def __init__(
        self,
        actions,
        actions_logp,
        actions_entropy,
        dones,
        behaviour_action_logp,
        behaviour_logits,
        target_logits,
        discount,
        rewards,
        values,
        bootstrap_value,
        dist_class,
        model,
        valid_mask,
        config,
        vf_loss_coeff=0.5,
        entropy_coeff=0.01,
        clip_rho_threshold=1.0,
        clip_pg_rho_threshold=1.0,
    ):
        """Policy gradient loss with vtrace importance weighting.

        VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
        batch_size. The reason we need to know `B` is for V-trace to properly
        handle episode cut boundaries.

        Args:
            actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
            actions_logp: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_action_logp: Tensor of shape [T, B].
            behaviour_logits: A list with length of ACTION_SPACE of float32
                tensors of shapes
                [T, B, ACTION_SPACE[0]],
                ...,
                [T, B, ACTION_SPACE[-1]]
            target_logits: A list with length of ACTION_SPACE of float32
                tensors of shapes
                [T, B, ACTION_SPACE[0]],
                ...,
                [T, B, ACTION_SPACE[-1]]
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            dist_class: action distribution class for logits.
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            config: Algorithm config dict.
        """
        import ray.rllib.algorithms.impala.vtrace_torch as vtrace

        if valid_mask is None:
            valid_mask = torch.ones_like(actions_logp)

        # Compute vtrace on the CPU for better perf
        # (devices handled inside `vtrace.multi_from_logits`).
        device = behaviour_action_logp[0].device
        self.vtrace_returns = vtrace.multi_from_logits(
            behaviour_action_log_probs=behaviour_action_logp,
            behaviour_policy_logits=behaviour_logits,
            target_policy_logits=target_logits,
            actions=torch.unbind(actions, dim=2),
            discounts=(1.0 - dones.float()) * discount,
            rewards=rewards,
            values=values,
            bootstrap_value=bootstrap_value,
            dist_class=dist_class,
            model=model,
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold,
        )
        # Move v-trace results back to GPU for actual loss computing.
        self.value_targets = self.vtrace_returns.vs.to(device)

        # The policy gradients loss.
        self.pi_loss = -torch.sum(
            actions_logp * self.vtrace_returns.pg_advantages.to(device) * valid_mask
        )

        # The baseline loss.
        delta = (values - self.value_targets) * valid_mask
        self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0))

        # The entropy loss.
        self.entropy = torch.sum(actions_entropy * valid_mask)
        self.mean_entropy = self.entropy / torch.sum(valid_mask)

        # The summed weighted loss.
        self.total_loss = (
            self.pi_loss + self.vf_loss * vf_loss_coeff - self.entropy * entropy_coeff
        )
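
# Shape sketch (illustrative comment, not part of the original module): for a
# Discrete(2) action space with rollout fragment length T=50 and batch size
# B=10 (assumed example values), VTraceLoss would receive roughly:
#     actions:                       int tensor [50, 10, 1]
#     actions_logp / actions_entropy /
#     rewards / values / dones:      tensors of shape [50, 10]
#     bootstrap_value:               float tensor [10]
#     behaviour_logits / target_logits: a 1-tuple holding one [50, 10, 2] tensor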


def make_time_major(policy, seq_lens, tensor):
    """Swaps batch and trajectory axis.

    Args:
        policy: Policy reference.
        seq_lens: Sequence lengths if recurrent or None.
        tensor: A tensor or list of tensors to reshape.

    Returns:
        res: A tensor with swapped axes or a list of tensors with
            swapped axes.
    """
    if isinstance(tensor, (list, tuple)):
        return [make_time_major(policy, seq_lens, t) for t in tensor]

    if policy.is_recurrent():
        B = seq_lens.shape[0]
        T = tensor.shape[0] // B
    else:
        # Important: chop the tensor into batches at known episode cut
        # boundaries.
        # TODO: (sven) this is kind of a hack and won't work for
        #  batch_mode=complete_episodes.
        T = policy.config["rollout_fragment_length"]
        B = tensor.shape[0] // T
    rs = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))

    # Swap B and T axes.
    res = torch.transpose(rs, 1, 0)

    return res
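
# Worked example (hedged sketch, not part of the original module): for a
# non-recurrent policy with rollout_fragment_length=4 and a flat batch of 8
# timesteps, make_time_major() reshapes [8] -> [2, 4] and transposes to the
# time-major layout [4, 2]:
#
#     t = torch.arange(8.0)  # order: [b0t0..b0t3, b1t0..b1t3]
#     # make_time_major(policy, None, t) then yields:
#     # tensor([[0., 4.],
#     #         [1., 5.],
#     #         [2., 6.],
#     #         [3., 7.]])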


class VTraceOptimizer:
    """Optimizer function for VTrace torch policies."""

    def __init__(self):
        pass

    @override(TorchPolicyV2)
    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        if self.config["opt_type"] == "adam":
            return torch.optim.Adam(params=self.model.parameters(), lr=self.cur_lr)
        else:
            return torch.optim.RMSprop(
                params=self.model.parameters(),
                lr=self.cur_lr,
                weight_decay=self.config["decay"],
                momentum=self.config["momentum"],
                eps=self.config["epsilon"],
            )
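
    # Example (assumed values, for illustration only): an IMPALA config with
    # {"opt_type": "rmsprop", "lr": 0.0005, "decay": 0.99, "momentum": 0.0,
    # "epsilon": 0.1} makes optimizer() return a single torch.optim.RMSprop
    # over all model parameters, while {"opt_type": "adam", "lr": 0.0005}
    # returns torch.optim.Adam instead.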


# VTrace mixins are placed in front of more general mixins to make sure
# their functions (e.g. optimizer()) override all the other implementations
# (e.g. LearningRateSchedule.optimizer()).
class ImpalaTorchPolicy(
    VTraceOptimizer,
    LearningRateSchedule,
    EntropyCoeffSchedule,
    ValueNetworkMixin,
    TorchPolicyV2,
):
    """PyTorch policy class used with Impala."""

    def __init__(self, observation_space, action_space, config):
        config = dict(
            ray.rllib.algorithms.impala.impala.ImpalaConfig().to_dict(), **config
        )

        # If Learner API is used, we don't need any loss-specific mixins.
        # However, we also would like to avoid creating special Policy-subclasses
        # for this as the entire Policy concept will soon not be used anymore with
        # the new Learner- and RLModule APIs.
        if not config.get("_enable_learner_api"):
            VTraceOptimizer.__init__(self)
            # Need to initialize learning rate variable before calling
            # TorchPolicyV2.__init__.
            LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
            EntropyCoeffSchedule.__init__(
                self, config["entropy_coeff"], config["entropy_coeff_schedule"]
            )

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        ValueNetworkMixin.__init__(self, config)

        self._initialize_loss_from_dummy_batch()

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[ActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)

        if isinstance(self.action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            output_hidden_shape = [self.action_space.n]
        elif isinstance(self.action_space, gym.spaces.MultiDiscrete):
            is_multidiscrete = True
            output_hidden_shape = self.action_space.nvec.astype(np.int32)
        else:
            is_multidiscrete = False
            output_hidden_shape = 1

        def _make_time_major(*args, **kw):
            return make_time_major(
                self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
            )

        actions = train_batch[SampleBatch.ACTIONS]
        dones = train_batch[SampleBatch.TERMINATEDS]
        rewards = train_batch[SampleBatch.REWARDS]
        behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
        behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(
                behaviour_logits, list(output_hidden_shape), dim=1
            )
            unpacked_outputs = torch.split(model_out, list(output_hidden_shape), dim=1)
        else:
            unpacked_behaviour_logits = torch.chunk(
                behaviour_logits, output_hidden_shape, dim=1
            )
            unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
        values = model.value_function()
        values_time_major = _make_time_major(values)
        bootstrap_values_time_major = _make_time_major(
            train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
        )
        bootstrap_value = bootstrap_values_time_major[-1]

        if self.is_recurrent():
            max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
            mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
            mask = torch.reshape(mask_orig, [-1])
        else:
            mask = torch.ones_like(rewards)

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)

        # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
        loss = VTraceLoss(
            actions=_make_time_major(loss_actions),
            actions_logp=_make_time_major(action_dist.logp(actions)),
            actions_entropy=_make_time_major(action_dist.entropy()),
            dones=_make_time_major(dones),
            behaviour_action_logp=_make_time_major(behaviour_action_logp),
            behaviour_logits=_make_time_major(unpacked_behaviour_logits),
            target_logits=_make_time_major(unpacked_outputs),
            discount=self.config["gamma"],
            rewards=_make_time_major(rewards),
            values=values_time_major,
            bootstrap_value=bootstrap_value,
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            valid_mask=_make_time_major(mask),
            config=self.config,
            vf_loss_coeff=self.config["vf_loss_coeff"],
            entropy_coeff=self.entropy_coeff,
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
        )

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["pi_loss"] = loss.pi_loss
        model.tower_stats["vf_loss"] = loss.vf_loss
        model.tower_stats["entropy"] = loss.entropy
        model.tower_stats["mean_entropy"] = loss.mean_entropy
        model.tower_stats["total_loss"] = loss.total_loss

        values_batched = make_time_major(
            self,
            train_batch.get(SampleBatch.SEQ_LENS),
            values,
        )
        model.tower_stats["vf_explained_var"] = explained_variance(
            torch.reshape(loss.value_targets, [-1]), torch.reshape(values_batched, [-1])
        )

        return loss.total_loss
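
    # Shape note (illustrative, assumed numbers): with
    # rollout_fragment_length=50 and a 500-timestep train batch,
    # _make_time_major() turns the flat [500]-shaped columns above into
    # [50, 10] time-major tensors before they enter VTraceLoss; the bootstrap
    # value taken from the last time step is then a [10]-vector.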

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        return convert_to_numpy(
            {
                "cur_lr": self.cur_lr,
                "total_loss": torch.mean(
                    torch.stack(self.get_tower_stats("total_loss"))
                ),
                "policy_loss": torch.mean(torch.stack(self.get_tower_stats("pi_loss"))),
                "entropy": torch.mean(
                    torch.stack(self.get_tower_stats("mean_entropy"))
                ),
                "entropy_coeff": self.entropy_coeff,
                "var_gnorm": global_norm(self.model.trainable_variables()),
                "vf_loss": torch.mean(torch.stack(self.get_tower_stats("vf_loss"))),
                "vf_explained_var": torch.mean(
                    torch.stack(self.get_tower_stats("vf_explained_var"))
                ),
            }
        )

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[SampleBatch] = None,
        episode: Optional["Episode"] = None,
    ):
        # Call super's postprocess_trajectory first.
        # sample_batch = super().postprocess_trajectory(
        #     sample_batch, other_agent_batches, episode
        # )

        if self.config["vtrace"]:
            # Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need
            # inside the loss for vtrace calculations.
            sample_batch = compute_bootstrap_value(sample_batch, self)

        return sample_batch
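
    # Note (sketch of assumed behavior): compute_bootstrap_value() fills the
    # SampleBatch.VALUES_BOOTSTRAPPED column so that its final entry holds the
    # value estimate for the last observation (0.0 if the episode terminated);
    # loss() above reads exactly that last time step as `bootstrap_value`.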

    @override(TorchPolicyV2)
    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        return apply_grad_clipping(self, optimizer, loss)

    @override(TorchPolicyV2)
    def get_batch_divisibility_req(self) -> int:
        return self.config["rollout_fragment_length"]
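

# Usage sketch (assumption, not from the original file): under the old Policy
# API, RLlib picks this policy automatically when IMPALA is configured with
# framework="torch". Roughly along these lines:
#
#     from ray.rllib.algorithms.impala import ImpalaConfig
#     config = (
#         ImpalaConfig()
#         .framework("torch")
#         .environment("CartPole-v1")
#         .rollouts(num_rollout_workers=1)
#     )
#     algo = config.build()
#     print(algo.train()["episode_reward_mean"])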