# (removed garbled line-number residue left over from file extraction)
import collections
import logging
import numpy as np
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict

# Import only needed for type annotations; avoids a circular import at runtime.
if TYPE_CHECKING:
    from ray.rllib.evaluation.worker_set import WorkerSet

# Module-level logger, following the standard `getLogger(__name__)` pattern.
logger = logging.getLogger(__name__)
# Lightweight record of per-episode rollout statistics, marked as DeveloperAPI.
# Fields:
#   episode_length: Number of env steps in the episode.
#   episode_reward: Sum of rewards over the episode.
#   agent_rewards: Per-(agent_id, policy_id) reward sums.
#   custom_metrics: User-defined metrics from callbacks.
#   perf_stats: Sampler performance timings.
#   hist_data: Values to be reported as histograms.
#   media: Arbitrary per-episode media (e.g. rendered videos).
#   episode_faulty: Whether the episode errored out (see summarize_episodes).
#   connector_metrics: Per-policy connector pipeline timings.
RolloutMetrics = DeveloperAPI(
    collections.namedtuple(
        "RolloutMetrics",
        [
            "episode_length",
            "episode_reward",
            "agent_rewards",
            "custom_metrics",
            "perf_stats",
            "hist_data",
            "media",
            "episode_faulty",
            "connector_metrics",
        ],
    )
)
# Defaults for all 9 fields (in field order).
# NOTE(review): the `{}` defaults are shared mutable objects across all
# instances that rely on defaults — callers appear to always pass fresh
# dicts, but confirm before mutating a defaulted field in place.
RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {}, False, {})
- def _extract_stats(stats: Dict, key: str) -> Dict[str, Any]:
- if key in stats:
- return stats[key]
- multiagent_stats = {}
- for k, v in stats.items():
- if isinstance(v, dict):
- if key in v:
- multiagent_stats[k] = v[key]
- return multiagent_stats
@DeveloperAPI
def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict:
    """Return optimization stats reported from the policy.

    For single-agent results, returns the contents of `LEARNER_STATS_KEY`
    directly. For multi-agent results (a dict of per-policy dicts), returns a
    {policy_id: learner_stats} mapping.

    Example:
        >>> grad_info = worker.learn_on_batch(samples)
        {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}
        >>> print(get_learner_stats(grad_info))
        {"vf_loss": ..., "policy_loss": ...}
    """
    if LEARNER_STATS_KEY in grad_info:
        return grad_info[LEARNER_STATS_KEY]

    # Multi-agent case: collect the learner stats from each policy's sub-dict.
    # Fixed: use isinstance() instead of `type(v) is dict` so dict subclasses
    # (e.g. OrderedDict) are handled as well.
    multiagent_stats = {}
    for k, v in grad_info.items():
        if isinstance(v, dict) and LEARNER_STATS_KEY in v:
            multiagent_stats[k] = v[LEARNER_STATS_KEY]
    return multiagent_stats
@DeveloperAPI
def collect_metrics(
    workers: "WorkerSet",
    remote_worker_ids: Optional[List[int]] = None,
    timeout_seconds: int = 180,
    keep_custom_metrics: bool = False,
) -> ResultDict:
    """Gathers episode metrics from rollout worker set.

    Args:
        workers: WorkerSet.
        remote_worker_ids: Optional list of IDs of remote workers to collect
            metrics from.
        timeout_seconds: Timeout in seconds for collecting metrics from remote workers.
        keep_custom_metrics: Whether to keep custom metrics in the result dict as
            they are (True) or to aggregate them (False).

    Returns:
        A result dict of metrics.
    """
    # Fetch the newly completed episodes, then summarize them. The collected
    # list is passed as both `episodes` and `new_episodes`.
    collected = collect_episodes(
        workers, remote_worker_ids, timeout_seconds=timeout_seconds
    )
    return summarize_episodes(
        collected, collected, keep_custom_metrics=keep_custom_metrics
    )
@DeveloperAPI
def collect_episodes(
    workers: "WorkerSet",
    remote_worker_ids: Optional[List[int]] = None,
    timeout_seconds: int = 180,
) -> List[RolloutMetrics]:
    """Gathers new episodes metrics tuples from the given RolloutWorkers.

    Args:
        workers: WorkerSet.
        remote_worker_ids: Optional list of IDs of remote workers to collect
            metrics from.
        timeout_seconds: Timeout in seconds for collecting metrics from remote workers.

    Returns:
        List of RolloutMetrics.
    """
    # This will drop get_metrics() calls that are too slow.
    # We can potentially make this an asynchronous call if this turns
    # out to be a problem.
    metric_lists = workers.foreach_worker(
        lambda w: w.get_metrics(),
        local_worker=True,
        remote_worker_ids=remote_worker_ids,
        timeout_seconds=timeout_seconds,
    )
    if not metric_lists:
        logger.warning("WARNING: collected no metrics.")
    # Flatten the per-worker lists into a single list of episode metrics.
    return [episode for worker_metrics in metric_lists for episode in worker_metrics]
@DeveloperAPI
def summarize_episodes(
    episodes: List[RolloutMetrics],
    new_episodes: Optional[List[RolloutMetrics]] = None,
    keep_custom_metrics: bool = False,
) -> ResultDict:
    """Summarizes a set of episode metrics tuples.

    Args:
        episodes: List of most recent n episodes. This may include historical ones
            (not newly collected in this iteration) in order to achieve the size of
            the smoothing window.
        new_episodes: All the episodes that were completed in this iteration.
        keep_custom_metrics: Whether to keep custom metrics in the result dict as
            they are (True) or to aggregate them (False).

    Returns:
        A result dict of metrics.
    """
    if new_episodes is None:
        new_episodes = episodes

    episode_rewards = []
    episode_lengths = []
    policy_rewards = collections.defaultdict(list)
    custom_metrics = collections.defaultdict(list)
    perf_stats = collections.defaultdict(list)
    hist_stats = collections.defaultdict(list)
    episode_media = collections.defaultdict(list)
    connector_metrics = collections.defaultdict(list)
    num_faulty_episodes = 0

    for episode in episodes:
        # Faulty episodes may still carry perf_stats data.
        for k, v in episode.perf_stats.items():
            perf_stats[k].append(v)
        # Continue if this is a faulty episode.
        # There should be other meaningful stats to be collected.
        if episode.episode_faulty:
            num_faulty_episodes += 1
            continue
        episode_lengths.append(episode.episode_length)
        episode_rewards.append(episode.episode_reward)
        for k, v in episode.custom_metrics.items():
            custom_metrics[k].append(v)
        # Per-policy rewards (multi-agent only; the default policy is skipped).
        for (_, policy_id), reward in episode.agent_rewards.items():
            if policy_id != DEFAULT_POLICY_ID:
                policy_rewards[policy_id].append(reward)
        for k, v in episode.hist_data.items():
            hist_stats[k] += v
        for k, v in episode.media.items():
            episode_media[k].append(v)
        if hasattr(episode, "connector_metrics"):
            # Group connector metrics by connector_metric name for all policies.
            for per_pipeline_metrics in episode.connector_metrics.values():
                for per_connector_metrics in per_pipeline_metrics.values():
                    for connector_metric_name, val in per_connector_metrics.items():
                        connector_metrics[connector_metric_name].append(val)

    # Reward/length aggregates; NaN when no (non-faulty) episodes were seen.
    if episode_rewards:
        min_reward = min(episode_rewards)
        max_reward = max(episode_rewards)
        avg_reward = np.mean(episode_rewards)
    else:
        min_reward = float("nan")
        max_reward = float("nan")
        avg_reward = float("nan")
    if episode_lengths:
        avg_length = np.mean(episode_lengths)
    else:
        avg_length = float("nan")

    # Show as histogram distributions.
    hist_stats["episode_reward"] = episode_rewards
    hist_stats["episode_lengths"] = episode_lengths

    policy_reward_min = {}
    policy_reward_mean = {}
    policy_reward_max = {}
    for policy_id, rewards in policy_rewards.copy().items():
        policy_reward_min[policy_id] = np.min(rewards)
        policy_reward_mean[policy_id] = np.mean(rewards)
        policy_reward_max[policy_id] = np.max(rewards)
        # Show as histogram distributions.
        hist_stats["policy_{}_reward".format(policy_id)] = rewards

    for k, v_list in custom_metrics.copy().items():
        # Drop NaN entries before aggregating.
        filt = [v for v in v_list if not np.any(np.isnan(v))]
        if keep_custom_metrics:
            custom_metrics[k] = filt
        else:
            # Fixed: previously `np.mean(filt)` ran unconditionally, which
            # emits a "Mean of empty slice" RuntimeWarning (and yields NaN)
            # when all values were filtered out. Guard the mean like min/max;
            # the resulting value (NaN) is unchanged.
            if filt:
                custom_metrics[k + "_mean"] = np.mean(filt)
                custom_metrics[k + "_min"] = np.min(filt)
                custom_metrics[k + "_max"] = np.max(filt)
            else:
                custom_metrics[k + "_mean"] = float("nan")
                custom_metrics[k + "_min"] = float("nan")
                custom_metrics[k + "_max"] = float("nan")
            # Remove the raw list; only the aggregates are reported.
            del custom_metrics[k]

    for k, v_list in perf_stats.copy().items():
        perf_stats[k] = np.mean(v_list)

    mean_connector_metrics = dict()
    for k, v_list in connector_metrics.items():
        mean_connector_metrics[k] = np.mean(v_list)

    return dict(
        episode_reward_max=max_reward,
        episode_reward_min=min_reward,
        episode_reward_mean=avg_reward,
        episode_len_mean=avg_length,
        episode_media=dict(episode_media),
        episodes_this_iter=len(new_episodes),
        policy_reward_min=policy_reward_min,
        policy_reward_max=policy_reward_max,
        policy_reward_mean=policy_reward_mean,
        custom_metrics=dict(custom_metrics),
        hist_stats=dict(hist_stats),
        sampler_perf=dict(perf_stats),
        num_faulty_episodes=num_faulty_episodes,
        connector_metrics=mean_connector_metrics,
    )
|