metrics.py

import collections
import logging
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import ray
from ray import ObjectRef
from ray.actor import ActorHandle
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict

if TYPE_CHECKING:
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

logger = logging.getLogger(__name__)
# Metrics reported for a single (finished or in-progress) episode.
RolloutMetrics = collections.namedtuple("RolloutMetrics", [
    "episode_length",
    "episode_reward",
    "agent_rewards",
    "custom_metrics",
    "perf_stats",
    "hist_data",
    "media",
])
RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {})
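# Example (illustrative sketch only; values are made up):
#
#     metrics = RolloutMetrics(
#         episode_length=200,
#         episode_reward=195.0,
#         agent_rewards={("agent_0", "default_policy"): 195.0},
#     )
#
# Fields left unspecified fall back to the defaults declared above
# (0 for the scalar fields, {} for the dict-valued fields).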
def extract_stats(stats: Dict, key: str) -> Dict[str, Any]:
    """Returns `stats[key]`, or collects `key` from each per-policy sub-dict."""
    if key in stats:
        return stats[key]

    multiagent_stats = {}
    for k, v in stats.items():
        if isinstance(v, dict):
            if key in v:
                multiagent_stats[k] = v[key]

    return multiagent_stats
@DeveloperAPI
def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict:
    """Return optimization stats reported from the policy.

    Example:
        >>> grad_info = worker.learn_on_batch(samples)
        {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}
        >>> print(get_learner_stats(grad_info))
        {"vf_loss": ..., "policy_loss": ...}
    """
    if LEARNER_STATS_KEY in grad_info:
        return grad_info[LEARNER_STATS_KEY]

    # Multi-agent case: collect the learner stats of each policy.
    multiagent_stats = {}
    for k, v in grad_info.items():
        if type(v) is dict:
            if LEARNER_STATS_KEY in v:
                multiagent_stats[k] = v[LEARNER_STATS_KEY]

    return multiagent_stats
@DeveloperAPI
def collect_metrics(local_worker: Optional["RolloutWorker"] = None,
                    remote_workers: Optional[List[ActorHandle]] = None,
                    to_be_collected: Optional[List[ObjectRef]] = None,
                    timeout_seconds: int = 180) -> ResultDict:
    """Gathers episode metrics from RolloutWorker instances."""
    if remote_workers is None:
        remote_workers = []
    if to_be_collected is None:
        to_be_collected = []

    episodes, to_be_collected = collect_episodes(
        local_worker,
        remote_workers,
        to_be_collected,
        timeout_seconds=timeout_seconds)
    metrics = summarize_episodes(episodes, episodes)
    return metrics
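# Example (sketch; assumes `workers` is the WorkerSet of an already built
# RLlib Trainer -- the surrounding setup is not shown here):
#
#     result = collect_metrics(
#         local_worker=workers.local_worker(),
#         remote_workers=workers.remote_workers(),
#         timeout_seconds=60)
#     print(result["episode_reward_mean"], result["episodes_this_iter"])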
@DeveloperAPI
def collect_episodes(
        local_worker: Optional["RolloutWorker"] = None,
        remote_workers: Optional[List[ActorHandle]] = None,
        to_be_collected: Optional[List[ObjectRef]] = None,
        timeout_seconds: int = 180
) -> Tuple[List[Union[RolloutMetrics, OffPolicyEstimate]], List[ObjectRef]]:
    """Gathers new episode metric tuples from the given RolloutWorkers.

    Returns the collected metrics plus any still-pending object refs that
    did not resolve within `timeout_seconds` (to be retried on the next call).
    """
    if remote_workers is None:
        remote_workers = []
    if to_be_collected is None:
        to_be_collected = []

    if remote_workers:
        pending = [
            a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_workers
        ] + to_be_collected
        collected, to_be_collected = ray.wait(
            pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
        if pending and len(collected) == 0:
            logger.warning(
                "WARNING: collected no metrics in {} seconds".format(
                    timeout_seconds))
        metric_lists = ray.get(collected)
    else:
        metric_lists = []

    if local_worker:
        metric_lists.append(local_worker.get_metrics())

    episodes = []
    for metrics in metric_lists:
        episodes.extend(metrics)
    return episodes, to_be_collected
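# Example of the retry pattern (sketch; `workers` as above). Object refs that
# did not resolve in time are carried over to the next call:
#
#     episodes, pending = collect_episodes(
#         workers.local_worker(), workers.remote_workers(), timeout_seconds=30)
#     # ... later, retry the refs that timed out:
#     more_episodes, pending = collect_episodes(
#         workers.local_worker(), workers.remote_workers(), pending)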
@DeveloperAPI
def summarize_episodes(
        episodes: List[Union[RolloutMetrics, OffPolicyEstimate]],
        new_episodes: List[Union[RolloutMetrics, OffPolicyEstimate]] = None
) -> ResultDict:
    """Summarizes a set of episode metrics tuples.

    Args:
        episodes: Smoothed set of episodes, including historical ones.
        new_episodes: Just the new episodes in this iteration. This must be
            a subset of `episodes`. If None, assumes all episodes are new.
    """
    if new_episodes is None:
        new_episodes = episodes

    episodes, estimates = _partition(episodes)
    new_episodes, _ = _partition(new_episodes)

    episode_rewards = []
    episode_lengths = []
    policy_rewards = collections.defaultdict(list)
    custom_metrics = collections.defaultdict(list)
    perf_stats = collections.defaultdict(list)
    hist_stats = collections.defaultdict(list)
    episode_media = collections.defaultdict(list)

    for episode in episodes:
        episode_lengths.append(episode.episode_length)
        episode_rewards.append(episode.episode_reward)
        for k, v in episode.custom_metrics.items():
            custom_metrics[k].append(v)
        for k, v in episode.perf_stats.items():
            perf_stats[k].append(v)
        for (_, policy_id), reward in episode.agent_rewards.items():
            if policy_id != DEFAULT_POLICY_ID:
                policy_rewards[policy_id].append(reward)
        for k, v in episode.hist_data.items():
            hist_stats[k] += v
        for k, v in episode.media.items():
            episode_media[k].append(v)

    if episode_rewards:
        min_reward = min(episode_rewards)
        max_reward = max(episode_rewards)
        avg_reward = np.mean(episode_rewards)
    else:
        min_reward = float("nan")
        max_reward = float("nan")
        avg_reward = float("nan")
    if episode_lengths:
        avg_length = np.mean(episode_lengths)
    else:
        avg_length = float("nan")

    # Show as histogram distributions.
    hist_stats["episode_reward"] = episode_rewards
    hist_stats["episode_lengths"] = episode_lengths

    policy_reward_min = {}
    policy_reward_mean = {}
    policy_reward_max = {}
    for policy_id, rewards in policy_rewards.copy().items():
        policy_reward_min[policy_id] = np.min(rewards)
        policy_reward_mean[policy_id] = np.mean(rewards)
        policy_reward_max[policy_id] = np.max(rewards)
        # Show as histogram distributions.
        hist_stats["policy_{}_reward".format(policy_id)] = rewards

    # Replace each raw custom-metrics list with its mean/min/max (NaNs filtered out).
    for k, v_list in custom_metrics.copy().items():
        filt = [v for v in v_list if not np.any(np.isnan(v))]
        custom_metrics[k + "_mean"] = np.mean(filt)
        if filt:
            custom_metrics[k + "_min"] = np.min(filt)
            custom_metrics[k + "_max"] = np.max(filt)
        else:
            custom_metrics[k + "_min"] = float("nan")
            custom_metrics[k + "_max"] = float("nan")
        del custom_metrics[k]

    for k, v_list in perf_stats.copy().items():
        perf_stats[k] = np.mean(v_list)

    # Aggregate off-policy estimates: average each metric per estimator name.
    estimators = collections.defaultdict(lambda: collections.defaultdict(list))
    for e in estimates:
        acc = estimators[e.estimator_name]
        for k, v in e.metrics.items():
            acc[k].append(v)
    for name, metrics in estimators.items():
        for k, v_list in metrics.items():
            metrics[k] = np.mean(v_list)
        estimators[name] = dict(metrics)

    return dict(
        episode_reward_max=max_reward,
        episode_reward_min=min_reward,
        episode_reward_mean=avg_reward,
        episode_len_mean=avg_length,
        episode_media=dict(episode_media),
        episodes_this_iter=len(new_episodes),
        policy_reward_min=policy_reward_min,
        policy_reward_max=policy_reward_max,
        policy_reward_mean=policy_reward_mean,
        custom_metrics=dict(custom_metrics),
        hist_stats=dict(hist_stats),
        sampler_perf=dict(perf_stats),
        off_policy_estimator=dict(estimators))
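# Shape of the returned ResultDict (sketch; the numbers and the custom-metric
# names such as "win_*" are made up for illustration):
#
#     {
#         "episode_reward_max": 200.0,
#         "episode_reward_min": 10.0,
#         "episode_reward_mean": 123.4,
#         "episode_len_mean": 87.5,
#         "episodes_this_iter": 12,
#         "policy_reward_mean": {"policy_1": 61.7},
#         "custom_metrics": {"win_mean": 0.6, "win_min": 0.0, "win_max": 1.0},
#         "hist_stats": {"episode_reward": [...], "episode_lengths": [...]},
#         "sampler_perf": {"<perf_stat>": <mean value>},
#         "off_policy_estimator": {"<estimator_name>": {"<metric>": <mean value>}},
#     }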
def _partition(episodes: List[RolloutMetrics]
               ) -> Tuple[List[RolloutMetrics], List[OffPolicyEstimate]]:
    """Divides metrics data into true rollouts vs off-policy estimates."""
    rollouts, estimates = [], []
    for e in episodes:
        if isinstance(e, RolloutMetrics):
            rollouts.append(e)
        elif isinstance(e, OffPolicyEstimate):
            estimates.append(e)
        else:
            raise ValueError("Unknown metric type: {}".format(e))
    return rollouts, estimates