# vtrace_tf.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions to compute V-trace off-policy actor critic targets.

For details and theory see:

"IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.

See https://arxiv.org/abs/1802.01561 for the full paper.

In addition to the original paper's code, changes have been made
to support MultiDiscrete action spaces. The behaviour_policy_logits,
target_policy_logits and actions parameters of the entry point
multi_from_logits method accept lists of tensors instead of just
tensors.
"""

import collections

from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

VTraceFromLogitsReturns = collections.namedtuple(
    "VTraceFromLogitsReturns",
    [
        "vs",
        "pg_advantages",
        "log_rhos",
        "behaviour_action_log_probs",
        "target_action_log_probs",
    ],
)

VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages")


def log_probs_from_logits_and_actions(
    policy_logits, actions, dist_class=Categorical, model=None
):
    return multi_log_probs_from_logits_and_actions(
        [policy_logits], [actions], dist_class, model
    )[0]


def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class, model):
    """Computes action log-probs from policy logits and actions.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size and
    ACTION_SPACE refers to the list of numbers each representing a number of
    actions.

    Args:
        policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing a softmax policy.
        actions: A list with length of ACTION_SPACE of tensors of shapes
            [T, B, ...], ..., [T, B, ...]
            with actions.
        dist_class: Python class of the action distribution.
        model: Backing ModelV2 instance, passed through to `dist_class`.

    Returns:
        A list with length of ACTION_SPACE of float32 tensors of shapes
        [T, B], ..., [T, B] corresponding to the sampling log probability
        of the chosen action w.r.t. the policy.
    """
    log_probs = []
    for i in range(len(policy_logits)):
        p_shape = tf.shape(policy_logits[i])
        a_shape = tf.shape(actions[i])
        policy_logits_flat = tf.reshape(
            policy_logits[i], tf.concat([[-1], p_shape[2:]], axis=0)
        )
        actions_flat = tf.reshape(actions[i], tf.concat([[-1], a_shape[2:]], axis=0))
        log_probs.append(
            tf.reshape(
                dist_class(policy_logits_flat, model).logp(actions_flat), a_shape[:2]
            )
        )

    return log_probs


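# The example below is an illustrative sketch, not part of the original module:
# it shows how multi_log_probs_from_logits_and_actions could be called for a
# MultiDiscrete action space with two sub-actions of sizes 3 and 5. T, B and
# all shapes/values are assumptions made up for this example, and the function
# is defined but never called here.
def _example_multi_log_probs_usage():
    T, B = 10, 4
    # One [T, B, num_actions] logits tensor and one [T, B] action tensor per
    # sub-action of the MultiDiscrete space.
    policy_logits = [tf.random.normal([T, B, 3]), tf.random.normal([T, B, 5])]
    actions = [
        tf.random.uniform([T, B], maxval=3, dtype=tf.int32),
        tf.random.uniform([T, B], maxval=5, dtype=tf.int32),
    ]
    # Returns a list of two [T, B] float32 tensors with log pi(a_t | x_t).
    return multi_log_probs_from_logits_and_actions(
        policy_logits, actions, Categorical, model=None
    )

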
def from_logits(
    behaviour_policy_logits,
    target_policy_logits,
    actions,
    discounts,
    rewards,
    values,
    bootstrap_value,
    dist_class=Categorical,
    model=None,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
    name="vtrace_from_logits",
):
    """multi_from_logits wrapper used only for tests."""

    res = multi_from_logits(
        [behaviour_policy_logits],
        [target_policy_logits],
        [actions],
        discounts,
        rewards,
        values,
        bootstrap_value,
        dist_class,
        model,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        name=name,
    )

    return VTraceFromLogitsReturns(
        vs=res.vs,
        pg_advantages=res.pg_advantages,
        log_rhos=res.log_rhos,
        behaviour_action_log_probs=tf.squeeze(res.behaviour_action_log_probs, axis=0),
        target_action_log_probs=tf.squeeze(res.target_action_log_probs, axis=0),
    )


def multi_from_logits(
    behaviour_policy_logits,
    target_policy_logits,
    actions,
    discounts,
    rewards,
    values,
    bootstrap_value,
    dist_class,
    model,
    behaviour_action_log_probs=None,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
    name="vtrace_from_logits",
):
    r"""V-trace for softmax policies.

    Calculates V-trace actor critic targets for softmax policies as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    Target policy refers to the policy we are interested in improving and
    behaviour policy refers to the policy that generated the given
    rewards and actions.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size and
    ACTION_SPACE refers to the list of numbers each representing a number of
    actions.

    Args:
        behaviour_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes
            [T, B, ACTION_SPACE[0]],
            ...,
            [T, B, ACTION_SPACE[-1]]
            with un-normalized log-probabilities parameterizing the softmax behaviour
            policy.
        target_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes
            [T, B, ACTION_SPACE[0]],
            ...,
            [T, B, ACTION_SPACE[-1]]
            with un-normalized log-probabilities parameterizing the softmax target
            policy.
        actions: A list with length of ACTION_SPACE of
            tensors of shapes
            [T, B, ...],
            ...,
            [T, B, ...]
            with actions sampled from the behaviour policy.
        discounts: A float32 tensor of shape [T, B] with the discount encountered
            when following the behaviour policy.
        rewards: A float32 tensor of shape [T, B] with the rewards generated by
            following the behaviour policy.
        values: A float32 tensor of shape [T, B] with the value function estimates
            wrt. the target policy.
        bootstrap_value: A float32 tensor of shape [B] with the value function
            estimate at time T.
        dist_class: Action distribution class for the logits.
        model: Backing ModelV2 instance.
        behaviour_action_log_probs: Precalculated log-probabilities of the
            behaviour-policy actions (optional); recomputed from the logits if not
            given or if more than one action component is present.
        clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
            importance weights (rho) when calculating the baseline targets (vs).
            rho^bar in the paper.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
            on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
        name: The name scope that all V-trace operations will be created in.

    Returns:
        A `VTraceFromLogitsReturns` namedtuple with the following fields:
            vs: A float32 tensor of shape [T, B]. Can be used as target to train a
                baseline (V(x_t) - vs_t)^2.
            pg_advantages: A float32 tensor of shape [T, B]. Can be used as an
                estimate of the advantage in the calculation of policy gradients.
            log_rhos: A float32 tensor of shape [T, B] containing the log importance
                sampling weights (log rhos).
            behaviour_action_log_probs: A float32 tensor of shape [T, B] containing
                behaviour policy action log probabilities (log \mu(a_t)).
            target_action_log_probs: A float32 tensor of shape [T, B] containing
                target policy action log probabilities (log \pi(a_t)).
    """
    for i in range(len(behaviour_policy_logits)):
        behaviour_policy_logits[i] = tf.convert_to_tensor(
            behaviour_policy_logits[i], dtype=tf.float32
        )
        target_policy_logits[i] = tf.convert_to_tensor(
            target_policy_logits[i], dtype=tf.float32
        )

        # Make sure tensor ranks are as expected.
        # The rest will be checked by from_importance_weights.
        behaviour_policy_logits[i].shape.assert_has_rank(3)
        target_policy_logits[i].shape.assert_has_rank(3)

    with tf1.name_scope(
        name,
        values=[
            behaviour_policy_logits,
            target_policy_logits,
            actions,
            discounts,
            rewards,
            values,
            bootstrap_value,
        ],
    ):
        target_action_log_probs = multi_log_probs_from_logits_and_actions(
            target_policy_logits, actions, dist_class, model
        )

        if len(behaviour_policy_logits) > 1 or behaviour_action_log_probs is None:
            # Can't use precalculated values, recompute them. Note that
            # recomputing won't work well for autoregressive action dists,
            # which may have variables not captured by `logits`.
            behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
                behaviour_policy_logits, actions, dist_class, model
            )

        log_rhos = get_log_rhos(target_action_log_probs, behaviour_action_log_probs)

        vtrace_returns = from_importance_weights(
            log_rhos=log_rhos,
            discounts=discounts,
            rewards=rewards,
            values=values,
            bootstrap_value=bootstrap_value,
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold,
        )

        return VTraceFromLogitsReturns(
            log_rhos=log_rhos,
            behaviour_action_log_probs=behaviour_action_log_probs,
            target_action_log_probs=target_action_log_probs,
            **vtrace_returns._asdict()
        )


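# The example below is an illustrative sketch, not part of the original module:
# an end-to-end call of multi_from_logits on dummy data for a MultiDiscrete
# action space with sub-action sizes (3, 5), T=10 and B=4. All shapes and
# values are assumptions made up for this example (in RLlib they come from
# sampled trajectories), and depending on the TF version the tf1.name_scope
# usage above may require running this inside a graph/tf.function context.
def _example_multi_from_logits_usage():
    T, B = 10, 4
    behaviour_logits = [tf.random.normal([T, B, 3]), tf.random.normal([T, B, 5])]
    target_logits = [tf.random.normal([T, B, 3]), tf.random.normal([T, B, 5])]
    actions = [
        tf.random.uniform([T, B], maxval=3, dtype=tf.int32),
        tf.random.uniform([T, B], maxval=5, dtype=tf.int32),
    ]
    discounts = tf.fill([T, B], 0.99)  # In practice 0.0 at terminal steps.
    rewards = tf.random.normal([T, B])
    values = tf.random.normal([T, B])
    bootstrap_value = tf.random.normal([B])
    out = multi_from_logits(
        behaviour_logits,
        target_logits,
        actions,
        discounts,
        rewards,
        values,
        bootstrap_value,
        dist_class=Categorical,
        model=None,
    )
    # out.vs and out.pg_advantages are [T, B] float32 tensors with gradients
    # stopped; out.log_rhos holds the summed per-component log importance ratios.
    return out

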
def from_importance_weights(
    log_rhos,
    discounts,
    rewards,
    values,
    bootstrap_value,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
    name="vtrace_from_importance_weights",
):
    r"""V-trace from log importance weights.

    Calculates V-trace actor critic targets as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size. This code
    also supports the case where all tensors have the same number of additional
    dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C],
    `bootstrap_value` is [B, C].

    Args:
        log_rhos: A float32 tensor of shape [T, B] representing the
            log importance sampling weights, i.e.
            log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
            on rhos in log-space for numerical stability.
        discounts: A float32 tensor of shape [T, B] with discounts encountered when
            following the behaviour policy.
        rewards: A float32 tensor of shape [T, B] containing rewards generated by
            following the behaviour policy.
        values: A float32 tensor of shape [T, B] with the value function estimates
            wrt. the target policy.
        bootstrap_value: A float32 tensor of shape [B] with the value function
            estimate at time T.
        clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
            importance weights (rho) when calculating the baseline targets (vs).
            rho^bar in the paper. If None, no clipping is applied.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
            on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). If
            None, no clipping is applied.
        name: The name scope that all V-trace operations will be created in.

    Returns:
        A VTraceReturns namedtuple (vs, pg_advantages) where:
            vs: A float32 tensor of shape [T, B]. Can be used as target to
                train a baseline (V(x_t) - vs_t)^2.
            pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
                advantage in the calculation of policy gradients.
    """
    log_rhos = tf.convert_to_tensor(log_rhos, dtype=tf.float32)
    discounts = tf.convert_to_tensor(discounts, dtype=tf.float32)
    rewards = tf.convert_to_tensor(rewards, dtype=tf.float32)
    values = tf.convert_to_tensor(values, dtype=tf.float32)
    bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32)
    if clip_rho_threshold is not None:
        clip_rho_threshold = tf.convert_to_tensor(clip_rho_threshold, dtype=tf.float32)
    if clip_pg_rho_threshold is not None:
        clip_pg_rho_threshold = tf.convert_to_tensor(
            clip_pg_rho_threshold, dtype=tf.float32
        )

    # Make sure tensor ranks are consistent.
    rho_rank = log_rhos.shape.ndims  # Usually 2.
    values.shape.assert_has_rank(rho_rank)
    bootstrap_value.shape.assert_has_rank(rho_rank - 1)
    discounts.shape.assert_has_rank(rho_rank)
    rewards.shape.assert_has_rank(rho_rank)
    if clip_rho_threshold is not None:
        clip_rho_threshold.shape.assert_has_rank(0)
    if clip_pg_rho_threshold is not None:
        clip_pg_rho_threshold.shape.assert_has_rank(0)

    with tf1.name_scope(
        name, values=[log_rhos, discounts, rewards, values, bootstrap_value]
    ):
        rhos = tf.math.exp(log_rhos)
        if clip_rho_threshold is not None:
            clipped_rhos = tf.minimum(clip_rho_threshold, rhos, name="clipped_rhos")
        else:
            clipped_rhos = rhos

        cs = tf.minimum(1.0, rhos, name="cs")
        # Append bootstrapped value to get [v1, ..., v_t+1].
        values_t_plus_1 = tf.concat(
            [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0
        )
        deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

        # All sequences are reversed, computation starts from the back.
        sequences = (
            tf.reverse(discounts, axis=[0]),
            tf.reverse(cs, axis=[0]),
            tf.reverse(deltas, axis=[0]),
        )

        # V-trace vs are calculated through a scan from the back to the
        # beginning of the given trajectory.
        def scanfunc(acc, sequence_item):
            discount_t, c_t, delta_t = sequence_item
            return delta_t + discount_t * c_t * acc

        initial_values = tf.zeros_like(bootstrap_value)
        vs_minus_v_xs = tf.nest.map_structure(
            tf.stop_gradient,
            tf.scan(
                fn=scanfunc,
                elems=sequences,
                initializer=initial_values,
                parallel_iterations=1,
                name="scan",
            ),
        )
        # Reverse the results back to original order.
        vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name="vs_minus_v_xs")

        # Add V(x_s) to get v_s.
        vs = tf.add(vs_minus_v_xs, values, name="vs")

        # Advantage for policy gradient.
        vs_t_plus_1 = tf.concat([vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
        if clip_pg_rho_threshold is not None:
            clipped_pg_rhos = tf.minimum(
                clip_pg_rho_threshold, rhos, name="clipped_pg_rhos"
            )
        else:
            clipped_pg_rhos = rhos
        pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)

        # Make sure no gradients are backpropagated through the returned values.
        return VTraceReturns(
            vs=tf.stop_gradient(vs), pg_advantages=tf.stop_gradient(pg_advantages)
        )


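# The helper below is an illustrative sketch, not part of the original module:
# a plain NumPy reference of the backward recursion computed by
# from_importance_weights above,
#     v_s = V(x_s) + delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})),
# with both clipping thresholds fixed at their 1.0 defaults. It can be handy
# for cross-checking the TF implementation on small [T, B] inputs.
def _example_vtrace_reference_numpy(
    log_rhos, discounts, rewards, values, bootstrap_value
):
    import numpy as np

    log_rhos = np.asarray(log_rhos, dtype=np.float32)
    discounts = np.asarray(discounts, dtype=np.float32)
    rewards = np.asarray(rewards, dtype=np.float32)
    values = np.asarray(values, dtype=np.float32)
    bootstrap_value = np.asarray(bootstrap_value, dtype=np.float32)

    rhos = np.exp(log_rhos)
    clipped_rhos = np.minimum(1.0, rhos)  # rho^bar = 1.0
    cs = np.minimum(1.0, rhos)  # c^bar = 1.0

    # Temporal-difference terms delta_s, using V(x_{s+1}) bootstrapped at T.
    values_t_plus_1 = np.concatenate([values[1:], bootstrap_value[None]], axis=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    # Backward scan: acc accumulates delta_s + gamma_s * c_s * acc.
    vs_minus_v_xs = np.zeros_like(values)
    acc = np.zeros_like(bootstrap_value)
    for t in reversed(range(values.shape[0])):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v_xs[t] = acc
    vs = vs_minus_v_xs + values

    # Policy-gradient advantages use v_{s+1} rather than V(x_{s+1}).
    vs_t_plus_1 = np.concatenate([vs[1:], bootstrap_value[None]], axis=0)
    pg_advantages = clipped_rhos * (rewards + discounts * vs_t_plus_1 - values)
    return vs, pg_advantages

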
def get_log_rhos(target_action_log_probs, behaviour_action_log_probs):
    """Computes the log importance-sampling ratios (log_rhos) for V-trace from
    the per-component action log-probs of the target and behaviour policies.
    For MultiDiscrete actions, the per-component log-ratios are summed over
    components, yielding a single [T, B] tensor of log rhos."""
    t = tf.stack(target_action_log_probs)
    b = tf.stack(behaviour_action_log_probs)
    log_rhos = tf.reduce_sum(t - b, axis=0)
    return log_rhos
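

# Illustrative note (not part of the original module): for a MultiDiscrete
# action with K components, get_log_rhos sums the per-component log-ratios,
#     log rho_t = sum_k [log pi_k(a_t_k | x_t) - log mu_k(a_t_k | x_t)],
# which is the log importance ratio of the joint (factorized) action
# distribution. For example, with two components and T = B = 1:
#
#     target = [tf.constant([[-0.5]]), tf.constant([[-1.0]])]
#     behaviour = [tf.constant([[-0.7]]), tf.constant([[-0.9]])]
#     get_log_rhos(target, behaviour)  # ~= [[0.1]] = [[-1.5 - (-1.6)]]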