compute_adapted_gae_on_postprocess_trajectory.py

  1. """
  2. Adapted (time-dependent) GAE for PPO algorithm can be activated by setting
  3. use_adapted_gae=True in the policy config. Additionally, it is required that
  4. "callbacks" include the custom callback class in the Trainer's config.
  5. Furthermore, the env must return in its info dictionary a key-value pair of
  6. the form "d_ts": ... where the value is the length (time) of recent agent step.
  7. This adapted, time-dependent computation of advantages may be useful in cases
  8. where agent's actions take various times and thus time steps are not
  9. equidistant (https://docdro.id/400TvlR)
  10. """
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import override
import numpy as np
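
# A minimal wiring sketch, assuming the old `ray.rllib.agents` API (the import
# path differs in newer Ray versions) and a hypothetical env "MyTimedEnv" that
# reports "d_ts" in every step's info dict. Depending on the Ray version, the
# custom `use_adapted_gae` key may have to be allowed explicitly before config
# validation accepts it:
#
#     from ray.rllib.agents.ppo import PPOTrainer
#
#     trainer = PPOTrainer(
#         env="MyTimedEnv",
#         config={
#             "callbacks": MyCallbacks,
#             "use_gae": True,
#             "use_adapted_gae": True,
#         })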


class MyCallbacks(DefaultCallbacks):
    @override(DefaultCallbacks)
    def on_postprocess_trajectory(self, *, worker, episode, agent_id,
                                  policy_id, policies, postprocessed_batch,
                                  original_batches, **kwargs):
        super().on_postprocess_trajectory(
            worker=worker,
            episode=episode,
            agent_id=agent_id,
            policy_id=policy_id,
            policies=policies,
            postprocessed_batch=postprocessed_batch,
            original_batches=original_batches,
            **kwargs)
        if policies[policy_id].config.get("use_adapted_gae", False):
            policy = policies[policy_id]
            assert policy.config["use_gae"], \
                "Can't use adapted GAE without use_gae=True!"
            info_dicts = postprocessed_batch[SampleBatch.INFOS]
            assert np.all(["d_ts" in info_dict for info_dict in info_dicts]), \
                "Info dicts in the sample batch must contain the key 'd_ts' " \
                "(=ts[i+1]-ts[i], the lengths of the time steps)!"
            d_ts = np.array(
                [float(info_dict.get("d_ts")) for info_dict in info_dicts])
            assert np.all([e.is_integer() for e in d_ts]), \
                "Elements of 'd_ts' (lengths of time steps) must be integers!"
            # Trajectory is actually complete -> last r=0.0.
            if postprocessed_batch[SampleBatch.DONES][-1]:
                last_r = 0.0
            # Trajectory has been truncated -> last r=VF estimate of last obs.
            else:
                # Create an input dict according to the Model's view
                # requirements. It's a single-timestep (the last one in the
                # trajectory) input dict.
                input_dict = postprocessed_batch.get_single_step_input_dict(
                    policy.model.view_requirements, index="last")
                last_r = policy._value(**input_dict)
            gamma = policy.config["gamma"]
            lambda_ = policy.config["lambda"]
            vpred_t = np.concatenate([
                postprocessed_batch[SampleBatch.VF_PREDS],
                np.array([last_r])
            ])
            delta_t = (postprocessed_batch[SampleBatch.REWARDS] +
                       gamma**d_ts * vpred_t[1:] - vpred_t[:-1])
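            # With unit-length time steps (d_ts == 1 everywhere), delta_t
            # reduces to the standard one-step TD residual
            # r_t + gamma * V(s_{t+1}) - V(s_t) of regular GAE.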
            # This formula for the advantage is an adaptation of
            # "Generalized Advantage Estimation"
            # (https://arxiv.org/abs/1506.02438) that accounts for time steps
            # of irregular length (see the proposal linked in the module
            # docstring above).
            # NOTE: The last time step length (d_ts[-1]) is not required here.
            postprocessed_batch[Postprocessing.ADVANTAGES] = \
                generalized_discount_cumsum(
                    delta_t, d_ts[:-1], gamma * lambda_)
            postprocessed_batch[Postprocessing.VALUE_TARGETS] = (
                postprocessed_batch[Postprocessing.ADVANTAGES] +
                postprocessed_batch[SampleBatch.VF_PREDS]).astype(np.float32)
            postprocessed_batch[Postprocessing.ADVANTAGES] = \
                postprocessed_batch[Postprocessing.ADVANTAGES].astype(
                    np.float32)


def generalized_discount_cumsum(x: np.ndarray, deltas: np.ndarray,
                                gamma: float) -> np.ndarray:
    """Calculates the 'time-dependent' discounted cumulative sum over a
    (reward) sequence `x`.

    Recursive equations:
        y[t] - gamma**deltas[t]*y[t+1] = x[t]
        reversed(y)[t] - gamma**reversed(deltas)[t-1]*reversed(y)[t-1] =
            reversed(x)[t]

    Args:
        x (np.ndarray): A sequence of rewards or one-step TD residuals.
        deltas (np.ndarray): A sequence of time step deltas (lengths of the
            time steps); one element shorter than `x`.
        gamma (float): The discount factor gamma.

    Returns:
        np.ndarray: The sequence containing the 'time-dependent' discounted
            cumulative sums for each individual element in `x` till the end
            of the trajectory.

    Examples:
        >>> x = np.array([0.0, 1.0, 2.0, 3.0])
        >>> deltas = np.array([1.0, 4.0, 15.0])
        >>> gamma = 0.9
        >>> generalized_discount_cumsum(x, deltas, gamma)
        ... array([0.0 + 0.9^1*1.0 + 0.9^(1+4)*2.0 + 0.9^(1+4+15)*3.0,
        ...        1.0 + 0.9^4*2.0 + 0.9^(4+15)*3.0,
        ...        2.0 + 0.9^15*3.0,
        ...        3.0])
    """
    reversed_x = x[::-1]
    reversed_deltas = deltas[::-1]
    reversed_y = np.empty_like(x)
    reversed_y[0] = reversed_x[0]
    for i in range(1, x.size):
        reversed_y[i] = \
            reversed_x[i] + gamma**reversed_deltas[i-1] * reversed_y[i-1]
    return reversed_y[::-1]
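

if __name__ == "__main__":
    # Minimal sanity-check sketch: with all time step lengths equal to 1, the
    # time-dependent cumulative sum must equal the standard discounted
    # cumulative sum y[t] = sum_k gamma**k * x[t + k].
    x = np.array([0.0, 1.0, 2.0, 3.0])
    gamma = 0.9
    result = generalized_discount_cumsum(x, np.ones(x.size - 1), gamma)
    expected = np.array([
        0.0 + gamma * 1.0 + gamma**2 * 2.0 + gamma**3 * 3.0,
        1.0 + gamma * 2.0 + gamma**2 * 3.0,
        2.0 + gamma * 3.0,
        3.0,
    ])
    assert np.allclose(result, expected), (result, expected)
    print("generalized_discount_cumsum self-check passed:", result)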