common.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from ray.util.iter import LocalIterator
  2. from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
  3. from ray.rllib.utils.typing import Dict, SampleBatchType
  4. from ray.util.iter_metrics import MetricsContext
  5. # Counters for training progress (keys for metrics.counters).
  6. STEPS_SAMPLED_COUNTER = "num_steps_sampled"
  7. AGENT_STEPS_SAMPLED_COUNTER = "num_agent_steps_sampled"
  8. STEPS_TRAINED_COUNTER = "num_steps_trained"
  9. STEPS_TRAINED_THIS_ITER_COUNTER = "num_steps_trained_this_iter"
  10. AGENT_STEPS_TRAINED_COUNTER = "num_agent_steps_trained"
  11. # Counters to track target network updates.
  12. LAST_TARGET_UPDATE_TS = "last_target_update_ts"
  13. NUM_TARGET_UPDATES = "num_target_updates"
  14. # Performance timers (keys for metrics.timers).
  15. APPLY_GRADS_TIMER = "apply_grad"
  16. COMPUTE_GRADS_TIMER = "compute_grads"
  17. WORKER_UPDATE_TIMER = "update"
  18. GRAD_WAIT_TIMER = "grad_wait"
  19. SAMPLE_TIMER = "sample"
  20. LEARN_ON_BATCH_TIMER = "learn"
  21. LOAD_BATCH_TIMER = "load"
  22. # Asserts that an object is a type of SampleBatch.
  23. def _check_sample_batch_type(batch: SampleBatchType) -> None:
  24. if not isinstance(batch, (SampleBatch, MultiAgentBatch)):
  25. raise ValueError("Expected either SampleBatch or MultiAgentBatch, "
  26. "got {}: {}".format(type(batch), batch))
  27. # Returns pipeline global vars that should be periodically sent to each worker.
  28. def _get_global_vars() -> Dict:
  29. metrics = LocalIterator.get_metrics()
  30. return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]}
  31. def _get_shared_metrics() -> MetricsContext:
  32. """Return shared metrics for the training workflow.
  33. This only applies if this trainer has an execution plan."""
  34. return LocalIterator.get_metrics()