- from abc import abstractmethod, ABCMeta
- from collections import defaultdict, namedtuple
- import logging
- import numpy as np
- import queue
- import threading
- import time
- import tree # pip install dm_tree
- from typing import Any, Callable, Dict, List, Iterator, Optional, Set, Tuple,\
- Type, TYPE_CHECKING, Union
- from ray.util.debug import log_once
- from ray.rllib.evaluation.collectors.sample_collector import \
- SampleCollector
- from ray.rllib.evaluation.collectors.simple_list_collector import \
- SimpleListCollector
- from ray.rllib.evaluation.episode import Episode
- from ray.rllib.evaluation.metrics import RolloutMetrics
- from ray.rllib.evaluation.sample_batch_builder import \
- MultiAgentSampleBatchBuilder
- from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
- from ray.rllib.env.wrappers.atari_wrappers import get_wrapper_by_cls, \
- MonitorEnv
- from ray.rllib.models.preprocessors import Preprocessor
- from ray.rllib.offline import InputReader
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.policy_map import PolicyMap
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.annotations import override, DeveloperAPI
- from ray.rllib.utils.debug import summarize
- from ray.rllib.utils.deprecation import deprecation_warning
- from ray.rllib.utils.filter import Filter
- from ray.rllib.utils.numpy import convert_to_numpy
- from ray.rllib.utils.spaces.space_utils import clip_action, \
- unsquash_action, unbatch
- from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
- EnvObsType, EnvInfoDict, EnvID, MultiEnvDict, EnvActionType, \
- TensorStructType
- if TYPE_CHECKING:
- from ray.rllib.agents.callbacks import DefaultCallbacks
- from ray.rllib.evaluation.observation_function import ObservationFunction
- from ray.rllib.evaluation.rollout_worker import RolloutWorker
- from ray.rllib.utils import try_import_tf
- _, tf, _ = try_import_tf()
- from gym.envs.classic_control.rendering import SimpleImageViewer
- logger = logging.getLogger(__name__)
- PolicyEvalData = namedtuple("PolicyEvalData", [
- "env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
- "prev_reward"
- ])
- # A batch of RNN states with dimensions [state_index, batch, state_object].
- StateBatch = List[List[Any]]
- class NewEpisodeDefaultDict(defaultdict):
- def __missing__(self, env_id):
- if self.default_factory is None:
- raise KeyError(env_id)
- else:
- ret = self[env_id] = self.default_factory(env_id)
- return ret
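- # Note (illustrative, not part of RLlib): unlike a plain
- # `collections.defaultdict`, whose factory takes no arguments, the factory
- # passed to `NewEpisodeDefaultDict` receives the missing key, so each env_id
- # gets a default value built specifically for it, e.g.:
- #     episodes = NewEpisodeDefaultDict(lambda env_id: f"episode-for-{env_id}")
- #     episodes[3]  # -> "episode-for-3"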
- class _PerfStats:
- """Sampler perf stats that will be included in rollout metrics."""
- def __init__(self):
- self.iters = 0
- self.raw_obs_processing_time = 0.0
- self.inference_time = 0.0
- self.action_processing_time = 0.0
- self.env_wait_time = 0.0
- self.env_render_time = 0.0
- def get(self):
- # Per-iteration mean multiplier (x1000 to convert accumulated sec to ms).
- factor = 1000 / self.iters
- return {
- # Raw observation preprocessing.
- "mean_raw_obs_processing_ms": self.raw_obs_processing_time *
- factor,
- # Computing actions through policy.
- "mean_inference_ms": self.inference_time * factor,
- # Processing actions (to be sent to env, e.g. clipping).
- "mean_action_processing_ms": self.action_processing_time * factor,
- # Waiting for environment (during poll).
- "mean_env_wait_ms": self.env_wait_time * factor,
- # Environment rendering (False by default).
- "mean_env_render_ms": self.env_render_time * factor,
- }
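- # Illustrative usage sketch (hypothetical timings, not part of RLlib): the
- # sampler loop below adds wall-clock seconds to each phase counter and bumps
- # `iters` once per loop; `get()` then reports per-iteration means in ms.
- def _example_perf_stats() -> dict:
-     stats = _PerfStats()
-     stats.iters = 2
-     stats.env_wait_time = 0.010  # 10ms total spent polling the env.
-     stats.inference_time = 0.004  # 4ms total spent computing actions.
-     # -> {"mean_env_wait_ms": 5.0, "mean_inference_ms": 2.0, ...}
-     return stats.get()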
- @DeveloperAPI
- class SamplerInput(InputReader, metaclass=ABCMeta):
- """Reads input experiences from an existing sampler."""
- @override(InputReader)
- def next(self) -> SampleBatchType:
- batches = [self.get_data()]
- batches.extend(self.get_extra_batches())
- if len(batches) > 1:
- return batches[0].concat_samples(batches)
- else:
- return batches[0]
- @abstractmethod
- @DeveloperAPI
- def get_data(self) -> SampleBatchType:
- """Called by `self.next()` to return the next batch of data.
- Override this in child classes.
- Returns:
- The next batch of data.
- """
- raise NotImplementedError
- @abstractmethod
- @DeveloperAPI
- def get_metrics(self) -> List[RolloutMetrics]:
- """Returns list of episode metrics since the last call to this method.
- The list will contain one RolloutMetrics object per completed episode.
- Returns:
- List of RolloutMetrics objects, one per completed episode since
- the last call to this method.
- """
- raise NotImplementedError
- @abstractmethod
- @DeveloperAPI
- def get_extra_batches(self) -> List[SampleBatchType]:
- """Returns list of extra batches since the last call to this method.
- The list will contain all SampleBatches or
- MultiAgentBatches that the user has provided thus far. Users can
- add these "extra batches" to an episode by calling the episode's
- `add_extra_batch([SampleBatchType])` method. This can be done from
- inside an overridden `Policy.compute_actions_from_input_dict(...,
- episodes)` or from a custom callback's `on_episode_[start|step|end]()`
- methods.
- Returns:
- List of SampleBatches or MultiAgentBatches provided thus far by
- the user since the last call to this method.
- """
- raise NotImplementedError
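- # Minimal sketch (illustrative only, assumes a pre-built SampleBatch): a
- # trivial SamplerInput subclass that "samples" from a fixed batch. The
- # inherited `next()` calls `get_data()` and appends any extra batches.
- class _ExampleFixedBatchSampler(SamplerInput):
-     def __init__(self, batch: SampleBatch):
-         self._batch = batch
-     @override(SamplerInput)
-     def get_data(self) -> SampleBatchType:
-         return self._batch.copy()
-     @override(SamplerInput)
-     def get_metrics(self) -> List[RolloutMetrics]:
-         return []  # This toy reader never completes real episodes.
-     @override(SamplerInput)
-     def get_extra_batches(self) -> List[SampleBatchType]:
-         return []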
- @DeveloperAPI
- class SyncSampler(SamplerInput):
- """Sync SamplerInput that collects experiences when `get_data()` is called.
- """
- def __init__(
- self,
- *,
- worker: "RolloutWorker",
- env: BaseEnv,
- clip_rewards: Union[bool, float],
- rollout_fragment_length: int,
- count_steps_by: str = "env_steps",
- callbacks: "DefaultCallbacks",
- horizon: int = None,
- multiple_episodes_in_batch: bool = False,
- normalize_actions: bool = True,
- clip_actions: bool = False,
- soft_horizon: bool = False,
- no_done_at_end: bool = False,
- observation_fn: Optional["ObservationFunction"] = None,
- sample_collector_class: Optional[Type[SampleCollector]] = None,
- render: bool = False,
- # Obsolete.
- policies=None,
- policy_mapping_fn=None,
- preprocessors=None,
- obs_filters=None,
- tf_sess=None,
- ):
- """Initializes a SyncSampler instance.
- Args:
- worker: The RolloutWorker that will use this Sampler for sampling.
- env: Any Env object. Will be converted into an RLlib BaseEnv.
- clip_rewards: True for +/-1.0 clipping,
- actual float value for +/- value clipping. False for no
- clipping.
- rollout_fragment_length: The length of a fragment to collect
- before building a SampleBatch from the data and resetting
- the SampleBatchBuilder object.
- count_steps_by: One of "env_steps" (default) or "agent_steps".
- Use "agent_steps", if you want rollout lengths to be counted
- by individual agent steps. In a multi-agent env,
- a single env_step contains one or more agent_steps, depending
- on how many agents are present at any given time in the
- ongoing episode.
- callbacks: The Callbacks object to use when episode
- events happen during rollout.
- horizon: Hard-reset the Env after this many timesteps.
- multiple_episodes_in_batch: Whether to pack multiple
- episodes into each batch. This guarantees batches will be
- exactly `rollout_fragment_length` in size.
- normalize_actions: Whether to normalize actions to the
- action space's bounds.
- clip_actions: Whether to clip actions according to the
- given action_space's bounds.
- soft_horizon: If True, calculate bootstrapped values as if
- episode had ended, but don't physically reset the environment
- when the horizon is hit.
- no_done_at_end: Ignore the done=True at the end of the
- episode and instead record done=False.
- observation_fn: Optional multi-agent observation func to use for
- preprocessing observations.
- sample_collector_class: An optional SampleCollector sub-class to
- use to collect, store, and retrieve environment-, model-,
- and sampler data.
- render: Whether to try to render the environment after each step.
- """
- # All of the following arguments are deprecated. They will instead be
- # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
- if log_once("deprecated_sync_sampler_args"):
- if policies is not None:
- deprecation_warning(old="policies")
- if policy_mapping_fn is not None:
- deprecation_warning(old="policy_mapping_fn")
- if preprocessors is not None:
- deprecation_warning(old="preprocessors")
- if obs_filters is not None:
- deprecation_warning(old="obs_filters")
- if tf_sess is not None:
- deprecation_warning(old="tf_sess")
- self.base_env = BaseEnv.to_base_env(env)
- self.rollout_fragment_length = rollout_fragment_length
- self.horizon = horizon
- self.extra_batches = queue.Queue()
- self.perf_stats = _PerfStats()
- if not sample_collector_class:
- sample_collector_class = SimpleListCollector
- self.sample_collector = sample_collector_class(
- worker.policy_map,
- clip_rewards,
- callbacks,
- multiple_episodes_in_batch,
- rollout_fragment_length,
- count_steps_by=count_steps_by)
- self.render = render
- # Create the rollout generator to use for calls to `get_data()`.
- self._env_runner = _env_runner(
- worker, self.base_env, self.extra_batches.put, self.horizon,
- normalize_actions, clip_actions, multiple_episodes_in_batch,
- callbacks, self.perf_stats, soft_horizon, no_done_at_end,
- observation_fn, self.sample_collector, self.render)
- self.metrics_queue = queue.Queue()
- @override(SamplerInput)
- def get_data(self) -> SampleBatchType:
- while True:
- item = next(self._env_runner)
- if isinstance(item, RolloutMetrics):
- self.metrics_queue.put(item)
- else:
- return item
- @override(SamplerInput)
- def get_metrics(self) -> List[RolloutMetrics]:
- completed = []
- while True:
- try:
- completed.append(self.metrics_queue.get_nowait()._replace(
- perf_stats=self.perf_stats.get()))
- except queue.Empty:
- break
- return completed
- @override(SamplerInput)
- def get_extra_batches(self) -> List[SampleBatchType]:
- extra = []
- while True:
- try:
- extra.append(self.extra_batches.get_nowait())
- except queue.Empty:
- break
- return extra
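- # Usage sketch (illustrative only; assumes `worker`, `env` and `callbacks`
- # were built elsewhere, e.g. by a RolloutWorker; values are placeholders):
- def _example_sync_sampling(worker, env, callbacks) -> SampleBatchType:
-     sampler = SyncSampler(
-         worker=worker,
-         env=env,
-         clip_rewards=False,
-         rollout_fragment_length=200,
-         callbacks=callbacks,
-     )
-     # Each call steps the (vectorized) env until one fragment is complete.
-     batch = sampler.get_data()
-     # Episode stats collected since the last call (may be empty).
-     _ = sampler.get_metrics()
-     return batch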
- @DeveloperAPI
- class AsyncSampler(threading.Thread, SamplerInput):
- """Async SamplerInput that collects experiences in thread and queues them.
- Once started, experiences are continuously collected in the background
- and put into a Queue, from where they can be unqueued by the caller
- of `get_data()`.
- """
- def __init__(
- self,
- *,
- worker: "RolloutWorker",
- env: BaseEnv,
- clip_rewards: Union[bool, float],
- rollout_fragment_length: int,
- count_steps_by: str = "env_steps",
- callbacks: "DefaultCallbacks",
- horizon: Optional[int] = None,
- multiple_episodes_in_batch: bool = False,
- normalize_actions: bool = True,
- clip_actions: bool = False,
- soft_horizon: bool = False,
- no_done_at_end: bool = False,
- observation_fn: Optional["ObservationFunction"] = None,
- sample_collector_class: Optional[Type[SampleCollector]] = None,
- render: bool = False,
- blackhole_outputs: bool = False,
- # Obsolete.
- policies=None,
- policy_mapping_fn=None,
- preprocessors=None,
- obs_filters=None,
- tf_sess=None,
- ):
- """Initializes an AsyncSampler instance.
- Args:
- worker: The RolloutWorker that will use this Sampler for sampling.
- env: Any Env object. Will be converted into an RLlib BaseEnv.
- clip_rewards: True for +/-1.0 clipping,
- actual float value for +/- value clipping. False for no
- clipping.
- rollout_fragment_length: The length of a fragment to collect
- before building a SampleBatch from the data and resetting
- the SampleBatchBuilder object.
- count_steps_by: One of "env_steps" (default) or "agent_steps".
- Use "agent_steps", if you want rollout lengths to be counted
- by individual agent steps. In a multi-agent env,
- a single env_step contains one or more agent_steps, depending
- on how many agents are present at any given time in the
- ongoing episode.
- callbacks: The Callbacks object to use when episode
- events happen during rollout.
- horizon: Hard-reset the Env after this many timesteps.
- multiple_episodes_in_batch: Whether to pack multiple
- episodes into each batch. This guarantees batches will be
- exactly `rollout_fragment_length` in size.
- normalize_actions: Whether to normalize actions to the
- action space's bounds.
- clip_actions: Whether to clip actions according to the
- given action_space's bounds.
- blackhole_outputs: Whether to collect samples, but then
- not further process or store them (throw away all samples).
- soft_horizon: If True, calculate bootstrapped values as if
- episode had ended, but don't physically reset the environment
- when the horizon is hit.
- no_done_at_end: Ignore the done=True at the end of the
- episode and instead record done=False.
- observation_fn: Optional multi-agent observation func to use for
- preprocessing observations.
- sample_collector_class: An optional SampleCollector sub-class to
- use to collect, store, and retrieve environment-, model-,
- and sampler data.
- render: Whether to try to render the environment after each step.
- """
- # All of the following arguments are deprecated. They will instead be
- # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
- if log_once("deprecated_async_sampler_args"):
- if policies is not None:
- deprecation_warning(old="policies")
- if policy_mapping_fn is not None:
- deprecation_warning(old="policy_mapping_fn")
- if preprocessors is not None:
- deprecation_warning(old="preprocessors")
- if obs_filters is not None:
- deprecation_warning(old="obs_filters")
- if tf_sess is not None:
- deprecation_warning(old="tf_sess")
- self.worker = worker
- for _, f in worker.filters.items():
- assert getattr(f, "is_concurrent", False), \
- "Observation Filter must support concurrent updates."
- self.base_env = BaseEnv.to_base_env(env)
- threading.Thread.__init__(self)
- self.queue = queue.Queue(5)
- self.extra_batches = queue.Queue()
- self.metrics_queue = queue.Queue()
- self.rollout_fragment_length = rollout_fragment_length
- self.horizon = horizon
- self.clip_rewards = clip_rewards
- self.daemon = True
- self.multiple_episodes_in_batch = multiple_episodes_in_batch
- self.callbacks = callbacks
- self.normalize_actions = normalize_actions
- self.clip_actions = clip_actions
- self.blackhole_outputs = blackhole_outputs
- self.soft_horizon = soft_horizon
- self.no_done_at_end = no_done_at_end
- self.perf_stats = _PerfStats()
- self.shutdown = False
- self.observation_fn = observation_fn
- self.render = render
- if not sample_collector_class:
- sample_collector_class = SimpleListCollector
- self.sample_collector = sample_collector_class(
- self.worker.policy_map,
- self.clip_rewards,
- self.callbacks,
- self.multiple_episodes_in_batch,
- self.rollout_fragment_length,
- count_steps_by=count_steps_by)
- @override(threading.Thread)
- def run(self):
- try:
- self._run()
- except BaseException as e:
- self.queue.put(e)
- raise e
- def _run(self):
- if self.blackhole_outputs:
- queue_putter = (lambda x: None)
- extra_batches_putter = (lambda x: None)
- else:
- queue_putter = self.queue.put
- extra_batches_putter = (
- lambda x: self.extra_batches.put(x, timeout=600.0))
- env_runner = _env_runner(
- self.worker, self.base_env, extra_batches_putter, self.horizon,
- self.normalize_actions, self.clip_actions,
- self.multiple_episodes_in_batch, self.callbacks, self.perf_stats,
- self.soft_horizon, self.no_done_at_end, self.observation_fn,
- self.sample_collector, self.render)
- while not self.shutdown:
- # The timeout variable exists because apparently, if one worker
- # dies, the other workers won't die with it, unless the timeout is
- # set to some large number. This is an empirical observation.
- item = next(env_runner)
- if isinstance(item, RolloutMetrics):
- self.metrics_queue.put(item)
- else:
- queue_putter(item)
- @override(SamplerInput)
- def get_data(self) -> SampleBatchType:
- if not self.is_alive():
- raise RuntimeError("Sampling thread has died")
- rollout = self.queue.get(timeout=600.0)
- # Propagate errors.
- if isinstance(rollout, BaseException):
- raise rollout
- return rollout
- @override(SamplerInput)
- def get_metrics(self) -> List[RolloutMetrics]:
- completed = []
- while True:
- try:
- completed.append(self.metrics_queue.get_nowait()._replace(
- perf_stats=self.perf_stats.get()))
- except queue.Empty:
- break
- return completed
- @override(SamplerInput)
- def get_extra_batches(self) -> List[SampleBatchType]:
- extra = []
- while True:
- try:
- extra.append(self.extra_batches.get_nowait())
- except queue.Empty:
- break
- return extra
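- # Usage sketch (illustrative only; assumes `worker`, `env` and `callbacks`
- # were built elsewhere). The sampler collects in a background thread, so it
- # must be started before the first `get_data()` call.
- def _example_async_sampling(worker, env, callbacks) -> SampleBatchType:
-     sampler = AsyncSampler(
-         worker=worker,
-         env=env,
-         clip_rewards=False,
-         rollout_fragment_length=200,
-         callbacks=callbacks,
-     )
-     sampler.start()  # Kicks off `run()` above in a daemon thread.
-     batch = sampler.get_data()  # Blocks (up to 600s) for the next fragment.
-     sampler.shutdown = True  # Ask the rollout loop to exit after its next item.
-     return batch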
- def _env_runner(
- worker: "RolloutWorker",
- base_env: BaseEnv,
- extra_batch_callback: Callable[[SampleBatchType], None],
- horizon: Optional[int],
- normalize_actions: bool,
- clip_actions: bool,
- multiple_episodes_in_batch: bool,
- callbacks: "DefaultCallbacks",
- perf_stats: _PerfStats,
- soft_horizon: bool,
- no_done_at_end: bool,
- observation_fn: "ObservationFunction",
- sample_collector: Optional[SampleCollector] = None,
- render: bool = None,
- ) -> Iterator[SampleBatchType]:
- """This implements the common experience collection logic.
- Args:
- worker: Reference to the current rollout worker.
- base_env: Env implementing BaseEnv.
- extra_batch_callback: function to send extra batch data to.
- horizon: Horizon of the episode.
- multiple_episodes_in_batch: Whether to pack multiple
- episodes into each batch. This guarantees batches will be exactly
- `rollout_fragment_length` in size.
- normalize_actions: Whether to normalize actions to the action
- space's bounds.
- clip_actions: Whether to clip actions to the space range.
- callbacks: User callbacks to run on episode events.
- perf_stats: Record perf stats into this object.
- soft_horizon: Calculate rewards but don't reset the
- environment when the horizon is hit.
- no_done_at_end: Ignore the done=True at the end of the episode
- and instead record done=False.
- observation_fn: Optional multi-agent
- observation func to use for preprocessing observations.
- sample_collector: An optional
- SampleCollector object to use.
- render: Whether to try to render the environment after each
- step.
- Yields:
- Object containing state, action, reward, terminal condition,
- and other fields as dictated by `policy`.
- """
- # May be populated with a SimpleImageViewer used for image rendering.
- simple_image_viewer: Optional["SimpleImageViewer"] = None
- # Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
- # error and continue with max_episode_steps=None.
- max_episode_steps = None
- try:
- max_episode_steps = base_env.get_sub_environments()[
- 0].spec.max_episode_steps
- except Exception:
- pass
- # Trainer has a given `horizon` setting.
- if horizon:
- # `horizon` is larger than env's limit.
- if max_episode_steps and horizon > max_episode_steps:
- # Try to override the env's own max-step setting with our horizon.
- # If this won't work, throw an error.
- try:
- base_env.get_sub_environments()[
- 0].spec.max_episode_steps = horizon
- base_env.get_sub_environments()[0]._max_episode_steps = horizon
- except Exception:
- raise ValueError(
- "Your `horizon` setting ({}) is larger than the Env's own "
- "timestep limit ({}), which seems to be unsettable! Try "
- "to increase the Env's built-in limit to be at least as "
- "large as your wanted `horizon`.".format(
- horizon, max_episode_steps))
- # Otherwise, set Trainer's horizon to env's max-steps.
- elif max_episode_steps:
- horizon = max_episode_steps
- logger.debug(
- "No episode horizon specified, setting it to Env's limit ({}).".
- format(max_episode_steps))
- # No horizon/max_episode_steps -> Episodes may be infinitely long.
- else:
- horizon = float("inf")
- logger.debug("No episode horizon specified, assuming inf.")
- # Pool of batch builders, which can be shared across episodes to pack
- # trajectory data.
- batch_builder_pool: List[MultiAgentSampleBatchBuilder] = []
- def get_batch_builder():
- if batch_builder_pool:
- return batch_builder_pool.pop()
- else:
- return None
- def new_episode(env_id):
- episode = Episode(
- worker.policy_map,
- worker.policy_mapping_fn,
- get_batch_builder,
- extra_batch_callback,
- env_id=env_id,
- worker=worker,
- )
- # Call each policy's Exploration.on_episode_start method.
- # Note: This may break the exploration (e.g. ParameterNoise) of
- # policies in the `policy_map` that have not been recently used
- # (and are therefore stashed to disk). However, we certainly do not
- # want to loop through all (even stashed) policies here as that
- # would counter the purpose of the LRU policy caching.
- for p in worker.policy_map.cache.values():
- if getattr(p, "exploration", None) is not None:
- p.exploration.on_episode_start(
- policy=p,
- environment=base_env,
- episode=episode,
- tf_sess=p.get_session())
- callbacks.on_episode_start(
- worker=worker,
- base_env=base_env,
- policies=worker.policy_map,
- episode=episode,
- env_index=env_id,
- )
- return episode
- active_episodes: Dict[EnvID, Episode] = \
- NewEpisodeDefaultDict(new_episode)
- while True:
- perf_stats.iters += 1
- t0 = time.time()
- # Get observations from all ready agents.
- # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
- unfiltered_obs, rewards, dones, infos, off_policy_actions = \
- base_env.poll()
- perf_stats.env_wait_time += time.time() - t0
- if log_once("env_returns"):
- logger.info("Raw obs from env: {}".format(
- summarize(unfiltered_obs)))
- logger.info("Info return from env: {}".format(summarize(infos)))
- # Process observations and prepare for policy evaluation.
- t1 = time.time()
- # types: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
- # List[Union[RolloutMetrics, SampleBatchType]]
- active_envs, to_eval, outputs = \
- _process_observations(
- worker=worker,
- base_env=base_env,
- active_episodes=active_episodes,
- unfiltered_obs=unfiltered_obs,
- rewards=rewards,
- dones=dones,
- infos=infos,
- horizon=horizon,
- multiple_episodes_in_batch=multiple_episodes_in_batch,
- callbacks=callbacks,
- soft_horizon=soft_horizon,
- no_done_at_end=no_done_at_end,
- observation_fn=observation_fn,
- sample_collector=sample_collector,
- )
- perf_stats.raw_obs_processing_time += time.time() - t1
- for o in outputs:
- yield o
- # Do batched policy eval (across vectorized envs).
- t2 = time.time()
- # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
- eval_results = _do_policy_eval(
- to_eval=to_eval,
- policies=worker.policy_map,
- sample_collector=sample_collector,
- active_episodes=active_episodes,
- )
- perf_stats.inference_time += time.time() - t2
- # Process results and update episode state.
- t3 = time.time()
- actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
- _process_policy_eval_results(
- to_eval=to_eval,
- eval_results=eval_results,
- active_episodes=active_episodes,
- active_envs=active_envs,
- off_policy_actions=off_policy_actions,
- policies=worker.policy_map,
- normalize_actions=normalize_actions,
- clip_actions=clip_actions,
- )
- perf_stats.action_processing_time += time.time() - t3
- # Return computed actions to ready envs. We also send to envs that have
- # taken off-policy actions; those envs are free to ignore the action.
- t4 = time.time()
- base_env.send_actions(actions_to_send)
- perf_stats.env_wait_time += time.time() - t4
- # Try to render the env, if required.
- if render:
- t5 = time.time()
- # Render can either return an RGB image (uint8 [w x h x 3] numpy
- # array) or take care of rendering itself (returning True).
- rendered = base_env.try_render()
- # Rendering returned an image -> Display it in a SimpleImageViewer.
- if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
- # ImageViewer not defined yet, try to create one.
- if simple_image_viewer is None:
- try:
- from gym.envs.classic_control.rendering import \
- SimpleImageViewer
- simple_image_viewer = SimpleImageViewer()
- except (ImportError, ModuleNotFoundError):
- render = False # disable rendering
- logger.warning(
- "Could not import gym.envs.classic_control."
- "rendering! Try `pip install gym[all]`.")
- if simple_image_viewer:
- simple_image_viewer.imshow(rendered)
- elif rendered not in [True, False, None]:
- raise ValueError(
- "The env's ({base_env}) `try_render()` method returned an"
- " unsupported value! Make sure you either return a "
- "uint8/w x h x 3 (RGB) image or handle rendering in a "
- "window and then return `True`.")
- perf_stats.env_render_time += time.time() - t5
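- # Illustrative sketch of how the generator above is consumed (this mirrors
- # `SyncSampler.get_data()`): RolloutMetrics items are completed-episode
- # summaries, everything else is a ready-to-return (MultiAgent)SampleBatch.
- def _example_drain_env_runner(
-         env_runner: Iterator[SampleBatchType],
-         metrics_out: List[RolloutMetrics]) -> SampleBatchType:
-     while True:
-         item = next(env_runner)
-         if isinstance(item, RolloutMetrics):
-             metrics_out.append(item)
-         else:
-             return item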
- def _process_observations(
- *,
- worker: "RolloutWorker",
- base_env: BaseEnv,
- active_episodes: Dict[EnvID, Episode],
- unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
- rewards: Dict[EnvID, Dict[AgentID, float]],
- dones: Dict[EnvID, Dict[AgentID, bool]],
- infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
- horizon: int,
- multiple_episodes_in_batch: bool,
- callbacks: "DefaultCallbacks",
- soft_horizon: bool,
- no_done_at_end: bool,
- observation_fn: "ObservationFunction",
- sample_collector: SampleCollector,
- ) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
- RolloutMetrics, SampleBatchType]]]:
- """Record new data from the environment and prepare for policy evaluation.
- Args:
- worker: Reference to the current rollout worker.
- base_env: Env implementing BaseEnv.
- active_episodes: Mapping from
- env ID to currently ongoing Episode object.
- unfiltered_obs: Doubly keyed dict of env-ids -> agent ids
- -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
- call.
- rewards: Doubly keyed dict of env-ids -> agent ids ->
- rewards tensor, returned by a `BaseEnv.poll()` call.
- dones: Doubly keyed dict of env-ids -> agent ids ->
- boolean done flags, returned by a `BaseEnv.poll()` call.
- infos: Doubly keyed dict of env-ids -> agent ids ->
- info dicts, returned by a `BaseEnv.poll()` call.
- horizon: Horizon of the episode.
- multiple_episodes_in_batch: Whether to pack multiple
- episodes into each batch. This guarantees batches will be exactly
- `rollout_fragment_length` in size.
- callbacks: User callbacks to run on episode events.
- soft_horizon: Calculate rewards but don't reset the
- environment when the horizon is hit.
- no_done_at_end: Ignore the done=True at the end of the episode
- and instead record done=False.
- observation_fn: Optional multi-agent
- observation func to use for preprocessing observations.
- sample_collector: The SampleCollector object
- used to store and retrieve environment samples.
- Returns:
- Tuple consisting of 1) active_envs: Set of non-terminated env ids.
- 2) to_eval: Map of policy_id to list of agent PolicyEvalData.
- 3) outputs: List of metrics and samples to return from the sampler.
- """
- # Output objects.
- active_envs: Set[EnvID] = set()
- to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
- outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
- # For each (vectorized) sub-environment.
- # types: EnvID, Dict[AgentID, EnvObsType]
- for env_id, all_agents_obs in unfiltered_obs.items():
- is_new_episode: bool = env_id not in active_episodes
- episode: Episode = active_episodes[env_id]
- if not is_new_episode:
- sample_collector.episode_step(episode)
- episode._add_agent_rewards(rewards[env_id])
- # Check episode termination conditions.
- if dones[env_id]["__all__"] or episode.length >= horizon:
- hit_horizon = (episode.length >= horizon
- and not dones[env_id]["__all__"])
- all_agents_done = True
- atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
- base_env)
- if atari_metrics is not None:
- for m in atari_metrics:
- outputs.append(
- m._replace(custom_metrics=episode.custom_metrics))
- else:
- outputs.append(
- RolloutMetrics(episode.length, episode.total_reward,
- dict(episode.agent_rewards),
- episode.custom_metrics, {},
- episode.hist_data, episode.media))
- # Check whether we have to create a fake-last observation
- # for some agents (the environment is not required to do so if
- # dones[__all__]=True).
- for ag_id in episode.get_agents():
- if not episode.last_done_for(
- ag_id) and ag_id not in all_agents_obs:
- # Create a fake (all-0s) observation.
- obs_sp = worker.policy_map[episode.policy_for(
- ag_id)].observation_space
- obs_sp = getattr(obs_sp, "original_space", obs_sp)
- all_agents_obs[ag_id] = tree.map_structure(
- np.zeros_like, obs_sp.sample())
- else:
- hit_horizon = False
- all_agents_done = False
- active_envs.add(env_id)
- # Custom observation function is applied before preprocessing.
- if observation_fn:
- all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
- agent_obs=all_agents_obs,
- worker=worker,
- base_env=base_env,
- policies=worker.policy_map,
- episode=episode)
- if not isinstance(all_agents_obs, dict):
- raise ValueError(
- "observe() must return a dict of agent observations")
- # For each agent in the environment.
- # types: AgentID, EnvObsType
- for agent_id, raw_obs in all_agents_obs.items():
- assert agent_id != "__all__"
- last_observation: EnvObsType = episode.last_observation_for(
- agent_id)
- agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
- # A new agent (initial obs) is already done -> Skip entirely.
- if last_observation is None and agent_done:
- continue
- policy_id: PolicyID = episode.policy_for(agent_id)
- preprocessor = _get_or_raise(worker.preprocessors, policy_id)
- prep_obs: EnvObsType = raw_obs
- if preprocessor is not None:
- prep_obs = preprocessor.transform(raw_obs)
- if log_once("prep_obs"):
- logger.info("Preprocessed obs: {}".format(
- summarize(prep_obs)))
- filtered_obs: EnvObsType = _get_or_raise(worker.filters,
- policy_id)(prep_obs)
- if log_once("filtered_obs"):
- logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
- episode._set_last_observation(agent_id, filtered_obs)
- episode._set_last_raw_obs(agent_id, raw_obs)
- episode._set_last_done(agent_id, agent_done)
- # Infos from the environment.
- agent_infos = infos[env_id].get(agent_id, {})
- episode._set_last_info(agent_id, agent_infos)
- # Record transition info if applicable.
- if last_observation is None:
- sample_collector.add_init_obs(episode, agent_id, env_id,
- policy_id, episode.length - 1,
- filtered_obs)
- elif agent_infos is None or agent_infos.get(
- "training_enabled", True):
- # Add actions, rewards, next-obs to collectors.
- values_dict = {
- SampleBatch.T: episode.length - 1,
- SampleBatch.ENV_ID: env_id,
- SampleBatch.AGENT_INDEX: episode._agent_index(agent_id),
- # Action (slot 0) taken at timestep t.
- SampleBatch.ACTIONS: episode.last_action_for(agent_id),
- # Reward received after taking the action at timestep t.
- SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
- # After taking action=a, did we reach terminal?
- SampleBatch.DONES: (False
- if (no_done_at_end
- or (hit_horizon and soft_horizon))
- else agent_done),
- # Next observation.
- SampleBatch.NEXT_OBS: filtered_obs,
- }
- # Add extra-action-fetches (policy-inference infos) to
- # collectors.
- pol = worker.policy_map[policy_id]
- for key, value in episode.last_extra_action_outs_for(
- agent_id).items():
- if key in pol.view_requirements:
- values_dict[key] = value
- # Env infos for this agent.
- if "infos" in pol.view_requirements:
- values_dict["infos"] = agent_infos
- sample_collector.add_action_reward_next_obs(
- episode.episode_id, agent_id, env_id, policy_id,
- agent_done, values_dict)
- if not agent_done:
- item = PolicyEvalData(
- env_id, agent_id, filtered_obs, agent_infos, None
- if last_observation is None else
- episode.rnn_state_for(agent_id), None
- if last_observation is None else
- episode.last_action_for(agent_id), rewards[env_id].get(
- agent_id, 0.0))
- to_eval[policy_id].append(item)
- # Invoke the `on_episode_step` callback after the step is logged
- # to the episode.
- # Exception: The very first env.poll() call causes the env to get reset
- # (no step taken yet, just a single starting observation logged).
- # We need to skip this callback in this case.
- if episode.length > 0:
- callbacks.on_episode_step(
- worker=worker,
- base_env=base_env,
- policies=worker.policy_map,
- episode=episode,
- env_index=env_id)
- # Episode is done for all agents (dones[__all__] == True)
- # or we hit the horizon.
- if all_agents_done:
- is_done = dones[env_id]["__all__"]
- check_dones = is_done and not no_done_at_end
- # If we are not allowed to pack the next episode into the same
- # SampleBatch (batch_mode=complete_episodes) -> Build the
- # MultiAgentBatch from a single episode and add it to "outputs".
- # Otherwise, just postprocess and continue collecting across
- # episodes.
- ma_sample_batch = sample_collector.postprocess_episode(
- episode,
- is_done=is_done or (hit_horizon and not soft_horizon),
- check_dones=check_dones,
- build=not multiple_episodes_in_batch)
- if ma_sample_batch:
- outputs.append(ma_sample_batch)
- # Call each (in-memory) policy's Exploration.on_episode_end
- # method.
- # Note: This may break the exploration (e.g. ParameterNoise) of
- # policies in the `policy_map` that have not been recently used
- # (and are therefore stashed to disk). However, we certainly do not
- # want to loop through all (even stashed) policies here as that
- # would counter the purpose of the LRU policy caching.
- for p in worker.policy_map.cache.values():
- if getattr(p, "exploration", None) is not None:
- p.exploration.on_episode_end(
- policy=p,
- environment=base_env,
- episode=episode,
- tf_sess=p.get_session())
- # Call custom on_episode_end callback.
- callbacks.on_episode_end(
- worker=worker,
- base_env=base_env,
- policies=worker.policy_map,
- episode=episode,
- env_index=env_id,
- )
- # Horizon hit and we have a soft horizon (no hard env reset).
- if hit_horizon and soft_horizon:
- episode.soft_reset()
- resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs
- else:
- del active_episodes[env_id]
- resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
- env_id)
- # Reset not supported, drop this env from the ready list.
- if resetted_obs is None:
- if horizon != float("inf"):
- raise ValueError(
- "Setting episode horizon requires reset() support "
- "from the environment.")
- # Creates a new episode if this is not an async reset return.
- # If reset is async, we will get its result in some future poll.
- elif resetted_obs != ASYNC_RESET_RETURN:
- new_episode: Episode = active_episodes[env_id]
- if observation_fn:
- resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
- agent_obs=resetted_obs,
- worker=worker,
- base_env=base_env,
- policies=worker.policy_map,
- episode=new_episode)
- # types: AgentID, EnvObsType
- for agent_id, raw_obs in resetted_obs.items():
- policy_id: PolicyID = new_episode.policy_for(agent_id)
- preprocessor = _get_or_raise(worker.preprocessors,
- policy_id)
- prep_obs: EnvObsType = raw_obs
- if preprocessor is not None:
- prep_obs = preprocessor.transform(raw_obs)
- filtered_obs: EnvObsType = _get_or_raise(
- worker.filters, policy_id)(prep_obs)
- new_episode._set_last_raw_obs(agent_id, raw_obs)
- new_episode._set_last_observation(agent_id, filtered_obs)
- # Add initial obs to buffer.
- sample_collector.add_init_obs(
- new_episode, agent_id, env_id, policy_id,
- new_episode.length - 1, filtered_obs)
- item = PolicyEvalData(
- env_id, agent_id, filtered_obs,
- episode.last_info_for(agent_id) or {},
- episode.rnn_state_for(agent_id), None, 0.0)
- to_eval[policy_id].append(item)
- # Try to build something.
- if multiple_episodes_in_batch:
- sample_batches = \
- sample_collector.try_build_truncated_episode_multi_agent_batch()
- if sample_batches:
- outputs.extend(sample_batches)
- return active_envs, to_eval, outputs
- def _do_policy_eval(
- *,
- to_eval: Dict[PolicyID, List[PolicyEvalData]],
- policies: PolicyMap,
- sample_collector,
- active_episodes: Dict[EnvID, Episode],
- ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
- """Call compute_actions on collected episode/model data to get next action.
- Args:
- to_eval: Mapping of policy IDs to lists of PolicyEvalData objects
- (items in these lists will be the batch's items for the model
- forward pass).
- policies: Mapping from policy ID to Policy obj.
- sample_collector: The SampleCollector object to use.
- active_episodes: Mapping of EnvID to its currently active episode.
- Returns:
- Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
- """
- eval_results: Dict[PolicyID, TensorStructType] = {}
- if log_once("compute_actions_input"):
- logger.info("Inputs to compute_actions():\n\n{}\n".format(
- summarize(to_eval)))
- for policy_id, eval_data in to_eval.items():
- # In case the policyID has been removed from this worker, we need to
- # re-assign policy_id and re-lookup the Policy object to use.
- try:
- policy: Policy = _get_or_raise(policies, policy_id)
- except ValueError:
- # Important: Get the policy_mapping_fn from the active
- # Episode as the policy_mapping_fn from the worker may
- # have already been changed (mapping fn stays constant
- # within one episode).
- episode = active_episodes[eval_data[0].env_id]
- policy_id = episode.policy_mapping_fn(
- eval_data[0].agent_id, episode, worker=episode.worker)
- policy: Policy = _get_or_raise(policies, policy_id)
- input_dict = sample_collector.get_inference_input_dict(policy_id)
- eval_results[policy_id] = \
- policy.compute_actions_from_input_dict(
- input_dict,
- timestep=policy.global_timestep,
- episodes=[active_episodes[t.env_id] for t in eval_data])
- if log_once("compute_actions_result"):
- logger.info("Outputs of compute_actions():\n\n{}\n".format(
- summarize(eval_results)))
- return eval_results
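- # Illustrative sketch (toy values, not produced by a real policy): the shape
- # of a single policy's entry in the dict returned above, i.e. a 3-tuple of
- # batched actions, RNN state-out columns, and extra action fetches, which
- # `_process_policy_eval_results()` below splits back into per-agent rows.
- def _example_eval_result_shape(
- ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
-     batched_actions = np.array([0, 1, 1])  # One row per PolicyEvalData item.
-     state_out_cols: StateBatch = []  # Empty for stateless (non-RNN) policies.
-     extra_fetches = {SampleBatch.ACTION_LOGP: np.log([0.5, 0.3, 0.7])}
-     return {"default_policy": (batched_actions, state_out_cols, extra_fetches)}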
- def _process_policy_eval_results(
- *,
- to_eval: Dict[PolicyID, List[PolicyEvalData]],
- eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
- dict]],
- active_episodes: Dict[EnvID, Episode],
- active_envs: Set[int],
- off_policy_actions: MultiEnvDict,
- policies: Dict[PolicyID, Policy],
- normalize_actions: bool,
- clip_actions: bool,
- ) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
- """Process the output of policy neural network evaluation.
- Records policy evaluation results into the given episode objects and
- returns replies to send back to agents in the env.
- Args:
- to_eval: Mapping of policy IDs to lists of PolicyEvalData objects.
- eval_results: Mapping of policy IDs to list of
- actions, rnn-out states, extra-action-fetches dicts.
- active_episodes: Mapping from env ID to currently ongoing
- Episode object.
- active_envs: Set of non-terminated env ids.
- off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
- off-policy-action, returned by a `BaseEnv.poll()` call.
- policies: Mapping from policy ID to Policy.
- normalize_actions: Whether to normalize actions to the action
- space's bounds.
- clip_actions: Whether to clip actions to the action space's bounds.
- Returns:
- Nested dict of env id -> agent id -> actions to be sent to
- Env (np.ndarrays).
- """
- actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
- defaultdict(dict)
- # types: int
- for env_id in active_envs:
- actions_to_send[env_id] = {} # at minimum send empty dict
- # types: PolicyID, List[PolicyEvalData]
- for policy_id, eval_data in to_eval.items():
- actions: TensorStructType = eval_results[policy_id][0]
- actions = convert_to_numpy(actions)
- rnn_out_cols: StateBatch = eval_results[policy_id][1]
- extra_action_out_cols: dict = eval_results[policy_id][2]
- # In case actions is a list (representing the 0th dim of a batch of
- # primitive actions), try converting it first.
- if isinstance(actions, list):
- actions = np.array(actions)
- # Store RNN state ins/outs and extra-action fetches to episode.
- for f_i, column in enumerate(rnn_out_cols):
- extra_action_out_cols["state_out_{}".format(f_i)] = column
- policy: Policy = _get_or_raise(policies, policy_id)
- # Split action-component batches into single action rows.
- actions: List[EnvActionType] = unbatch(actions)
- # types: int, EnvActionType
- for i, action in enumerate(actions):
- # Normalize, if necessary.
- if normalize_actions:
- action_to_send = unsquash_action(action,
- policy.action_space_struct)
- # Clip, if necessary.
- elif clip_actions:
- action_to_send = clip_action(action,
- policy.action_space_struct)
- else:
- action_to_send = action
- env_id: int = eval_data[i].env_id
- agent_id: AgentID = eval_data[i].agent_id
- episode: Episode = active_episodes[env_id]
- episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
- episode._set_last_extra_action_outs(
- agent_id, {k: v[i]
- for k, v in extra_action_out_cols.items()})
- if env_id in off_policy_actions and \
- agent_id in off_policy_actions[env_id]:
- episode._set_last_action(agent_id,
- off_policy_actions[env_id][agent_id])
- else:
- episode._set_last_action(agent_id, action)
- assert agent_id not in actions_to_send[env_id]
- actions_to_send[env_id][agent_id] = action_to_send
- return actions_to_send
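- # Illustrative sketch (toy Box space): the two action transforms used above.
- # `unsquash_action` maps a normalized [-1.0, 1.0] action into the space's
- # bounds (and takes precedence), while `clip_action` merely clips to them.
- def _example_action_transforms():
-     from gym.spaces import Box
-     space = Box(low=-2.0, high=2.0, shape=(1,))
-     unsquashed = unsquash_action(np.array([0.5]), space)  # -> approx. [1.0]
-     clipped = clip_action(np.array([3.0]), space)  # -> [2.0]
-     return unsquashed, clipped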
- def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]:
- """Atari games have multiple logical episodes, one per life.
- However, for metrics reporting we count full episodes, all lives included.
- """
- sub_environments = base_env.get_sub_environments()
- if not sub_environments:
- return None
- atari_out = []
- for sub_env in sub_environments:
- monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
- if not monitor:
- return None
- for eps_rew, eps_len in monitor.next_episode_results():
- atari_out.append(RolloutMetrics(eps_len, eps_rew))
- return atari_out
- def _to_column_format(rnn_state_rows: List[List[Any]]) -> StateBatch:
- num_cols = len(rnn_state_rows[0])
- return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
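- # Example (illustrative): two per-agent RNN state rows, each holding an
- # (h, c) pair, become two batched state columns, as `StateBatch` expects:
- #     _to_column_format([[h_a, c_a], [h_b, c_b]])
- #     -> [[h_a, h_b], [c_a, c_b]]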
- def _get_or_raise(mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]],
- policy_id: PolicyID) -> Union[Policy, Preprocessor, Filter]:
- """Returns an object under key `policy_id` in `mapping`.
- Args:
- mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
- mapping dict from policy id (str) to actual object (Policy,
- Preprocessor, etc.).
- policy_id (str): The policy ID to lookup.
- Returns:
- Union[Policy, Preprocessor, Filter]: The found object.
- Raises:
- ValueError: If `policy_id` cannot be found in `mapping`.
- """
- if policy_id not in mapping:
- raise ValueError(
- "Could not find policy for agent: PolicyID `{}` not found "
- "in policy map, whose keys are `{}`.".format(
- policy_id, mapping.keys()))
- return mapping[policy_id]
|