- import copy
- import importlib.util
- import logging
- import os
- import platform
- import threading
- from collections import defaultdict
- from types import FunctionType
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Container,
- Dict,
- List,
- Optional,
- Set,
- Tuple,
- Type,
- Union,
- )
- import numpy as np
- import tree # pip install dm_tree
- from gymnasium.spaces import Discrete, MultiDiscrete, Space
- import ray
- from ray import ObjectRef
- from ray import cloudpickle as pickle
- from ray.rllib.connectors.util import (
- create_connectors_for_policy,
- maybe_get_filters_for_syncing,
- )
- from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
- from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
- from ray.rllib.env.env_context import EnvContext
- from ray.rllib.env.env_runner import EnvRunner
- from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
- from ray.rllib.env.multi_agent_env import MultiAgentEnv
- from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
- from ray.rllib.evaluation.metrics import RolloutMetrics
- from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
- from ray.rllib.models import ModelCatalog
- from ray.rllib.models.preprocessors import Preprocessor
- from ray.rllib.offline import (
- D4RLReader,
- DatasetReader,
- DatasetWriter,
- InputReader,
- IOContext,
- JsonReader,
- JsonWriter,
- MixedInput,
- NoopOutput,
- OutputWriter,
- ShuffledInput,
- )
- from ray.rllib.policy.policy import Policy, PolicySpec
- from ray.rllib.policy.policy_map import PolicyMap
- from ray.rllib.policy.sample_batch import (
- DEFAULT_POLICY_ID,
- MultiAgentBatch,
- concat_samples,
- convert_ma_batch_to_sample_batch,
- )
- from ray.rllib.policy.torch_policy import TorchPolicy
- from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
- from ray.rllib.utils import check_env, force_list
- from ray.rllib.utils.annotations import DeveloperAPI, override
- from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
- from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
- from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
- from ray.rllib.utils.filter import Filter, NoFilter, get_filter
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
- from ray.rllib.utils.from_config import from_config
- from ray.rllib.utils.policy import create_policy_for_framework, validate_policy_id
- from ray.rllib.utils.sgd import do_minibatch_sgd
- from ray.rllib.utils.tf_run_builder import _TFRunBuilder
- from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
- from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
- from ray.rllib.utils.typing import (
- AgentID,
- EnvCreator,
- EnvType,
- ModelGradients,
- ModelWeights,
- MultiAgentPolicyConfigDict,
- PartialAlgorithmConfigDict,
- PolicyID,
- PolicyState,
- SampleBatchType,
- T,
- )
- from ray.tune.registry import registry_contains_input, registry_get_input
- from ray.util.annotations import PublicAPI
- from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once
- from ray.util.iter import ParallelIteratorWorker
- if TYPE_CHECKING:
- from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
- from ray.rllib.algorithms.callbacks import DefaultCallbacks # noqa
- from ray.rllib.evaluation.episode import Episode
- tf1, tf, tfv = try_import_tf()
- torch, _ = try_import_torch()
- logger = logging.getLogger(__name__)
- # Handle to the current rollout worker, which will be set to the most recently
- # created RolloutWorker in this process. This can be helpful to access in
- # custom env or policy classes for debugging or advanced use cases.
- _global_worker: Optional["RolloutWorker"] = None
- @DeveloperAPI
- def get_global_worker() -> "RolloutWorker":
- """Returns a handle to the active rollout worker in this process."""
- global _global_worker
- return _global_worker
- def _update_env_seed_if_necessary(
- env: EnvType, seed: int, worker_idx: int, vector_idx: int
- ):
- """Set a deterministic random seed on environment.
- NOTE: this may not work with remote environments (issue #18154).
- """
- if seed is None:
- return
- # Seeds are spaced `max_num_envs_per_workers` apart per worker
- # (final seed = worker_idx * 1000 + vector_idx + base seed).
- max_num_envs_per_workers: int = 1000
- assert (
- worker_idx < max_num_envs_per_workers
- ), "Too many envs per worker. Random seeds may collide."
- computed_seed: int = worker_idx * max_num_envs_per_workers + vector_idx + seed
- # Gymnasium.env.
- # This will silently fail for most Farama-foundation gymnasium environments.
- # (they do nothing and return None by default).
- if not hasattr(env, "reset"):
- if log_once("env_has_no_reset_method"):
- logger.info(f"Env {env} doesn't have a `reset()` method. Cannot seed.")
- else:
- try:
- env.reset(seed=computed_seed)
- except Exception:
- logger.info(
- f"Env {env} doesn't support setting a seed via its `reset()` "
- "method! Implement this method as `reset(self, *, seed=None, "
- "options=None)` for it to abide to the correct API. Cannot seed."
- )
- @DeveloperAPI
- class RolloutWorker(ParallelIteratorWorker, EnvRunner):
- """Common experience collection class.
- This class wraps a policy instance and an environment class to
- collect experiences from the environment. You can create many replicas of
- this class as Ray actors to scale RL training.
- This class supports vectorized and multi-agent policy evaluation (e.g.,
- VectorEnv, MultiAgentEnv).
- Examples:
- >>> # Create a rollout worker and use it to collect experiences.
- >>> import gymnasium as gym
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
- >>> worker = RolloutWorker( # doctest: +SKIP
- ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
- ... default_policy_class=PGTF1Policy) # doctest: +SKIP
- >>> print(worker.sample()) # doctest: +SKIP
- SampleBatch({
- "obs": [[...]], "actions": [[...]], "rewards": [[...]],
- "terminateds": [[...]], "truncateds": [[...]], "new_obs": [[...]]})
- >>> # Creating a multi-agent rollout worker
- >>> from gymnasium.spaces import Discrete, Box
- >>> import random
- >>> MultiAgentTrafficGrid = ... # doctest: +SKIP
- >>> worker = RolloutWorker( # doctest: +SKIP
- ... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
- ... config=AlgorithmConfig().multi_agent(
- ... policies={ # doctest: +SKIP
- ... # Use an ensemble of two policies for car agents
- ... "car_policy1": # doctest: +SKIP
- ... (PGTF1Policy, Box(...), Discrete(...),
- ... AlgorithmConfig.overrides(gamma=0.99)),
- ... "car_policy2": # doctest: +SKIP
- ... (PGTF1Policy, Box(...), Discrete(...),
- ... AlgorithmConfig.overrides(gamma=0.95)),
- ... # Use a single shared policy for all traffic lights
- ... "traffic_light_policy":
- ... (PGTF1Policy, Box(...), Discrete(...), {}),
- ... },
- ... policy_mapping_fn=(
- ... lambda agent_id, episode, **kwargs:
- ... random.choice(["car_policy1", "car_policy2"])
- ... if agent_id.startswith("car_") else "traffic_light_policy"),
- ... ),
- ... )
- >>> print(worker.sample()) # doctest: +SKIP
- MultiAgentBatch({
- "car_policy1": SampleBatch(...),
- "car_policy2": SampleBatch(...),
- "traffic_light_policy": SampleBatch(...)})
- """
- def __init__(
- self,
- *,
- env_creator: EnvCreator,
- validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
- config: Optional["AlgorithmConfig"] = None,
- worker_index: int = 0,
- num_workers: Optional[int] = None,
- recreated_worker: bool = False,
- log_dir: Optional[str] = None,
- spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
- default_policy_class: Optional[Type[Policy]] = None,
- dataset_shards: Optional[List[ray.data.Dataset]] = None,
- # Deprecated: This is all specified in `config` anyways.
- tf_session_creator=DEPRECATED_VALUE, # Use config.tf_session_options instead.
- ):
- """Initializes a RolloutWorker instance.
- Args:
- env_creator: Function that returns a gym.Env given an EnvContext
- wrapped configuration.
- validate_env: Optional callable to validate the generated
- environment (only on worker=0).
- worker_index: For remote workers, this should be set to a
- non-zero and unique value. This index is passed to created envs
- through EnvContext so that envs can be configured per worker.
- recreated_worker: Whether this worker is a recreated one. Workers are
- recreated by an Algorithm (via WorkerSet) in case
- `recreate_failed_workers=True` and one of the original workers (or an
- already recreated one) has failed. They don't differ from original
- workers other than the value of this flag (`self.recreated_worker`).
- log_dir: Directory where logs can be placed.
- spaces: An optional space dict mapping policy IDs
- to (obs_space, action_space)-tuples. This is used in case no
- Env is created on this RolloutWorker.
- """
- # Deprecated args.
- if tf_session_creator != DEPRECATED_VALUE:
- deprecation_warning(
- old="RolloutWorker(.., tf_session_creator=.., ..)",
- new="config.framework(tf_session_args={..}); "
- "RolloutWorker(config=config, ..)",
- error=True,
- )
- self._original_kwargs: dict = locals().copy()
- del self._original_kwargs["self"]
- global _global_worker
- _global_worker = self
- from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
- # Default config needed?
- if config is None or isinstance(config, dict):
- config = AlgorithmConfig().update_from_dict(config or {})
- # Freeze config, so no one else can alter it from here on.
- config.freeze()
- # Set extra python env variables before calling super constructor.
- if config.extra_python_environs_for_driver and worker_index == 0:
- for key, value in config.extra_python_environs_for_driver.items():
- os.environ[key] = str(value)
- elif config.extra_python_environs_for_worker and worker_index > 0:
- for key, value in config.extra_python_environs_for_worker.items():
- os.environ[key] = str(value)
- def gen_rollouts():
- while True:
- yield self.sample()
- ParallelIteratorWorker.__init__(self, gen_rollouts, False)
- EnvRunner.__init__(self, config=config)
- self.num_workers = (
- num_workers if num_workers is not None else self.config.num_rollout_workers
- )
- # In case we are reading from distributed datasets, store the shards here
- # and pick our shard by our worker-index.
- self._ds_shards = dataset_shards
- self.worker_index: int = worker_index
- # Lock to be able to lock this entire worker
- # (via `self.lock()` and `self.unlock()`).
- # This might be crucial to prevent a race condition in case
- # `config.policy_states_are_swappable=True` and you are using an Algorithm
- # with a learner thread. In this case, the thread might update a policy
- # that is being swapped (during the update) by the Algorithm's
- # training_step's `RolloutWorker.get_weights()` call (to sync back the
- # new weights to all remote workers).
- self._lock = threading.Lock()
- if (
- tf1
- and (config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
- # This eager check is necessary for certain all-framework tests
- # that use tf's eager_mode() context generator.
- and not tf1.executing_eagerly()
- ):
- tf1.enable_eager_execution()
- if self.config.log_level:
- logging.getLogger("ray.rllib").setLevel(self.config.log_level)
- if self.worker_index > 1:
- disable_log_once_globally() # only need 1 worker to log
- elif self.config.log_level == "DEBUG":
- enable_periodic_logging()
- env_context = EnvContext(
- self.config.env_config,
- worker_index=self.worker_index,
- vector_index=0,
- num_workers=self.num_workers,
- remote=self.config.remote_worker_envs,
- recreated_worker=recreated_worker,
- )
- self.env_context = env_context
- self.config: AlgorithmConfig = config
- self.callbacks: DefaultCallbacks = self.config.callbacks_class()
- self.recreated_worker: bool = recreated_worker
- # Setup current policy_mapping_fn. Start with the one from the config, which
- # might be None in older checkpoints (nowadays AlgorithmConfig has a proper
- # default for this); we cover that case via the backup lambda here.
- self.policy_mapping_fn = (
- lambda agent_id, episode, worker, **kw: DEFAULT_POLICY_ID
- )
- self.set_policy_mapping_fn(self.config.policy_mapping_fn)
- self.env_creator: EnvCreator = env_creator
- # Resolve possible auto-fragment length.
- configured_rollout_fragment_length = self.config.get_rollout_fragment_length(
- worker_index=self.worker_index
- )
- self.total_rollout_fragment_length: int = (
- configured_rollout_fragment_length * self.config.num_envs_per_worker
- )
- self.preprocessing_enabled: bool = not config._disable_preprocessor_api
- self.last_batch: Optional[SampleBatchType] = None
- self.global_vars: dict = {
- # TODO(sven): Make this per-policy!
- "timestep": 0,
- # Counter for performed gradient updates per policy in `self.policy_map`.
- # Allows for compiling metrics on the off-policy'ness of an update given
- # that the number of gradient updates of the sampling policies are known
- # to the learner (and can be compared to the learner version of the same
- # policy).
- "num_grad_updates_per_policy": defaultdict(int),
- }
- # If a seed is given, add the worker index to it, plus 10000 if this is an evaluation worker.
- self.seed = (
- None
- if self.config.seed is None
- else self.config.seed
- + self.worker_index
- + self.config.in_evaluation * 10000
- )
- # Update the global seed for numpy/random/tf-eager/torch if we are not
- # the local worker, otherwise, this was already done in the Algorithm
- # object itself.
- if self.worker_index > 0:
- update_global_seed_if_necessary(self.config.framework_str, self.seed)
- # A single environment provided by the user (via config.env). This may
- # also remain None.
- # 1) Create the env using the user provided env_creator. This may
- # return a gym.Env (incl. MultiAgentEnv), an already vectorized
- # VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
- # 2) Wrap - if applicable - with Atari/rendering wrappers.
- # 3) Seed the env, if necessary.
- # 4) Vectorize the existing single env by creating more clones of
- # this env and wrapping it with the RLlib BaseEnv class.
- self.env = self.make_sub_env_fn = None
- # Create a (single) env for this worker.
- if not (
- self.worker_index == 0
- and self.num_workers > 0
- and not self.config.create_env_on_local_worker
- ):
- # Run the `env_creator` function passing the EnvContext.
- self.env = env_creator(copy.deepcopy(self.env_context))
- clip_rewards = self.config.clip_rewards
- if self.env is not None:
- # Validate environment (general validation function).
- if not self.config.disable_env_checking:
- check_env(self.env, self.config)
- # Custom validation function given, typically a function attribute of the
- # Algorithm.
- if validate_env is not None:
- validate_env(self.env, self.env_context)
- # We can't auto-wrap a BaseEnv.
- if isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)):
- def wrap(env):
- return env
- # Atari type env and "deepmind" preprocessor pref.
- elif is_atari(self.env) and self.config.preprocessor_pref == "deepmind":
- # Deepmind wrappers already handle all preprocessing.
- self.preprocessing_enabled = False
- # If clip_rewards not explicitly set to False, switch it
- # on here (clip between -1.0 and 1.0).
- if self.config.clip_rewards is None:
- clip_rewards = True
- # Framestacking is used.
- use_framestack = self.config.model.get("framestack") is True
- def wrap(env):
- env = wrap_deepmind(
- env,
- dim=self.config.model.get("dim"),
- framestack=use_framestack,
- noframeskip=self.config.env_config.get("frameskip", 0) == 1,
- )
- return env
- elif self.config.preprocessor_pref is None:
- # Only turn off preprocessing
- self.preprocessing_enabled = False
- def wrap(env):
- return env
- else:
- def wrap(env):
- return env
- # Wrap env through the correct wrapper.
- self.env: EnvType = wrap(self.env)
- # Ideally, we would use the same make_sub_env() function below
- # to create self.env, but wrap(env) and self.env have a cyclic
- # dependency on each other right now, so we settle for
- # duplicating the random-seed-setting logic for now.
- _update_env_seed_if_necessary(self.env, self.seed, self.worker_index, 0)
- # Call custom callback function `on_sub_environment_created`.
- self.callbacks.on_sub_environment_created(
- worker=self,
- sub_environment=self.env,
- env_context=self.env_context,
- )
- self.make_sub_env_fn = self._get_make_sub_env_fn(
- env_creator, env_context, validate_env, wrap, self.seed
- )
- self.spaces = spaces
- self.default_policy_class = default_policy_class
- self.policy_dict, self.is_policy_to_train = self.config.get_multi_agent_setup(
- env=self.env,
- spaces=self.spaces,
- default_policy_class=self.default_policy_class,
- )
- self.policy_map: Optional[PolicyMap] = None
- # TODO(jungong) : clean up after non-connector env_runner is fully deprecated.
- self.preprocessors: Dict[PolicyID, Preprocessor] = None
- # Check available number of GPUs.
- num_gpus = (
- self.config.num_gpus
- if self.worker_index == 0
- else self.config.num_gpus_per_worker
- )
- # This is only for the old API, where the local worker was responsible for learning.
- if not self.config._enable_learner_api:
- # Error if we don't find enough GPUs.
- if (
- ray.is_initialized()
- and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE
- and not config._fake_gpus
- ):
- devices = []
- if self.config.framework_str in ["tf2", "tf"]:
- devices = get_tf_gpu_devices()
- elif self.config.framework_str == "torch":
- devices = list(range(torch.cuda.device_count()))
- if len(devices) < num_gpus:
- raise RuntimeError(
- ERR_MSG_NO_GPUS.format(len(devices), devices)
- + HOWTO_CHANGE_CONFIG
- )
- # Warn, if running in local-mode and actual GPUs (not faked) are
- # requested.
- elif (
- ray.is_initialized()
- and ray._private.worker._mode() == ray._private.worker.LOCAL_MODE
- and num_gpus > 0
- and not self.config._fake_gpus
- ):
- logger.warning(
- "You are running ray with `local_mode=True`, but have "
- f"configured {num_gpus} GPUs to be used! In local mode, "
- f"Policies are placed on the CPU and the `num_gpus` setting "
- f"is ignored."
- )
- self.filters: Dict[PolicyID, Filter] = defaultdict(NoFilter)
- # if RLModule API is enabled, marl_module_spec holds the specs of the RLModules
- self.marl_module_spec = None
- self._update_policy_map(policy_dict=self.policy_dict)
- # Update Policy's view requirements from Model, only if Policy directly
- # inherited from base `Policy` class. At this point here, the Policy
- # must have its Model (if any) defined and ready to output an initial
- # state.
- for pol in self.policy_map.values():
- if not pol._model_init_state_automatically_added and not pol.config.get(
- "_enable_rl_module_api", False
- ):
- pol._update_model_view_requirements_from_init_state()
- self.multiagent: bool = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
- if self.multiagent and self.env is not None:
- if not isinstance(
- self.env,
- (BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv, ray.actor.ActorHandle),
- ):
- raise ValueError(
- f"Have multiple policies {self.policy_map}, but the "
- f"env {self.env} is not a subclass of BaseEnv, "
- f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!"
- )
- if self.worker_index == 0:
- logger.info("Built filter map: {}".format(self.filters))
- # This RolloutWorker has no env.
- if self.env is None:
- self.async_env = None
- # Use a custom env-vectorizer and call it providing self.env.
- elif "custom_vector_env" in self.config:
- self.async_env = self.config.custom_vector_env(self.env)
- # Default: Vectorize self.env via the make_sub_env function. This adds
- # further clones of self.env and creates a RLlib BaseEnv (which is
- # vectorized under the hood).
- else:
- # Always use vector env for consistency even if num_envs_per_worker=1.
- self.async_env: BaseEnv = convert_to_base_env(
- self.env,
- make_env=self.make_sub_env_fn,
- num_envs=self.config.num_envs_per_worker,
- remote_envs=self.config.remote_worker_envs,
- remote_env_batch_wait_ms=self.config.remote_env_batch_wait_ms,
- worker=self,
- restart_failed_sub_environments=(
- self.config.restart_failed_sub_environments
- ),
- )
- # `truncate_episodes`: Allow a batch to contain more than one episode
- # (fragments) and always make the batch `rollout_fragment_length`
- # long.
- rollout_fragment_length_for_sampler = configured_rollout_fragment_length
- if self.config.batch_mode == "truncate_episodes":
- pack = True
- # `complete_episodes`: Never cut episodes and sampler will return
- # exactly one (complete) episode per poll.
- else:
- assert self.config.batch_mode == "complete_episodes"
- rollout_fragment_length_for_sampler = float("inf")
- pack = False
- # Create the IOContext for this worker.
- self.io_context: IOContext = IOContext(
- log_dir, self.config, self.worker_index, self
- )
- render = False
- if self.config.render_env is True and (
- self.num_workers == 0 or self.worker_index == 1
- ):
- render = True
- if self.env is None:
- self.sampler = None
- elif self.config.sample_async:
- self.sampler = AsyncSampler(
- worker=self,
- env=self.async_env,
- clip_rewards=clip_rewards,
- rollout_fragment_length=rollout_fragment_length_for_sampler,
- count_steps_by=self.config.count_steps_by,
- callbacks=self.callbacks,
- multiple_episodes_in_batch=pack,
- normalize_actions=self.config.normalize_actions,
- clip_actions=self.config.clip_actions,
- observation_fn=self.config.observation_fn,
- sample_collector_class=self.config.sample_collector,
- render=render,
- )
- # Start the Sampler thread.
- self.sampler.start()
- else:
- self.sampler = SyncSampler(
- worker=self,
- env=self.async_env,
- clip_rewards=clip_rewards,
- rollout_fragment_length=rollout_fragment_length_for_sampler,
- count_steps_by=self.config.count_steps_by,
- callbacks=self.callbacks,
- multiple_episodes_in_batch=pack,
- normalize_actions=self.config.normalize_actions,
- clip_actions=self.config.clip_actions,
- observation_fn=self.config.observation_fn,
- sample_collector_class=self.config.sample_collector,
- render=render,
- )
- self.input_reader: InputReader = self._get_input_creator_from_config()(
- self.io_context
- )
- self.output_writer: OutputWriter = self._get_output_creator_from_config()(
- self.io_context
- )
- # The current weights sequence number (version). May remain None when
- # not tracking weights versions.
- self.weights_seq_no: Optional[int] = None
- logger.debug(
- "Created rollout worker with env {} ({}), policies {}".format(
- self.async_env, self.env, self.policy_map
- )
- )
- @override(EnvRunner)
- def assert_healthy(self):
- is_healthy = self.policy_map and self.input_reader and self.output_writer
- assert is_healthy, (
- f"RolloutWorker {self} (idx={self.worker_index}; "
- f"num_workers={self.num_workers}) not healthy!"
- )
- @override(EnvRunner)
- def sample(self, **kwargs) -> SampleBatchType:
- """Returns a batch of experience sampled from this worker.
- Returns:
- A columnar batch of experiences (e.g., tensors) or a MultiAgentBatch.
- Examples:
- >>> import gymnasium as gym
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
- >>> worker = RolloutWorker( # doctest: +SKIP
- ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
- ... default_policy_class=PGTF1Policy, # doctest: +SKIP
- ... config=AlgorithmConfig(), # doctest: +SKIP
- ... )
- >>> print(worker.sample()) # doctest: +SKIP
- SampleBatch({"obs": [...], "action": [...], ...})
- """
- if self.config.fake_sampler and self.last_batch is not None:
- return self.last_batch
- elif self.input_reader is None:
- raise ValueError(
- "RolloutWorker has no `input_reader` object! "
- "Cannot call `sample()`. You can try setting "
- "`create_env_on_driver` to True."
- )
- if log_once("sample_start"):
- logger.info(
- "Generating sample batch of size {}".format(
- self.total_rollout_fragment_length
- )
- )
- batches = [self.input_reader.next()]
- steps_so_far = (
- batches[0].count
- if self.config.count_steps_by == "env_steps"
- else batches[0].agent_steps()
- )
- # In truncate_episodes mode, never pull more than 1 batch per env.
- # This avoids over-running the target batch size.
- if (
- self.config.batch_mode == "truncate_episodes"
- and not self.config.offline_sampling
- ):
- max_batches = self.config.num_envs_per_worker
- else:
- max_batches = float("inf")
- while steps_so_far < self.total_rollout_fragment_length and (
- len(batches) < max_batches
- ):
- batch = self.input_reader.next()
- steps_so_far += (
- batch.count
- if self.config.count_steps_by == "env_steps"
- else batch.agent_steps()
- )
- batches.append(batch)
- batch = concat_samples(batches)
- self.callbacks.on_sample_end(worker=self, samples=batch)
- # Always do writes prior to compression for consistency and to allow
- # for better compression inside the writer.
- self.output_writer.write(batch)
- if log_once("sample_end"):
- logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch)))
- if self.config.compress_observations:
- batch.compress(bulk=self.config.compress_observations == "bulk")
- if self.config.fake_sampler:
- self.last_batch = batch
- return batch
- @ray.method(num_returns=2)
- def sample_with_count(self) -> Tuple[SampleBatchType, int]:
- """Same as sample() but returns the count as a separate value.
- Returns:
- A columnar batch of experiences (e.g., tensors) and the
- size of the collected batch.
- Examples:
- >>> import gymnasium as gym
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
- >>> worker = RolloutWorker( # doctest: +SKIP
- ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
- ... default_policy_class=PGTF1Policy) # doctest: +SKIP
- >>> print(worker.sample_with_count()) # doctest: +SKIP
- (SampleBatch({"obs": [...], "action": [...], ...}), 3)
- """
- batch = self.sample()
- return batch, batch.count
- def learn_on_batch(self, samples: SampleBatchType) -> Dict:
- """Update policies based on the given batch.
- This is the equivalent to apply_gradients(compute_gradients(samples)),
- but can be optimized to avoid pulling gradients into CPU memory.
- Args:
- samples: The SampleBatch or MultiAgentBatch to learn on.
- Returns:
- Dictionary of extra metadata from compute_gradients().
- Examples:
- >>> import gymnasium as gym
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
- >>> worker = RolloutWorker( # doctest: +SKIP
- ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
- ... default_policy_class=PGTF1Policy) # doctest: +SKIP
- >>> batch = worker.sample() # doctest: +SKIP
- >>> info = worker.learn_on_batch(batch) # doctest: +SKIP
- """
- if log_once("learn_on_batch"):
- logger.info(
- "Training on concatenated sample batches:\n\n{}\n".format(
- summarize(samples)
- )
- )
- info_out = {}
- if isinstance(samples, MultiAgentBatch):
- builders = {}
- to_fetch = {}
- for pid, batch in samples.policy_batches.items():
- if self.is_policy_to_train is not None and not self.is_policy_to_train(
- pid, samples
- ):
- continue
- # Decompress SampleBatch, in case some columns are compressed.
- batch.decompress_if_needed()
- policy = self.policy_map[pid]
- tf_session = policy.get_session()
- if tf_session and hasattr(policy, "_build_learn_on_batch"):
- builders[pid] = _TFRunBuilder(tf_session, "learn_on_batch")
- to_fetch[pid] = policy._build_learn_on_batch(builders[pid], batch)
- else:
- info_out[pid] = policy.learn_on_batch(batch)
- info_out.update({pid: builders[pid].get(v) for pid, v in to_fetch.items()})
- else:
- if self.is_policy_to_train is None or self.is_policy_to_train(
- DEFAULT_POLICY_ID, samples
- ):
- info_out.update(
- {
- DEFAULT_POLICY_ID: self.policy_map[
- DEFAULT_POLICY_ID
- ].learn_on_batch(samples)
- }
- )
- if log_once("learn_out"):
- logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
- return info_out
- def sample_and_learn(
- self,
- expected_batch_size: int,
- num_sgd_iter: int,
- sgd_minibatch_size: int,
- standardize_fields: List[str],
- ) -> Tuple[dict, int]:
- """Sample and batch and learn on it.
- This is typically used in combination with distributed allreduce.
- Args:
- expected_batch_size: Expected number of samples to learn on.
- num_sgd_iter: Number of SGD iterations.
- sgd_minibatch_size: SGD minibatch size.
- standardize_fields: List of sample fields to normalize.
- Returns:
- A tuple consisting of a dictionary of extra metadata returned from
- the policies' `learn_on_batch()` and the number of samples
- learned on.
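- Examples:
- >>> # A minimal sketch: assumes `worker` is an already-built
- >>> # RolloutWorker whose rollout fragment length matches
- >>> # `expected_batch_size`.
- >>> worker = ... # doctest: +SKIP
- >>> info, count = worker.sample_and_learn( # doctest: +SKIP
- ... expected_batch_size=200, num_sgd_iter=2,
- ... sgd_minibatch_size=64, standardize_fields=["advantages"])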
- """
- batch = self.sample()
- assert batch.count == expected_batch_size, (
- "Batch size possibly out of sync between workers, expected:",
- expected_batch_size,
- "got:",
- batch.count,
- )
- logger.info(
- "Executing distributed minibatch SGD "
- "with epoch size {}, minibatch size {}".format(
- batch.count, sgd_minibatch_size
- )
- )
- info = do_minibatch_sgd(
- batch,
- self.policy_map,
- self,
- num_sgd_iter,
- sgd_minibatch_size,
- standardize_fields,
- )
- return info, batch.count
- def compute_gradients(
- self,
- samples: SampleBatchType,
- single_agent: Optional[bool] = None,
- ) -> Tuple[ModelGradients, dict]:
- """Returns a gradient computed w.r.t the specified samples.
- Uses the Policy's/ies' compute_gradients method(s) to perform the
- calculations. Skips policies that are not trainable as per
- `self.is_policy_to_train()`.
- Args:
- samples: The SampleBatch or MultiAgentBatch to compute gradients
- for using this worker's trainable policies.
- single_agent: If True, convert a possible MultiAgentBatch into a
- single-agent SampleBatch and use the default policy only.
- Returns:
- In the single-agent case, a tuple consisting of ModelGradients and
- info dict of the worker's policy.
- In the multi-agent case, a tuple consisting of a dict mapping
- PolicyID to ModelGradients and a dict mapping PolicyID to extra
- metadata info.
- Note that the first return value (grads) can be applied as is to a
- compatible worker using the worker's `apply_gradients()` method.
- Examples:
- >>> import gymnasium as gym
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
- >>> worker = RolloutWorker( # doctest: +SKIP
- ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
- ... default_policy_class=PGTF1Policy) # doctest: +SKIP
- >>> batch = worker.sample() # doctest: +SKIP
- >>> grads, info = worker.compute_gradients(batch) # doctest: +SKIP
- """
- if log_once("compute_gradients"):
- logger.info("Compute gradients on:\n\n{}\n".format(summarize(samples)))
- if single_agent is True:
- samples = convert_ma_batch_to_sample_batch(samples)
- grad_out, info_out = self.policy_map[DEFAULT_POLICY_ID].compute_gradients(
- samples
- )
- info_out["batch_count"] = samples.count
- return grad_out, info_out
- # Treat everything as multi-agent.
- samples = samples.as_multi_agent()
- # Calculate gradients for all policies.
- grad_out, info_out = {}, {}
- if self.config.framework_str == "tf":
- for pid, batch in samples.policy_batches.items():
- if self.is_policy_to_train is not None and not self.is_policy_to_train(
- pid, samples
- ):
- continue
- policy = self.policy_map[pid]
- builder = _TFRunBuilder(policy.get_session(), "compute_gradients")
- grad_out[pid], info_out[pid] = policy._build_compute_gradients(
- builder, batch
- )
- grad_out = {k: builder.get(v) for k, v in grad_out.items()}
- info_out = {k: builder.get(v) for k, v in info_out.items()}
- else:
- for pid, batch in samples.policy_batches.items():
- if self.is_policy_to_train is not None and not self.is_policy_to_train(
- pid, samples
- ):
- continue
- grad_out[pid], info_out[pid] = self.policy_map[pid].compute_gradients(
- batch
- )
- info_out["batch_count"] = samples.count
- if log_once("grad_out"):
- logger.info("Compute grad info:\n\n{}\n".format(summarize(info_out)))
- return grad_out, info_out
- def apply_gradients(
- self,
- grads: Union[ModelGradients, Dict[PolicyID, ModelGradients]],
- ) -> None:
- """Applies the given gradients to this worker's models.
- Uses the Policy's/ies' apply_gradients method(s) to perform the
- operations.
- Args:
- grads: Single ModelGradients (single-agent case) or a dict
- mapping PolicyIDs to the respective model gradients
- structs.
- Examples:
- >>> import gymnasium as gym
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
- >>> worker = RolloutWorker( # doctest: +SKIP
- ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
- ... default_policy_class=PGTF1Policy) # doctest: +SKIP
- >>> samples = worker.sample() # doctest: +SKIP
- >>> grads, info = worker.compute_gradients(samples) # doctest: +SKIP
- >>> worker.apply_gradients(grads) # doctest: +SKIP
- """
- if log_once("apply_gradients"):
- logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
- # Grads is a dict (mapping PolicyIDs to ModelGradients).
- # Multi-agent case.
- if isinstance(grads, dict):
- for pid, g in grads.items():
- if self.is_policy_to_train is None or self.is_policy_to_train(
- pid, None
- ):
- self.policy_map[pid].apply_gradients(g)
- # Grads is a ModelGradients type. Single-agent case.
- elif self.is_policy_to_train is None or self.is_policy_to_train(
- DEFAULT_POLICY_ID, None
- ):
- self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
- def get_metrics(self) -> List[RolloutMetrics]:
- """Returns the thus-far collected metrics from this worker's rollouts.
- Returns:
- List of RolloutMetrics collected thus-far.
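- Examples:
- >>> # Sketch only: assumes `worker` is a ready RolloutWorker that
- >>> # has already collected some rollouts.
- >>> worker = ... # doctest: +SKIP
- >>> _ = worker.sample() # doctest: +SKIP
- >>> metrics = worker.get_metrics() # doctest: +SKIP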
- """
- # Get metrics from sampler (if any).
- if self.sampler is not None:
- out = self.sampler.get_metrics()
- else:
- out = []
- return out
- def foreach_env(self, func: Callable[[EnvType], T]) -> List[T]:
- """Calls the given function with each sub-environment as arg.
- Args:
- func: The function to call for each underlying
- sub-environment (as only arg).
- Returns:
- The list of return values of all calls to `func([env])`.
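- Examples:
- >>> # Sketch only: assumes `worker` holds gym.Env sub-environments.
- >>> worker = ... # doctest: +SKIP
- >>> spaces = worker.foreach_env( # doctest: +SKIP
- ... lambda env: env.observation_space)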
- """
- if self.async_env is None:
- return []
- envs = self.async_env.get_sub_environments()
- # Empty list (not implemented): Call function directly on the
- # BaseEnv.
- if not envs:
- return [func(self.async_env)]
- # Call function on all underlying (vectorized) sub environments.
- else:
- return [func(e) for e in envs]
- def foreach_env_with_context(
- self, func: Callable[[EnvType, EnvContext], T]
- ) -> List[T]:
- """Calls given function with each sub-env plus env_ctx as args.
- Args:
- func: The function to call for each underlying
- sub-environment and its EnvContext (as the args).
- Returns:
- The list of return values of all calls to `func([env, ctx])`.
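- Examples:
- >>> # Sketch only: gather each sub-env's vector index from its
- >>> # EnvContext.
- >>> worker = ... # doctest: +SKIP
- >>> idxs = worker.foreach_env_with_context( # doctest: +SKIP
- ... lambda env, ctx: ctx.vector_index)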
- """
- if self.async_env is None:
- return []
- envs = self.async_env.get_sub_environments()
- # Empty list (not implemented): Call function directly on the
- # BaseEnv.
- if not envs:
- return [func(self.async_env, self.env_context)]
- # Call function on all underlying (vectorized) sub environments.
- else:
- ret = []
- for i, e in enumerate(envs):
- ctx = self.env_context.copy_with_overrides(vector_index=i)
- ret.append(func(e, ctx))
- return ret
- def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Optional[Policy]:
- """Return policy for the specified id, or None.
- Args:
- policy_id: ID of the policy to return. Defaults to DEFAULT_POLICY_ID
- (the single-agent case).
- Returns:
- The policy under the given ID (or None if not found).
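- Examples:
- >>> # Sketch only: assumes a single-agent `worker` (default policy ID).
- >>> worker = ... # doctest: +SKIP
- >>> policy = worker.get_policy() # doctest: +SKIP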
- """
- return self.policy_map.get(policy_id)
- def add_policy(
- self,
- policy_id: PolicyID,
- policy_cls: Optional[Type[Policy]] = None,
- policy: Optional[Policy] = None,
- *,
- observation_space: Optional[Space] = None,
- action_space: Optional[Space] = None,
- config: Optional[PartialAlgorithmConfigDict] = None,
- policy_state: Optional[PolicyState] = None,
- policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None,
- policies_to_train: Optional[
- Union[Container[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
- ] = None,
- module_spec: Optional[SingleAgentRLModuleSpec] = None,
- ) -> Policy:
- """Adds a new policy to this RolloutWorker.
- Args:
- policy_id: ID of the policy to add.
- policy_cls: The Policy class to use for constructing the new Policy.
- Note: Exactly one of `policy_cls` or `policy` must be provided.
- policy: The Policy instance to add to this algorithm.
- Note: Exactly one of `policy_cls` or `policy` must be provided.
- observation_space: The observation space of the policy to add.
- action_space: The action space of the policy to add.
- config: The config overrides for the policy to add.
- policy_state: Optional state dict to apply to the new
- policy instance, right after its construction.
- policy_mapping_fn: An optional (updated) policy mapping function
- to use from here on. Note that already ongoing episodes will
- not change their mapping but will use the old mapping till
- the end of the episode.
- policies_to_train: An optional container of policy IDs to be
- trained or a callable taking PolicyID and - optionally -
- SampleBatchType and returning a bool (trainable or not?).
- If None, will keep the existing setup in place.
- Policies whose IDs are not in the list (or for which the
- callable returns False) will not be updated.
- module_spec: In the new RLModule API we need to pass in the module_spec for
- the new module that is supposed to be added. Knowing the policy spec is
- not sufficient.
- Returns:
- The newly added policy.
- Raises:
- ValueError: If both `policy_cls` AND `policy` are provided.
- KeyError: If the given `policy_id` already exists in this worker's
- PolicyMap.
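- Examples:
- >>> # Sketch only: `MyPolicyClass` is a hypothetical stand-in for
- >>> # any Policy subclass; `worker` is an already-built RolloutWorker.
- >>> worker = ... # doctest: +SKIP
- >>> new_policy = worker.add_policy( # doctest: +SKIP
- ... policy_id="new_policy",
- ... policy_cls=MyPolicyClass,
- ... policy_mapping_fn=(
- ... lambda agent_id, episode, **kw: "new_policy"),
- ... )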
- """
- validate_policy_id(policy_id, error=False)
- if module_spec is not None and not self.config._enable_rl_module_api:
- raise ValueError(
- "If you pass in module_spec to the policy, the RLModule API needs "
- "to be enabled."
- )
- if policy_id in self.policy_map:
- raise KeyError(
- f"Policy ID '{policy_id}' already exists in policy map! "
- "Make sure you use a Policy ID that has not been taken yet."
- " Policy IDs that are already in your policy map: "
- f"{list(self.policy_map.keys())}"
- )
- if (policy_cls is None) == (policy is None):
- raise ValueError(
- "Only one of `policy_cls` or `policy` must be provided to "
- "RolloutWorker.add_policy()!"
- )
- if policy is None:
- policy_dict_to_add, _ = self.config.get_multi_agent_setup(
- policies={
- policy_id: PolicySpec(
- policy_cls, observation_space, action_space, config
- )
- },
- env=self.env,
- spaces=self.spaces,
- default_policy_class=self.default_policy_class,
- )
- else:
- policy_dict_to_add = {
- policy_id: PolicySpec(
- type(policy),
- policy.observation_space,
- policy.action_space,
- policy.config,
- )
- }
- self.policy_dict.update(policy_dict_to_add)
- self._update_policy_map(
- policy_dict=policy_dict_to_add,
- policy=policy,
- policy_states={policy_id: policy_state},
- single_agent_rl_module_spec=module_spec,
- )
- self.set_policy_mapping_fn(policy_mapping_fn)
- if policies_to_train is not None:
- self.set_is_policy_to_train(policies_to_train)
- return self.policy_map[policy_id]
- def remove_policy(
- self,
- *,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
- policies_to_train: Optional[
- Union[Container[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
- ] = None,
- ) -> None:
- """Removes a policy from this RolloutWorker.
- Args:
- policy_id: ID of the policy to be removed. Defaults to
- DEFAULT_POLICY_ID.
- policy_mapping_fn: An optional (updated) policy mapping function
- to use from here on. Note that already ongoing episodes will
- not change their mapping but will use the old mapping till
- the end of the episode.
- policies_to_train: An optional container of policy IDs to be
- trained or a callable taking PolicyID and - optionally -
- SampleBatchType and returning a bool (trainable or not?).
- If None, will keep the existing setup in place.
- Policies whose IDs are not in the list (or for which the
- callable returns False) will not be updated.
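- Examples:
- >>> # Sketch only: remove a previously added policy and map all
- >>> # agents back to the default policy.
- >>> worker = ... # doctest: +SKIP
- >>> worker.remove_policy( # doctest: +SKIP
- ... policy_id="new_policy",
- ... policy_mapping_fn=(
- ... lambda agent_id, episode, **kw: "default_policy"),
- ... )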
- """
- if policy_id not in self.policy_map:
- raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
- del self.policy_map[policy_id]
- del self.preprocessors[policy_id]
- self.set_policy_mapping_fn(policy_mapping_fn)
- if policies_to_train is not None:
- self.set_is_policy_to_train(policies_to_train)
- def set_policy_mapping_fn(
- self,
- policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None,
- ) -> None:
- """Sets `self.policy_mapping_fn` to a new callable (if provided).
- Args:
- policy_mapping_fn: The new mapping function to use. If None,
- will keep the existing mapping function in place.
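- Examples:
- >>> # Sketch only: map all agents to the default policy from here on.
- >>> worker = ... # doctest: +SKIP
- >>> worker.set_policy_mapping_fn( # doctest: +SKIP
- ... lambda agent_id, episode, **kw: "default_policy")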
- """
- if policy_mapping_fn is not None:
- self.policy_mapping_fn = policy_mapping_fn
- if not callable(self.policy_mapping_fn):
- raise ValueError("`policy_mapping_fn` must be a callable!")
- def set_is_policy_to_train(
- self,
- is_policy_to_train: Union[
- Container[PolicyID], Callable[[PolicyID, Optional[SampleBatchType]], bool]
- ],
- ) -> None:
- """Sets `self.is_policy_to_train()` to a new callable.
- Args:
- is_policy_to_train: A container of policy IDs to be
- trained or a callable taking PolicyID and - optionally -
- SampleBatchType and returning a bool (trainable or not?).
- If None, will keep the existing setup in place.
- Policies whose IDs are not in the list (or for which the
- callable returns False) will not be updated.
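- Examples:
- >>> # Sketch only: train only a hypothetical "learner_policy" and
- >>> # freeze all other policies.
- >>> worker = ... # doctest: +SKIP
- >>> worker.set_is_policy_to_train(["learner_policy"]) # doctest: +SKIP
- >>> # Equivalent callable form:
- >>> worker.set_is_policy_to_train( # doctest: +SKIP
- ... lambda pid, batch=None: pid == "learner_policy")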
- """
- # If container given, construct a simple default callable returning True
- # if the PolicyID is found in the list/set of IDs.
- if not callable(is_policy_to_train):
- assert isinstance(is_policy_to_train, (list, set, tuple)), (
- "ERROR: `is_policy_to_train`must be a [list|set|tuple] or a "
- "callable taking PolicyID and SampleBatch and returning "
- "True|False (trainable or not?)."
- )
- pols = set(is_policy_to_train)
- def is_policy_to_train(pid, batch=None):
- return pid in pols
- self.is_policy_to_train = is_policy_to_train
- @PublicAPI(stability="alpha")
- def get_policies_to_train(
- self, batch: Optional[SampleBatchType] = None
- ) -> Set[PolicyID]:
- """Returns all policies-to-train, given an optional batch.
- Loops through all policies currently in `self.policy_map` and checks
- the return value of `self.is_policy_to_train(pid, batch)`.
- Args:
- batch: An optional SampleBatchType for the
- `self.is_policy_to_train(pid, [batch]?)` check.
- Returns:
- The set of currently trainable policy IDs, given the optional
- `batch`.
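- Examples:
- >>> # Sketch only: assumes a ready multi-agent `worker`.
- >>> worker = ... # doctest: +SKIP
- >>> trainable_ids = worker.get_policies_to_train() # doctest: +SKIP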
- """
- return {
- pid
- for pid in self.policy_map.keys()
- if self.is_policy_to_train is None or self.is_policy_to_train(pid, batch)
- }
- def for_policy(
- self,
- func: Callable[[Policy, Optional[Any]], T],
- policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID,
- **kwargs,
- ) -> T:
- """Calls the given function with the specified policy as first arg.
- Args:
- func: The function to call with the policy as first arg.
- policy_id: The PolicyID of the policy to call the function with.
- Keyword Args:
- kwargs: Additional kwargs to be passed to the call.
- Returns:
- The return value of the function call.
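- Examples:
- >>> # Sketch only: fetch the weights of the default policy.
- >>> worker = ... # doctest: +SKIP
- >>> weights = worker.for_policy( # doctest: +SKIP
- ... lambda policy: policy.get_weights())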
- """
- return func(self.policy_map[policy_id], **kwargs)
- def foreach_policy(
- self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs
- ) -> List[T]:
- """Calls the given function with each (policy, policy_id) tuple.
- Args:
- func: The function to call with each (policy, policy ID) tuple.
- Keyword Args:
- kwargs: Additional kwargs to be passed to the call.
- Returns:
- The list of return values of all calls to
- `func([policy, pid, **kwargs])`.
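- Examples:
- >>> # Sketch only: list all (policy ID, observation space) pairs.
- >>> worker = ... # doctest: +SKIP
- >>> spaces = worker.foreach_policy( # doctest: +SKIP
- ... lambda policy, pid: (pid, policy.observation_space))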
- """
- return [func(policy, pid, **kwargs) for pid, policy in self.policy_map.items()]
- def foreach_policy_to_train(
- self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs
- ) -> List[T]:
- """
- Calls the given function with each (policy, policy_id) tuple.
- Only those policies/IDs will be called on, for which
- `self.is_policy_to_train()` returns True.
- Args:
- func: The function to call with each (policy, policy ID) tuple,
- for only those policies that `self.is_policy_to_train`
- returns True.
- Keyword Args:
- kwargs: Additional kwargs to be passed to the call.
- Returns:
- The list of return values of all calls to
- `func([policy, pid, **kwargs])`.
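- Examples:
- >>> # Sketch only: get the IDs of all currently trainable policies.
- >>> worker = ... # doctest: +SKIP
- >>> trainable_ids = worker.foreach_policy_to_train( # doctest: +SKIP
- ... lambda policy, pid: pid)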
- """
- return [
- # Make sure to only iterate over keys() and not items(). Iterating over
- # items will access policy_map elements even for pids that we do not need,
- # i.e. those that are not in policy_to_train. Access to policy_map elements
- # can cause disk access for policies that were offloaded to disk. Since
- # these policies will be skipped in the for-loop accessing them is
- # unnecessary, making subsequent disk access unnecessary.
- func(self.policy_map[pid], pid, **kwargs)
- for pid in self.policy_map.keys()
- if self.is_policy_to_train is None or self.is_policy_to_train(pid, None)
- ]
- def sync_filters(self, new_filters: dict) -> None:
- """Changes self's filter to given and rebases any accumulated delta.
- Args:
- new_filters: Filters with new state to update local copy.
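- Examples:
- >>> # Sketch only: copy filter state between two hypothetical
- >>> # workers, `worker_a` and `worker_b`.
- >>> filters = worker_a.get_filters(flush_after=True) # doctest: +SKIP
- >>> worker_b.sync_filters(filters) # doctest: +SKIP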
- """
- assert all(k in new_filters for k in self.filters)
- for k in self.filters:
- self.filters[k].sync(new_filters[k])
- def get_filters(self, flush_after: bool = False) -> Dict:
- """Returns a snapshot of filters.
- Args:
- flush_after: Whether to clear the filter buffer state after taking
- the snapshot.
- Returns:
- Dict mapping PolicyIDs to serializable filters.
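- Examples:
- >>> # Sketch only: snapshot filters and clear accumulated deltas.
- >>> worker = ... # doctest: +SKIP
- >>> filters = worker.get_filters(flush_after=True) # doctest: +SKIP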
- """
- return_filters = {}
- for k, f in self.filters.items():
- return_filters[k] = f.as_serializable()
- if flush_after:
- f.reset_buffer()
- return return_filters
- @override(EnvRunner)
- def get_state(self) -> dict:
- filters = self.get_filters(flush_after=True)
- policy_states = {}
- for pid in self.policy_map.keys():
- # If required by the user, only capture policies that are actually
- # trainable. Otherwise, capture all policies (for saving to disk).
- if (
- not self.config.checkpoint_trainable_policies_only
- or self.is_policy_to_train is None
- or self.is_policy_to_train(pid)
- ):
- policy_states[pid] = self.policy_map[pid].get_state()
- return {
- # List all known policy IDs here for convenience. When an Algorithm gets
- # restored from a checkpoint, it will not have access to the list of
- # possible IDs as each policy is stored in its own sub-dir
- # (see "policy_states").
- "policy_ids": list(self.policy_map.keys()),
- # Note that this field will not be stored in the algorithm checkpoint's
- # state file, but each policy will get its own state file generated in
- # a sub-dir within the algo's checkpoint dir.
- "policy_states": policy_states,
- # Also store current mapping fn and which policies to train.
- "policy_mapping_fn": self.policy_mapping_fn,
- "is_policy_to_train": self.is_policy_to_train,
- # TODO: Filters will be replaced by connectors.
- "filters": filters,
- }
- @override(EnvRunner)
- def set_state(self, state: dict) -> None:
- # Backward compatibility (old checkpoints' states would have the local
- # worker state as a bytes object, not a dict).
- if isinstance(state, bytes):
- state = pickle.loads(state)
- # TODO: Once filters are handled by connectors, get rid of the "filters"
- # key in `state` entirely (will be part of the policies then).
- self.sync_filters(state["filters"])
- connector_enabled = self.config.enable_connectors
- # Support older checkpoint versions (< 1.0), in which the policy_map
- # was stored under the "state" key, not "policy_states".
- policy_states = (
- state["policy_states"] if "policy_states" in state else state["state"]
- )
- for pid, policy_state in policy_states.items():
- # If - for some reason - we have an invalid PolicyID in the state,
- # this might be from an older checkpoint (pre v1.0). Just warn here.
- validate_policy_id(pid, error=False)
- if pid not in self.policy_map:
- spec = policy_state.get("policy_spec", None)
- if spec is None:
- logger.warning(
- f"PolicyID '{pid}' was probably added on-the-fly (not"
- " part of the static `multagent.policies` config) and"
- " no PolicySpec objects found in the pickled policy "
- f"state. Will not add `{pid}`, but ignore it for now."
- )
- else:
- policy_spec = (
- PolicySpec.deserialize(spec)
- if connector_enabled or isinstance(spec, dict)
- else spec
- )
- self.add_policy(
- policy_id=pid,
- policy_cls=policy_spec.policy_class,
- observation_space=policy_spec.observation_space,
- action_space=policy_spec.action_space,
- config=policy_spec.config,
- )
- if pid in self.policy_map:
- self.policy_map[pid].set_state(policy_state)
- # Also restore mapping fn and which policies to train.
- if "policy_mapping_fn" in state:
- self.set_policy_mapping_fn(state["policy_mapping_fn"])
- if state.get("is_policy_to_train") is not None:
- self.set_is_policy_to_train(state["is_policy_to_train"])
- def get_weights(
- self,
- policies: Optional[Container[PolicyID]] = None,
- ) -> Dict[PolicyID, ModelWeights]:
- """Returns each policies' model weights of this worker.
- Args:
- policies: List of PolicyIDs to get the weights from.
- Use None for all policies.
- Returns:
- Dict mapping PolicyIDs to ModelWeights.
- Examples:
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> # Create a RolloutWorker.
- >>> worker = ... # doctest: +SKIP
- >>> weights = worker.get_weights() # doctest: +SKIP
- >>> print(weights) # doctest: +SKIP
- {"default_policy": {"layer1": array(...), "layer2": ...}}
- """
- if policies is None:
- policies = list(self.policy_map.keys())
- policies = force_list(policies)
- return {
- # Make sure to only iterate over keys() and not items(). Iterating over
- # items would access policy_map elements even for pids that are not in
- # `policies`. Accessing a policy_map element can trigger a disk read for
- # policies that were offloaded to disk; since such policies are skipped
- # by the if-condition below anyway, we avoid that unnecessary disk access.
- pid: self.policy_map[pid].get_weights()
- for pid in self.policy_map.keys()
- if pid in policies
- }
- def set_weights(
- self,
- weights: Dict[PolicyID, ModelWeights],
- global_vars: Optional[Dict] = None,
- weights_seq_no: Optional[int] = None,
- ) -> None:
- """Sets each policies' model weights of this worker.
- Args:
- weights: Dict mapping PolicyIDs to the new weights to be used.
- global_vars: An optional global vars dict to set this
- worker to. If None, do not update the global_vars.
- weights_seq_no: If needed, a sequence number for the weights version
- can be passed into this method. If not None, this seq no will be
- stored (in self.weights_seq_no) and - if a future call passes in the
- same seq no - that call will be ignored to save on performance.
- Examples:
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> # Create a RolloutWorker.
- >>> worker = ... # doctest: +SKIP
- >>> weights = worker.get_weights() # doctest: +SKIP
- >>> # Set `global_vars` (timestep) as well.
- >>> worker.set_weights(weights, {"timestep": 42}) # doctest: +SKIP
- """
- # Only update our weights if no seq no was given OR the given seq no
- # differs from ours.
- if weights_seq_no is None or weights_seq_no != self.weights_seq_no:
- # If per-policy weights are object refs, `ray.get()` them first.
- if weights and isinstance(next(iter(weights.values())), ObjectRef):
- actual_weights = ray.get(list(weights.values()))
- weights = {
- pid: actual_weights[i] for i, pid in enumerate(weights.keys())
- }
- for pid, w in weights.items():
- self.policy_map[pid].set_weights(w)
- self.weights_seq_no = weights_seq_no
- if global_vars:
- self.set_global_vars(global_vars)
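- # Illustrative sketch (assumes an initialized Ray runtime): per-policy
- # weights may also be passed in as object refs; `set_weights()` resolves
- # them via `ray.get()` before applying them.
- #
- #     weights_refs = {pid: ray.put(w) for pid, w in worker.get_weights().items()}
- #     worker.set_weights(weights_refs, weights_seq_no=1)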
- def get_global_vars(self) -> dict:
- """Returns the current `self.global_vars` dict of this RolloutWorker.
- Returns:
- The current `self.global_vars` dict of this RolloutWorker.
- Examples:
- >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
- >>> # Create a RolloutWorker.
- >>> worker = ... # doctest: +SKIP
- >>> global_vars = worker.get_global_vars() # doctest: +SKIP
- >>> print(global_vars) # doctest: +SKIP
- {"timestep": 424242}
- """
- return self.global_vars
- def set_global_vars(
- self,
- global_vars: dict,
- policy_ids: Optional[List[PolicyID]] = None,
- ) -> None:
- """Updates this worker's and all its policies' global vars.
- Updates are done using the dict's update method.
- Args:
- global_vars: The global_vars dict to update the `self.global_vars` dict
- from.
- policy_ids: Optional list of Policy IDs to update. If None, will update all
- policies on the to-be-updated workers.
- Examples:
- >>> worker = ... # doctest: +SKIP
- >>> global_vars = worker.set_global_vars( # doctest: +SKIP
- ... {"timestep": 4242})
- """
- # Handle per-policy values.
- global_vars_copy = global_vars.copy()
- gradient_updates_per_policy = global_vars_copy.pop(
- "num_grad_updates_per_policy", {}
- )
- self.global_vars["num_grad_updates_per_policy"].update(
- gradient_updates_per_policy
- )
- # Only update explicitly provided policies or those that are being
- # trained, in order to avoid superfluous access of policies, which might have
- # been offloaded to the object store.
- # Important b/c global vars are constantly being updated.
- for pid in policy_ids if policy_ids is not None else self.policy_map.keys():
- if self.is_policy_to_train is None or self.is_policy_to_train(pid, None):
- self.policy_map[pid].on_global_var_update(
- dict(
- global_vars_copy,
- # If count is None, Policy won't update the counter.
- **{"num_grad_updates": gradient_updates_per_policy.get(pid)},
- )
- )
- # Update all other global vars.
- self.global_vars.update(global_vars_copy)
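- # Illustrative sketch: per-policy gradient-update counters travel inside
- # `global_vars` under the "num_grad_updates_per_policy" key and are
- # dispatched to each policy as "num_grad_updates" (the values below are
- # assumptions for illustration only):
- #
- #     worker.set_global_vars({
- #         "timestep": 100,
- #         "num_grad_updates_per_policy": {"default_policy": 3},
- #     })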
- @override(EnvRunner)
- def stop(self) -> None:
- """Releases all resources used by this RolloutWorker."""
- # If we have an env -> Release its resources.
- if self.env is not None:
- self.async_env.stop()
- # In case we have an AsyncSampler, kill its sampling thread.
- if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler):
- self.sampler.shutdown = True
- # Close all policies' sessions (if tf static graph).
- for policy in self.policy_map.cache.values():
- sess = policy.get_session()
- # Closes the tf session, if any.
- if sess is not None:
- sess.close()
- def lock(self) -> None:
- """Locks this RolloutWorker via its own threading.Lock."""
- self._lock.acquire()
- def unlock(self) -> None:
- """Unlocks this RolloutWorker via its own threading.Lock."""
- self._lock.release()
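- # Illustrative sketch: `lock()`/`unlock()` guard concurrent access to this
- # worker from multiple threads; pairing them in try/finally avoids leaking
- # the lock on errors.
- #
- #     worker.lock()
- #     try:
- #         weights = worker.get_weights()
- #     finally:
- #         worker.unlock()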
- def setup_torch_data_parallel(
- self, url: str, world_rank: int, world_size: int, backend: str
- ) -> None:
- """Join a torch process group for distributed SGD."""
- logger.info(
- "Joining process group, url={}, world_rank={}, "
- "world_size={}, backend={}".format(url, world_rank, world_size, backend)
- )
- torch.distributed.init_process_group(
- backend=backend, init_method=url, rank=world_rank, world_size=world_size
- )
- for pid, policy in self.policy_map.items():
- if not isinstance(policy, (TorchPolicy, TorchPolicyV2)):
- raise ValueError(
- "This policy does not support torch distributed", policy
- )
- policy.distributed_world_size = world_size
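- # Illustrative call sketch (url, ranks, and backend are assumptions):
- # joins this worker into a 2-process gloo group for distributed SGD.
- #
- #     worker.setup_torch_data_parallel(
- #         url="tcp://127.0.0.1:29500", world_rank=0, world_size=2,
- #         backend="gloo",
- #     )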
- def creation_args(self) -> dict:
- """Returns the kwargs dict used to create this worker."""
- return self._original_kwargs
- @DeveloperAPI
- def get_host(self) -> str:
- """Returns the hostname of the process running this evaluator."""
- return platform.node()
- @DeveloperAPI
- def get_node_ip(self) -> str:
- """Returns the IP address of the node that this worker runs on."""
- return ray.util.get_node_ip_address()
- @DeveloperAPI
- def find_free_port(self) -> int:
- """Finds a free port on the node that this worker runs on."""
- from ray.air._internal.util import find_free_port
- return find_free_port()
- def _update_policy_map(
- self,
- *,
- policy_dict: MultiAgentPolicyConfigDict,
- policy: Optional[Policy] = None,
- policy_states: Optional[Dict[PolicyID, PolicyState]] = None,
- single_agent_rl_module_spec: Optional[SingleAgentRLModuleSpec] = None,
- ) -> None:
- """Updates the policy map (and other stuff) on this worker.
- It performs the following:
- 1. It updates the observation preprocessors and updates the policy_specs
- with the postprocessed observation_spaces.
- 2. It updates the policy_specs with the complete algorithm_config (merged
- with the policy_spec's config).
- 3. If needed, it updates the self.marl_module_spec on this worker.
- 4. It updates the policy map with the new policies.
- 5. It updates the filter dict.
- 6. It calls the on_create_policy() hook of the callbacks on the newly added
- policies.
- Args:
- policy_dict: The policy dict to update the policy map with.
- policy: The policy to update the policy map with.
- policy_states: The policy states to update the policy map with.
- single_agent_rl_module_spec: The SingleAgentRLModuleSpec to add to the
- MultiAgentRLModuleSpec. If None, the config's
- `get_default_rl_module_spec` method's output will be used to create
- the policy with.
- """
- # Update the input policy dict with the postprocessed observation spaces and
- # merge configs. Also updates the preprocessor dict.
- updated_policy_dict = self._get_complete_policy_specs_dict(policy_dict)
- # Use the updated policy dict to create the marl_module_spec if necessary
- if self.config._enable_rl_module_api:
- spec = self.config.get_marl_module_spec(
- policy_dict=updated_policy_dict,
- single_agent_rl_module_spec=single_agent_rl_module_spec,
- )
- if self.marl_module_spec is None:
- # This is the first time, so we create the marl_module_spec here.
- self.marl_module_spec = spec
- else:
- # This is adding a new policy, so we need to call add_modules on the
- # module_specs of returned spec.
- self.marl_module_spec.add_modules(spec.module_specs)
- # Add __marl_module_spec key into the config so that the policy can access
- # it.
- updated_policy_dict = self._update_policy_dict_with_marl_module(
- updated_policy_dict
- )
- # Builds the self.policy_map dict
- self._build_policy_map(
- policy_dict=updated_policy_dict,
- policy=policy,
- policy_states=policy_states,
- )
- # Initialize the filter dict
- self._update_filter_dict(updated_policy_dict)
- # Call callback policy init hooks (only if the added policy did not exist
- # before).
- if policy is None:
- self._call_callbacks_on_create_policy()
- if self.worker_index == 0:
- logger.info(f"Built policy map: {self.policy_map}")
- logger.info(f"Built preprocessor map: {self.preprocessors}")
- def _get_complete_policy_specs_dict(
- self, policy_dict: MultiAgentPolicyConfigDict
- ) -> MultiAgentPolicyConfigDict:
- """Processes the policy dict and creates a new copy with the processed attrs.
- This processes the observation spaces and prepares them for passing to RLModule
- construction. It also merges the policy configs with the algorithm config.
- During this processing, we will also construct the preprocessors dict.
- """
- from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
- updated_policy_dict = copy.deepcopy(policy_dict)
- # If our preprocessors dict does not exist yet, create it here.
- self.preprocessors = self.preprocessors or {}
- # Loop through given policy-dict and add each entry to our map.
- for name, policy_spec in sorted(updated_policy_dict.items()):
- logger.debug("Creating policy for {}".format(name))
- # Policy brings its own complete AlgorithmConfig -> Use it for this policy.
- if isinstance(policy_spec.config, AlgorithmConfig):
- merged_conf = policy_spec.config
- else:
- # Update the general config with the specific config
- # for this particular policy.
- merged_conf: "AlgorithmConfig" = self.config.copy(copy_frozen=False)
- merged_conf.update_from_dict(policy_spec.config or {})
- # Update num_workers and worker_index.
- merged_conf.worker_index = self.worker_index
- # Preprocessors.
- obs_space = policy_spec.observation_space
- # Initialize preprocessor for this policy to None.
- self.preprocessors[name] = None
- if self.preprocessing_enabled:
- # Policies should deal with preprocessed (automatically flattened)
- # observations if preprocessing is enabled.
- preprocessor = ModelCatalog.get_preprocessor_for_space(
- obs_space,
- merged_conf.model,
- include_multi_binary=self.config.get(
- "_enable_rl_module_api", False
- ),
- )
- # Original observation space should be accessible at
- # obs_space.original_space after this step.
- if preprocessor is not None:
- obs_space = preprocessor.observation_space
- if not merged_conf.enable_connectors:
- # If connectors are not enabled, rollout worker will handle
- # the running of these preprocessors.
- self.preprocessors[name] = preprocessor
- policy_spec.config = merged_conf
- policy_spec.observation_space = obs_space
- return updated_policy_dict
- def _update_policy_dict_with_marl_module(
- self, policy_dict: MultiAgentPolicyConfigDict
- ) -> MultiAgentPolicyConfigDict:
- for name, policy_spec in policy_dict.items():
- policy_spec.config["__marl_module_spec"] = self.marl_module_spec
- return policy_dict
- def _build_policy_map(
- self,
- *,
- policy_dict: MultiAgentPolicyConfigDict,
- policy: Optional[Policy] = None,
- policy_states: Optional[Dict[PolicyID, PolicyState]] = None,
- ) -> None:
- """Adds the given policy_dict to `self.policy_map`.
- Args:
- policy_dict: The MultiAgentPolicyConfigDict to be added to this
- worker's PolicyMap.
- policy: If the policy to add already exists, the user can provide it here.
- policy_states: Optional dict from PolicyIDs to PolicyStates to
- restore the states of the policies being built.
- """
- # If our policy_map does not exist yet, create it here.
- self.policy_map = self.policy_map or PolicyMap(
- capacity=self.config.policy_map_capacity,
- policy_states_are_swappable=self.config.policy_states_are_swappable,
- )
- # Loop through given policy-dict and add each entry to our map.
- for name, policy_spec in sorted(policy_dict.items()):
- # Create the actual policy object.
- if policy is None:
- new_policy = create_policy_for_framework(
- policy_id=name,
- policy_class=get_tf_eager_cls_if_necessary(
- policy_spec.policy_class, policy_spec.config
- ),
- merged_config=policy_spec.config,
- observation_space=policy_spec.observation_space,
- action_space=policy_spec.action_space,
- worker_index=self.worker_index,
- seed=self.seed,
- )
- else:
- new_policy = policy
- # Maybe torch compile an RLModule.
- if self.config.get("_enable_rl_module_api", False) and self.config.get(
- "torch_compile_worker"
- ):
- if self.config.framework_str != "torch":
- raise ValueError("Attempting to compile a non-torch RLModule.")
- rl_module = getattr(new_policy, "model", None)
- if rl_module is not None:
- compile_config = self.config.get_torch_compile_worker_config()
- rl_module.compile(compile_config)
- self.policy_map[name] = new_policy
- restore_states = (policy_states or {}).get(name, None)
- # Set the state of the newly created policy before syncing filters, etc.
- if restore_states:
- new_policy.set_state(restore_states)
- def _update_filter_dict(self, policy_dict: MultiAgentPolicyConfigDict) -> None:
- """Updates the filter dict for the given policy_dict."""
- for name, policy_spec in sorted(policy_dict.items()):
- new_policy = self.policy_map[name]
- if policy_spec.config.enable_connectors:
- # Note(jungong) : We should only create new connectors for the
- # policy iff we are creating a new policy from scratch, i.e.,
- # we should NOT create new connectors when we already have the
- # policy object created before this function call or have the
- # restoring states from the caller.
- # Also note that we cannot just check the existence of connectors
- # to decide whether we should create connectors because we may be
- # restoring a policy that has 0 connectors configured.
- if (
- new_policy.agent_connectors is None
- or new_policy.action_connectors is None
- ):
- # TODO(jungong) : revisit this. It would be nicer to create
- # connectors as the last step of Policy.__init__().
- create_connectors_for_policy(new_policy, policy_spec.config)
- maybe_get_filters_for_syncing(self, name)
- else:
- filter_shape = tree.map_structure(
- lambda s: (
- None
- if isinstance(s, (Discrete, MultiDiscrete)) # noqa
- else np.array(s.shape)
- ),
- new_policy.observation_space_struct,
- )
- self.filters[name] = get_filter(
- policy_spec.config.observation_filter,
- filter_shape,
- )
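- # Sketch of the filter shapes produced above for common observation spaces
- # (illustrative): Discrete/MultiDiscrete spaces map to None (no filtering),
- # while e.g. a Box with shape (84, 84, 3) maps to np.array([84, 84, 3]).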
- def _call_callbacks_on_create_policy(self):
- """Calls the on_create_policy callback for each policy in the policy map."""
- for name, policy in self.policy_map.items():
- self.callbacks.on_create_policy(policy_id=name, policy=policy)
- def _get_input_creator_from_config(self):
- def valid_module(class_path):
- if (
- isinstance(class_path, str)
- and not os.path.isfile(class_path)
- and "." in class_path
- ):
- module_path, class_name = class_path.rsplit(".", 1)
- try:
- spec = importlib.util.find_spec(module_path)
- if spec is not None:
- return True
- except (ModuleNotFoundError, ValueError):
- print(
- f"module {module_path} not found while trying to get "
- f"input {class_path}"
- )
- return False
- # A callable returning an InputReader object to use.
- if isinstance(self.config.input_, FunctionType):
- return self.config.input_
- # Use RLlib's Sampler classes (SyncSampler or AsyncSampler, depending
- # on `config.sample_async` setting).
- elif self.config.input_ == "sampler":
- return lambda ioctx: ioctx.default_sampler_input()
- # Ray Dataset input -> Use `config.input_config` to construct DatasetReader.
- elif self.config.input_ == "dataset":
- assert self._ds_shards is not None
- # Input dataset shards should have already been prepared.
- # We just need to take the proper shard here.
- return lambda ioctx: DatasetReader(
- self._ds_shards[self.worker_index], ioctx
- )
- # Dict: Mix of different input methods with different ratios.
- elif isinstance(self.config.input_, dict):
- return lambda ioctx: ShuffledInput(
- MixedInput(self.config.input_, ioctx), self.config.shuffle_buffer_size
- )
- # A pre-registered input descriptor (str).
- elif isinstance(self.config.input_, str) and registry_contains_input(
- self.config.input_
- ):
- return registry_get_input(self.config.input_)
- # D4RL input.
- elif "d4rl" in self.config.input_:
- env_name = self.config.input_.split(".")[-1]
- return lambda ioctx: D4RLReader(env_name, ioctx)
- # Valid python module (class path) -> Create using `from_config`.
- elif valid_module(self.config.input_):
- return lambda ioctx: ShuffledInput(
- from_config(self.config.input_, ioctx=ioctx)
- )
- # JSON file or list of JSON files -> Use JsonReader (shuffled).
- else:
- return lambda ioctx: ShuffledInput(
- JsonReader(self.config.input_, ioctx), self.config.shuffle_buffer_size
- )
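- # Illustrative `config.input_` values handled above (sketch; the paths are
- # assumptions):
- #     config.input_ = "sampler"                      # env sampling (default)
- #     config.input_ = "dataset"                      # Ray Dataset shards
- #     config.input_ = {"sampler": 0.5, "/data": 0.5} # mixed input ratios
- #     config.input_ = "/data/*.json"                 # JSON files (JsonReader)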
- def _get_output_creator_from_config(self):
- if isinstance(self.config.output, FunctionType):
- return self.config.output
- elif self.config.output is None:
- return lambda ioctx: NoopOutput()
- elif self.config.output == "dataset":
- return lambda ioctx: DatasetWriter(
- ioctx, compress_columns=self.config.output_compress_columns
- )
- elif self.config.output == "logdir":
- return lambda ioctx: JsonWriter(
- ioctx.log_dir,
- ioctx,
- max_file_size=self.config.output_max_file_size,
- compress_columns=self.config.output_compress_columns,
- )
- else:
- return lambda ioctx: JsonWriter(
- self.config.output,
- ioctx,
- max_file_size=self.config.output_max_file_size,
- compress_columns=self.config.output_compress_columns,
- )
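- # Illustrative `config.output` values handled above (sketch; the custom
- # path is an assumption):
- #     config.output = None       # NoopOutput (discard samples)
- #     config.output = "dataset"  # DatasetWriter
- #     config.output = "logdir"   # JsonWriter into the worker's log dir
- #     config.output = "/tmp/out" # JsonWriter into a custom directory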
- def _get_make_sub_env_fn(
- self, env_creator, env_context, validate_env, env_wrapper, seed
- ):
- config = self.config
- def _make_sub_env_local(vector_index):
- # Used to create additional environments during environment
- # vectorization.
- # Create the env context (config dict + meta-data) for
- # this particular sub-env within the vectorized one.
- env_ctx = env_context.copy_with_overrides(vector_index=vector_index)
- # Create the sub-env.
- env = env_creator(env_ctx)
- # Validate first.
- if not config.disable_env_checking:
- try:
- check_env(env, config)
- except Exception as e:
- logger.warning(
- "We've added a module for checking environments that "
- "are used in experiments. Your env may not be set up"
- "correctly. You can disable env checking for now by setting "
- "`disable_env_checking` to True in your experiment config "
- "dictionary. You can run the environment checking module "
- "standalone by calling ray.rllib.utils.check_env(env)."
- )
- raise e
- # Custom validation function given by user.
- if validate_env is not None:
- validate_env(env, env_ctx)
- # Use our wrapper, defined above.
- env = env_wrapper(env)
- # Make sure a deterministic random seed is set on
- # all the sub-environments if specified.
- _update_env_seed_if_necessary(
- env, seed, env_context.worker_index, vector_index
- )
- return env
- if not env_context.remote:
- def _make_sub_env_remote(vector_index):
- sub_env = _make_sub_env_local(vector_index)
- self.callbacks.on_sub_environment_created(
- worker=self,
- sub_environment=sub_env,
- env_context=env_context.copy_with_overrides(
- worker_index=env_context.worker_index,
- vector_index=vector_index,
- remote=False,
- ),
- )
- return sub_env
- return _make_sub_env_remote
- else:
- return _make_sub_env_local