postprocessing.py

import numpy as np
import scipy.signal
from typing import Dict, Optional

from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import AgentID
from ray.rllib.utils.typing import TensorType


@DeveloperAPI
class Postprocessing:
    """Constant definitions for postprocessing."""

    ADVANTAGES = "advantages"
    VALUE_TARGETS = "value_targets"


@DeveloperAPI
def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None:
    """Rewrites `batch` to encode n-step rewards, terminateds, truncateds, and next-obs.

    Observations and actions remain unaffected. At the end of the trajectory,
    n is truncated to fit into the trajectory length.

    Args:
        n_step: The number of steps to look ahead and adjust.
        gamma: The discount factor.
        batch: The SampleBatch to adjust (in place).

    Examples:
        n-step=3
        Trajectory=o0 r0 d0, o1 r1 d1, o2 r2 d2, o3 r3 d3, o4 r4 d4=True o5
        gamma=0.9
        Returned trajectory:
        0: o0 [r0 + 0.9*r1 + 0.9^2*r2] d2 o0'=o3
        1: o1 [r1 + 0.9*r2 + 0.9^2*r3] d3 o1'=o4
        2: o2 [r2 + 0.9*r3 + 0.9^2*r4] d4 o2'=o5
        3: o3 [r3 + 0.9*r4] d4 o3'=o5
        4: o4 r4 d4 o4'=o5
    """
    assert (
        batch.is_single_trajectory()
    ), "Unexpected terminated|truncated in middle of trajectory!"

    len_ = len(batch)

    # Shift NEXT_OBS, TERMINATEDS, and TRUNCATEDS.
    batch[SampleBatch.NEXT_OBS] = np.concatenate(
        [
            batch[SampleBatch.OBS][n_step:],
            np.stack([batch[SampleBatch.NEXT_OBS][-1]] * min(n_step, len_)),
        ],
        axis=0,
    )
    batch[SampleBatch.TERMINATEDS] = np.concatenate(
        [
            batch[SampleBatch.TERMINATEDS][n_step - 1 :],
            np.tile(batch[SampleBatch.TERMINATEDS][-1], min(n_step - 1, len_)),
        ],
        axis=0,
    )
    # Only fix `truncateds`, if present in the batch.
    if SampleBatch.TRUNCATEDS in batch:
        batch[SampleBatch.TRUNCATEDS] = np.concatenate(
            [
                batch[SampleBatch.TRUNCATEDS][n_step - 1 :],
                np.tile(batch[SampleBatch.TRUNCATEDS][-1], min(n_step - 1, len_)),
            ],
            axis=0,
        )
    # Change rewards in place: add the discounted rewards of the following
    # (n_step - 1) timesteps to each timestep's reward.
    for i in range(len_):
        for j in range(1, n_step):
            if i + j < len_:
                batch[SampleBatch.REWARDS][i] += (
                    gamma**j * batch[SampleBatch.REWARDS][i + j]
                )
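

# Hedged usage sketch (not part of the original module): builds a tiny
# single-trajectory batch and applies a 3-step reward adjustment in place.
# `_example_adjust_nstep_usage` and all column values below are illustrative
# assumptions, not RLlib defaults.
def _example_adjust_nstep_usage() -> SampleBatch:
    batch = SampleBatch(
        {
            SampleBatch.OBS: np.arange(5, dtype=np.float32),
            SampleBatch.NEXT_OBS: np.arange(1, 6, dtype=np.float32),
            SampleBatch.REWARDS: np.ones(5, dtype=np.float32),
            SampleBatch.TERMINATEDS: np.array([False, False, False, False, True]),
            SampleBatch.TRUNCATEDS: np.array([False] * 5),
        }
    )
    # After this call, REWARDS holds discounted 3-step returns and NEXT_OBS,
    # TERMINATEDS, and TRUNCATEDS are shifted as described in the docstring above.
    adjust_nstep(n_step=3, gamma=0.9, batch=batch)
    return batch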


@DeveloperAPI
def compute_advantages(
    rollout: SampleBatch,
    last_r: float,
    gamma: float = 0.9,
    lambda_: float = 1.0,
    use_gae: bool = True,
    use_critic: bool = True,
    rewards: Optional[TensorType] = None,
    vf_preds: Optional[TensorType] = None,
):
    """Given a rollout, compute its value targets and the advantages.

    Args:
        rollout: SampleBatch of a single trajectory.
        last_r: Value estimation for the last observation.
        gamma: Discount factor.
        lambda_: Parameter for GAE.
        use_gae: Whether to use Generalized Advantage Estimation.
        use_critic: Whether to use a critic (value estimates). Setting
            this to False will use 0 as baseline.
        rewards: Override the reward values in rollout.
        vf_preds: Override the value function predictions in rollout.

    Returns:
        SampleBatch with experience from rollout and processed rewards.
    """
    assert (
        SampleBatch.VF_PREDS in rollout or not use_critic
    ), "use_critic=True but values not found"
    assert use_critic or not use_gae, "Can't use gae without using a value function"

    last_r = convert_to_numpy(last_r)

    if rewards is None:
        rewards = rollout[SampleBatch.REWARDS]
    if vf_preds is None and use_critic:
        vf_preds = rollout[SampleBatch.VF_PREDS]

    if use_gae:
        vpred_t = np.concatenate([vf_preds, np.array([last_r])])
        delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
        # This formula for the advantage comes from:
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        rollout[Postprocessing.ADVANTAGES] = discount_cumsum(delta_t, gamma * lambda_)
        rollout[Postprocessing.VALUE_TARGETS] = (
            rollout[Postprocessing.ADVANTAGES] + vf_preds
        ).astype(np.float32)
    else:
        rewards_plus_v = np.concatenate([rewards, np.array([last_r])])
        discounted_returns = discount_cumsum(rewards_plus_v, gamma)[:-1].astype(
            np.float32
        )
        if use_critic:
            rollout[Postprocessing.ADVANTAGES] = discounted_returns - vf_preds
            rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
        else:
            rollout[Postprocessing.ADVANTAGES] = discounted_returns
            rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
                rollout[Postprocessing.ADVANTAGES]
            )

    rollout[Postprocessing.ADVANTAGES] = rollout[Postprocessing.ADVANTAGES].astype(
        np.float32
    )

    return rollout
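

# Hedged usage sketch (not part of the original module): runs GAE on a tiny,
# already-terminated trajectory. `_example_compute_advantages_usage` and the
# reward/value numbers are made up for illustration; `last_r=0.0` assumes the
# episode ended in a terminal state.
def _example_compute_advantages_usage() -> SampleBatch:
    rollout = SampleBatch(
        {
            SampleBatch.REWARDS: np.array([1.0, 0.0, 2.0], dtype=np.float32),
            SampleBatch.VF_PREDS: np.array([0.5, 0.4, 0.3], dtype=np.float32),
        }
    )
    # Adds the Postprocessing.ADVANTAGES and Postprocessing.VALUE_TARGETS columns.
    return compute_advantages(
        rollout, last_r=0.0, gamma=0.99, lambda_=0.95, use_gae=True
    )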


@DeveloperAPI
def compute_gae_for_sample_batch(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Adds GAE (generalized advantage estimations) to a trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy: The Policy used to generate the trajectory (`sample_batch`).
        sample_batch: The SampleBatch to postprocess.
        other_agent_batches: Optional dict of AgentIDs mapping to other
            agents' trajectory data (from the same episode).
            NOTE: The other agents use the same policy.
        episode: Optional multi-agent episode object in which the agents
            operated.

    Returns:
        The postprocessed, modified SampleBatch (or a new one).
    """
    # Compute the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need for the
    # following `last_r` arg in `compute_advantages()`.
    sample_batch = compute_bootstrap_value(sample_batch, policy)

    vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
    rewards = np.array(sample_batch[SampleBatch.REWARDS])

    # We need to squeeze out the time dimension if there is one.
    # Sanity check that both have the same shape.
    if len(vf_preds.shape) == 2:
        assert vf_preds.shape == rewards.shape
        vf_preds = np.squeeze(vf_preds, axis=1)
        rewards = np.squeeze(rewards, axis=1)
        squeezed = True
    else:
        squeezed = False

    # Adds the policy logits, VF preds, and advantages to the batch,
    # using GAE ("generalized advantage estimation") or not.
    batch = compute_advantages(
        rollout=sample_batch,
        last_r=sample_batch[SampleBatch.VALUES_BOOTSTRAPPED][-1],
        gamma=policy.config["gamma"],
        lambda_=policy.config["lambda"],
        use_gae=policy.config["use_gae"],
        use_critic=policy.config.get("use_critic", True),
        vf_preds=vf_preds,
        rewards=rewards,
    )

    if squeezed:
        # If we needed to squeeze rewards and vf_preds, we need to unsqueeze
        # advantages again for it to have the same shape.
        batch[Postprocessing.ADVANTAGES] = np.expand_dims(
            batch[Postprocessing.ADVANTAGES], axis=1
        )

    return batch
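

# Hedged sketch (an assumption, not original RLlib code): a typical call site for
# `compute_gae_for_sample_batch` is a Policy's `postprocess_trajectory` override,
# run once per collected trajectory right after sampling.
# `_example_postprocess_trajectory` is a hypothetical stand-in for such an override.
def _example_postprocess_trajectory(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    # Delegate to the helper above; it adds the ADVANTAGES and VALUE_TARGETS
    # columns (and VALUES_BOOTSTRAPPED via `compute_bootstrap_value`).
    return compute_gae_for_sample_batch(
        policy, sample_batch, other_agent_batches, episode
    )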


@DeveloperAPI
def compute_bootstrap_value(sample_batch: SampleBatch, policy: Policy) -> SampleBatch:
    """Performs a value function computation at the end of a trajectory.

    If the trajectory is terminated (not truncated), will not use the value function,
    but assume that the value of the last timestep is 0.0.
    In all other cases, will use the given policy's value function to compute the
    "bootstrapped" value estimate at the end of the given trajectory. To do so, the
    very last observation (sample_batch[NEXT_OBS][-1]) and - if applicable -
    the very last state output (sample_batch[STATE_OUT][-1]) will be used as inputs
    to the value function.

    The thus computed value estimate will be stored in a new column of the
    `sample_batch`: SampleBatch.VALUES_BOOTSTRAPPED. Thereby, the values at all
    timesteps in this column are the value predictions of the respective next
    timestep (VF_PREDS[1:]), except for the very last timestep, which receives the
    computed bootstrapped value.
    This is done so that any loss function (which processes raw, intact
    trajectories, such as those of IMPALA and APPO) can use this new column as
    follows:

    Example: numbers=ts in episode, v#=VF_PREDS at that ts,
    X=bootstrapped value (!= 0.0 b/c ts=12 is not a terminal).

    T:                    8    9   10   11   12  <- no terminal
    VF_PREDS:             v8   v9  v10  v11  v12
    VALUES_BOOTSTRAPPED:  v9  v10  v11  v12    X

    Args:
        sample_batch: The SampleBatch (single trajectory) for which to compute the
            bootstrap value at the end. This SampleBatch will be altered in place
            (by adding a new column: SampleBatch.VALUES_BOOTSTRAPPED).
        policy: The Policy object, whose value function to use.

    Returns:
        The altered SampleBatch (with the extra SampleBatch.VALUES_BOOTSTRAPPED
        column).
    """
    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.TERMINATEDS][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Input dict is provided to us automatically via the Model's
        # requirements. It's a single-timestep (last one in trajectory)
        # input_dict.
        # Create an input dict according to the Policy's requirements.
        input_dict = sample_batch.get_single_step_input_dict(
            policy.view_requirements, index="last"
        )
        if policy.config.get("_enable_rl_module_api"):
            # Note: During sampling, the parameters from the beginning of the
            # sampling process are used. If these advantages are used during
            # training, shouldn't the latest parameters be used for this to be
            # correct? Does this mean the trajectory information needs to be
            # preserved during training and the advantages computed inside the
            # loss function?
            # TODO (Kourosh): Another thing we need to figure out is which end point
            #  to call here (why forward_exploration)? What if this method is getting
            #  called inside the learner loop or via another abstraction like
            #  RLSampler.postprocess_trajectory(), which is a non-batched cpu/gpu
            #  task running across different processes for different trajectories?
            #  This implementation right now will compute even the action_dist, which
            #  will not be needed but takes time to compute.
            if policy.framework == "torch":
                input_dict = convert_to_torch_tensor(input_dict, device=policy.device)

            # For recurrent models, we need to add a time dimension.
            input_dict = policy.maybe_add_time_dimension(
                input_dict, seq_lens=input_dict[SampleBatch.SEQ_LENS]
            )
            input_dict = NestedDict(input_dict)
            fwd_out = policy.model.forward_exploration(input_dict)
            # For recurrent models, we need to remove the time dimension.
            fwd_out = policy.maybe_remove_time_dimension(fwd_out)
            last_r = fwd_out[SampleBatch.VF_PREDS][-1]
        else:
            last_r = policy._value(**input_dict)
    vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
    # We need to squeeze out the time dimension if there is one.
    if len(vf_preds.shape) == 2:
        vf_preds = np.squeeze(vf_preds, axis=1)
        squeezed = True
    else:
        squeezed = False

    # Set the SampleBatch.VALUES_BOOTSTRAPPED field to VF_PREDS[1:] + the
    # very last timestep (where this bootstrapping value is actually needed), which
    # we set to the computed `last_r`.
    sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate(
        [
            convert_to_numpy(vf_preds[1:]),
            np.array([convert_to_numpy(last_r)], dtype=np.float32),
        ],
        axis=0,
    )

    if squeezed:
        sample_batch[SampleBatch.VF_PREDS] = np.expand_dims(vf_preds, axis=1)
        sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.expand_dims(
            sample_batch[SampleBatch.VALUES_BOOTSTRAPPED], axis=1
        )

    return sample_batch
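

# Hedged sketch (not part of the original module): because VALUES_BOOTSTRAPPED
# equals VF_PREDS shifted by one timestep (with the bootstrap value at the end),
# it directly provides V(s_{t+1}) for every timestep t. The helper below is a
# hypothetical illustration of how a trajectory-based loss could build one-step
# TD targets from it.
def _example_one_step_td_targets(
    sample_batch: SampleBatch, gamma: float
) -> np.ndarray:
    rewards = convert_to_numpy(sample_batch[SampleBatch.REWARDS])
    next_values = convert_to_numpy(sample_batch[SampleBatch.VALUES_BOOTSTRAPPED])
    terminateds = convert_to_numpy(sample_batch[SampleBatch.TERMINATEDS])
    # Standard one-step TD target: r_t + gamma * V(s_{t+1}) * (1 - terminated_t).
    return rewards + gamma * next_values * (1.0 - terminateds.astype(np.float32))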


@DeveloperAPI
def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """Calculates the discounted cumulative sum over a reward sequence `x`.

    y[t] - discount*y[t+1] = x[t]
    reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]

    Args:
        x: The sequence of rewards to cumulatively sum (and discount) over.
        gamma: The discount factor gamma.

    Returns:
        The sequence containing the discounted cumulative sums
        for each individual reward in `x` till the end of the trajectory.

    Examples:
        >>> x = np.array([0.0, 1.0, 2.0, 3.0])
        >>> gamma = 0.9
        >>> discount_cumsum(x, gamma)
        ... array([0.0 + 0.9*1.0 + 0.9^2*2.0 + 0.9^3*3.0,
        ...        1.0 + 0.9*2.0 + 0.9^2*3.0,
        ...        2.0 + 0.9*3.0,
        ...        3.0])
    """
    return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
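

# Hedged sketch (not part of the original module): an explicit reverse-loop
# reference implementation of the same discounted cumulative sum, useful for
# sanity-checking the scipy.signal.lfilter formulation above.
def _example_discount_cumsum_reference(x: np.ndarray, gamma: float) -> np.ndarray:
    out = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    # y[t] = x[t] + gamma * y[t+1], computed back-to-front.
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out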