utils.py

from typing import List, Optional

from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch


def post_process_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[List[SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Adds the "advantages" column to `sample_batch`.

    Args:
        policy: The Policy object to do post-processing for.
        sample_batch: The actual sample batch to post-process.
        other_agent_batches: Optional list of other agents' SampleBatch objects.
        episode: The multi-agent episode object, from which `sample_batch` was
            generated.

    Returns:
        SampleBatch: The SampleBatch enhanced by the added ADVANTAGES field.
    """
    # Calculates advantage values based on the rewards in the sample batch.
    # The value of the last observation is assumed to be 0.0 (no value function
    # estimation at the end of the sampled chunk).
    return compute_advantages(
        rollout=sample_batch,
        last_r=0.0,
        gamma=policy.config["gamma"],
        use_gae=False,
        use_critic=False,
    )
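
With `use_gae=False` and `use_critic=False`, `compute_advantages` fills the "advantages" column with each timestep's plain discounted return (reward-to-go), bootstrapped with `last_r=0.0` at the end of the sampled chunk. A minimal standalone sketch of that computation (not RLlib's implementation; the helper name and example values are illustrative only):

import numpy as np

def discounted_returns(rewards, gamma, last_r=0.0):
    # Reward-to-go: G_t = r_t + gamma * G_{t+1}, seeded with last_r
    # at the end of the chunk (matching last_r=0.0 above).
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = last_r
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Example: three steps with reward 1.0 each and gamma = 0.99
# yields [2.9701, 1.99, 1.0].
print(discounted_returns([1.0, 1.0, 1.0], gamma=0.99))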