from collections import defaultdict
import concurrent
import copy
from datetime import datetime
import functools
import gymnasium as gym
import importlib
import json
import logging
import numpy as np
import os
from packaging import version
import pkg_resources
import re
import tempfile
import time
import tree  # pip install dm_tree
from typing import (
    Callable,
    Container,
    DefaultDict,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import ray
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray.actor import ActorHandle
from ray.train import Checkpoint
import ray.cloudpickle as pickle
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.registry import ALGORITHMS_CLASS_TO_NAME as ALL_ALGORITHMS
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import (
    collect_episodes,
    collect_metrics,
    summarize_episodes,
)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    STEPS_TRAINED_THIS_ITER_COUNTER,  # TODO: Backward compatibility.
)
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.offline import get_dataset_and_shards
from ray.rllib.offline.estimators import (
    OffPolicyEstimator,
    ImportanceSampling,
    WeightedImportanceSampling,
    DirectMethod,
    DoublyRobust,
)
from ray.rllib.offline.offline_evaluator import OfflineEvaluator
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, concat_samples
from ray.rllib.utils import deep_update, FilterManager
from ray.rllib.utils.annotations import (
    DeveloperAPI,
    ExperimentalAPI,
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    PublicAPI,
    override,
)
from ray.rllib.utils.checkpoints import (
    CHECKPOINT_VERSION,
    CHECKPOINT_VERSION_LEARNER,
    get_checkpoint_info,
    try_import_msgpack,
)
from ray.rllib.utils.debug import update_global_seed_if_necessary
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    Deprecated,
    deprecation_warning,
)
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_SAMPLED_THIS_ITER,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED_THIS_ITER,
    NUM_ENV_STEPS_TRAINED,
    SYNCH_WORKER_WEIGHTS_TIMER,
    TRAINING_ITERATION_TIMER,
    SAMPLE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer, ReplayBuffer
from ray.rllib.utils.serialization import deserialize_type, NOT_SERIALIZABLE
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.typing import (
    AgentConnectorDataType,
    AgentID,
    AlgorithmConfigDict,
    EnvCreator,
    EnvInfoDict,
    EnvType,
    EpisodeID,
    PartialAlgorithmConfigDict,
    PolicyID,
    PolicyState,
    ResultDict,
    SampleBatchType,
    TensorStructType,
    TensorType,
)
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.experiment.trial import ExportFormat
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.resources import Resources
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trainable import Trainable
from ray.util import log_once
from ray.util.timer import _Timer
from ray.tune.registry import get_trainable_cls

try:
    from ray.rllib.extensions import AlgorithmBase
except ImportError:

    class AlgorithmBase:
        @staticmethod
        def _get_learner_bundles(cf: AlgorithmConfig) -> List[Dict[str, int]]:
            """Selects the right resource bundles for learner workers based on cf.

            Args:
                cf: The algorithm config.

            Returns:
                A list of resource bundles for the learner workers.
            """
            if cf.num_learner_workers > 0:
                if cf.num_gpus_per_learner_worker:
                    learner_bundles = [
                        {
                            "GPU": cf.num_learner_workers
                            * cf.num_gpus_per_learner_worker,
                        }
                    ]
                elif cf.num_cpus_per_learner_worker:
                    learner_bundles = [
                        {
                            "CPU": cf.num_cpus_per_learner_worker
                            * cf.num_learner_workers,
                        }
                    ]
            else:
                learner_bundles = [
                    {
                        # Sampling and training are not done concurrently when
                        # local is used, so pick the max.
                        "CPU": max(
                            cf.num_cpus_per_learner_worker,
                            cf.num_cpus_for_local_worker,
                        ),
                        "GPU": cf.num_gpus_per_learner_worker,
                    }
                ]
            return learner_bundles
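
# Illustrative examples for `_get_learner_bundles` (not part of the original
# source): with `num_learner_workers=2` and `num_gpus_per_learner_worker=1`, it
# returns [{"GPU": 2}]; with `num_learner_workers=0`,
# `num_cpus_for_local_worker=4`, and `num_cpus_per_learner_worker=1`, it
# returns [{"CPU": 4, "GPU": 0}] (local sampling and training share one bundle,
# so the max CPU count is used).
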
tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


@Deprecated(
    new="config = AlgorithmConfig().update_from_dict({'a': 1, 'b': 2}); ... ; "
    "print(config.lr) -> 0.001; if config.a > 0: [do something];",
    error=True,
)
def with_common_config(*args, **kwargs):
    pass


@PublicAPI
class Algorithm(Trainable, AlgorithmBase):
    """An RLlib algorithm responsible for optimizing one or more Policies.

    Algorithms contain a WorkerSet under `self.workers`. A WorkerSet is
    normally composed of a single local worker
    (`self.workers.local_worker()`), used to compute and apply learning
    updates, and optionally one or more remote workers, used to generate
    environment samples in parallel.

    WorkerSet is fault-tolerant and elastic. It tracks health states for all
    the managed remote worker actors. As a result, an Algorithm should never
    access the underlying actor handles directly. Instead, always access them
    via the WorkerSet's `foreach` APIs, passing the IDs of the underlying
    workers.

    Each worker (remote or local) contains a PolicyMap, which itself
    may contain either one policy for single-agent training or one or more
    policies for multi-agent training. Policies are synchronized
    automatically from time to time using ray.remote calls. The exact
    synchronization logic depends on the specific algorithm used,
    but this usually happens from the local worker to all remote workers and
    after each training update.

    You can write your own Algorithm classes by sub-classing from `Algorithm`
    or any of its built-in sub-classes.
    This allows you to override the `training_step` method to implement
    your own algorithm logic. You can find the different built-in
    algorithms' `training_step()` methods in their respective main .py files,
    e.g. rllib.algorithms.dqn.dqn.py or rllib.algorithms.impala.impala.py.

    The most important API methods an Algorithm exposes are `train()`,
    `evaluate()`, `save()` and `restore()`.
    """

    # Whether to allow unknown top-level config keys.
    _allow_unknown_configs = False

    # List of top-level keys with value=dict, for which new sub-keys are
    # allowed to be added to the value dict.
    _allow_unknown_subkeys = [
        "tf_session_args",
        "local_tf_session_args",
        "env_config",
        "model",
        "optimizer",
        "custom_resources_per_worker",
        "evaluation_config",
        "exploration_config",
        "replay_buffer_config",
        "extra_python_environs_for_worker",
        "input_config",
        "output_config",
    ]

    # List of top-level keys with value=dict, for which we always override the
    # entire value (dict), iff the "type" key in that value dict changes.
    _override_all_subkeys_if_type_changes = [
        "exploration_config",
        "replay_buffer_config",
    ]

    # List of keys that are always fully overridden if present in any dict
    # or sub-dict.
    _override_all_key_list = ["off_policy_estimation_methods", "policies"]

    _progress_metrics = (
        "num_env_steps_sampled",
        "num_env_steps_trained",
        "episodes_total",
        "sampler_results/episode_len_mean",
        "sampler_results/episode_reward_mean",
        "evaluation/sampler_results/episode_reward_mean",
    )

    @staticmethod
    def from_checkpoint(
        checkpoint: Union[str, Checkpoint],
        policy_ids: Optional[Container[PolicyID]] = None,
        policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
        policies_to_train: Optional[
            Union[
                Container[PolicyID],
                Callable[[PolicyID, Optional[SampleBatchType]], bool],
            ]
        ] = None,
    ) -> "Algorithm":
        """Creates a new algorithm instance from a given checkpoint.

        Note: This method must remain backward compatible from 2.0.0 on.

        Args:
            checkpoint: The path (str) to the checkpoint directory to use
                or an AIR Checkpoint instance to restore from.
            policy_ids: Optional list of PolicyIDs to recover. This allows users to
                restore an Algorithm with only a subset of the originally present
                Policies.
            policy_mapping_fn: An optional (updated) policy mapping function
                to use from here on.
            policies_to_train: An optional list of policy IDs to be trained
                or a callable taking PolicyID and SampleBatchType and
                returning a bool (trainable or not?).
                If None, will keep the existing setup in place. Policies
                whose IDs are not in the list (or for which the callable
                returns False) will not be updated.

        Returns:
            The instantiated Algorithm.
        """
        checkpoint_info = get_checkpoint_info(checkpoint)

        # Not possible for v0.1 checkpoints (algo class and config information
        # missing or very hard to retrieve).
        if checkpoint_info["checkpoint_version"] == version.Version("0.1"):
            raise ValueError(
                "Cannot restore a v0 checkpoint using `Algorithm.from_checkpoint()`! "
                "In this case, do the following:\n"
                "1) Create a new Algorithm object using your original config.\n"
                "2) Call the `restore()` method of this algo object passing it"
                " your checkpoint dir or AIR Checkpoint object."
            )
        elif checkpoint_info["checkpoint_version"] < version.Version("1.0"):
            raise ValueError(
                "`checkpoint_info['checkpoint_version']` in `Algorithm.from_checkpoint"
                "()` must be 1.0 or later! You are using a checkpoint with "
                f"version v{checkpoint_info['checkpoint_version']}."
            )

        # This is a msgpack checkpoint.
        if checkpoint_info["format"] == "msgpack":
            # User did not provide an (unserializable) function with this call
            # (`policy_mapping_fn`). Note that if `policies_to_train` is None, it
            # defaults to training all policies (so it's ok to not provide this here).
            if policy_mapping_fn is None:
                # Only DEFAULT_POLICY_ID present in this algorithm -> provide a
                # default implementation of this function.
                if checkpoint_info["policy_ids"] == {DEFAULT_POLICY_ID}:
                    policy_mapping_fn = AlgorithmConfig.DEFAULT_POLICY_MAPPING_FN
                # Provide a meaningful error message.
                else:
                    raise ValueError(
                        "You are trying to restore a multi-agent algorithm from a "
                        "`msgpack` formatted checkpoint, which does NOT store the "
                        "`policy_mapping_fn` or `policies_to_train` "
                        "functions! Make sure that when using the "
                        "`Algorithm.from_checkpoint()` utility, you also pass the "
                        "args: `policy_mapping_fn` and `policies_to_train` with your "
                        "call. You might leave `policies_to_train=None` in case "
                        "you would like to train all policies anyway."
                    )

        state = Algorithm._checkpoint_info_to_algorithm_state(
            checkpoint_info=checkpoint_info,
            policy_ids=policy_ids,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=policies_to_train,
        )

        return Algorithm.from_state(state)
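
    # Illustrative restore sketch (not part of the original source); the
    # checkpoint path and policy ID below are hypothetical:
    #
    #     algo = Algorithm.from_checkpoint(
    #         "/tmp/my_checkpoint_dir",  # hypothetical path
    #         policy_ids={"pol1"},  # only restore a subset of all policies
    #         policy_mapping_fn=lambda agent_id, episode, **kwargs: "pol1",
    #     )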

    @staticmethod
    def from_state(state: Dict) -> "Algorithm":
        """Recovers an Algorithm from a state object.

        The `state` of an instantiated Algorithm can be retrieved by calling its
        `get_state` method. It contains all information necessary
        to create the Algorithm from scratch. No access to the original code (e.g.
        configs, knowledge of the Algorithm's class, etc.) is needed.

        Args:
            state: The state to recover a new Algorithm instance from.

        Returns:
            A new Algorithm instance.
        """
        algorithm_class: Type[Algorithm] = state.get("algorithm_class")
        if algorithm_class is None:
            raise ValueError(
                "No `algorithm_class` key was found in given `state`! "
                "Cannot create new Algorithm."
            )
        # Create the new algo.
        config = state.get("config")
        if not config:
            raise ValueError("No `config` found in given Algorithm state!")
        new_algo = algorithm_class(config=config)
        # Set the new algo's state.
        new_algo.__setstate__(state)

        # Return the new algo.
        return new_algo
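
    # Illustrative round-trip sketch (not part of the original source); assumes
    # the state dict carries the "algorithm_class" and "config" keys described
    # above:
    #
    #     state = algo.get_state()  # see docstring above
    #     new_algo = Algorithm.from_state(state)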

    @PublicAPI
    def __init__(
        self,
        config: Optional[AlgorithmConfig] = None,
        env=None,  # deprecated arg
        logger_creator: Optional[Callable[[], Logger]] = None,
        **kwargs,
    ):
        """Initializes an Algorithm instance.

        Args:
            config: Algorithm-specific configuration object.
            logger_creator: Callable that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
            **kwargs: Arguments passed to the Trainable base class.
        """
        config = config or self.get_default_config()

        # Translate a possible dict into an AlgorithmConfig object and resolve
        # generic config objects into specific ones (e.g. passing an
        # `AlgorithmConfig` super-class instance into a PPO constructor,
        # which normally would expect a PPOConfig object).
        if isinstance(config, dict):
            default_config = self.get_default_config()
            # `self.get_default_config()` also returned a dict ->
            # Last resort: Create a core AlgorithmConfig from the merged dicts.
            if isinstance(default_config, dict):
                config = AlgorithmConfig.from_dict(
                    config_dict=self.merge_algorithm_configs(
                        default_config, config, True
                    )
                )
            # Default config is an AlgorithmConfig -> update its properties
            # from the given config dict.
            else:
                config = default_config.update_from_dict(config)
        else:
            default_config = self.get_default_config()
            # Given AlgorithmConfig is not of the same type as the default config:
            # This could be the case e.g. if the user is building an algo from a
            # generic AlgorithmConfig() object.
            if not isinstance(config, type(default_config)):
                config = default_config.update_from_dict(config.to_dict())

        # In case this algo is using a generic config (with no algo_class set),
        # set it here.
        if config.algo_class is None:
            config.algo_class = type(self)

        if env is not None:
            deprecation_warning(
                old=f"algo = Algorithm(env='{env}', ...)",
                new=f"algo = AlgorithmConfig().environment('{env}').build()",
                error=False,
            )
            config.environment(env)

        # Validate and freeze our AlgorithmConfig object (no more changes possible).
        config.validate()
        config.freeze()

        # Convert the `env` provided in the config into a concrete env creator
        # callable, which takes an EnvContext (config dict) as arg and returns an
        # RLlib-supported Env type (e.g. a gym.Env).
        self._env_id, self.env_creator = self._get_env_id_and_creator(
            config.env, config
        )
        env_descr = (
            self._env_id.__name__ if isinstance(self._env_id, type) else self._env_id
        )

        # Placeholder for a local replay buffer instance.
        self.local_replay_buffer = None

        # Create a default logger creator if no logger_creator is specified.
        if logger_creator is None:
            # Default logdir prefix containing the agent's name and the
            # env id.
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            env_descr_for_dir = re.sub("[/\\\\]", "-", str(env_descr))
            logdir_prefix = f"{str(self)}_{env_descr_for_dir}_{timestr}"
            if not os.path.exists(DEFAULT_RESULTS_DIR):
                # Possible race condition if the dir is created several times on
                # rollout workers.
                os.makedirs(DEFAULT_RESULTS_DIR, exist_ok=True)
            logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)

            # Allow users to more precisely configure the created logger
            # via "logger_config.type".
            if config.logger_config and "type" in config.logger_config:

                def default_logger_creator(config):
                    """Creates a custom logger with the default prefix."""
                    cfg = config["logger_config"].copy()
                    cls = cfg.pop("type")
                    # Provide a default for logdir, in case the user does
                    # not specify this in the "logger_config" dict.
                    logdir_ = cfg.pop("logdir", logdir)
                    return from_config(cls=cls, _args=[cfg], logdir=logdir_)

            # If no `type` given, use tune's UnifiedLogger as last resort.
            else:

                def default_logger_creator(config):
                    """Creates a Unified logger with the default prefix."""
                    return UnifiedLogger(config, logdir, loggers=None)

            logger_creator = default_logger_creator

        # Metrics-related properties.
        self._timers = defaultdict(_Timer)
        self._counters = defaultdict(int)
        self._episode_history = []
        self._episodes_to_be_collected = []

        # The fully qualified AlgorithmConfig used for evaluation
        # (or None if evaluation is not set up).
        self.evaluation_config: Optional[AlgorithmConfig] = None
        # Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
        self.evaluation_workers: Optional[WorkerSet] = None
        # Initialize common evaluation_metrics to nan, before they become
        # available. We want to make sure the metrics are always present
        # (although their values may be nan), so that Tune does not complain
        # when we use these as stopping criteria.
        self.evaluation_metrics = {
            # TODO: Don't dump sampler results into top-level.
            "evaluation": {
                "episode_reward_max": np.nan,
                "episode_reward_min": np.nan,
                "episode_reward_mean": np.nan,
                "sampler_results": {
                    "episode_reward_max": np.nan,
                    "episode_reward_min": np.nan,
                    "episode_reward_mean": np.nan,
                },
            },
        }

        super().__init__(
            config=config,
            logger_creator=logger_creator,
            **kwargs,
        )

        # Check whether `training_iteration` is still a tune.Trainable property
        # and has not been overridden by the user in an attempt to implement the
        # algo's logic (this should now be done inside `training_step`).
        try:
            assert isinstance(self.training_iteration, int)
        except AssertionError:
            raise AssertionError(
                "Your Algorithm's `training_iteration` seems to be overridden by your "
                "custom training logic! To solve this problem, simply rename your "
                "`self.training_iteration()` method into `self.training_step`."
            )
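
    # Illustrative construction sketch (not part of the original source);
    # `MyAlgo` is a hypothetical Algorithm sub-class. A plain dict is merged
    # into the algo's default config as described above:
    #
    #     algo = MyAlgo(config={"lr": 0.0005, "num_rollout_workers": 2})
    #     # ...which resolves to roughly:
    #     config = MyAlgo.get_default_config().update_from_dict(
    #         {"lr": 0.0005, "num_rollout_workers": 2}
    #     )
    #     algo = MyAlgo(config=config)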

    @OverrideToImplementCustomLogic
    @classmethod
    def get_default_config(cls) -> AlgorithmConfig:
        return AlgorithmConfig()

    @OverrideToImplementCustomLogic
    def _remote_worker_ids_for_metrics(self) -> List[int]:
        """Returns a list of remote worker IDs to fetch metrics from.

        Specific Algorithm implementations can override this method to
        use a subset of the workers for metrics collection.

        Returns:
            List of remote worker IDs to fetch metrics from.
        """
        return self.workers.healthy_worker_ids()

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(Trainable)
    def setup(self, config: AlgorithmConfig) -> None:
        # Setup our config: Merge the user-supplied config dict (which could
        # be a partial config dict) with the class' default.
        if not isinstance(config, AlgorithmConfig):
            assert isinstance(config, PartialAlgorithmConfigDict)
            config_obj = self.get_default_config()
            if not isinstance(config_obj, AlgorithmConfig):
                assert isinstance(config, PartialAlgorithmConfigDict)
                config_obj = AlgorithmConfig().from_dict(config_obj)
            config_obj.update_from_dict(config)
            config_obj.env = self._env_id
            self.config = config_obj

        # Set the Algorithm's seed after we have - if necessary - enabled
        # tf eager-execution.
        update_global_seed_if_necessary(self.config.framework_str, self.config.seed)

        self._record_usage(self.config)

        # Create the callbacks object.
        self.callbacks = self.config.callbacks_class()

        if self.config.log_level in ["WARN", "ERROR"]:
            logger.info(
                f"Current log_level is {self.config.log_level}. For more information, "
                "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                "-vv flags."
            )
        if self.config.log_level:
            logging.getLogger("ray.rllib").setLevel(self.config.log_level)

        # Create a local replay buffer if necessary.
        self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
            self.config
        )

        # Create a dict, mapping ActorHandles to sets of open remote
        # requests (object refs). This way, we keep track of which actors
        # inside this Algorithm (e.g. a remote RolloutWorker) have
        # already been sent how many (e.g. `sample()`) requests.
        self.remote_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]
        ] = defaultdict(set)

        self.workers: Optional[WorkerSet] = None
        self.train_exec_impl = None

        # Offline RL settings.
        input_evaluation = self.config.get("input_evaluation")
        if input_evaluation is not None and input_evaluation is not DEPRECATED_VALUE:
            ope_dict = {str(ope): {"type": ope} for ope in input_evaluation}
            deprecation_warning(
                old="config.input_evaluation={}".format(input_evaluation),
                new="config.evaluation(evaluation_config=config.overrides("
                f"off_policy_estimation_methods={ope_dict}"
                "))",
                error=True,
                help="Running OPE during training is not recommended.",
            )
            self.config.off_policy_estimation_methods = ope_dict

        # Deprecated way of implementing Algorithm sub-classes (or "templates"
        # via the `build_trainer` utility function).
        # Instead, sub-classes should override the Trainable's `setup()`
        # method and call super().setup() from within that override at some
        # point.
        # Old design: Override `Algorithm._init`.
        _init = False
        try:
            self._init(self.config, self.env_creator)
            _init = True
        # New design: Override `Algorithm.setup()` (as intended by tune.Trainable)
        # and do or don't call `super().setup()` from within your override.
        # By default, `super().setup()` will create both worker sets:
        # "rollout workers" for collecting samples for training and - if
        # applicable - "evaluation workers" for evaluation runs in between or
        # parallel to training.
        # TODO: Deprecate `_init()` and remove this try/except block.
        except NotImplementedError:
            pass

        # Only if the user did not override `_init()`:
        if _init is False:
            # Create a set of env runner actors via a WorkerSet.
            self.workers = WorkerSet(
                env_creator=self.env_creator,
                validate_env=self.validate_env,
                default_policy_class=self.get_default_policy_class(self.config),
                config=self.config,
                num_workers=self.config.num_rollout_workers,
                local_worker=True,
                logdir=self.logdir,
            )

            # TODO (avnishn): Remove the execution plan API by q1 2023.
            # Function defining one single training iteration's behavior.
            if self.config._disable_execution_plan_api:
                # Ensure remote workers are initially in sync with the local worker.
                self.workers.sync_weights()
            # LocalIterator-creating "execution plan".
            # Only call this once here to create `self.train_exec_impl`,
            # which is a ray.util.iter.LocalIterator that will be `next`'d
            # on each training iteration.
            else:
                self.train_exec_impl = self.execution_plan(
                    self.workers, self.config, **self._kwargs_for_execution_plan()
                )

        # Compile, validate, and freeze an evaluation config.
        self.evaluation_config = self.config.get_evaluation_config_object()
        self.evaluation_config.validate()
        self.evaluation_config.freeze()

        # Evaluation WorkerSet setup.
        # User would like to set up a separate evaluation worker set.
        # Note: We skip WorkerSet creation if we need to do offline evaluation.
        if self._should_create_evaluation_rollout_workers(self.evaluation_config):
            _, env_creator = self._get_env_id_and_creator(
                self.evaluation_config.env, self.evaluation_config
            )

            # Create a separate evaluation worker set for evaluation.
            # If evaluation_num_workers=0, use the evaluation set's local
            # worker for evaluation, otherwise, use its remote workers
            # (parallelized evaluation).
            self.evaluation_workers: WorkerSet = WorkerSet(
                env_creator=env_creator,
                validate_env=None,
                default_policy_class=self.get_default_policy_class(self.config),
                config=self.evaluation_config,
                num_workers=self.config.evaluation_num_workers,
                logdir=self.logdir,
            )

            if self.config.enable_async_evaluation:
                self._evaluation_weights_seq_number = 0

        self.evaluation_dataset = None
        if (
            self.evaluation_config.off_policy_estimation_methods
            and not self.evaluation_config.ope_split_batch_by_episode
        ):
            # num_workers is set to 0 to avoid creating shards. The dataset will
            # not be repartitioned into num_workers blocks.
            logger.info("Creating evaluation dataset ...")
            self.evaluation_dataset, _ = get_dataset_and_shards(
                self.evaluation_config, num_workers=0
            )
            logger.info("Evaluation dataset created.")

        self.reward_estimators: Dict[str, OffPolicyEstimator] = {}
        ope_types = {
            "is": ImportanceSampling,
            "wis": WeightedImportanceSampling,
            "dm": DirectMethod,
            "dr": DoublyRobust,
        }
        for name, method_config in self.config.off_policy_estimation_methods.items():
            method_type = method_config.pop("type")
            if method_type in ope_types:
                deprecation_warning(
                    old=method_type,
                    new=str(ope_types[method_type]),
                    error=True,
                )
                method_type = ope_types[method_type]
            elif isinstance(method_type, str):
                logger.log(0, "Trying to import from string: " + method_type)
                mod, obj = method_type.rsplit(".", 1)
                mod = importlib.import_module(mod)
                method_type = getattr(mod, obj)
            if isinstance(method_type, type) and issubclass(
                method_type, OfflineEvaluator
            ):
                # TODO (kourosh): Add an integration test for all these
                # offline evaluators.
                policy = self.get_policy()
                if issubclass(method_type, OffPolicyEstimator):
                    method_config["gamma"] = self.config.gamma
                self.reward_estimators[name] = method_type(policy, **method_config)
            else:
                raise ValueError(
                    f"Unknown off_policy_estimation type: {method_type}! Must be "
                    "either a class path or a sub-class of ray.rllib."
                    "offline.offline_evaluator::OfflineEvaluator"
                )
            # TODO (Rohan138): Refactor this and remove deprecated methods.
            # Need to add back `method_type` in case the Algorithm is restored
            # from a checkpoint.
            method_config["type"] = method_type

        self.learner_group = None
        if self.config._enable_learner_api:
            # TODO (Kourosh): This is an interim solution where policies and modules
            # co-exist. In this world we have both policy_map and MARLModule that
            # need to be consistent with one another. To make a consistent parity
            # between the two, we need to loop through the policy modules and
            # create a simple MARLModule from the RLModule within each policy.
            local_worker = self.workers.local_worker()
            policy_dict, _ = self.config.get_multi_agent_setup(
                env=local_worker.env,
                spaces=getattr(local_worker, "spaces", None),
            )
            # TODO (Sven): Unify the inference of the MARLModuleSpec. Right now,
            # we get this from the RolloutWorker's `marl_module_spec` property.
            # However, this is hacky (information leak) and should not remain this
            # way. For other EnvRunner classes (that don't have this property),
            # Algorithm should infer this itself.
            if hasattr(local_worker, "marl_module_spec"):
                module_spec = local_worker.marl_module_spec
            else:
                module_spec = self.config.get_marl_module_spec(policy_dict=policy_dict)
            learner_group_config = self.config.get_learner_group_config(module_spec)
            self.learner_group = learner_group_config.build()

            # Check if there are modules to load from the `module_spec`.
            rl_module_ckpt_dirs = {}
            marl_module_ckpt_dir = module_spec.load_state_path
            modules_to_load = module_spec.modules_to_load
            for module_id, sub_module_spec in module_spec.module_specs.items():
                if sub_module_spec.load_state_path:
                    rl_module_ckpt_dirs[module_id] = sub_module_spec.load_state_path
            if marl_module_ckpt_dir or rl_module_ckpt_dirs:
                self.learner_group.load_module_state(
                    marl_module_ckpt_dir=marl_module_ckpt_dir,
                    modules_to_load=modules_to_load,
                    rl_module_ckpt_dirs=rl_module_ckpt_dirs,
                )
            # Sync the weights from the learner group to the rollout workers.
            weights = self.learner_group.get_weights()
            local_worker.set_weights(weights)
            self.workers.sync_weights()

        # Run the `on_algorithm_init` callback after initialization is done.
        self.callbacks.on_algorithm_init(algorithm=self)

    # TODO: Deprecated: In your sub-classes of Algorithm, override `setup()`
    # directly and call super().setup() from within it if you would like the
    # default setup behavior plus some of your own setup logic.
    # If you don't need the env/workers/config/etc. set up for you by super,
    # simply do not call super().setup() from your overridden method.
    def _init(self, config: AlgorithmConfigDict, env_creator: EnvCreator) -> None:
        raise NotImplementedError

    @OverrideToImplementCustomLogic
    @classmethod
    def get_default_policy_class(
        cls,
        config: AlgorithmConfig,
    ) -> Optional[Type[Policy]]:
        """Returns a default Policy class to use, given a config.

        This class will be used by an Algorithm in case
        the policy class is not provided by the user in any single- or
        multi-agent PolicySpec.

        Note: This method is ignored when the RLModule API is enabled.
        """
        return None
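
    # Illustrative override sketch (not part of the original source);
    # `MyTorchPolicy` is a hypothetical Policy sub-class:
    #
    #     class MyAlgo(Algorithm):
    #         @classmethod
    #         def get_default_policy_class(cls, config):
    #             if config.framework_str == "torch":
    #                 return MyTorchPolicy
    #             return None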

    @override(Trainable)
    def step(self) -> ResultDict:
        """Implements the main `Algorithm.train()` logic.

        Takes n attempts to perform a single training step. Thereby
        catches RayErrors resulting from worker failures. After n attempts,
        fails gracefully.

        Override this method in your Algorithm sub-classes if you would like to
        handle worker failures yourself.
        Otherwise, override only `training_step()` to implement the core
        algorithm logic.

        Returns:
            The results dict with stats/infos on sampling, training,
            and - if required - evaluation.
        """
        # Do we have to run `self.evaluate()` this iteration?
        # `self.iteration` gets incremented after this function returns,
        # meaning that e.g. the first time this function is called,
        # self.iteration will be 0.
        evaluate_this_iter = (
            self.config.evaluation_interval is not None
            and (self.iteration + 1) % self.config.evaluation_interval == 0
        )

        # Results dict for training (and, if applicable, evaluation).
        results: ResultDict = {}

        # Parallel eval + training: Kick off the evaluation loop and a parallel
        # train() call.
        if evaluate_this_iter and self.config.evaluation_parallel_to_training:
            (
                results,
                train_iter_ctx,
            ) = self._run_one_training_iteration_and_evaluation_in_parallel()
        # - No evaluation necessary, just run the next training iteration.
        # - We have to evaluate in this training iteration, but no parallelism ->
        #   evaluate after the training iteration is entirely done.
        else:
            results, train_iter_ctx = self._run_one_training_iteration()

        # Sequential: Train (already done above), then evaluate.
        if evaluate_this_iter and not self.config.evaluation_parallel_to_training:
            results.update(self._run_one_evaluation(train_future=None))

        # Attach the latest available evaluation results to the train results,
        # if necessary.
        if not evaluate_this_iter and self.config.always_attach_evaluation_results:
            assert isinstance(
                self.evaluation_metrics, dict
            ), "Algorithm.evaluate() needs to return a dict."
            results.update(self.evaluation_metrics)

        if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
            # Sync filters on workers.
            self._sync_filters_if_needed(
                central_worker=self.workers.local_worker(),
                workers=self.workers,
                config=self.config,
            )
            # TODO (avnishn): Remove the execution plan API by q1 2023.
            # Collect worker metrics and combine them with `results`.
            if self.config._disable_execution_plan_api:
                episodes_this_iter = collect_episodes(
                    self.workers,
                    self._remote_worker_ids_for_metrics(),
                    timeout_seconds=self.config.metrics_episode_collection_timeout_s,
                )
                results = self._compile_iteration_results(
                    episodes_this_iter=episodes_this_iter,
                    step_ctx=train_iter_ctx,
                    iteration_results=results,
                )

        # Check `env_task_fn` for a possible update of the env's task.
        if self.config.env_task_fn is not None:
            if not callable(self.config.env_task_fn):
                raise ValueError(
                    "`env_task_fn` must be None or a callable taking "
                    "[train_results, env, env_ctx] as args!"
                )

            def fn(env, env_context, task_fn):
                new_task = task_fn(results, env, env_context)
                cur_task = env.get_task()
                if cur_task != new_task:
                    env.set_task(new_task)

            fn = functools.partial(fn, task_fn=self.config.env_task_fn)
            self.workers.foreach_env_with_context(fn)

        return results

    @PublicAPI
    def evaluate(
        self,
        duration_fn: Optional[Callable[[int], int]] = None,
    ) -> dict:
        """Evaluates the current policy under `evaluation_config` settings.

        Args:
            duration_fn: An optional callable taking the already run
                num episodes as only arg and returning the number of
                episodes left to run. It's used to find out whether
                evaluation should continue.
        """
        # Call the `_before_evaluate` hook.
        self._before_evaluate()

        if self.evaluation_dataset is not None:
            return {"evaluation": self._run_offline_evaluation()}

        # Sync weights to the evaluation WorkerSet.
        if self.evaluation_workers is not None:
            self.evaluation_workers.sync_weights(
                from_worker_or_learner_group=self.workers.local_worker()
            )
            self._sync_filters_if_needed(
                central_worker=self.workers.local_worker(),
                workers=self.evaluation_workers,
                config=self.evaluation_config,
            )

        self.callbacks.on_evaluate_start(algorithm=self)

        if self.config.custom_evaluation_function:
            logger.info(
                "Running custom eval function {}".format(
                    self.config.custom_evaluation_function
                )
            )
            metrics = self.config.custom_evaluation_function(
                self, self.evaluation_workers
            )
            if not metrics or not isinstance(metrics, dict):
                raise ValueError(
                    "Custom eval function must return "
                    "dict of metrics, got {}.".format(metrics)
                )
        else:
            if (
                self.evaluation_workers is None
                and self.workers.local_worker().input_reader is None
            ):
                raise ValueError(
                    "Cannot evaluate w/o an evaluation worker set in "
                    "the Algorithm or w/o an env on the local worker!\n"
                    "Try one of the following:\n1) Set "
                    "`evaluation_interval` >= 1 to force creating a "
                    "separate evaluation worker set.\n2) Set "
                    "`create_env_on_driver=True` to force the local "
                    "(non-eval) worker to have an environment to "
                    "evaluate on."
                )

            # How many episodes/timesteps do we need to run?
            # In "auto" mode (only for parallel eval + training): Run as long
            # as training lasts.
            unit = self.config.evaluation_duration_unit
            eval_cfg = self.evaluation_config
            rollout = eval_cfg.rollout_fragment_length
            num_envs = eval_cfg.num_envs_per_worker
            auto = self.config.evaluation_duration == "auto"
            duration = (
                self.config.evaluation_duration
                if not auto
                else (self.config.evaluation_num_workers or 1)
                * (1 if unit == "episodes" else rollout)
            )
            agent_steps_this_iter = 0
            env_steps_this_iter = 0

            # The default done-function returns True whenever num episodes
            # have been completed.
            if duration_fn is None:

                def duration_fn(num_units_done):
                    return duration - num_units_done

            logger.info(f"Evaluating current state of {self} for {duration} {unit}.")

            metrics = None
            all_batches = []

            # No evaluation worker set ->
            # Do evaluation using the local worker. Expect an error due to the
            # local worker not having an env.
            if self.evaluation_workers is None:
                # If unit=episodes -> Run n times `sample()` (each sample
                # produces exactly 1 episode).
                # If unit=ts -> Run 1 `sample()` b/c the
                # `rollout_fragment_length` is exactly the desired ts.
                iters = duration if unit == "episodes" else 1
                for _ in range(iters):
                    batch = self.workers.local_worker().sample()
                    agent_steps_this_iter += batch.agent_steps()
                    env_steps_this_iter += batch.env_steps()
                    if self.reward_estimators:
                        all_batches.append(batch)
                metrics = collect_metrics(
                    self.workers,
                    keep_custom_metrics=eval_cfg.keep_per_episode_custom_metrics,
                    timeout_seconds=eval_cfg.metrics_episode_collection_timeout_s,
                )
            # Evaluation worker set only has a local worker.
            elif self.evaluation_workers.num_remote_workers() == 0:
                # If unit=episodes -> Run n times `sample()` (each sample
                # produces exactly 1 episode).
                # If unit=ts -> Run 1 `sample()` b/c the
                # `rollout_fragment_length` is exactly the desired ts.
                iters = duration if unit == "episodes" else 1
                for _ in range(iters):
                    batch = self.evaluation_workers.local_worker().sample()
                    agent_steps_this_iter += batch.agent_steps()
                    env_steps_this_iter += batch.env_steps()
                    if self.reward_estimators:
                        all_batches.append(batch)
            # Evaluation worker set has n remote workers.
            elif self.evaluation_workers.num_healthy_remote_workers() > 0:
                # How many episodes have we run (across all eval workers)?
                num_units_done = 0
                _round = 0
                # In case all of the remote evaluation workers die during a round
                # of evaluation, we need to stop.
                while self.evaluation_workers.num_healthy_remote_workers() > 0:
                    units_left_to_do = duration_fn(num_units_done)
                    if units_left_to_do <= 0:
                        break

                    _round += 1
                    unit_per_remote_worker = (
                        1 if unit == "episodes" else rollout * num_envs
                    )
                    # Select the proper number of evaluation workers for this round.
                    selected_eval_worker_ids = [
                        worker_id
                        for i, worker_id in enumerate(
                            self.evaluation_workers.healthy_worker_ids()
                        )
                        if i * unit_per_remote_worker < units_left_to_do
                    ]
                    batches = self.evaluation_workers.foreach_worker(
                        func=lambda w: w.sample(),
                        local_worker=False,
                        remote_worker_ids=selected_eval_worker_ids,
                        timeout_seconds=self.config.evaluation_sample_timeout_s,
                    )
                    if len(batches) != len(selected_eval_worker_ids):
                        logger.warning(
                            "Calling `sample()` on your remote evaluation worker(s) "
                            "resulted in a timeout (after the configured "
                            f"{self.config.evaluation_sample_timeout_s} seconds)! "
                            "Try to set `evaluation_sample_timeout_s` in your config"
                            " to a larger value."
                            + (
                                " If your episodes don't terminate easily, you may "
                                "also want to set `evaluation_duration_unit` to "
                                "'timesteps' (instead of 'episodes')."
                                if unit == "episodes"
                                else ""
                            )
                        )
                        break

                    _agent_steps = sum(b.agent_steps() for b in batches)
                    _env_steps = sum(b.env_steps() for b in batches)
                    # 1 episode per returned batch.
                    if unit == "episodes":
                        num_units_done += len(batches)
                        # Make sure all batches are exactly one episode.
                        for ma_batch in batches:
                            ma_batch = ma_batch.as_multi_agent()
                            for batch in ma_batch.policy_batches.values():
                                assert batch.is_terminated_or_truncated()
                    # n timesteps per returned batch.
                    else:
                        num_units_done += (
                            _agent_steps
                            if self.config.count_steps_by == "agent_steps"
                            else _env_steps
                        )
                    if self.reward_estimators:
                        # TODO (kourosh): This approach will cause an OOM issue when
                        # the dataset gets huge (should be ok for now).
                        all_batches.extend(batches)

                    agent_steps_this_iter += _agent_steps
                    env_steps_this_iter += _env_steps

                    logger.info(
                        f"Ran round {_round} of non-parallel evaluation "
                        f"({num_units_done}/{duration if not auto else '?'} "
                        f"{unit} done)"
                    )
            else:
                # Can't find a good way to run this evaluation.
                # Wait for the next iteration.
                pass

            if metrics is None:
                metrics = collect_metrics(
                    self.evaluation_workers,
                    keep_custom_metrics=self.config.keep_per_episode_custom_metrics,
                    timeout_seconds=eval_cfg.metrics_episode_collection_timeout_s,
                )

        # TODO: Don't dump sampler results into top-level.
        # (Step counters and off-policy estimates only exist when the default
        # evaluation path above ran, so they are guarded by the same check.)
        if not self.config.custom_evaluation_function:
            metrics = dict({"sampler_results": metrics}, **metrics)

            metrics[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps_this_iter
            metrics[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps_this_iter
            # TODO: Remove this key at some point. Here for backward compatibility.
            metrics["timesteps_this_iter"] = env_steps_this_iter

            # Compute off-policy estimates.
            estimates = defaultdict(list)
            # For each batch, run the estimator's forward pass.
            for name, estimator in self.reward_estimators.items():
                for batch in all_batches:
                    estimate_result = estimator.estimate(
                        batch,
                        split_batch_by_episode=self.config.ope_split_batch_by_episode,
                    )
                    estimates[name].append(estimate_result)

            # Collate the estimates from all batches.
            if estimates:
                metrics["off_policy_estimator"] = {}
                for name, estimate_list in estimates.items():
                    avg_estimate = tree.map_structure(
                        lambda *x: np.mean(x, axis=0), *estimate_list
                    )
                    metrics["off_policy_estimator"][name] = avg_estimate

        # Evaluation does not run for every step.
        # Save evaluation metrics on this Algorithm, so they can be attached to
        # subsequent step results as the latest evaluation result.
        self.evaluation_metrics = {"evaluation": metrics}

        # Trigger the `on_evaluate_end` callback.
        self.callbacks.on_evaluate_end(
            algorithm=self, evaluation_metrics=self.evaluation_metrics
        )

        # Also return the results here for convenience.
        return self.evaluation_metrics
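
    # Illustrative call sketch (not part of the original source): cap evaluation
    # at 10 episodes total, regardless of the configured `evaluation_duration`:
    #
    #     results = algo.evaluate(duration_fn=lambda num_done: 10 - num_done)
    #     print(results["evaluation"]["episode_reward_mean"])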

    @ExperimentalAPI
    def _evaluate_async(
        self,
        duration_fn: Optional[Callable[[int], int]] = None,
    ) -> dict:
        """Evaluates the current policy under `evaluation_config` settings.

        Uses the AsyncParallelRequests manager to send frequent `sample.remote()`
        requests to the evaluation RolloutWorkers and collect the results of these
        calls. Handles worker failures (or slowdowns) gracefully thanks to the
        asynchronous nature of these requests and the fact that other eval
        RolloutWorkers can cover the workload.

        Important note: This will replace the current `self.evaluate()` method as
        the default in the future.

        Args:
            duration_fn: An optional callable taking the already run
                num episodes as only arg and returning the number of
                episodes left to run. It's used to find out whether
                evaluation should continue.
        """
        # How many episodes/timesteps do we need to run?
        # In "auto" mode (only for parallel eval + training): Run as long
        # as training lasts.
        unit = self.config.evaluation_duration_unit
        eval_cfg = self.evaluation_config
        rollout = eval_cfg.rollout_fragment_length
        num_envs = eval_cfg.num_envs_per_worker
        auto = self.config.evaluation_duration == "auto"
        duration = (
            self.config.evaluation_duration
            if not auto
            else (self.config.evaluation_num_workers or 1)
            * (1 if unit == "episodes" else rollout)
        )

        # Call the `_before_evaluate` hook.
        self._before_evaluate()

        # TODO(Jun): Implement solution via connectors.
        self._sync_filters_if_needed(
            central_worker=self.workers.local_worker(),
            workers=self.evaluation_workers,
            config=eval_cfg,
        )

        if self.config.custom_evaluation_function:
            raise ValueError(
                "`config.custom_evaluation_function` not supported in combination "
                "with `enable_async_evaluation=True` config setting!"
            )
        if self.evaluation_workers is None and (
            self.workers.local_worker().input_reader is None
            or self.config.evaluation_num_workers == 0
        ):
            raise ValueError(
                "Evaluation w/o eval workers (calling Algorithm.evaluate() w/o "
                "evaluation specifically set up) OR evaluation without input reader "
                "OR evaluation with only a local evaluation worker "
                "(`evaluation_num_workers=0`) not supported in combination "
                "with `enable_async_evaluation=True` config setting!"
            )

        agent_steps_this_iter = 0
        env_steps_this_iter = 0

        logger.info(f"Evaluating current state of {self} for {duration} {unit}.")

        all_batches = []

        # The default done-function returns True whenever num episodes
        # have been completed.
        if duration_fn is None:

            def duration_fn(num_units_done):
                return duration - num_units_done

        # Put the weights only once into the object store and use the same
        # object ref to sync to all workers.
        self._evaluation_weights_seq_number += 1
        weights_ref = ray.put(self.workers.local_worker().get_weights())
        weights_seq_no = self._evaluation_weights_seq_number

        def remote_fn(worker):
            # Pass in the seq-no so that eval workers may ignore this call if no
            # update has happened since the last call to `remote_fn` (sample).
            worker.set_weights(
                weights=ray.get(weights_ref), weights_seq_no=weights_seq_no
            )
            batch = worker.sample()
            metrics = worker.get_metrics()
            return batch, metrics, weights_seq_no

        rollout_metrics = []

        # How many episodes have we run (across all eval workers)?
        num_units_done = 0
        _round = 0

        while self.evaluation_workers.num_healthy_remote_workers() > 0:
            units_left_to_do = duration_fn(num_units_done)
            if units_left_to_do <= 0:
                break

            _round += 1
            # Get ready evaluation results and metrics asynchronously.
            self.evaluation_workers.foreach_worker_async(
                func=remote_fn,
                healthy_only=True,
            )
            eval_results = self.evaluation_workers.fetch_ready_async_reqs()

            batches = []
            i = 0
            for _, result in eval_results:
                batch, metrics, seq_no = result
                # Ignore results, if the weights seq-number does not match (is
                # from a previous evaluation step) OR if we have already reached
                # the configured duration (e.g. number of episodes to evaluate
                # for).
                if seq_no == self._evaluation_weights_seq_number and (
                    i * (1 if unit == "episodes" else rollout * num_envs)
                    < units_left_to_do
                ):
                    batches.append(batch)
                    rollout_metrics.extend(metrics)
                i += 1

            _agent_steps = sum(b.agent_steps() for b in batches)
            _env_steps = sum(b.env_steps() for b in batches)

            # 1 episode per returned batch.
            if unit == "episodes":
                num_units_done += len(batches)
                # Make sure all batches are exactly one episode.
                for ma_batch in batches:
                    ma_batch = ma_batch.as_multi_agent()
                    for batch in ma_batch.policy_batches.values():
                        assert batch.is_terminated_or_truncated()
            # n timesteps per returned batch.
            else:
                num_units_done += (
                    _agent_steps
                    if self.config.count_steps_by == "agent_steps"
                    else _env_steps
                )

            if self.reward_estimators:
                all_batches.extend(batches)

            agent_steps_this_iter += _agent_steps
            env_steps_this_iter += _env_steps

            logger.info(
                f"Ran round {_round} of parallel evaluation "
                f"({num_units_done}/{duration if not auto else '?'} "
                f"{unit} done)"
            )

        sampler_results = summarize_episodes(
            rollout_metrics,
            keep_custom_metrics=eval_cfg["keep_per_episode_custom_metrics"],
        )

        # TODO: Don't dump sampler results into top-level.
        metrics = dict({"sampler_results": sampler_results}, **sampler_results)

        metrics[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps_this_iter
        metrics[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps_this_iter
        # TODO: Remove this key at some point. Here for backward compatibility.
        metrics["timesteps_this_iter"] = env_steps_this_iter

        if self.reward_estimators:
            # Compute off-policy estimates.
            metrics["off_policy_estimator"] = {}
            total_batch = concat_samples(all_batches)
            for name, estimator in self.reward_estimators.items():
                estimates = estimator.estimate(total_batch)
                metrics["off_policy_estimator"][name] = estimates

        # Evaluation does not run for every step.
        # Save evaluation metrics on this Algorithm, so they can be attached to
        # subsequent step results as the latest evaluation result.
        self.evaluation_metrics = {"evaluation": metrics}

        # Trigger the `on_evaluate_end` callback.
        self.callbacks.on_evaluate_end(
            algorithm=self, evaluation_metrics=self.evaluation_metrics
        )

        # Return the evaluation results.
        return self.evaluation_metrics
- @OverrideToImplementCustomLogic
- @DeveloperAPI
- def restore_workers(self, workers: WorkerSet):
- """Try to restore failed workers if necessary.
- Algorithms that use custom RolloutWorkers may override this method to
- disable the default behavior and implement custom restoration logic.
- Args:
- workers: The WorkerSet to restore. This may be Rollout or Evaluation
- workers.
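- Example:
- A minimal sketch (uses this Algorithm's own WorkerSets; normally
- called from within the training loop):
- >>> algo.restore_workers(algo.workers) # doctest: +SKIP
- >>> algo.restore_workers(algo.evaluation_workers) # doctest: +SKIP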
- """
- # If `workers` is None, or
- # 1. `workers` (WorkerSet) does not have a local worker, and
- # 2. `self.workers` (WorkerSet used for training) does not have a local worker
- # -> we don't have a local worker to get state from, so we can't recover
- # remote workers in this case.
- if not workers or (
- not workers.local_worker() and not self.workers.local_worker()
- ):
- return
- # This is really cheap, since probe_unhealthy_workers() is a no-op
- # if there are no unhealthy workers.
- restored = workers.probe_unhealthy_workers()
- if restored:
- from_worker = workers.local_worker() or self.workers.local_worker()
- # Get the state of the correct (reference) worker. E.g. The local worker
- # of the main WorkerSet.
- state_ref = ray.put(from_worker.get_state())
- # By default, entire local worker state is synced after restoration
- # to bring these workers up to date.
- workers.foreach_worker(
- func=lambda w: w.set_state(ray.get(state_ref)),
- remote_worker_ids=restored,
- # Don't update the local_worker, b/c it's the one we are synching from.
- local_worker=False,
- timeout_seconds=self.config.worker_restore_timeout_s,
- # Bring back actor after successful state syncing.
- mark_healthy=True,
- )
- @OverrideToImplementCustomLogic
- @DeveloperAPI
- def training_step(self) -> ResultDict:
- """Default single iteration logic of an algorithm.
- - Collect on-policy samples (SampleBatches) in parallel using the
- Algorithm's RolloutWorkers (@ray.remote).
- - Concatenate collected SampleBatches into one train batch.
- - Note that we may have more than one policy in the multi-agent case:
- Call the different policies' `learn_on_batch` (simple optimizer) OR
- `load_batch_into_buffer` + `learn_on_loaded_batch` (multi-GPU
- optimizer) methods to calculate loss and update the model(s).
- - Return all collected metrics for the iteration.
- Returns:
- The results dict from executing the training iteration.
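- Example:
- A minimal sketch of a custom override (`MyAlgo` and the batch size
- are placeholders, not a recommendation):
- >>> class MyAlgo(Algorithm): # doctest: +SKIP
- >>> def training_step(self) -> ResultDict: # doctest: +SKIP
- >>> # Sample, learn on the batch, then sync new weights.
- >>> batch = synchronous_parallel_sample( # doctest: +SKIP
- >>> worker_set=self.workers, max_env_steps=4000) # doctest: +SKIP
- >>> results = train_one_step(self, batch.as_multi_agent()) # doctest: +SKIP
- >>> self.workers.sync_weights() # doctest: +SKIP
- >>> return results # doctest: +SKIP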
- """
- # Collect SampleBatches from sample workers until we have a full batch.
- with self._timers[SAMPLE_TIMER]:
- if self.config.count_steps_by == "agent_steps":
- train_batch = synchronous_parallel_sample(
- worker_set=self.workers,
- max_agent_steps=self.config.train_batch_size,
- )
- else:
- train_batch = synchronous_parallel_sample(
- worker_set=self.workers, max_env_steps=self.config.train_batch_size
- )
- train_batch = train_batch.as_multi_agent()
- self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
- self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
- # Only train if train_batch is not empty.
- # In an extreme situation, all rollout workers die during the
- # synchronous_parallel_sample() call above.
- # In which case, we should skip training, wait a little bit, then probe again.
- train_results = {}
- if train_batch.agent_steps() > 0:
- # Use simple optimizer (only for multi-agent or tf-eager; all other
- # cases should use the multi-GPU optimizer, even if only using 1 GPU).
- # TODO: (sven) rename MultiGPUOptimizer into something more
- # meaningful.
- if self.config._enable_learner_api:
- is_module_trainable = self.workers.local_worker().is_policy_to_train
- self.learner_group.set_is_module_trainable(is_module_trainable)
- train_results = self.learner_group.update(train_batch)
- elif self.config.get("simple_optimizer") is True:
- train_results = train_one_step(self, train_batch)
- else:
- train_results = multi_gpu_train_one_step(self, train_batch)
- else:
- # Wait 1 sec before probing again via weight syncing.
- time.sleep(1)
- # Update weights and global_vars - after learning on the local worker - on all
- # remote workers (only those policies that were actually trained).
- global_vars = {
- "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
- }
- with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
- # TODO (Avnish): Implement this on learner_group.get_weights().
- # TODO (Kourosh): figure out how we are going to sync MARLModule
- # weights to MARLModule weights under the policy_map objects?
- from_worker_or_trainer = None
- if self.config._enable_learner_api:
- from_worker_or_trainer = self.learner_group
- self.workers.sync_weights(
- from_worker_or_learner_group=from_worker_or_trainer,
- policies=list(train_results.keys()),
- global_vars=global_vars,
- )
- return train_results
- @staticmethod
- def execution_plan(workers, config, **kwargs):
- raise NotImplementedError(
- "It is no longer supported to use the `Algorithm.execution_plan()` API!"
- " Set `_disable_execution_plan_api=True` in your config and override the "
- "`Algorithm.training_step()` method with your algo's custom "
- "execution logic instead."
- )
- @PublicAPI
- def compute_single_action(
- self,
- observation: Optional[TensorStructType] = None,
- state: Optional[List[TensorStructType]] = None,
- *,
- prev_action: Optional[TensorStructType] = None,
- prev_reward: Optional[float] = None,
- info: Optional[EnvInfoDict] = None,
- input_dict: Optional[SampleBatch] = None,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- full_fetch: bool = False,
- explore: Optional[bool] = None,
- timestep: Optional[int] = None,
- episode: Optional[Episode] = None,
- unsquash_action: Optional[bool] = None,
- clip_action: Optional[bool] = None,
- # Kwargs placeholder for future compatibility.
- **kwargs,
- ) -> Union[
- TensorStructType,
- Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]],
- ]:
- """Computes an action for the specified policy on the local worker.
- Note that you can also access the policy object through
- self.get_policy(policy_id) and call compute_single_action() on it
- directly.
- Args:
- observation: Single (unbatched) observation from the
- environment.
- state: List of all RNN hidden (single, unbatched) state tensors.
- prev_action: Single (unbatched) previous action value.
- prev_reward: Single (unbatched) previous reward value.
- info: Env info dict, if any.
- input_dict: An optional SampleBatch that holds all the values
- for: obs, state, prev_action, and prev_reward, plus maybe
- custom defined views of the current env trajectory. Note
- that exactly one of `observation` or `input_dict` must be non-None.
- policy_id: Policy to query (only applies to multi-agent).
- Default: "default_policy".
- full_fetch: Whether to return extra action fetch results.
- This is always set to True if `state` is specified.
- explore: Whether to apply exploration to the action.
- Default: None -> use self.config.explore.
- timestep: The current (sampling) time step.
- episode: This provides access to all of the internal episodes'
- state, which may be useful for model-based or multi-agent
- algorithms.
- unsquash_action: Should actions be unsquashed according to the
- env's/Policy's action space? If None, use the value of
- self.config.normalize_actions.
- clip_action: Should actions be clipped according to the
- env's/Policy's action space? If None, use the value of
- self.config.clip_actions.
- Keyword Args:
- kwargs: forward compatibility placeholder
- Returns:
- The computed action if full_fetch=False, or the full output of
- policy.compute_single_action() (action, RNN state outs, extra
- action fetches) if full_fetch=True or we have an RNN-based Policy.
- Raises:
- KeyError: If the `policy_id` cannot be found in this Algorithm's local
- worker.
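- Example:
- A minimal sketch; samples a dummy observation from the policy's
- observation space instead of using a real env:
- >>> from ray.rllib.algorithms.ppo import PPO
- >>> algo = PPO(...) # doctest: +SKIP
- >>> obs = algo.get_policy().observation_space.sample() # doctest: +SKIP
- >>> action = algo.compute_single_action(obs) # doctest: +SKIP
- >>> # With full fetches (action, RNN state outs, extra fetches):
- >>> action, state, extra = algo.compute_single_action( # doctest: +SKIP
- >>> obs, full_fetch=True) # doctest: +SKIP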
- """
- # `unsquash_action` is None: Use value of config['normalize_actions'].
- if unsquash_action is None:
- unsquash_action = self.config.normalize_actions
- # `clip_action` is None: Use value of config['clip_actions'].
- if clip_action is None:
- clip_action = self.config.clip_actions
- # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state`
- # are all None.
- err_msg = (
- "Provide either `input_dict` OR [`observation`, ...] as "
- "args to `Algorithm.compute_single_action()`!"
- )
- if input_dict is not None:
- assert (
- observation is None
- and prev_action is None
- and prev_reward is None
- and state is None
- ), err_msg
- observation = input_dict[SampleBatch.OBS]
- else:
- assert observation is not None, err_msg
- # Get the policy to compute the action for (in the multi-agent case,
- # Algorithm may hold >1 policies).
- policy = self.get_policy(policy_id)
- if policy is None:
- raise KeyError(
- f"PolicyID '{policy_id}' not found in PolicyMap of the "
- f"Algorithm's local worker!"
- )
- local_worker = self.workers.local_worker()
- if not self.config.get("enable_connectors"):
- # Check the preprocessor and preprocess, if necessary.
- pp = local_worker.preprocessors[policy_id]
- if pp and type(pp).__name__ != "NoPreprocessor":
- observation = pp.transform(observation)
- observation = local_worker.filters[policy_id](observation, update=False)
- else:
- # Just preprocess observations, similar to how it used to be done before.
- pp = policy.agent_connectors[ObsPreprocessorConnector]
- # convert the observation to array if possible
- if not isinstance(observation, (np.ndarray, dict, tuple)):
- try:
- observation = np.asarray(observation)
- except Exception:
- raise ValueError(
- f"Observation type {type(observation)} cannot be converted to "
- f"np.ndarray."
- )
- if pp:
- assert len(pp) == 1, "Only one preprocessor should be in the pipeline"
- pp = pp[0]
- if not pp.is_identity():
- # Note(Kourosh): This call will leave the policy's connector
- # in eval mode. Would that be a problem?
- pp.in_eval()
- if observation is not None:
- _input_dict = {SampleBatch.OBS: observation}
- elif input_dict is not None:
- _input_dict = {SampleBatch.OBS: input_dict[SampleBatch.OBS]}
- else:
- raise ValueError(
- "Either observation or input_dict must be provided."
- )
- # TODO (Kourosh): Create a new util method for algorithm that
- # computes actions based on raw inputs from env and can keep track
- # of its own internal state.
- acd = AgentConnectorDataType("0", "0", _input_dict)
- # make sure the state is reset since we are only applying the
- # preprocessor
- pp.reset(env_id="0")
- ac_o = pp([acd])[0]
- observation = ac_o.data[SampleBatch.OBS]
- # Input-dict.
- if input_dict is not None:
- input_dict[SampleBatch.OBS] = observation
- action, state, extra = policy.compute_single_action(
- input_dict=input_dict,
- explore=explore,
- timestep=timestep,
- episode=episode,
- )
- # Individual args.
- else:
- action, state, extra = policy.compute_single_action(
- obs=observation,
- state=state,
- prev_action=prev_action,
- prev_reward=prev_reward,
- info=info,
- explore=explore,
- timestep=timestep,
- episode=episode,
- )
- # If we work in normalized action space (normalize_actions=True),
- # we re-translate here into the env's action space.
- if unsquash_action:
- action = space_utils.unsquash_action(action, policy.action_space_struct)
- # Clip, according to env's action space.
- elif clip_action:
- action = space_utils.clip_action(action, policy.action_space_struct)
- # Return 3-Tuple: Action, states, and extra-action fetches.
- if state or full_fetch:
- return action, state, extra
- # Ensure backward compatibility.
- else:
- return action
- @PublicAPI
- def compute_actions(
- self,
- observations: TensorStructType,
- state: Optional[List[TensorStructType]] = None,
- *,
- prev_action: Optional[TensorStructType] = None,
- prev_reward: Optional[TensorStructType] = None,
- info: Optional[EnvInfoDict] = None,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- full_fetch: bool = False,
- explore: Optional[bool] = None,
- timestep: Optional[int] = None,
- episodes: Optional[List[Episode]] = None,
- unsquash_actions: Optional[bool] = None,
- clip_actions: Optional[bool] = None,
- **kwargs,
- ):
- """Computes an action for the specified policy on the local Worker.
- Note that you can also access the policy object through
- self.get_policy(policy_id) and call compute_actions() on it directly.
- Args:
- observations: Multi-agent dict mapping agent ids to individual
- agents' observations.
- state: RNN hidden states, if any. If state is not None, the full
- compute_actions() output is returned (computed actions, RNN
- state(s), extra action fetches). Otherwise, only the computed
- actions are returned.
- prev_action: Previous action value, if any.
- prev_reward: Previous reward, if any.
- info: Env info dict, if any.
- policy_id: Policy to query (only applies to multi-agent).
- full_fetch: Whether to return extra action fetch results.
- This is always set to True if RNN state is specified.
- explore: Whether to pick an exploitation or exploration
- action (default: None -> use self.config.explore).
- timestep: The current (sampling) time step.
- episodes: This provides access to all of the internal episodes'
- state, which may be useful for model-based or multi-agent
- algorithms.
- unsquash_actions: Should actions be unsquashed according
- to the env's/Policy's action space? If None, use
- self.config.normalize_actions.
- clip_actions: Should actions be clipped according to the
- env's/Policy's action space? If None, use
- self.config.clip_actions.
- Keyword Args:
- kwargs: forward compatibility placeholder
- Returns:
- The computed action if full_fetch=False, or a tuple consisting of
- the full output of policy.compute_actions_from_input_dict() if
- full_fetch=True or we have an RNN-based Policy.
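- Example:
- A minimal sketch for the multi-agent case (agent ids and observation
- values are placeholders):
- >>> obs_dict = {"agent_0": obs0, "agent_1": obs1} # doctest: +SKIP
- >>> actions = algo.compute_actions(obs_dict) # doctest: +SKIP
- >>> actions["agent_0"] # doctest: +SKIP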
- """
- # `unsquash_actions` is None: Use value of config['normalize_actions'].
- if unsquash_actions is None:
- unsquash_actions = self.config.normalize_actions
- # `clip_actions` is None: Use value of config['clip_actions'].
- if clip_actions is None:
- clip_actions = self.config.clip_actions
- # Preprocess obs and states.
- state_defined = state is not None
- policy = self.get_policy(policy_id)
- filtered_obs, filtered_state = [], []
- for agent_id, ob in observations.items():
- worker = self.workers.local_worker()
- preprocessed = worker.preprocessors[policy_id].transform(ob)
- filtered = worker.filters[policy_id](preprocessed, update=False)
- filtered_obs.append(filtered)
- if state is None:
- continue
- elif agent_id in state:
- filtered_state.append(state[agent_id])
- else:
- filtered_state.append(policy.get_initial_state())
- # Batch obs and states
- obs_batch = np.stack(filtered_obs)
- if state is None:
- state = []
- else:
- state = list(zip(*filtered_state))
- state = [np.stack(s) for s in state]
- input_dict = {SampleBatch.OBS: obs_batch}
- # prev_action and prev_reward can be None, np.ndarray, or tensor-like structure.
- # Explicitly check for None here to avoid the error message "The truth value of
- # an array with more than one element is ambiguous.", when np arrays are passed
- # as arguments.
- if prev_action is not None:
- input_dict[SampleBatch.PREV_ACTIONS] = prev_action
- if prev_reward is not None:
- input_dict[SampleBatch.PREV_REWARDS] = prev_reward
- if info:
- input_dict[SampleBatch.INFOS] = info
- for i, s in enumerate(state):
- input_dict[f"state_in_{i}"] = s
- # Batch compute actions
- actions, states, infos = policy.compute_actions_from_input_dict(
- input_dict=input_dict,
- explore=explore,
- timestep=timestep,
- episodes=episodes,
- )
- # Unbatch actions for the environment into a multi-agent dict.
- single_actions = space_utils.unbatch(actions)
- actions = {}
- for key, a in zip(observations, single_actions):
- # If we work in normalized action space (normalize_actions=True),
- # we re-translate here into the env's action space.
- if unsquash_actions:
- a = space_utils.unsquash_action(a, policy.action_space_struct)
- # Clip, according to env's action space.
- elif clip_actions:
- a = space_utils.clip_action(a, policy.action_space_struct)
- actions[key] = a
- # Unbatch states into a multi-agent dict.
- unbatched_states = {}
- for idx, agent_id in enumerate(observations):
- unbatched_states[agent_id] = [s[idx] for s in states]
- # Return only actions or full tuple
- if state_defined or full_fetch:
- return actions, unbatched_states, infos
- else:
- return actions
- @PublicAPI
- def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
- """Return policy for the specified id, or None.
- Args:
- policy_id: ID of the policy to return.
- """
- return self.workers.local_worker().get_policy(policy_id)
- @PublicAPI
- def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
- """Return a dictionary of policy ids to weights.
- Args:
- policies: Optional list of policies to return weights for,
- or None for all policies.
- """
- return self.workers.local_worker().get_weights(policies)
- @PublicAPI
- def set_weights(self, weights: Dict[PolicyID, dict]):
- """Set policy weights by policy id.
- Args:
- weights: Map of policy ids to weights to set.
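- Example:
- A sketch of transferring weights between two compatible Algorithm
- instances (`algo_src` and `algo_dest` are placeholders):
- >>> weights = algo_src.get_weights() # doctest: +SKIP
- >>> algo_dest.set_weights(weights) # doctest: +SKIP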
- """
- self.workers.local_worker().set_weights(weights)
- @PublicAPI
- def add_policy(
- self,
- policy_id: PolicyID,
- policy_cls: Optional[Type[Policy]] = None,
- policy: Optional[Policy] = None,
- *,
- observation_space: Optional[gym.spaces.Space] = None,
- action_space: Optional[gym.spaces.Space] = None,
- config: Optional[Union[AlgorithmConfig, PartialAlgorithmConfigDict]] = None,
- policy_state: Optional[PolicyState] = None,
- policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
- policies_to_train: Optional[
- Union[
- Container[PolicyID],
- Callable[[PolicyID, Optional[SampleBatchType]], bool],
- ]
- ] = None,
- evaluation_workers: bool = True,
- module_spec: Optional[SingleAgentRLModuleSpec] = None,
- ) -> Optional[Policy]:
- """Adds a new policy to this Algorithm.
- Args:
- policy_id: ID of the policy to add.
- IMPORTANT: Must not contain characters that are not allowed in
- Unix/Win filesystems, such as `<>:"/|?*`, and must not end with a
- dot, space, or backslash.
- policy_cls: The Policy class to use for constructing the new Policy.
- Note: Only one of `policy_cls` or `policy` must be provided.
- policy: The Policy instance to add to this algorithm. If not None, the
- given Policy object will be directly inserted into the Algorithm's
- local worker and clones of that Policy will be created on all remote
- workers as well as all evaluation workers.
- Note: Only one of `policy_cls` or `policy` must be provided.
- observation_space: The observation space of the policy to add.
- If None, try to infer this space from the environment.
- action_space: The action space of the policy to add.
- If None, try to infer this space from the environment.
- config: The config object or overrides for the policy to add.
- policy_state: Optional state dict to apply to the new
- policy instance, right after its construction.
- policy_mapping_fn: An optional (updated) policy mapping function
- to use from here on. Note that already ongoing episodes will
- not change their mapping but will use the old mapping till
- the end of the episode.
- policies_to_train: An optional list of policy IDs to be trained
- or a callable taking PolicyID and SampleBatchType and
- returning a bool (trainable or not?).
- If None, will keep the existing setup in place. Policies whose
- IDs are not in the list (or for which the callable returns
- False) will not be updated.
- evaluation_workers: Whether to add the new policy also
- to the evaluation WorkerSet.
- module_spec: In the new RLModule API we need to pass in the module_spec for
- the new module that is supposed to be added. Knowing the policy spec is
- not sufficient.
- Returns:
- The newly added policy (the copy that got added to the local
- worker).
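- Example:
- A minimal sketch (the policy id and mapping function are
- placeholders; spaces are inferred from the env if not given):
- >>> algo.add_policy( # doctest: +SKIP
- >>> policy_id="new_policy", # doctest: +SKIP
- >>> policy_cls=type(algo.get_policy()), # doctest: +SKIP
- >>> policy_mapping_fn=lambda aid, *a, **kw: "new_policy", # doctest: +SKIP
- >>> ) # doctest: +SKIP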
- """
- validate_policy_id(policy_id, error=True)
- self.workers.add_policy(
- policy_id,
- policy_cls,
- policy,
- observation_space=observation_space,
- action_space=action_space,
- config=config,
- policy_state=policy_state,
- policy_mapping_fn=policy_mapping_fn,
- policies_to_train=policies_to_train,
- module_spec=module_spec,
- )
- # If learner API is enabled, we need to also add the underlying module
- # to the learner group.
- if self.config._enable_learner_api:
- policy = self.get_policy(policy_id)
- module = policy.model
- self.learner_group.add_module(
- module_id=policy_id,
- module_spec=SingleAgentRLModuleSpec.from_module(module),
- )
- weights = policy.get_weights()
- self.learner_group.set_weights({policy_id: weights})
- # Add to evaluation workers, if necessary.
- if evaluation_workers is True and self.evaluation_workers is not None:
- self.evaluation_workers.add_policy(
- policy_id,
- policy_cls,
- policy,
- observation_space=observation_space,
- action_space=action_space,
- config=config,
- policy_state=policy_state,
- policy_mapping_fn=policy_mapping_fn,
- policies_to_train=policies_to_train,
- module_spec=module_spec,
- )
- # Return newly added policy (from the local rollout worker).
- return self.get_policy(policy_id)
- @PublicAPI
- def remove_policy(
- self,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- *,
- policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
- policies_to_train: Optional[
- Union[
- Container[PolicyID],
- Callable[[PolicyID, Optional[SampleBatchType]], bool],
- ]
- ] = None,
- evaluation_workers: bool = True,
- ) -> None:
- """Removes a new policy from this Algorithm.
- Args:
- policy_id: ID of the policy to be removed.
- policy_mapping_fn: An optional (updated) policy mapping function
- to use from here on. Note that already ongoing episodes will
- not change their mapping but will use the old mapping till
- the end of the episode.
- policies_to_train: An optional list of policy IDs to be trained
- or a callable taking PolicyID and SampleBatchType and
- returning a bool (trainable or not?).
- If None, will keep the existing setup in place. Policies whose
- IDs are not in the list (or for which the callable returns
- False) will not be updated.
- evaluation_workers: Whether to also remove the policy from the
- evaluation WorkerSet.
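- Example:
- A minimal sketch; removes the (hypothetical) "new_policy" and maps
- all agents back to the default policy:
- >>> algo.remove_policy( # doctest: +SKIP
- >>> "new_policy", # doctest: +SKIP
- >>> policy_mapping_fn=lambda aid, *a, **kw: "default_policy", # doctest: +SKIP
- >>> ) # doctest: +SKIP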
- """
- def fn(worker):
- worker.remove_policy(
- policy_id=policy_id,
- policy_mapping_fn=policy_mapping_fn,
- policies_to_train=policies_to_train,
- )
- self.workers.foreach_worker(fn, local_worker=True, healthy_only=True)
- if evaluation_workers and self.evaluation_workers is not None:
- self.evaluation_workers.foreach_worker(
- fn,
- local_worker=True,
- healthy_only=True,
- )
- @DeveloperAPI
- def export_policy_model(
- self,
- export_dir: str,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- onnx: Optional[int] = None,
- ) -> None:
- """Exports policy model with given policy_id to a local directory.
- Args:
- export_dir: Writable local directory.
- policy_id: Optional policy id to export.
- onnx: If given, will export model in ONNX format. The
- value of this parameter set the ONNX OpSet version to use.
- If None, the output format will be DL framework specific.
- Example:
- >>> from ray.rllib.algorithms.ppo import PPO
- >>> # Use an Algorithm from RLlib or define your own.
- >>> algo = PPO(...) # doctest: +SKIP
- >>> for _ in range(10): # doctest: +SKIP
- >>> algo.train() # doctest: +SKIP
- >>> algo.export_policy_model("/tmp/dir") # doctest: +SKIP
- >>> algo.export_policy_model("/tmp/dir/onnx", onnx=1) # doctest: +SKIP
- """
- self.get_policy(policy_id).export_model(export_dir, onnx)
- @DeveloperAPI
- def export_policy_checkpoint(
- self,
- export_dir: str,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- ) -> None:
- """Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.
- Args:
- export_dir: Writable local directory to store the AIR Checkpoint
- information into.
- policy_id: Optional policy ID to export. If not provided, will export
- "default_policy". If `policy_id` does not exist in this Algorithm,
- will raise a KeyError.
- Raises:
- KeyError: If `policy_id` cannot be found in this Algorithm.
- Example:
- >>> from ray.rllib.algorithms.ppo import PPO
- >>> # Use an Algorithm from RLlib or define your own.
- >>> algo = PPO(...) # doctest: +SKIP
- >>> for _ in range(10): # doctest: +SKIP
- >>> algo.train() # doctest: +SKIP
- >>> algo.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
- """
- policy = self.get_policy(policy_id)
- if policy is None:
- raise KeyError(f"Policy with ID {policy_id} not found in Algorithm!")
- policy.export_checkpoint(export_dir)
- @DeveloperAPI
- def import_policy_model_from_h5(
- self,
- import_file: str,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- ) -> None:
- """Imports a policy's model with given policy_id from a local h5 file.
- Args:
- import_file: The h5 file to import from.
- policy_id: Optional policy id to import into.
- Example:
- >>> from ray.rllib.algorithms.ppo import PPO
- >>> algo = PPO(...) # doctest: +SKIP
- >>> algo.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP
- >>> for _ in range(10): # doctest: +SKIP
- >>> algo.train() # doctest: +SKIP
- """
- self.get_policy(policy_id).import_model_from_h5(import_file)
- # Sync new weights to remote workers.
- self._sync_weights_to_workers(worker_set=self.workers)
- @override(Trainable)
- def save_checkpoint(self, checkpoint_dir: str) -> None:
- """Exports checkpoint to a local directory.
- The structure of an Algorithm checkpoint dir will be as follows::
- policies/
- pol_1/
- policy_state.pkl
- pol_2/
- policy_state.pkl
- learner/
- learner_state.json
- module_state/
- module_1/
- ...
- optimizer_state/
- optimizers_module_1/
- ...
- rllib_checkpoint.json
- algorithm_state.pkl
- Note: `rllib_checkpoint.json` contains a "version" key (e.g. with value 0.1)
- helping RLlib to remain backward compatible w.r.t. restoring from
- checkpoints created with Ray 2.0 or later.
- Args:
- checkpoint_dir: The directory where the checkpoint files will be stored.
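- Example:
- Usually not called directly; a sketch via the Tune Trainable API
- (the directory is a placeholder):
- >>> algo = PPO(...) # doctest: +SKIP
- >>> for _ in range(10): # doctest: +SKIP
- >>> algo.train() # doctest: +SKIP
- >>> algo.save("/tmp/my_checkpoint_dir") # doctest: +SKIP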
- """
- state = self.__getstate__()
- # Extract policy states from worker state (Policies get their own
- # checkpoint sub-dirs).
- policy_states = {}
- if "worker" in state and "policy_states" in state["worker"]:
- policy_states = state["worker"].pop("policy_states", {})
- # Add RLlib checkpoint version.
- if self.config._enable_learner_api:
- state["checkpoint_version"] = CHECKPOINT_VERSION_LEARNER
- else:
- state["checkpoint_version"] = CHECKPOINT_VERSION
- # Write state (w/o policies) to disk.
- state_file = os.path.join(checkpoint_dir, "algorithm_state.pkl")
- with open(state_file, "wb") as f:
- pickle.dump(state, f)
- # Write rllib_checkpoint.json.
- with open(os.path.join(checkpoint_dir, "rllib_checkpoint.json"), "w") as f:
- json.dump(
- {
- "type": "Algorithm",
- "checkpoint_version": str(state["checkpoint_version"]),
- "format": "cloudpickle",
- "state_file": state_file,
- "policy_ids": list(policy_states.keys()),
- "ray_version": ray.__version__,
- "ray_commit": ray.__commit__,
- },
- f,
- )
- # Write individual policies to disk, each in their own sub-directory.
- for pid, policy_state in policy_states.items():
- # From here on, disallow policyIDs that would not work as directory names.
- validate_policy_id(pid, error=True)
- policy_dir = os.path.join(checkpoint_dir, "policies", pid)
- os.makedirs(policy_dir, exist_ok=True)
- policy = self.get_policy(pid)
- policy.export_checkpoint(policy_dir, policy_state=policy_state)
- # if we are using the learner API, save the learner group state
- if self.config._enable_learner_api:
- learner_state_dir = os.path.join(checkpoint_dir, "learner")
- self.learner_group.save_state(learner_state_dir)
- @override(Trainable)
- def load_checkpoint(self, checkpoint_dir: str) -> None:
- # Checkpoint is provided as a local directory.
- # Restore from the checkpoint file or dir.
- checkpoint_info = get_checkpoint_info(checkpoint_dir)
- checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(checkpoint_info)
- self.__setstate__(checkpoint_data)
- if self.config._enable_learner_api:
- learner_state_dir = os.path.join(checkpoint_dir, "learner")
- self.learner_group.load_state(learner_state_dir)
- @override(Trainable)
- def log_result(self, result: ResultDict) -> None:
- # Log after the callback is invoked, so that the user has a chance
- # to mutate the result.
- # TODO: Remove `algorithm` arg at some point to fully deprecate the old
- # signature.
- self.callbacks.on_train_result(algorithm=self, result=result)
- # Then log according to Trainable's logging logic.
- Trainable.log_result(self, result)
- @override(Trainable)
- def cleanup(self) -> None:
- # Stop all workers.
- if hasattr(self, "workers") and self.workers is not None:
- self.workers.stop()
- if hasattr(self, "evaluation_workers") and self.evaluation_workers is not None:
- self.evaluation_workers.stop()
- @OverrideToImplementCustomLogic
- @classmethod
- @override(Trainable)
- def default_resource_request(
- cls, config: Union[AlgorithmConfig, PartialAlgorithmConfigDict]
- ) -> Union[Resources, PlacementGroupFactory]:
- # Default logic for RLlib Algorithms:
- # Create one bundle per individual worker (local or remote).
- # Use `num_cpus_for_local_worker` and `num_gpus` for the local worker and
- # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote
- # workers to determine their CPU/GPU resource needs.
- # Convenience config handles.
- cf = cls.get_default_config().update_from_dict(config)
- cf.validate()
- cf.freeze()
- # get evaluation config
- eval_cf = cf.get_evaluation_config_object()
- eval_cf.validate()
- eval_cf.freeze()
- # resources for the driver of this trainable
- if cf._enable_learner_api:
- if cf.num_learner_workers == 0:
- # in this case local_worker only does sampling and training is done on
- # local learner worker
- driver = cls._get_learner_bundles(cf)[0]
- else:
- # in this case local_worker only does sampling and training is done on
- # remote learner workers
- driver = {"CPU": cf.num_cpus_for_local_worker, "GPU": 0}
- else:
- driver = {
- "CPU": cf.num_cpus_for_local_worker,
- "GPU": 0 if cf._fake_gpus else cf.num_gpus,
- }
- # resources for remote rollout env samplers
- rollout_bundles = [
- {
- "CPU": cf.num_cpus_per_worker,
- "GPU": cf.num_gpus_per_worker,
- **cf.custom_resources_per_worker,
- }
- for _ in range(cf.num_rollout_workers)
- ]
- # resources for remote evaluation env samplers or datasets (if any)
- if cls._should_create_evaluation_rollout_workers(eval_cf):
- # Evaluation workers.
- # Note: The local eval worker is located on the driver CPU.
- evaluation_bundles = [
- {
- "CPU": eval_cf.num_cpus_per_worker,
- "GPU": eval_cf.num_gpus_per_worker,
- **eval_cf.custom_resources_per_worker,
- }
- for _ in range(eval_cf.evaluation_num_workers)
- ]
- else:
- # resources for offline dataset readers during evaluation
- # Note (Kourosh): we should not claim extra workers for
- # training on the offline dataset, since rollout workers have already
- # claimed it.
- # Another Note (Kourosh): dataset reader will not use placement groups so
- # whatever we specify here won't matter because dataset won't even use it.
- # Disclaimer: using ray dataset in tune may cause deadlock when multiple
- # tune trials get scheduled on the same node and do not leave any spare
- # resources for dataset operations. The workaround is to limit the
- # max_concurrent trials so that some spare cpus are left for dataset
- # operations. This behavior should get fixed by the dataset team. More info
- # found here:
- # https://docs.ray.io/en/master/data/dataset-internals.html#datasets-tune
- evaluation_bundles = []
- # resources for remote learner workers
- learner_bundles = []
- if cf._enable_learner_api and cf.num_learner_workers > 0:
- learner_bundles = cls._get_learner_bundles(cf)
- bundles = [driver] + rollout_bundles + evaluation_bundles + learner_bundles
- # Return PlacementGroupFactory containing all needed resources
- # (already properly defined as device bundles).
- return PlacementGroupFactory(
- bundles=bundles,
- strategy=config.get("placement_strategy", "PACK"),
- )
- @DeveloperAPI
- def _before_evaluate(self):
- """Pre-evaluation callback."""
- pass
- @staticmethod
- def _get_env_id_and_creator(
- env_specifier: Union[str, EnvType, None], config: AlgorithmConfig
- ) -> Tuple[Optional[str], EnvCreator]:
- """Returns env_id and creator callable given original env id from config.
- Args:
- env_specifier: An env class, an already tune registered env ID, a known
- gym env name, or None (if no env is used).
- config: The AlgorithmConfig object.
- Returns:
- Tuple consisting of a) env ID string and b) env creator callable.
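- Example:
- A sketch, assuming "CartPole-v1" as a known gym env id and `config`
- as an AlgorithmConfig instance:
- >>> env_id, creator = Algorithm._get_env_id_and_creator( # doctest: +SKIP
- >>> "CartPole-v1", config) # doctest: +SKIP
- >>> env = creator(config.env_config) # doctest: +SKIP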
- """
- # Environment is specified via a string.
- if isinstance(env_specifier, str):
- # An already registered env.
- if _global_registry.contains(ENV_CREATOR, env_specifier):
- return env_specifier, _global_registry.get(ENV_CREATOR, env_specifier)
- # A class path specifier.
- elif "." in env_specifier:
- def env_creator_from_classpath(env_context):
- try:
- env_obj = from_config(env_specifier, env_context)
- except ValueError:
- raise EnvError(
- ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_specifier)
- )
- return env_obj
- return env_specifier, env_creator_from_classpath
- # Try gym/PyBullet/Vizdoom.
- else:
- return env_specifier, functools.partial(
- _gym_env_creator, env_descriptor=env_specifier
- )
- elif isinstance(env_specifier, type):
- env_id = env_specifier # .__name__
- if config["remote_worker_envs"]:
- # Check gym version (0.22 or higher?).
- # If > 0.21, can't perform auto-wrapping of the given class as this
- # would lead to a pickle error.
- gym_version = pkg_resources.get_distribution("gym").version
- if version.parse(gym_version) >= version.parse("0.22"):
- raise ValueError(
- "Cannot specify a gym.Env class via `config.env` while setting "
- "`config.remote_worker_env=True` AND your gym version is >= "
- "0.22! Try installing an older version of gym or set `config."
- "remote_worker_env=False`."
- )
- @ray.remote(num_cpus=1)
- class _wrapper(env_specifier):
- # Add convenience `_get_spaces` and `_is_multi_agent`
- # methods:
- def _get_spaces(self):
- return self.observation_space, self.action_space
- def _is_multi_agent(self):
- from ray.rllib.env.multi_agent_env import MultiAgentEnv
- return isinstance(self, MultiAgentEnv)
- return env_id, lambda cfg: _wrapper.remote(cfg)
- # gym.Env-subclass: Also go through our RLlib gym-creator.
- elif issubclass(env_specifier, gym.Env):
- return env_id, functools.partial(
- _gym_env_creator,
- env_descriptor=env_specifier,
- auto_wrap_old_gym_envs=config.get("auto_wrap_old_gym_envs", True),
- )
- # All other env classes: Call c'tor directly.
- else:
- return env_id, lambda cfg: env_specifier(cfg)
- # No env -> Env creator always returns None.
- elif env_specifier is None:
- return None, lambda env_config: None
- else:
- raise ValueError(
- "{} is an invalid env specifier. ".format(env_specifier)
- + "You can specify a custom env as either a class "
- '(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
- )
- def _sync_filters_if_needed(
- self,
- *,
- central_worker: RolloutWorker,
- workers: WorkerSet,
- config: AlgorithmConfig,
- ) -> None:
- """Synchronizes the filter stats from `workers` to `central_worker`.
- ... and broadcasts the central_worker's filter stats back to all `workers`
- (if configured).
- Args:
- central_worker: The worker to sync/aggregate all `workers`' filter stats to
- and from which to (possibly) broadcast the updated filter stats back to
- `workers`.
- workers: The WorkerSet, whose workers' filter stats should be used for
- aggregation on `central_worker` and which (possibly) get updated
- from `central_worker` after the sync.
- config: The algorithm config instance. This is used to determine whether
- syncing from `workers` should happen at all and whether broadcasting
- back to `workers` (after possible syncing) should happen.
- """
- if central_worker and config.observation_filter != "NoFilter":
- FilterManager.synchronize(
- central_worker.filters,
- workers,
- update_remote=config.update_worker_filter_stats,
- timeout_seconds=config.sync_filters_on_rollout_workers_timeout_s,
- use_remote_data_for_update=config.use_worker_filter_stats,
- )
- @DeveloperAPI
- def _sync_weights_to_workers(
- self,
- *,
- worker_set: WorkerSet,
- ) -> None:
- """Sync "main" weights to given WorkerSet or list of workers."""
- # Broadcast the new policy weights to all remote workers in worker_set.
- logger.info("Synchronizing weights to workers.")
- worker_set.sync_weights()
- @classmethod
- @override(Trainable)
- def resource_help(cls, config: Union[AlgorithmConfig, AlgorithmConfigDict]) -> str:
- return (
- "\n\nYou can adjust the resource requests of RLlib Algorithms by calling "
- "`AlgorithmConfig.resources("
- "num_gpus=.., num_cpus_per_worker=.., num_gpus_per_worker=.., ..)` or "
- "`AgorithmConfig.rollouts(num_rollout_workers=..)`. See "
- "the `ray.rllib.algorithms.algorithm_config.AlgorithmConfig` classes "
- "(each Algorithm has its own subclass of this class) for more info.\n\n"
- f"The config of this Algorithm is: {config}"
- )
- @override(Trainable)
- def get_auto_filled_metrics(
- self,
- now: Optional[datetime] = None,
- time_this_iter: Optional[float] = None,
- timestamp: Optional[int] = None,
- debug_metrics_only: bool = False,
- ) -> dict:
- # Override this method to make sure the `config` key of the returned results
- # contains the proper Tune config dict (instead of an AlgorithmConfig object).
- auto_filled = super().get_auto_filled_metrics(
- now, time_this_iter, timestamp, debug_metrics_only
- )
- if "config" not in auto_filled:
- raise KeyError("`config` key not found in auto-filled results dict!")
- # If the `config` key is not a dict (but an AlgorithmConfig object),
- # convert it to a dict so as not to break Tune APIs.
- if not isinstance(auto_filled["config"], dict):
- assert isinstance(auto_filled["config"], AlgorithmConfig)
- auto_filled["config"] = auto_filled["config"].to_dict()
- return auto_filled
- @classmethod
- def merge_algorithm_configs(
- cls,
- config1: AlgorithmConfigDict,
- config2: PartialAlgorithmConfigDict,
- _allow_unknown_configs: Optional[bool] = None,
- ) -> AlgorithmConfigDict:
- """Merges a complete Algorithm config dict with a partial override dict.
- Respects nested structures within the config dicts. The values in the
- partial override dict take priority.
- Args:
- config1: The complete Algorithm's dict to be merged (overridden)
- with `config2`.
- config2: The partial override config dict to merge on top of
- `config1`.
- _allow_unknown_configs: If True, keys in `config2` that don't exist
- in `config1` are allowed and will be added to the final config.
- Returns:
- The merged full algorithm config dict.
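- Example:
- A sketch of the nested-merge behavior (keys and values are
- placeholders; unknown top-level keys may be rejected):
- >>> full = {"lr": 0.001, "env_config": {"a": 1, "b": 2}} # doctest: +SKIP
- >>> override = {"env_config": {"b": 3}} # doctest: +SKIP
- >>> Algorithm.merge_algorithm_configs(full, override) # doctest: +SKIP
- >>> # -> {"lr": 0.001, "env_config": {"a": 1, "b": 3}}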
- """
- config1 = copy.deepcopy(config1)
- if "callbacks" in config2 and type(config2["callbacks"]) is dict:
- deprecation_warning(
- "callbacks dict interface",
- "a class extending rllib.algorithms.callbacks.DefaultCallbacks; "
- "see `rllib/examples/custom_metrics_and_callbacks.py` for an example.",
- error=True,
- )
- if _allow_unknown_configs is None:
- _allow_unknown_configs = cls._allow_unknown_configs
- return deep_update(
- config1,
- config2,
- _allow_unknown_configs,
- cls._allow_unknown_subkeys,
- cls._override_all_subkeys_if_type_changes,
- cls._override_all_key_list,
- )
- @staticmethod
- @ExperimentalAPI
- def validate_env(env: EnvType, env_context: EnvContext) -> None:
- """Env validator function for this Algorithm class.
- Override this in child classes to define custom validation
- behavior.
- Args:
- env: The (sub-)environment to validate. This is normally a
- single sub-environment (e.g. a gym.Env) within a vectorized
- setup.
- env_context: The EnvContext to configure the environment.
- Raises:
- Exception in case something is wrong with the given environment.
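- Example:
- A sketch of a custom override in a child class (the check shown is
- illustrative only):
- >>> class MyAlgo(Algorithm): # doctest: +SKIP
- >>> @staticmethod # doctest: +SKIP
- >>> def validate_env(env, env_context): # doctest: +SKIP
- >>> if not hasattr(env, "observation_space"): # doctest: +SKIP
- >>> raise ValueError("Env needs an observation_space!") # doctest: +SKIP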
- """
- pass
- @override(Trainable)
- def _export_model(
- self, export_formats: List[str], export_dir: str
- ) -> Dict[str, str]:
- ExportFormat.validate(export_formats)
- exported = {}
- if ExportFormat.CHECKPOINT in export_formats:
- path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
- self.export_policy_checkpoint(path)
- exported[ExportFormat.CHECKPOINT] = path
- if ExportFormat.MODEL in export_formats:
- path = os.path.join(export_dir, ExportFormat.MODEL)
- self.export_policy_model(path)
- exported[ExportFormat.MODEL] = path
- if ExportFormat.ONNX in export_formats:
- path = os.path.join(export_dir, ExportFormat.ONNX)
- self.export_policy_model(path, onnx=int(os.getenv("ONNX_OPSET", "11")))
- exported[ExportFormat.ONNX] = path
- return exported
- def import_model(self, import_file: str):
- """Imports a model from import_file.
- Note: Currently, only h5 files are supported.
- Args:
- import_file: The file to import the model from.
- Returns:
- The result of the underlying `import_policy_model_from_h5()` call.
- """
- # Check for existence.
- if not os.path.exists(import_file):
- raise FileNotFoundError(
- "`import_file` '{}' does not exist! Can't import Model.".format(
- import_file
- )
- )
- # Get the format of the given file.
- import_format = "h5" # TODO(sven): Support checkpoint loading.
- ExportFormat.validate([import_format])
- if import_format != ExportFormat.H5:
- raise NotImplementedError
- else:
- return self.import_policy_model_from_h5(import_file)
- @PublicAPI
- def __getstate__(self) -> Dict:
- """Returns current state of Algorithm, sufficient to restore it from scratch.
- Returns:
- The current state dict of this Algorithm, which can be used to sufficiently
- restore the algorithm from scratch without any other information.
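- Example:
- A sketch of a manual state round-trip between two compatible
- Algorithm instances (`algo` and `other_algo` are placeholders):
- >>> state = algo.__getstate__() # doctest: +SKIP
- >>> other_algo.__setstate__(state) # doctest: +SKIP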
- """
- # Add config to state so complete Algorithm can be reproduced w/o it.
- state = {
- "algorithm_class": type(self),
- "config": self.config,
- }
- if hasattr(self, "workers"):
- state["worker"] = self.workers.local_worker().get_state()
- # TODO: Experimental functionality: Store contents of replay buffer
- # to checkpoint, only if user has configured this.
- if self.local_replay_buffer is not None and self.config.get(
- "store_buffer_in_checkpoints"
- ):
- state["local_replay_buffer"] = self.local_replay_buffer.get_state()
- if self.train_exec_impl is not None:
- state["train_exec_impl"] = self.train_exec_impl.shared_metrics.get().save()
- else:
- state["counters"] = self._counters
- state["training_iteration"] = self.training_iteration
- return state
- @PublicAPI
- def __setstate__(self, state) -> None:
- """Sets the algorithm to the provided state.
- Args:
- state: The state dict to restore this Algorithm instance to. `state` may
- have been returned by a call to an Algorithm's `__getstate__()` method.
- """
- # TODO (sven): Validate that our config and the config in state are compatible.
- # For example, the model architectures may differ.
- # Also, what should the behavior be if e.g. some training parameter
- # (e.g. lr) changed?
- if hasattr(self, "workers") and "worker" in state:
- self.workers.local_worker().set_state(state["worker"])
- remote_state = ray.put(state["worker"])
- self.workers.foreach_worker(
- lambda w: w.set_state(ray.get(remote_state)),
- local_worker=False,
- healthy_only=False,
- )
- if self.evaluation_workers:
- # If evaluation workers are used, also restore the policies
- # there in case they are used for evaluation purpose.
- self.evaluation_workers.foreach_worker(
- lambda w: w.set_state(ray.get(remote_state)),
- healthy_only=False,
- )
- # If necessary, restore replay data as well.
- if self.local_replay_buffer is not None:
- # TODO: Experimental functionality: Restore contents of replay
- # buffer from checkpoint, only if user has configured this.
- if self.config.get("store_buffer_in_checkpoints"):
- if "local_replay_buffer" in state:
- self.local_replay_buffer.set_state(state["local_replay_buffer"])
- else:
- logger.warning(
- "`store_buffer_in_checkpoints` is True, but no replay "
- "data found in state!"
- )
- elif "local_replay_buffer" in state and log_once(
- "no_store_buffer_in_checkpoints_but_data_found"
- ):
- logger.warning(
- "`store_buffer_in_checkpoints` is False, but some replay "
- "data found in state!"
- )
- if self.train_exec_impl is not None:
- self.train_exec_impl.shared_metrics.get().restore(state["train_exec_impl"])
- elif "counters" in state:
- self._counters = state["counters"]
- if "training_iteration" in state:
- self._iteration = state["training_iteration"]
- @staticmethod
- def _checkpoint_info_to_algorithm_state(
- checkpoint_info: dict,
- policy_ids: Optional[Container[PolicyID]] = None,
- policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
- policies_to_train: Optional[
- Union[
- Container[PolicyID],
- Callable[[PolicyID, Optional[SampleBatchType]], bool],
- ]
- ] = None,
- ) -> Dict:
- """Converts a checkpoint info or object to a proper Algorithm state dict.
- The returned state dict can be used inside self.__setstate__().
- Args:
- checkpoint_info: A checkpoint info dict as returned by
- `ray.rllib.utils.checkpoints.get_checkpoint_info(
- [checkpoint dir or AIR Checkpoint])`.
- policy_ids: Optional list/set of PolicyIDs. If not None, only those policies
- listed here will be included in the returned state. Note that
- state items such as filters, the `is_policy_to_train` function, as
- well as the multi-agent `policy_ids` dict will be adjusted as well,
- based on this arg.
- policy_mapping_fn: An optional (updated) policy mapping function
- to include in the returned state.
- policies_to_train: An optional list of policy IDs to be trained
- or a callable taking PolicyID and SampleBatchType and
- returning a bool (trainable or not?) to include in the returned state.
- Returns:
- The state dict usable within the `self.__setstate__()` method.
- """
- if checkpoint_info["type"] != "Algorithm":
- raise ValueError(
- "`checkpoint` arg passed to "
- "`Algorithm._checkpoint_info_to_algorithm_state()` must be an "
- f"Algorithm checkpoint (but is {checkpoint_info['type']})!"
- )
- msgpack = None
- if checkpoint_info.get("format") == "msgpack":
- msgpack = try_import_msgpack(error=True)
- with open(checkpoint_info["state_file"], "rb") as f:
- if msgpack is not None:
- state = msgpack.load(f)
- else:
- state = pickle.load(f)
- # New checkpoint format: Policies are in separate sub-dirs.
- # Note: Algorithms like ES/ARS don't have a WorkerSet, so we just return
- # the plain state here.
- if (
- checkpoint_info["checkpoint_version"] > version.Version("0.1")
- and state.get("worker") is not None
- ):
- worker_state = state["worker"]
- # Retrieve the set of all required policy IDs.
- policy_ids = set(
- policy_ids if policy_ids is not None else worker_state["policy_ids"]
- )
- # Remove those policies entirely from filters that are not in
- # `policy_ids`.
- worker_state["filters"] = {
- pid: filter
- for pid, filter in worker_state["filters"].items()
- if pid in policy_ids
- }
- # Get Algorithm class.
- if isinstance(state["algorithm_class"], str):
- # Try deserializing from a full classpath.
- # Or as a last resort: Tune registered algorithm name.
- state["algorithm_class"] = deserialize_type(
- state["algorithm_class"]
- ) or get_trainable_cls(state["algorithm_class"])
- # Compile actual config object.
- default_config = state["algorithm_class"].get_default_config()
- if isinstance(default_config, AlgorithmConfig):
- new_config = default_config.update_from_dict(state["config"])
- else:
- new_config = Algorithm.merge_algorithm_configs(
- default_config, state["config"]
- )
- # Remove policies from multiagent dict that are not in `policy_ids`.
- new_policies = new_config.policies
- if isinstance(new_policies, (set, list, tuple)):
- new_policies = {pid for pid in new_policies if pid in policy_ids}
- else:
- new_policies = {
- pid: spec for pid, spec in new_policies.items() if pid in policy_ids
- }
- new_config.multi_agent(
- policies=new_policies,
- policies_to_train=policies_to_train,
- **(
- {"policy_mapping_fn": policy_mapping_fn}
- if policy_mapping_fn is not None
- else {}
- ),
- )
- state["config"] = new_config
- # Prepare local `worker` state to add policies' states into it,
- # read from separate policy checkpoint files.
- worker_state["policy_states"] = {}
- for pid in policy_ids:
- policy_state_file = os.path.join(
- checkpoint_info["checkpoint_dir"],
- "policies",
- pid,
- "policy_state."
- + ("msgpck" if checkpoint_info["format"] == "msgpack" else "pkl"),
- )
- if not os.path.isfile(policy_state_file):
- raise ValueError(
- "Given checkpoint does not seem to be valid! No policy "
- f"state file found for PID={pid}. "
- f"The file not found is: {policy_state_file}."
- )
- with open(policy_state_file, "rb") as f:
- if msgpack is not None:
- worker_state["policy_states"][pid] = msgpack.load(f)
- else:
- worker_state["policy_states"][pid] = pickle.load(f)
- # These two functions are never serialized in a msgpack checkpoint (which
- # does not store code, unlike a cloudpickle checkpoint). Hence the user has
- # to provide them with the `Algorithm.from_checkpoint()` call.
- if policy_mapping_fn is not None:
- worker_state["policy_mapping_fn"] = policy_mapping_fn
- if (
- policies_to_train is not None
- # `policies_to_train` might be left None in case all policies should be
- # trained.
- or worker_state["is_policy_to_train"] == NOT_SERIALIZABLE
- ):
- worker_state["is_policy_to_train"] = policies_to_train
- return state
- @DeveloperAPI
- def _create_local_replay_buffer_if_necessary(
- self, config: PartialAlgorithmConfigDict
- ) -> Optional[MultiAgentReplayBuffer]:
- """Create a MultiAgentReplayBuffer instance if necessary.
- Args:
- config: Algorithm-specific configuration data.
- Returns:
- MultiAgentReplayBuffer instance based on algorithm config.
- None, if local replay buffer is not needed.
- """
- if not config.get("replay_buffer_config") or config["replay_buffer_config"].get(
- "no_local_replay_buffer"
- ):
- return
- return from_config(ReplayBuffer, config["replay_buffer_config"])
- @DeveloperAPI
- def _kwargs_for_execution_plan(self):
- kwargs = {}
- if self.local_replay_buffer is not None:
- kwargs["local_replay_buffer"] = self.local_replay_buffer
- return kwargs
- def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
- """Runs one training iteration (self.iteration will be +1 after this).
- Calls `self.training_step()` repeatedly until the minimum time (sec),
- sample- or training steps have been reached.
- Returns:
- The results dict from the training iteration.
- """
- # In case we are training (in a thread) parallel to evaluation,
- # we may have to re-enable eager mode here (gets disabled in the
- # thread).
- if self.config.get("framework") == "tf2" and not tf.executing_eagerly():
- tf1.enable_eager_execution()
- results = None
- # Create a step context ...
- with TrainIterCtx(algo=self) as train_iter_ctx:
- # .. so we can query it whether we should stop the iteration loop (e.g.
- # when we have reached `min_time_s_per_iteration`).
- while not train_iter_ctx.should_stop(results):
- # Try to train one step.
- # TODO (avnishn): Remove the execution plan API by q1 2023
- with self._timers[TRAINING_ITERATION_TIMER]:
- if self.config._disable_execution_plan_api:
- results = self.training_step()
- else:
- results = next(self.train_exec_impl)
- # With training step done. Try to bring failed workers back.
- self.restore_workers(self.workers)
- return results, train_iter_ctx
- def _run_one_evaluation(
- self,
- train_future: Optional[concurrent.futures.ThreadPoolExecutor] = None,
- ) -> ResultDict:
- """Runs evaluation step via `self.evaluate()` and handling worker failures.
- Args:
- train_future: In case we are training and evaluating in parallel,
- this arg carries the currently running ThreadPoolExecutor
- object that runs the training iteration.
- Returns:
- The results dict from the evaluation call.
- """
- eval_func_to_use = (
- self._evaluate_async
- if self.config.enable_async_evaluation
- else self.evaluate
- )
- if self.config.evaluation_duration == "auto":
- assert (
- train_future is not None and self.config.evaluation_parallel_to_training
- )
- unit = self.config.evaluation_duration_unit
- eval_results = eval_func_to_use(
- duration_fn=functools.partial(
- self._automatic_evaluation_duration_fn,
- unit,
- self.config.evaluation_num_workers,
- self.evaluation_config,
- train_future,
- )
- )
- # Run `self.evaluate()` only once per training iteration.
- else:
- eval_results = eval_func_to_use()
- if self.evaluation_workers is not None:
- # After evaluation, do a round of health check to see if any of
- # the failed workers are back.
- self.restore_workers(self.evaluation_workers)
- # Add number of healthy evaluation workers after this iteration.
- eval_results["evaluation"][
- "num_healthy_workers"
- ] = self.evaluation_workers.num_healthy_remote_workers()
- eval_results["evaluation"][
- "num_in_flight_async_reqs"
- ] = self.evaluation_workers.num_in_flight_async_reqs()
- eval_results["evaluation"][
- "num_remote_worker_restarts"
- ] = self.evaluation_workers.num_remote_worker_restarts()
- return eval_results
- def _run_one_training_iteration_and_evaluation_in_parallel(
- self,
- ) -> Tuple[ResultDict, "TrainIterCtx"]:
- """Runs one training iteration and one evaluation step in parallel.
- First starts the training iteration (via `self._run_one_training_iteration()`)
- within a ThreadPoolExecutor, then runs the evaluation step in parallel.
- In auto-duration mode (config.evaluation_duration=auto), makes sure the
- evaluation step takes roughly the same time as the training iteration.
- Returns:
- The accumulated training and evaluation results.
- """
- with concurrent.futures.ThreadPoolExecutor() as executor:
- train_future = executor.submit(lambda: self._run_one_training_iteration())
- # Pass the train_future into `self._run_one_evaluation()` to allow it
- # to run exactly as long as the training iteration takes in case
- # evaluation_duration=auto.
- results = self._run_one_evaluation(train_future)
- # Collect the training results from the future.
- train_results, train_iter_ctx = train_future.result()
- results.update(train_results)
- return results, train_iter_ctx
- def _run_offline_evaluation(self):
- """Runs offline evaluation via `OfflineEvaluator.estimate_on_dataset()` API.
- This method will be used when `evaluation_dataset` is provided.
- Note: This will only work if the policy is a single agent policy.
- Returns:
- The results dict from the offline evaluation call.
- """
- assert len(self.workers.local_worker().policy_map) == 1
- parallelism = self.evaluation_config.evaluation_num_workers or 1
- offline_eval_results = {"off_policy_estimator": {}}
- for evaluator_name, offline_evaluator in self.reward_estimators.items():
- offline_eval_results["off_policy_estimator"][
- evaluator_name
- ] = offline_evaluator.estimate_on_dataset(
- self.evaluation_dataset,
- n_parallelism=parallelism,
- )
- return offline_eval_results
- @classmethod
- def _should_create_evaluation_rollout_workers(cls, eval_config: "AlgorithmConfig"):
- """Determines whether we need to create evaluation workers.
- Returns False if we need to run offline evaluation
- (with the ope.estimate_on_dataset API) or when the local worker is to be used
- for evaluation. Note: We only use the estimate_on_dataset API with bandits
- for now.
- That is when ope_split_batch_by_episode is False.
- TODO: In future we will do the same for episodic RL OPE.
- """
- run_offline_evaluation = (
- eval_config.off_policy_estimation_methods
- and not eval_config.ope_split_batch_by_episode
- )
- return not run_offline_evaluation and (
- eval_config.evaluation_num_workers > 0 or eval_config.evaluation_interval
- )
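
    # Decision sketch (illustrative): dataset-based OPE takes precedence over
    # rollout workers. For hypothetical configs:
    #   - OPE methods set and ope_split_batch_by_episode=False
    #     -> offline evaluation -> returns False.
    #   - Otherwise -> returns True iff evaluation_num_workers > 0 or
    #     evaluation_interval is set.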

    @staticmethod
    def _automatic_evaluation_duration_fn(
        unit, num_eval_workers, eval_cfg, train_future, num_units_done
    ):
        # Training is done and we already ran at least one
        # evaluation -> Nothing left to run.
        if num_units_done > 0 and train_future.done():
            return 0
        # Count by episodes. -> Run n more
        # (n=num eval workers).
        elif unit == "episodes":
            return num_eval_workers
        # Count by timesteps. -> Run n*m*p more
        # (n=num eval workers; m=rollout fragment length;
        # p=num-envs-per-worker).
        else:
            return (
                num_eval_workers
                * eval_cfg["rollout_fragment_length"]
                * eval_cfg["num_envs_per_worker"]
            )
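
    # Worked example (illustrative): with 4 eval workers, a rollout fragment
    # length of 100, and 2 envs per worker, each "auto" duration request asks
    # for 4 * 100 * 2 = 800 more timesteps while training is still in flight.
    # `still_running_future` is a hypothetical, not-yet-done future:
    #
    #   cfg = {"rollout_fragment_length": 100, "num_envs_per_worker": 2}
    #   Algorithm._automatic_evaluation_duration_fn(
    #       "timesteps", 4, cfg, still_running_future, num_units_done=0
    #   )  # -> 800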

    def _compile_iteration_results(
        self, *, episodes_this_iter, step_ctx, iteration_results=None
    ):
        # Return dict.
        results: ResultDict = {}
        iteration_results = iteration_results or {}

        # Evaluation results.
        if "evaluation" in iteration_results:
            results["evaluation"] = iteration_results.pop("evaluation")

        # Custom metrics and episode media.
        results["custom_metrics"] = iteration_results.pop("custom_metrics", {})
        results["episode_media"] = iteration_results.pop("episode_media", {})

        # Learner info.
        results["info"] = {LEARNER_INFO: iteration_results}

        # Calculate how many (if any) of the older, historical episodes we
        # have to add to `episodes_this_iter` in order to reach the required
        # smoothing window.
        episodes_for_metrics = episodes_this_iter[:]
        missing = self.config.metrics_num_episodes_for_smoothing - len(
            episodes_this_iter
        )
        # We have to add some older episodes to reach the smoothing window size.
        if missing > 0:
            episodes_for_metrics = (
                self._episode_history[-missing:] + episodes_this_iter
            )
            assert (
                len(episodes_for_metrics)
                <= self.config.metrics_num_episodes_for_smoothing
            )
        # Note that when there are more than `metrics_num_episodes_for_smoothing`
        # episodes in `episodes_for_metrics`, leave them as-is. In this case, we'll
        # compute the stats over that larger number.

        # Add new episodes to our history and make sure it doesn't grow larger
        # than needed.
        self._episode_history.extend(episodes_this_iter)
        self._episode_history = self._episode_history[
            -self.config.metrics_num_episodes_for_smoothing :
        ]
        results["sampler_results"] = summarize_episodes(
            episodes_for_metrics,
            episodes_this_iter,
            self.config.keep_per_episode_custom_metrics,
        )
        # TODO: Don't dump sampler results into top-level.
        results.update(results["sampler_results"])

        results["num_healthy_workers"] = self.workers.num_healthy_remote_workers()
        results["num_in_flight_async_reqs"] = self.workers.num_in_flight_async_reqs()
        results[
            "num_remote_worker_restarts"
        ] = self.workers.num_remote_worker_restarts()

        # Train-steps- and env/agent-steps this iteration.
        for c in [
            NUM_AGENT_STEPS_SAMPLED,
            NUM_AGENT_STEPS_TRAINED,
            NUM_ENV_STEPS_SAMPLED,
            NUM_ENV_STEPS_TRAINED,
        ]:
            results[c] = self._counters[c]

        time_taken_sec = step_ctx.get_time_taken_sec()
        if self.config.count_steps_by == "agent_steps":
            results[NUM_AGENT_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
            results[NUM_AGENT_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
            results[NUM_AGENT_STEPS_SAMPLED + "_throughput_per_sec"] = (
                step_ctx.sampled / time_taken_sec
            )
            results[NUM_AGENT_STEPS_TRAINED + "_throughput_per_sec"] = (
                step_ctx.trained / time_taken_sec
            )
            # TODO: For CQL and other algos, count by trained steps.
            results["timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED]
        else:
            results[NUM_ENV_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
            results[NUM_ENV_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
            results[NUM_ENV_STEPS_SAMPLED + "_throughput_per_sec"] = (
                step_ctx.sampled / time_taken_sec
            )
            results[NUM_ENV_STEPS_TRAINED + "_throughput_per_sec"] = (
                step_ctx.trained / time_taken_sec
            )
            # TODO: For CQL and other algos, count by trained steps.
            results["timesteps_total"] = self._counters[NUM_ENV_STEPS_SAMPLED]
        # TODO: Backward compatibility.
        results[STEPS_TRAINED_THIS_ITER_COUNTER] = step_ctx.trained
        results["agent_timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED]

        # Process timer results.
        timers = {}
        for k, timer in self._timers.items():
            timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
            if timer.has_units_processed():
                timers["{}_throughput".format(k)] = round(timer.mean_throughput, 3)
        results["timers"] = timers

        # Process counter results.
        counters = {}
        for k, counter in self._counters.items():
            counters[k] = counter
        results["counters"] = counters
        # TODO: Backward compatibility.
        results["info"].update(counters)

        return results
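
    # Worked example (illustrative) for the smoothing window above: with
    # `metrics_num_episodes_for_smoothing=100` and 30 episodes collected this
    # iteration, the 70 most recent historical episodes are prepended, so the
    # sampler stats are computed over exactly 100 episodes:
    #
    #   window, new = 100, list(range(30))
    #   history = list(range(1000))
    #   missing = window - len(new)     # 70
    #   eps = history[-missing:] + new  # exactly 100 episodes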

    def __repr__(self):
        return type(self).__name__

    def _record_usage(self, config):
        """Records the framework and algorithm used.

        Args:
            config: Algorithm config dict.
        """
        record_extra_usage_tag(TagKey.RLLIB_FRAMEWORK, config["framework"])
        record_extra_usage_tag(TagKey.RLLIB_NUM_WORKERS, str(config["num_workers"]))
        alg = self.__class__.__name__
        # We do not want to collect user-defined algorithm names.
        if alg not in ALL_ALGORITHMS:
            alg = "USER_DEFINED"
        record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg)

    @Deprecated(new="AlgorithmConfig.validate()", error=True)
    def validate_config(self, config):
        pass


# TODO: Create a dict that throws a deprecation warning once we have fully
#  moved to AlgorithmConfig() objects (some algos are still missing).
COMMON_CONFIG: AlgorithmConfigDict = AlgorithmConfig(Algorithm).to_dict()


class TrainIterCtx:
    def __init__(self, algo: Algorithm):
        self.algo = algo
        self.time_start = None
        self.time_stop = None

    def __enter__(self):
        # Before the first call to `step()`, `results` is expected to be
        # None -> Start with self.failures=-1 -> it gets set to 0 before the
        # very first call to `self.step()`.
        self.failures = -1

        self.time_start = time.time()
        self.sampled = 0
        self.trained = 0
        self.init_env_steps_sampled = self.algo._counters[NUM_ENV_STEPS_SAMPLED]
        self.init_env_steps_trained = self.algo._counters[NUM_ENV_STEPS_TRAINED]
        self.init_agent_steps_sampled = self.algo._counters[NUM_AGENT_STEPS_SAMPLED]
        self.init_agent_steps_trained = self.algo._counters[NUM_AGENT_STEPS_TRAINED]
        self.failure_tolerance = self.algo.config[
            "num_consecutive_worker_failures_tolerance"
        ]
        return self

    def __exit__(self, *args):
        self.time_stop = time.time()

    def get_time_taken_sec(self) -> float:
        """Returns the time we spent in the context in seconds."""
        return self.time_stop - self.time_start

    def should_stop(self, results):
        # Before the first call to `step()`.
        if results is None:
            # Fail after n retries.
            self.failures += 1
            if self.failures > self.failure_tolerance:
                raise RuntimeError(
                    "More than `num_consecutive_worker_failures_tolerance="
                    f"{self.failure_tolerance}` consecutive worker failures! "
                    "Exiting."
                )
            # Continue to the very first `step()` call or retry `step()`
            # after a (tolerable) failure.
            return False

        # Stopping criteria.
        elif self.algo.config._disable_execution_plan_api:
            if self.algo.config.count_steps_by == "agent_steps":
                self.sampled = (
                    self.algo._counters[NUM_AGENT_STEPS_SAMPLED]
                    - self.init_agent_steps_sampled
                )
                self.trained = (
                    self.algo._counters[NUM_AGENT_STEPS_TRAINED]
                    - self.init_agent_steps_trained
                )
            else:
                self.sampled = (
                    self.algo._counters[NUM_ENV_STEPS_SAMPLED]
                    - self.init_env_steps_sampled
                )
                self.trained = (
                    self.algo._counters[NUM_ENV_STEPS_TRAINED]
                    - self.init_env_steps_trained
                )

            min_t = self.algo.config["min_time_s_per_iteration"]
            min_sample_ts = self.algo.config["min_sample_timesteps_per_iteration"]
            min_train_ts = self.algo.config["min_train_timesteps_per_iteration"]
            # Repeat if not enough time has passed or if not enough
            # env|train timesteps have been processed (or these min
            # values are not provided by the user).
            if (
                (not min_t or time.time() - self.time_start >= min_t)
                and (not min_sample_ts or self.sampled >= min_sample_ts)
                and (not min_train_ts or self.trained >= min_train_ts)
            ):
                return True
            else:
                return False
        # No errors (we got results != None) -> Return True
        # (meaning: yes, should stop -> no further step attempts).
        else:
            return True
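
# Usage sketch (illustrative, not part of the original file), mirroring how
# `Algorithm._run_one_training_iteration()` drives this context: keep calling
# `training_step()` until the configured minimum time / timestep thresholds
# for one iteration are satisfied. `algo` is a hypothetical, fully built
# Algorithm instance:
#
#   with TrainIterCtx(algo=algo) as train_iter_ctx:
#       results = None
#       while not train_iter_ctx.should_stop(results):
#           results = algo.training_step()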