# vtrace_torch.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

  14. """PyTorch version of the functions to compute V-trace off-policy actor critic
  15. targets.
  16. For details and theory see:
  17. "IMPALA: Scalable Distributed Deep-RL with
  18. Importance Weighted Actor-Learner Architectures"
  19. by Espeholt, Soyer, Munos et al.
  20. See https://arxiv.org/abs/1802.01561 for the full paper.
  21. In addition to the original paper's code, changes have been made
  22. to support MultiDiscrete action spaces. behaviour_policy_logits,
  23. target_policy_logits and actions parameters in the entry point
  24. multi_from_logits method accepts lists of tensors instead of just
  25. tensors.
  26. """

from ray.rllib.agents.impala.vtrace_tf import VTraceFromLogitsReturns, \
    VTraceReturns
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.utils import force_list
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

torch, nn = try_import_torch()


def log_probs_from_logits_and_actions(policy_logits,
                                      actions,
                                      dist_class=TorchCategorical,
                                      model=None):
    return multi_log_probs_from_logits_and_actions([policy_logits], [actions],
                                                   dist_class, model)[0]


def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class,
                                            model):
    """Computes action log-probs from policy logits and actions.

    In the notation used throughout documentation and comments, T refers to
    the time dimension ranging from 0 to T-1. B refers to the batch size and
    ACTION_SPACE refers to the list of numbers each representing a number of
    actions.

    Args:
        policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing a softmax policy.
        actions: A list with length of ACTION_SPACE of tensors of shapes
            [T, B, ...], ..., [T, B, ...] with actions.
        dist_class: Python class of the action distribution.
        model: Backing ModelV2 instance used to construct the distribution.

    Returns:
        A list with length of ACTION_SPACE of float32 tensors of shapes
        [T, B], ..., [T, B] corresponding to the sampling log probability
        of the chosen action w.r.t. the policy.
    """
    log_probs = []
    for i in range(len(policy_logits)):
        p_shape = policy_logits[i].shape
        a_shape = actions[i].shape
        policy_logits_flat = torch.reshape(policy_logits[i],
                                           (-1, ) + tuple(p_shape[2:]))
        actions_flat = torch.reshape(actions[i], (-1, ) + tuple(a_shape[2:]))
        log_probs.append(
            torch.reshape(
                dist_class(policy_logits_flat, model).logp(actions_flat),
                a_shape[:2]))

    return log_probs
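

# Example (a sketch, not part of the original file): for a hypothetical
# MultiDiscrete action space with component sizes [3, 2], trajectory length
# T=5 and batch size B=4, the call below would return a list of two [5, 4]
# log-prob tensors, assuming the default TorchCategorical distribution and
# no model state:
#
#     logits = [torch.randn(5, 4, 3), torch.randn(5, 4, 2)]
#     acts = [torch.randint(0, 3, (5, 4)), torch.randint(0, 2, (5, 4))]
#     logps = multi_log_probs_from_logits_and_actions(
#         logits, acts, TorchCategorical, None)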


def from_logits(behaviour_policy_logits,
                target_policy_logits,
                actions,
                discounts,
                rewards,
                values,
                bootstrap_value,
                dist_class=TorchCategorical,
                model=None,
                clip_rho_threshold=1.0,
                clip_pg_rho_threshold=1.0):
    """multi_from_logits wrapper used only for tests"""
    res = multi_from_logits(
        [behaviour_policy_logits], [target_policy_logits], [actions],
        discounts,
        rewards,
        values,
        bootstrap_value,
        dist_class,
        model,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    assert len(res.behaviour_action_log_probs) == 1
    assert len(res.target_action_log_probs) == 1

    return VTraceFromLogitsReturns(
        vs=res.vs,
        pg_advantages=res.pg_advantages,
        log_rhos=res.log_rhos,
        behaviour_action_log_probs=res.behaviour_action_log_probs[0],
        target_action_log_probs=res.target_action_log_probs[0],
    )


def multi_from_logits(behaviour_policy_logits,
                      target_policy_logits,
                      actions,
                      discounts,
                      rewards,
                      values,
                      bootstrap_value,
                      dist_class,
                      model,
                      behaviour_action_log_probs=None,
                      clip_rho_threshold=1.0,
                      clip_pg_rho_threshold=1.0):
  116. """V-trace for softmax policies.
  117. Calculates V-trace actor critic targets for softmax polices as described in
  118. "IMPALA: Scalable Distributed Deep-RL with
  119. Importance Weighted Actor-Learner Architectures"
  120. by Espeholt, Soyer, Munos et al.
  121. Target policy refers to the policy we are interested in improving and
  122. behaviour policy refers to the policy that generated the given
  123. rewards and actions.
  124. In the notation used throughout documentation and comments, T refers to the
  125. time dimension ranging from 0 to T-1. B refers to the batch size and
  126. ACTION_SPACE refers to the list of numbers each representing a number of
  127. actions.
  128. Args:
  129. behaviour_policy_logits: A list with length of ACTION_SPACE of float32
  130. tensors of shapes [T, B, ACTION_SPACE[0]], ...,
  131. [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
  132. parameterizing the softmax behavior policy.
  133. target_policy_logits: A list with length of ACTION_SPACE of float32
  134. tensors of shapes [T, B, ACTION_SPACE[0]], ...,
  135. [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
  136. parameterizing the softmax target policy.
  137. actions: A list with length of ACTION_SPACE of tensors of shapes
  138. [T, B, ...], ..., [T, B, ...]
  139. with actions sampled from the behavior policy.
  140. discounts: A float32 tensor of shape [T, B] with the discount
  141. encountered when following the behavior policy.
  142. rewards: A float32 tensor of shape [T, B] with the rewards generated by
  143. following the behavior policy.
  144. values: A float32 tensor of shape [T, B] with the value function
  145. estimates wrt. the target policy.
  146. bootstrap_value: A float32 of shape [B] with the value function
  147. estimate at time T.
  148. dist_class: action distribution class for the logits.
  149. model: backing ModelV2 instance
  150. behaviour_action_log_probs: Precalculated values of the behavior
  151. actions.
  152. clip_rho_threshold: A scalar float32 tensor with the clipping threshold
  153. for importance weights (rho) when calculating the baseline targets
  154. (vs). rho^bar in the paper.
  155. clip_pg_rho_threshold: A scalar float32 tensor with the clipping
  156. threshold on rho_s in:
  157. \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
  158. Returns:
  159. A `VTraceFromLogitsReturns` namedtuple with the following fields:
  160. vs: A float32 tensor of shape [T, B]. Can be used as target to train a
  161. baseline (V(x_t) - vs_t)^2.
  162. pg_advantages: A float 32 tensor of shape [T, B]. Can be used as an
  163. estimate of the advantage in the calculation of policy gradients.
  164. log_rhos: A float32 tensor of shape [T, B] containing the log
  165. importance sampling weights (log rhos).
  166. behaviour_action_log_probs: A float32 tensor of shape [T, B] containing
  167. behaviour policy action log probabilities (log \mu(a_t)).
  168. target_action_log_probs: A float32 tensor of shape [T, B] containing
  169. target policy action probabilities (log \pi(a_t)).
  170. """
    behaviour_policy_logits = convert_to_torch_tensor(
        behaviour_policy_logits, device="cpu")
    target_policy_logits = convert_to_torch_tensor(
        target_policy_logits, device="cpu")
    actions = convert_to_torch_tensor(actions, device="cpu")

    # Make sure tensor ranks are as expected.
    # The rest will be checked by from_action_log_probs.
    for i in range(len(behaviour_policy_logits)):
        assert len(behaviour_policy_logits[i].size()) == 3
        assert len(target_policy_logits[i].size()) == 3

    target_action_log_probs = multi_log_probs_from_logits_and_actions(
        target_policy_logits, actions, dist_class, model)

    if (len(behaviour_policy_logits) > 1
            or behaviour_action_log_probs is None):
        # Can't use precalculated values; recompute them. Note that
        # recomputing won't work well for autoregressive action dists,
        # which may have variables not captured by `logits`.
        behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
            behaviour_policy_logits, actions, dist_class, model)

    behaviour_action_log_probs = convert_to_torch_tensor(
        behaviour_action_log_probs, device="cpu")
    behaviour_action_log_probs = force_list(behaviour_action_log_probs)
    log_rhos = get_log_rhos(target_action_log_probs,
                            behaviour_action_log_probs)

    vtrace_returns = from_importance_weights(
        log_rhos=log_rhos,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    return VTraceFromLogitsReturns(
        log_rhos=log_rhos,
        behaviour_action_log_probs=behaviour_action_log_probs,
        target_action_log_probs=target_action_log_probs,
        **vtrace_returns._asdict())
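

# Usage note (a sketch, not part of the original file): for a single discrete
# action space, behaviour log-probs already computed at rollout time can be
# passed in via `behaviour_action_log_probs` to avoid recomputing them from
# `behaviour_policy_logits`. Names below are illustrative only:
#
#     out = multi_from_logits(
#         [behaviour_logits], [target_logits], [actions],
#         discounts, rewards, values, bootstrap_value,
#         dist_class=TorchCategorical, model=None,
#         behaviour_action_log_probs=[stored_action_logp])  # list of [T, B]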


def from_importance_weights(log_rhos,
                            discounts,
                            rewards,
                            values,
                            bootstrap_value,
                            clip_rho_threshold=1.0,
                            clip_pg_rho_threshold=1.0):
  215. """V-trace from log importance weights.
  216. Calculates V-trace actor critic targets as described in
  217. "IMPALA: Scalable Distributed Deep-RL with
  218. Importance Weighted Actor-Learner Architectures"
  219. by Espeholt, Soyer, Munos et al.
  220. In the notation used throughout documentation and comments, T refers to the
  221. time dimension ranging from 0 to T-1. B refers to the batch size. This code
  222. also supports the case where all tensors have the same number of additional
  223. dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C],
  224. `bootstrap_value` is [B, C].
  225. Args:
  226. log_rhos: A float32 tensor of shape [T, B] representing the log
  227. importance sampling weights, i.e.
  228. log(target_policy(a) / behaviour_policy(a)). V-trace performs
  229. operations on rhos in log-space for numerical stability.
  230. discounts: A float32 tensor of shape [T, B] with discounts encountered
  231. when following the behaviour policy.
  232. rewards: A float32 tensor of shape [T, B] containing rewards generated
  233. by following the behaviour policy.
  234. values: A float32 tensor of shape [T, B] with the value function
  235. estimates wrt. the target policy.
  236. bootstrap_value: A float32 of shape [B] with the value function
  237. estimate at time T.
  238. clip_rho_threshold: A scalar float32 tensor with the clipping threshold
  239. for importance weights (rho) when calculating the baseline targets
  240. (vs). rho^bar in the paper. If None, no clipping is applied.
  241. clip_pg_rho_threshold: A scalar float32 tensor with the clipping
  242. threshold on rho_s in
  243. \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
  244. If None, no clipping is applied.
  245. Returns:
  246. A VTraceReturns namedtuple (vs, pg_advantages) where:
  247. vs: A float32 tensor of shape [T, B]. Can be used as target to
  248. train a baseline (V(x_t) - vs_t)^2.
  249. pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
  250. advantage in the calculation of policy gradients.
  251. """
    log_rhos = convert_to_torch_tensor(log_rhos, device="cpu")
    discounts = convert_to_torch_tensor(discounts, device="cpu")
    rewards = convert_to_torch_tensor(rewards, device="cpu")
    values = convert_to_torch_tensor(values, device="cpu")
    bootstrap_value = convert_to_torch_tensor(bootstrap_value, device="cpu")

    # Make sure tensor ranks are consistent.
    rho_rank = len(log_rhos.size())  # Usually 2.
    assert rho_rank == len(values.size())
    assert rho_rank - 1 == len(bootstrap_value.size()), \
        "must have rank {}".format(rho_rank - 1)
    assert rho_rank == len(discounts.size())
    assert rho_rank == len(rewards.size())

    rhos = torch.exp(log_rhos)
    if clip_rho_threshold is not None:
        clipped_rhos = torch.clamp_max(rhos, clip_rho_threshold)
    else:
        clipped_rhos = rhos

    cs = torch.clamp_max(rhos, 1.0)
    # Append bootstrapped value to get [v1, ..., v_t+1].
    values_t_plus_1 = torch.cat(
        [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    # Backward recursion over the trajectory to accumulate v_s - V(x_s).
    vs_minus_v_xs = [torch.zeros_like(bootstrap_value)]
    for i in reversed(range(len(discounts))):
        discount_t, c_t, delta_t = discounts[i], cs[i], deltas[i]
        vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1])
    vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:])
    # Reverse the results back to original order.
    vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[0])

    # Add V(x_s) to get v_s.
    vs = vs_minus_v_xs + values

    # Advantage for policy gradient.
    vs_t_plus_1 = torch.cat(
        [vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    if clip_pg_rho_threshold is not None:
        clipped_pg_rhos = torch.clamp_max(rhos, clip_pg_rho_threshold)
    else:
        clipped_pg_rhos = rhos
    pg_advantages = (
        clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values))

    # Make sure no gradients are backpropagated through the returned values.
    return VTraceReturns(vs=vs.detach(), pg_advantages=pg_advantages.detach())
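

# Summary (added comment, not part of the original file): the backward loop in
# from_importance_weights() implements the V-trace recursion from the IMPALA
# paper,
#
#     delta_s      = clipped_rho_s * (r_s + gamma_s * V(x_{s+1}) - V(x_s))
#     v_s - V(x_s) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1}))
#
# initialised at v_T = bootstrap_value, with
# clipped_rho_s = min(rho_s, clip_rho_threshold) and c_s = min(rho_s, 1.0).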


def get_log_rhos(target_action_log_probs, behaviour_action_log_probs):
    """Computes the log importance-sampling weights (log rhos) for V-trace.

    Sums the differences between target- and behaviour-policy action log-probs
    over the (possibly multi-discrete) action components."""
    t = torch.stack(target_action_log_probs)
    b = torch.stack(behaviour_action_log_probs)
    log_rhos = torch.sum(t - b, dim=0)
    return log_rhos
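

# A minimal smoke-test sketch (not part of the original module). It runs
# multi_from_logits on random data for a hypothetical MultiDiscrete action
# space with two components of sizes 3 and 2; all shapes and values below are
# illustrative only.
if __name__ == "__main__":
    T, B = 5, 4  # Trajectory length and batch size.
    behaviour_logits = [torch.randn(T, B, 3), torch.randn(T, B, 2)]
    target_logits = [torch.randn(T, B, 3), torch.randn(T, B, 2)]
    sampled_actions = [
        torch.randint(0, 3, (T, B)),
        torch.randint(0, 2, (T, B)),
    ]
    out = multi_from_logits(
        behaviour_logits,
        target_logits,
        sampled_actions,
        discounts=torch.full((T, B), 0.99),
        rewards=torch.randn(T, B),
        values=torch.randn(T, B),
        bootstrap_value=torch.randn(B),
        dist_class=TorchCategorical,
        model=None)
    # vs and pg_advantages come back as detached [T, B] tensors.
    assert out.vs.shape == (T, B)
    assert out.pg_advantages.shape == (T, B)
    print("V-trace targets (first timestep):", out.vs[0])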