common.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from ray.util.iter import LocalIterator
  2. from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
  3. from ray.rllib.utils.deprecation import deprecation_warning
  4. from ray.rllib.utils.typing import Dict, SampleBatchType
  5. from ray.util.iter_metrics import MetricsContext
  6. from ray.util import log_once
  7. # Backward compatibility.
  8. from ray.rllib.utils.metrics import ( # noqa: F401
  9. LAST_TARGET_UPDATE_TS,
  10. NUM_TARGET_UPDATES,
  11. APPLY_GRADS_TIMER,
  12. COMPUTE_GRADS_TIMER,
  13. SYNCH_WORKER_WEIGHTS_TIMER as WORKER_UPDATE_TIMER,
  14. GRAD_WAIT_TIMER,
  15. SAMPLE_TIMER,
  16. LEARN_ON_BATCH_TIMER,
  17. LOAD_BATCH_TIMER,
  18. )
  19. STEPS_SAMPLED_COUNTER = "num_steps_sampled"
  20. AGENT_STEPS_SAMPLED_COUNTER = "num_agent_steps_sampled"
  21. STEPS_TRAINED_COUNTER = "num_steps_trained"
  22. STEPS_TRAINED_THIS_ITER_COUNTER = "num_steps_trained_this_iter"
  23. AGENT_STEPS_TRAINED_COUNTER = "num_agent_steps_trained"
  24. # End: Backward compatibility.
  25. # Asserts that an object is a type of SampleBatch.
  26. def _check_sample_batch_type(batch: SampleBatchType) -> None:
  27. if not isinstance(batch, (SampleBatch, MultiAgentBatch)):
  28. raise ValueError(
  29. "Expected either SampleBatch or MultiAgentBatch, "
  30. "got {}: {}".format(type(batch), batch)
  31. )
  32. # Returns pipeline global vars that should be periodically sent to each worker.
  33. def _get_global_vars() -> Dict:
  34. metrics = LocalIterator.get_metrics()
  35. return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]}
  36. def _get_shared_metrics() -> MetricsContext:
  37. """Return shared metrics for the training workflow.
  38. This only applies if this algorithm has an execution plan."""
  39. if log_once("shared-metrics-deprecation-warning"):
  40. deprecation_warning(old="ray.rllib.execution.common._get_shared_metrics")
  41. return LocalIterator.get_metrics()