# metric_ops.py
from typing import Any, Dict, List, Optional

from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    AGENT_STEPS_SAMPLED_COUNTER,
    STEPS_SAMPLED_COUNTER,
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
    _get_shared_metrics,
)
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once
  13. class CollectMetrics:
  14. """Callable that collects metrics from workers.
  15. The metrics are smoothed over a given history window.
  16. This should be used with the .for_each() operator. For a higher level
  17. API, consider using StandardMetricsReporting instead.
  18. Examples:
  19. >>> from ray.rllib.execution.metric_ops import CollectMetrics
  20. >>> train_op, workers = ... # doctest: +SKIP
  21. >>> output_op = train_op.for_each(CollectMetrics(workers)) # doctest: +SKIP
  22. >>> print(next(output_op)) # doctest: +SKIP
  23. {"episode_reward_max": ..., "episode_reward_mean": ..., ...}
  24. """
  25. def __init__(
  26. self,
  27. workers: WorkerSet,
  28. min_history: int = 100,
  29. timeout_seconds: int = 180,
  30. keep_per_episode_custom_metrics: bool = False,
  31. selected_workers: List[int] = None,
  32. by_steps_trained: bool = False,
  33. ):
  34. self.workers = workers
  35. self.episode_history = []
  36. self.min_history = min_history
  37. self.timeout_seconds = timeout_seconds
  38. self.keep_custom_metrics = keep_per_episode_custom_metrics
  39. self.selected_workers = selected_workers
  40. self.by_steps_trained = by_steps_trained
  41. if log_once("learner-thread-deprecation-warning"):
  42. deprecation_warning(old="ray.rllib.execution.metric_ops.CollectMetrics")
  43. def __call__(self, _: Any) -> Dict:
  44. # Collect worker metrics.
  45. episodes = collect_episodes(
  46. self.workers,
  47. self.selected_workers or self.workers.healthy_worker_ids(),
  48. timeout_seconds=self.timeout_seconds,
  49. )
  50. orig_episodes = list(episodes)
  51. missing = self.min_history - len(episodes)
  52. if missing > 0:
  53. episodes = self.episode_history[-missing:] + episodes
  54. assert len(episodes) <= self.min_history
  55. self.episode_history.extend(orig_episodes)
  56. self.episode_history = self.episode_history[-self.min_history :]
  57. res = summarize_episodes(episodes, orig_episodes, self.keep_custom_metrics)
  58. # Add in iterator metrics.
  59. metrics = _get_shared_metrics()
  60. custom_metrics_from_info = metrics.info.pop("custom_metrics", {})
  61. timers = {}
  62. counters = {}
  63. info = {}
  64. info.update(metrics.info)
  65. for k, counter in metrics.counters.items():
  66. counters[k] = counter
  67. for k, timer in metrics.timers.items():
  68. timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
  69. if timer.has_units_processed():
  70. timers["{}_throughput".format(k)] = round(timer.mean_throughput, 3)
  71. res.update(
  72. {
  73. "num_healthy_workers": self.workers.num_healthy_workers(),
  74. "timesteps_total": (
  75. metrics.counters[STEPS_TRAINED_COUNTER]
  76. if self.by_steps_trained
  77. else metrics.counters[STEPS_SAMPLED_COUNTER]
  78. ),
  79. # tune.Trainable uses timesteps_this_iter for tracking
  80. # total timesteps.
  81. "timesteps_this_iter": metrics.counters[
  82. STEPS_TRAINED_THIS_ITER_COUNTER
  83. ],
  84. "agent_timesteps_total": metrics.counters.get(
  85. AGENT_STEPS_SAMPLED_COUNTER, 0
  86. ),
  87. }
  88. )
  89. res["timers"] = timers
  90. res["info"] = info
  91. res["info"].update(counters)
  92. res["custom_metrics"] = res.get("custom_metrics", {})
  93. res["episode_media"] = res.get("episode_media", {})
  94. res["custom_metrics"].update(custom_metrics_from_info)
  95. return res