# metric_ops.py
from typing import Any, List, Dict
import time
from ray.actor import ActorHandle
from ray.util.iter import LocalIterator
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
    STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, \
    STEPS_TRAINED_THIS_ITER_COUNTER, _get_shared_metrics
from ray.rllib.evaluation.worker_set import WorkerSet
  10. def StandardMetricsReporting(
  11. train_op: LocalIterator[Any],
  12. workers: WorkerSet,
  13. config: dict,
  14. selected_workers: List[ActorHandle] = None,
  15. by_steps_trained: bool = False,
  16. ) -> LocalIterator[dict]:
  17. """Operator to periodically collect and report metrics.
  18. Args:
  19. train_op (LocalIterator): Operator for executing training steps.
  20. We ignore the output values.
  21. workers (WorkerSet): Rollout workers to collect metrics from.
  22. config (dict): Trainer configuration, used to determine the frequency
  23. of stats reporting.
  24. selected_workers (list): Override the list of remote workers
  25. to collect metrics from.
  26. by_steps_trained (bool): If True, uses the `STEPS_TRAINED_COUNTER`
  27. instead of the `STEPS_SAMPLED_COUNTER` in metrics.
  28. Returns:
  29. LocalIterator[dict]: A local iterator over training results.
  30. Examples:
  31. >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
  32. >>> metrics_op = StandardMetricsReporting(train_op, workers, config)
  33. >>> next(metrics_op)
  34. {"episode_reward_max": ..., "episode_reward_mean": ..., ...}
  35. """
  36. output_op = train_op \
  37. .filter(OncePerTimestepsElapsed(config["timesteps_per_iteration"],
  38. by_steps_trained=by_steps_trained)) \
  39. .filter(OncePerTimeInterval(config["min_iter_time_s"])) \
  40. .for_each(CollectMetrics(
  41. workers,
  42. min_history=config["metrics_smoothing_episodes"],
  43. timeout_seconds=config["collect_metrics_timeout"],
  44. selected_workers=selected_workers))
  45. return output_op
  46. class CollectMetrics:
  47. """Callable that collects metrics from workers.
  48. The metrics are smoothed over a given history window.
  49. This should be used with the .for_each() operator. For a higher level
  50. API, consider using StandardMetricsReporting instead.
  51. Examples:
  52. >>> output_op = train_op.for_each(CollectMetrics(workers))
  53. >>> print(next(output_op))
  54. {"episode_reward_max": ..., "episode_reward_mean": ..., ...}
  55. """
  56. def __init__(self,
  57. workers: WorkerSet,
  58. min_history: int = 100,
  59. timeout_seconds: int = 180,
  60. selected_workers: List[ActorHandle] = None):
  61. self.workers = workers
  62. self.episode_history = []
  63. self.to_be_collected = []
  64. self.min_history = min_history
  65. self.timeout_seconds = timeout_seconds
  66. self.selected_workers = selected_workers
  67. def __call__(self, _: Any) -> Dict:
  68. # Collect worker metrics.
  69. episodes, self.to_be_collected = collect_episodes(
  70. self.workers.local_worker(),
  71. self.selected_workers or self.workers.remote_workers(),
  72. self.to_be_collected,
  73. timeout_seconds=self.timeout_seconds)
  74. orig_episodes = list(episodes)
  75. missing = self.min_history - len(episodes)
  76. if missing > 0:
  77. episodes = self.episode_history[-missing:] + episodes
  78. assert len(episodes) <= self.min_history
  79. self.episode_history.extend(orig_episodes)
  80. self.episode_history = self.episode_history[-self.min_history:]
  81. res = summarize_episodes(episodes, orig_episodes)
  82. # Add in iterator metrics.
  83. metrics = _get_shared_metrics()
  84. custom_metrics_from_info = metrics.info.pop("custom_metrics", {})
  85. timers = {}
  86. counters = {}
  87. info = {}
  88. info.update(metrics.info)
  89. for k, counter in metrics.counters.items():
  90. counters[k] = counter
  91. for k, timer in metrics.timers.items():
  92. timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
  93. if timer.has_units_processed():
  94. timers["{}_throughput".format(k)] = round(
  95. timer.mean_throughput, 3)
  96. res.update({
  97. "num_healthy_workers": len(self.workers.remote_workers()),
  98. "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
  99. # tune.Trainable uses timesteps_this_iter for tracking
  100. # total timesteps.
  101. "timesteps_this_iter": metrics.counters[
  102. STEPS_TRAINED_THIS_ITER_COUNTER],
  103. "agent_timesteps_total": metrics.counters.get(
  104. AGENT_STEPS_SAMPLED_COUNTER, 0),
  105. })
  106. res["timers"] = timers
  107. res["info"] = info
  108. res["info"].update(counters)
  109. res["custom_metrics"] = res.get("custom_metrics", {})
  110. res["episode_media"] = res.get("episode_media", {})
  111. res["custom_metrics"].update(custom_metrics_from_info)
  112. return res
  113. class OncePerTimeInterval:
  114. """Callable that returns True once per given interval.
  115. This should be used with the .filter() operator to throttle / rate-limit
  116. metrics reporting. For a higher-level API, consider using
  117. StandardMetricsReporting instead.
  118. Examples:
  119. >>> throttled_op = train_op.filter(OncePerTimeInterval(5))
  120. >>> start = time.time()
  121. >>> next(throttled_op)
  122. >>> print(time.time() - start)
  123. 5.00001 # will be greater than 5 seconds
  124. """
  125. def __init__(self, delay: int):
  126. self.delay = delay
  127. self.last_called = 0
  128. def __call__(self, item: Any) -> bool:
  129. if self.delay <= 0.0:
  130. return True
  131. now = time.time()
  132. if now - self.last_called > self.delay:
  133. self.last_called = now
  134. return True
  135. return False
  136. class OncePerTimestepsElapsed:
  137. """Callable that returns True once per given number of timesteps.
  138. This should be used with the .filter() operator to throttle / rate-limit
  139. metrics reporting. For a higher-level API, consider using
  140. StandardMetricsReporting instead.
  141. Examples:
  142. >>> throttled_op = train_op.filter(OncePerTimestepsElapsed(1000))
  143. >>> next(throttled_op)
  144. # will only return after 1000 steps have elapsed
  145. """
  146. def __init__(self, delay_steps: int, by_steps_trained: bool = False):
  147. """
  148. Args:
  149. delay_steps (int): The number of steps (sampled or trained) every
  150. which this op returns True.
  151. by_steps_trained (bool): If True, uses the `STEPS_TRAINED_COUNTER`
  152. instead of the `STEPS_SAMPLED_COUNTER` in metrics.
  153. """
  154. self.delay_steps = delay_steps
  155. self.by_steps_trained = by_steps_trained
  156. self.last_called = 0
  157. def __call__(self, item: Any) -> bool:
  158. if self.delay_steps <= 0:
  159. return True
  160. metrics = _get_shared_metrics()
  161. if self.by_steps_trained:
  162. now = metrics.counters[STEPS_TRAINED_COUNTER]
  163. else:
  164. now = metrics.counters[STEPS_SAMPLED_COUNTER]
  165. if now - self.last_called >= self.delay_steps:
  166. self.last_called = now
  167. return True
  168. return False