# postprocessing.py

import numpy as np
import scipy.signal
from typing import Dict, Optional

from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentID


class Postprocessing:
    """Constant definitions for postprocessing."""

    ADVANTAGES = "advantages"
    VALUE_TARGETS = "value_targets"


def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None:
    """Rewrites `batch` to encode n-step rewards, dones, and next-obs.

    Observations and actions remain unaffected. At the end of the
    trajectory, n is truncated to fit within the trajectory length.

    Args:
        n_step: The number of steps to look ahead and adjust.
        gamma: The discount factor.
        batch: The SampleBatch to adjust (in place).

    Examples:
        n_step=3
        Trajectory=o0 r0 d0, o1 r1 d1, o2 r2 d2, o3 r3 d3, o4 r4 d4=True o5
        gamma=0.9
        Returned trajectory:
        0: o0 [r0 + 0.9*r1 + 0.9^2*r2] d2 o0'=o3
        1: o1 [r1 + 0.9*r2 + 0.9^2*r3] d3 o1'=o4
        2: o2 [r2 + 0.9*r3 + 0.9^2*r4] d4 o2'=o5
        3: o3 [r3 + 0.9*r4] d4 o3'=o5
        4: o4 r4 d4 o4'=o5
    """
    assert not any(batch[SampleBatch.DONES][:-1]), \
        "Unexpected done in middle of trajectory!"

    len_ = len(batch)

    # Shift NEXT_OBS and DONES.
    batch[SampleBatch.NEXT_OBS] = np.concatenate(
        [
            batch[SampleBatch.OBS][n_step:],
            np.stack([batch[SampleBatch.NEXT_OBS][-1]] * min(n_step, len_)),
        ],
        axis=0)
    batch[SampleBatch.DONES] = np.concatenate(
        [
            batch[SampleBatch.DONES][n_step - 1:],
            np.tile(batch[SampleBatch.DONES][-1], min(n_step - 1, len_)),
        ],
        axis=0)

    # Change rewards in place.
    for i in range(len_):
        for j in range(1, n_step):
            if i + j < len_:
                batch[SampleBatch.REWARDS][i] += \
                    gamma**j * batch[SampleBatch.REWARDS][i + j]
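

# Illustrative helper (not part of the original module): a plain-numpy sketch
# of the n-step reward folding performed by `adjust_nstep` above. It lets the
# docstring example be checked without constructing a SampleBatch; the name
# `_nstep_rewards_example` is a made-up label for this sketch.
def _nstep_rewards_example(rewards: np.ndarray,
                           n_step: int = 3,
                           gamma: float = 0.9) -> np.ndarray:
    """Returns the n-step discounted reward sums of a 1D reward array.

    E.g. for rewards=[1, 1, 1, 1, 1], n_step=3, gamma=0.9 this yields
    [2.71, 2.71, 2.71, 1.9, 1.0].
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    out = rewards.copy()
    for i in range(len(out)):
        for j in range(1, n_step):
            if i + j < len(out):
                # Same accumulation as the reward loop in `adjust_nstep`.
                out[i] += gamma**j * rewards[i + j]
    return out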


@DeveloperAPI
def compute_advantages(rollout: SampleBatch,
                       last_r: float,
                       gamma: float = 0.9,
                       lambda_: float = 1.0,
                       use_gae: bool = True,
                       use_critic: bool = True):
    """Given a rollout, compute its value targets and the advantages.

    Args:
        rollout: SampleBatch of a single trajectory.
        last_r: Value estimation for the last observation.
        gamma: Discount factor.
        lambda_: Parameter for GAE.
        use_gae: Whether to use Generalized Advantage Estimation (GAE).
        use_critic: Whether to use critic (value estimates). Setting
            this to False will use 0 as baseline.

    Returns:
        SampleBatch with experience from rollout and processed rewards.
    """
    assert SampleBatch.VF_PREDS in rollout or not use_critic, \
        "use_critic=True but values not found"
    assert use_critic or not use_gae, \
        "Can't use gae without using a value function"

    if use_gae:
        vpred_t = np.concatenate(
            [rollout[SampleBatch.VF_PREDS],
             np.array([last_r])])
        delta_t = (
            rollout[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1])
        # This formula for the advantage comes from:
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        rollout[Postprocessing.ADVANTAGES] = discount_cumsum(
            delta_t, gamma * lambda_)
        rollout[Postprocessing.VALUE_TARGETS] = (
            rollout[Postprocessing.ADVANTAGES] +
            rollout[SampleBatch.VF_PREDS]).astype(np.float32)
    else:
        rewards_plus_v = np.concatenate(
            [rollout[SampleBatch.REWARDS],
             np.array([last_r])])
        discounted_returns = discount_cumsum(rewards_plus_v,
                                             gamma)[:-1].astype(np.float32)
        if use_critic:
            rollout[Postprocessing.ADVANTAGES] = \
                discounted_returns - rollout[SampleBatch.VF_PREDS]
            rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
        else:
            rollout[Postprocessing.ADVANTAGES] = discounted_returns
            rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
                rollout[Postprocessing.ADVANTAGES])

    rollout[Postprocessing.ADVANTAGES] = rollout[
        Postprocessing.ADVANTAGES].astype(np.float32)

    return rollout
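

# Illustrative helper (not part of the original module): computes GAE on
# plain numpy arrays, mirroring the `use_gae=True` branch of
# `compute_advantages` above; handy for checking the math in isolation. The
# name `_gae_example` and the default hyperparameters are arbitrary choices
# for this sketch.
def _gae_example(rewards: np.ndarray,
                 vf_preds: np.ndarray,
                 last_r: float,
                 gamma: float = 0.99,
                 lambda_: float = 0.95):
    """Returns (advantages, value_targets) for a single trajectory."""
    vpred_t = np.concatenate([vf_preds, np.array([last_r])])
    # TD residuals: delta[t] = r[t] + gamma * V(s[t+1]) - V(s[t]).
    delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
    # GAE advantages: discounted cumulative sum of the residuals with
    # factor gamma * lambda.
    advantages = discount_cumsum(delta_t, gamma * lambda_)
    value_targets = (advantages + vf_preds).astype(np.float32)
    return advantages.astype(np.float32), value_targets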


def compute_gae_for_sample_batch(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None) -> SampleBatch:
    """Adds GAE (generalized advantage estimations) to a trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).

    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy: The Policy used to generate the trajectory (`sample_batch`).
        sample_batch: The SampleBatch to postprocess.
        other_agent_batches: Optional dict of AgentIDs mapping to other
            agents' trajectory data (from the same episode).
            NOTE: The other agents use the same policy.
        episode: Optional multi-agent episode object in which the agents
            operated.

    Returns:
        The postprocessed, modified SampleBatch (or a new one).
    """
    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Create an input dict according to the Model's view requirements.
        # It covers a single timestep: the last one in the trajectory.
        input_dict = sample_batch.get_single_step_input_dict(
            policy.model.view_requirements, index="last")
        last_r = policy._value(**input_dict)

    # Adds the policy logits, VF preds, and advantages to the batch,
    # using GAE ("generalized advantage estimation") or not.
    batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
        use_critic=policy.config.get("use_critic", True))

    return batch
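

# Illustrative usage sketch (not part of the original module): builds a tiny,
# made-up completed-episode SampleBatch and postprocesses it the same way
# `compute_gae_for_sample_batch` does for a finished episode (last_r=0.0),
# but without needing a Policy object. All column values, shapes, and
# hyperparameters below are arbitrary demo choices.
def _postprocess_example() -> SampleBatch:
    batch = SampleBatch({
        SampleBatch.OBS: np.zeros((4, 2), dtype=np.float32),
        SampleBatch.REWARDS: np.array([1.0, 0.5, 0.0, 2.0], dtype=np.float32),
        SampleBatch.DONES: np.array([False, False, False, True]),
        SampleBatch.VF_PREDS: np.array([0.9, 0.8, 0.7, 0.6], dtype=np.float32),
    })
    # Episode terminated -> no bootstrap value for the last observation.
    last_r = 0.0
    return compute_advantages(
        batch, last_r, gamma=0.99, lambda_=0.95, use_gae=True,
        use_critic=True)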


def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """Calculates the discounted cumulative sum over a reward sequence `x`.

    y[t] - discount*y[t+1] = x[t]
    reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]

    Args:
        x: The sequence of rewards to discount and accumulate.
        gamma: The discount factor gamma.

    Returns:
        The sequence containing the discounted cumulative sums
        for each individual reward in `x` till the end of the trajectory.

    Examples:
        >>> x = np.array([0.0, 1.0, 2.0, 3.0])
        >>> gamma = 0.9
        >>> discount_cumsum(x, gamma)
        ... array([0.0 + 0.9*1.0 + 0.9^2*2.0 + 0.9^3*3.0,
        ...        1.0 + 0.9*2.0 + 0.9^2*3.0,
        ...        2.0 + 0.9*3.0,
        ...        3.0])
    """
    return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
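

# Illustrative helper (not part of the original module): an explicit backward
# loop equivalent of `discount_cumsum`, useful for verifying the
# scipy.signal.lfilter trick above. The name `_discount_cumsum_reference` is
# a made-up label for this sketch.
def _discount_cumsum_reference(x: np.ndarray, gamma: float) -> np.ndarray:
    """Computes y[t] = x[t] + gamma * y[t + 1], with y[len(x)] = 0."""
    y = np.zeros(len(x), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = float(x[t]) + gamma * running
        y[t] = running
    return y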