utils.py

from typing import List, Optional

from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch


def post_process_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[List[SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Adds the "advantages" column to `sample_batch`.

    Args:
        policy (Policy): The Policy object to do post-processing for.
        sample_batch (SampleBatch): The actual sample batch to post-process.
        other_agent_batches (Optional[List[SampleBatch]]): Optional list of
            other agents' SampleBatch objects.
        episode (Episode): The multi-agent episode object, from which
            `sample_batch` was generated.

    Returns:
        SampleBatch: The SampleBatch enhanced by the added ADVANTAGES field.
    """
    # Calculates advantage values based on the rewards in the sample batch.
    # The value of the last observation is assumed to be 0.0 (no value
    # function estimation at the end of the sampled chunk).
    return compute_advantages(
        rollout=sample_batch,
        last_r=0.0,
        gamma=policy.config["gamma"],
        use_gae=False,
        use_critic=False)
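

# --- Illustrative sketch (not part of the original file or of RLlib's API) ---
# With use_gae=False and use_critic=False, compute_advantages reduces the
# "advantages" column to the discounted return-to-go of the raw rewards,
# bootstrapped from `last_r` (0.0 here) at the end of the sampled chunk.
# The helper below is a hypothetical, self-contained reference for that
# calculation, assuming rewards arrive as a plain sequence of floats.
def discounted_returns(rewards, gamma, last_r=0.0):
    """Return-to-go per step: R_t = r_t + gamma * R_{t+1}, with R after the last step = last_r."""
    returns = [0.0] * len(rewards)
    running = last_r
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# Example: rewards [1.0, 1.0, 1.0] with gamma=0.9 -> [2.71, 1.9, 1.0]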