# algorithm.py

from collections import defaultdict
import concurrent
import copy
from datetime import datetime
import functools
import gymnasium as gym
import importlib
import json
import logging
import numpy as np
import os
from packaging import version
import pkg_resources
import re
import tempfile
import time
import tree  # pip install dm_tree
from typing import (
    Callable,
    Container,
    DefaultDict,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)
import ray
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray.actor import ActorHandle
from ray.train import Checkpoint
import ray.cloudpickle as pickle
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.registry import ALGORITHMS_CLASS_TO_NAME as ALL_ALGORITHMS
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import (
    collect_episodes,
    collect_metrics,
    summarize_episodes,
)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    STEPS_TRAINED_THIS_ITER_COUNTER,  # TODO: Backward compatibility.
)
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.offline import get_dataset_and_shards
from ray.rllib.offline.estimators import (
    OffPolicyEstimator,
    ImportanceSampling,
    WeightedImportanceSampling,
    DirectMethod,
    DoublyRobust,
)
from ray.rllib.offline.offline_evaluator import OfflineEvaluator
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, concat_samples
from ray.rllib.utils import deep_update, FilterManager
from ray.rllib.utils.annotations import (
    DeveloperAPI,
    ExperimentalAPI,
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    PublicAPI,
    override,
)
from ray.rllib.utils.checkpoints import (
    CHECKPOINT_VERSION,
    CHECKPOINT_VERSION_LEARNER,
    get_checkpoint_info,
    try_import_msgpack,
)
from ray.rllib.utils.debug import update_global_seed_if_necessary
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    Deprecated,
    deprecation_warning,
)
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_SAMPLED_THIS_ITER,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED_THIS_ITER,
    NUM_ENV_STEPS_TRAINED,
    SYNCH_WORKER_WEIGHTS_TIMER,
    TRAINING_ITERATION_TIMER,
    SAMPLE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer, ReplayBuffer
from ray.rllib.utils.serialization import deserialize_type, NOT_SERIALIZABLE
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.typing import (
    AgentConnectorDataType,
    AgentID,
    AlgorithmConfigDict,
    EnvCreator,
    EnvInfoDict,
    EnvType,
    EpisodeID,
    PartialAlgorithmConfigDict,
    PolicyID,
    PolicyState,
    ResultDict,
    SampleBatchType,
    TensorStructType,
    TensorType,
)
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.experiment.trial import ExportFormat
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.resources import Resources
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trainable import Trainable
from ray.util import log_once
from ray.util.timer import _Timer
from ray.tune.registry import get_trainable_cls

try:
    from ray.rllib.extensions import AlgorithmBase
except ImportError:

    class AlgorithmBase:
        @staticmethod
        def _get_learner_bundles(cf: AlgorithmConfig) -> List[Dict[str, int]]:
            """Selects the right resource bundles for learner workers based off of cf.

            Args:
                cf: The algorithm config.

            Returns:
                A list of resource bundles for the learner workers.
            """
            if cf.num_learner_workers > 0:
                if cf.num_gpus_per_learner_worker:
                    learner_bundles = [
                        {
                            "GPU": cf.num_learner_workers
                            * cf.num_gpus_per_learner_worker
                        }
                    ]
                elif cf.num_cpus_per_learner_worker:
                    learner_bundles = [
                        {
                            "CPU": cf.num_cpus_per_learner_worker
                            * cf.num_learner_workers,
                        }
                    ]
            else:
                learner_bundles = [
                    {
                        # Sampling and training are not done concurrently when local is
                        # used, so pick the max.
                        "CPU": max(
                            cf.num_cpus_per_learner_worker, cf.num_cpus_for_local_worker
                        ),
                        "GPU": cf.num_gpus_per_learner_worker,
                    }
                ]

            return learner_bundles
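
# A minimal sketch of what `_get_learner_bundles()` returns for a hypothetical
# config (illustrative values only; kept as a comment so nothing runs at import):
#
#   cf = AlgorithmConfig().resources(
#       num_learner_workers=2,
#       num_gpus_per_learner_worker=1,
#   )
#   AlgorithmBase._get_learner_bundles(cf)  # -> [{"GPU": 2}]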


tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


@Deprecated(
    new="config = AlgorithmConfig().update_from_dict({'a': 1, 'b': 2}); ... ; "
    "print(config.lr) -> 0.001; if config.a > 0: [do something];",
    error=True,
)
def with_common_config(*args, **kwargs):
    pass


@PublicAPI
class Algorithm(Trainable, AlgorithmBase):
    """An RLlib algorithm responsible for optimizing one or more Policies.

    Algorithms contain a WorkerSet under `self.workers`. A WorkerSet is
    normally composed of a single local worker
    (self.workers.local_worker()), used to compute and apply learning updates,
    and optionally one or more remote workers used to generate environment
    samples in parallel.

    WorkerSet is fault tolerant and elastic. It tracks health states for all
    the managed remote worker actors. As a result, Algorithm should never
    access the underlying actor handles directly. Instead, always access them
    via all the foreach APIs with assigned IDs of the underlying workers.

    Each worker (remote or local) contains a PolicyMap, which itself
    may contain either one policy for single-agent training or one or more
    policies for multi-agent training. Policies are synchronized
    automatically from time to time using ray.remote calls. The exact
    synchronization logic depends on the specific algorithm used,
    but this usually happens from local worker to all remote workers and
    after each training update.

    You can write your own Algorithm classes by sub-classing from `Algorithm`
    or any of its built-in sub-classes.
    This allows you to override the `training_step` method to implement
    your own algorithm logic. You can find the different built-in
    algorithms' `training_step()` methods in their respective main .py files,
    e.g. rllib.algorithms.dqn.dqn.py or rllib.algorithms.impala.impala.py.

    The most important API methods an Algorithm exposes are `train()`,
    `evaluate()`, `save()` and `restore()`.
    """

    # Whether to allow unknown top-level config keys.
    _allow_unknown_configs = False

    # List of top-level keys with value=dict, for which new sub-keys are
    # allowed to be added to the value dict.
    _allow_unknown_subkeys = [
        "tf_session_args",
        "local_tf_session_args",
        "env_config",
        "model",
        "optimizer",
        "custom_resources_per_worker",
        "evaluation_config",
        "exploration_config",
        "replay_buffer_config",
        "extra_python_environs_for_worker",
        "input_config",
        "output_config",
    ]

    # List of top-level keys with value=dict, for which we always override the
    # entire value (dict), iff the "type" key in that value dict changes.
    _override_all_subkeys_if_type_changes = [
        "exploration_config",
        "replay_buffer_config",
    ]

    # List of keys that are always fully overridden if present in any dict or sub-dict.
    _override_all_key_list = ["off_policy_estimation_methods", "policies"]

    _progress_metrics = (
        "num_env_steps_sampled",
        "num_env_steps_trained",
        "episodes_total",
        "sampler_results/episode_len_mean",
        "sampler_results/episode_reward_mean",
        "evaluation/sampler_results/episode_reward_mean",
    )
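
    # Illustrative usage sketch (assumes the built-in PPO sub-class and a Gymnasium
    # env id; kept as a comment so the module stays import-safe):
    #
    #   from ray.rllib.algorithms.ppo import PPOConfig
    #   algo = PPOConfig().environment("CartPole-v1").build()
    #   results = algo.train()          # runs `step()` / `training_step()` once
    #   eval_results = algo.evaluate()  # runs under `evaluation_config` settings
    #   checkpoint = algo.save()        # persist state for a later `restore()`
    #   algo.stop()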

    @staticmethod
    def from_checkpoint(
        checkpoint: Union[str, Checkpoint],
        policy_ids: Optional[Container[PolicyID]] = None,
        policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
        policies_to_train: Optional[
            Union[
                Container[PolicyID],
                Callable[[PolicyID, Optional[SampleBatchType]], bool],
            ]
        ] = None,
    ) -> "Algorithm":
        """Creates a new algorithm instance from a given checkpoint.

        Note: This method must remain backward compatible from 2.0.0 on.

        Args:
            checkpoint: The path (str) to the checkpoint directory to use
                or an AIR Checkpoint instance to restore from.
            policy_ids: Optional list of PolicyIDs to recover. This allows users to
                restore an Algorithm with only a subset of the originally present
                Policies.
            policy_mapping_fn: An optional (updated) policy mapping function
                to use from here on.
            policies_to_train: An optional list of policy IDs to be trained
                or a callable taking PolicyID and SampleBatchType and
                returning a bool (trainable or not?).
                If None, will keep the existing setup in place. Policies
                whose IDs are not in the list (or for which the callable
                returns False) will not be updated.

        Returns:
            The instantiated Algorithm.
        """
        checkpoint_info = get_checkpoint_info(checkpoint)

        # Not possible for (v0.1) (algo class and config information missing
        # or very hard to retrieve).
        if checkpoint_info["checkpoint_version"] == version.Version("0.1"):
            raise ValueError(
                "Cannot restore a v0 checkpoint using `Algorithm.from_checkpoint()`! "
                "In this case, do the following:\n"
                "1) Create a new Algorithm object using your original config.\n"
                "2) Call the `restore()` method of this algo object passing it"
                " your checkpoint dir or AIR Checkpoint object."
            )
        elif checkpoint_info["checkpoint_version"] < version.Version("1.0"):
            raise ValueError(
                "`checkpoint_info['checkpoint_version']` in `Algorithm.from_checkpoint"
                "()` must be 1.0 or later! You are using a checkpoint with "
                f"version v{checkpoint_info['checkpoint_version']}."
            )

        # This is a msgpack checkpoint.
        if checkpoint_info["format"] == "msgpack":
            # User did not provide unserializable function with this call
            # (`policy_mapping_fn`). Note that if `policies_to_train` is None, it
            # defaults to training all policies (so it's ok to not provide this here).
            if policy_mapping_fn is None:
                # Only DEFAULT_POLICY_ID present in this algorithm, provide default
                # implementations of these two functions.
                if checkpoint_info["policy_ids"] == {DEFAULT_POLICY_ID}:
                    policy_mapping_fn = AlgorithmConfig.DEFAULT_POLICY_MAPPING_FN
                # Provide meaningful error message.
                else:
                    raise ValueError(
                        "You are trying to restore a multi-agent algorithm from a "
                        "`msgpack` formatted checkpoint, which does NOT store the "
                        "`policy_mapping_fn` or `policies_to_train` "
                        "functions! Make sure that when using the "
                        "`Algorithm.from_checkpoint()` utility, you also pass the "
                        "args: `policy_mapping_fn` and `policies_to_train` with your "
                        "call. You might leave `policies_to_train=None` in case "
                        "you would like to train all policies anyway."
                    )

        state = Algorithm._checkpoint_info_to_algorithm_state(
            checkpoint_info=checkpoint_info,
            policy_ids=policy_ids,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=policies_to_train,
        )

        return Algorithm.from_state(state)
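
    # Sketch of a `from_checkpoint()` call (the path and policy IDs below are
    # placeholders, not values from this code base):
    #
    #   algo = Algorithm.from_checkpoint(
    #       checkpoint="/tmp/my_checkpoint_dir",  # hypothetical checkpoint dir
    #       policy_ids={"pol1"},                  # restore only this subset
    #       policy_mapping_fn=lambda agent_id, episode, **kw: "pol1",
    #   )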

    @staticmethod
    def from_state(state: Dict) -> "Algorithm":
        """Recovers an Algorithm from a state object.

        The `state` of an instantiated Algorithm can be retrieved by calling its
        `get_state` method. It contains all information necessary
        to create the Algorithm from scratch. No access to the original code (e.g.
        configs, knowledge of the Algorithm's class, etc.) is needed.

        Args:
            state: The state to recover a new Algorithm instance from.

        Returns:
            A new Algorithm instance.
        """
        algorithm_class: Type[Algorithm] = state.get("algorithm_class")
        if algorithm_class is None:
            raise ValueError(
                "No `algorithm_class` key was found in given `state`! "
                "Cannot create new Algorithm."
            )
        # algo_class = get_trainable_cls(algo_class_name)
        # Create the new algo.
        config = state.get("config")
        if not config:
            raise ValueError("No `config` found in given Algorithm state!")
        new_algo = algorithm_class(config=config)
        # Set the new algo's state.
        new_algo.__setstate__(state)

        # Return the new algo.
        return new_algo
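
    # Sketch of a `get_state()` / `from_state()` round-trip (illustrative only):
    #
    #   state = algo.get_state()           # dict incl. config and algorithm_class
    #   new_algo = Algorithm.from_state(state)
    #   assert type(new_algo) is type(algo)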

    @PublicAPI
    def __init__(
        self,
        config: Optional[AlgorithmConfig] = None,
        env=None,  # deprecated arg
        logger_creator: Optional[Callable[[], Logger]] = None,
        **kwargs,
    ):
        """Initializes an Algorithm instance.

        Args:
            config: Algorithm-specific configuration object.
            logger_creator: Callable that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
            **kwargs: Arguments passed to the Trainable base class.
        """
        config = config or self.get_default_config()

        # Translate possible dict into an AlgorithmConfig object, as well as
        # resolving generic config objects into specific ones (e.g. passing
        # an `AlgorithmConfig` super-class instance into a PPO constructor,
        # which normally would expect a PPOConfig object).
        if isinstance(config, dict):
            default_config = self.get_default_config()
            # `self.get_default_config()` also returned a dict ->
            # Last resort: Create core AlgorithmConfig from merged dicts.
            if isinstance(default_config, dict):
                config = AlgorithmConfig.from_dict(
                    config_dict=self.merge_algorithm_configs(
                        default_config, config, True
                    )
                )
            # Default config is an AlgorithmConfig -> update its properties
            # from the given config dict.
            else:
                config = default_config.update_from_dict(config)
        else:
            default_config = self.get_default_config()
            # Given AlgorithmConfig is not of the same type as the default config:
            # This could be the case e.g. if the user is building an algo from a
            # generic AlgorithmConfig() object.
            if not isinstance(config, type(default_config)):
                config = default_config.update_from_dict(config.to_dict())

        # In case this algo is using a generic config (with no algo_class set), set it
        # here.
        if config.algo_class is None:
            config.algo_class = type(self)

        if env is not None:
            deprecation_warning(
                old=f"algo = Algorithm(env='{env}', ...)",
                new=f"algo = AlgorithmConfig().environment('{env}').build()",
                error=False,
            )
            config.environment(env)

        # Validate and freeze our AlgorithmConfig object (no more changes possible).
        config.validate()
        config.freeze()

        # Convert `env` provided in config into a concrete env creator callable, which
        # takes an EnvContext (config dict) as arg and returns an RLlib-supported Env
        # type (e.g. a gym.Env).
        self._env_id, self.env_creator = self._get_env_id_and_creator(
            config.env, config
        )
        env_descr = (
            self._env_id.__name__ if isinstance(self._env_id, type) else self._env_id
        )

        # Placeholder for a local replay buffer instance.
        self.local_replay_buffer = None

        # Create a default logger creator if no logger_creator is specified.
        if logger_creator is None:
            # Default logdir prefix containing the agent's name and the
            # env id.
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            env_descr_for_dir = re.sub("[/\\\\]", "-", str(env_descr))
            logdir_prefix = f"{str(self)}_{env_descr_for_dir}_{timestr}"
            if not os.path.exists(DEFAULT_RESULTS_DIR):
                # Possible race condition if dir is created several times on
                # rollout workers.
                os.makedirs(DEFAULT_RESULTS_DIR, exist_ok=True)
            logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)

            # Allow users to more precisely configure the created logger
            # via "logger_config.type".
            if config.logger_config and "type" in config.logger_config:

                def default_logger_creator(config):
                    """Creates a custom logger with the default prefix."""
                    cfg = config["logger_config"].copy()
                    cls = cfg.pop("type")
                    # Provide default for logdir, in case the user does
                    # not specify this in the "logger_config" dict.
                    logdir_ = cfg.pop("logdir", logdir)
                    return from_config(cls=cls, _args=[cfg], logdir=logdir_)

            # If no `type` given, use tune's UnifiedLogger as last resort.
            else:

                def default_logger_creator(config):
                    """Creates a Unified logger with the default prefix."""
                    return UnifiedLogger(config, logdir, loggers=None)

            logger_creator = default_logger_creator

        # Metrics-related properties.
        self._timers = defaultdict(_Timer)
        self._counters = defaultdict(int)
        self._episode_history = []
        self._episodes_to_be_collected = []

        # The fully qualified AlgorithmConfig used for evaluation
        # (or None if evaluation is not set up).
        self.evaluation_config: Optional[AlgorithmConfig] = None
        # Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
        self.evaluation_workers: Optional[WorkerSet] = None
        # Initialize common evaluation_metrics to nan, before they become
        # available. We want to make sure the metrics are always present
        # (although their values may be nan), so that Tune does not complain
        # when we use these as stopping criteria.
        self.evaluation_metrics = {
            # TODO: Don't dump sampler results into top-level.
            "evaluation": {
                "episode_reward_max": np.nan,
                "episode_reward_min": np.nan,
                "episode_reward_mean": np.nan,
                "sampler_results": {
                    "episode_reward_max": np.nan,
                    "episode_reward_min": np.nan,
                    "episode_reward_mean": np.nan,
                },
            },
        }

        super().__init__(
            config=config,
            logger_creator=logger_creator,
            **kwargs,
        )

        # Check whether `training_iteration` is still a tune.Trainable property
        # and has not been overridden by the user in the attempt to implement the
        # algo's logic (this should be done now inside `training_step`).
        try:
            assert isinstance(self.training_iteration, int)
        except AssertionError:
            raise AssertionError(
                "Your Algorithm's `training_iteration` seems to be overridden by your "
                "custom training logic! To solve this problem, simply rename your "
                "`self.training_iteration()` method to `self.training_step()`."
            )
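
    # Sketch of the dict-translation path handled above: a partial config dict is
    # merged into the algo's default AlgorithmConfig (PPO used as an assumed example):
    #
    #   from ray.rllib.algorithms.ppo import PPO
    #   algo = PPO(config={"env": "CartPole-v1", "lr": 0.0003})
    #   # -> internally becomes PPOConfig().update_from_dict({...})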

    @OverrideToImplementCustomLogic
    @classmethod
    def get_default_config(cls) -> AlgorithmConfig:
        return AlgorithmConfig()

    @OverrideToImplementCustomLogic
    def _remote_worker_ids_for_metrics(self) -> List[int]:
        """Returns a list of remote worker IDs to fetch metrics from.

        Specific Algorithm implementations can override this method to
        use a subset of the workers for metrics collection.

        Returns:
            List of remote worker IDs to fetch metrics from.
        """
        return self.workers.healthy_worker_ids()
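
    # Sketch of a sub-class override that restricts metrics collection to the first
    # two healthy remote workers (illustrative only):
    #
    #   class MyAlgo(Algorithm):
    #       def _remote_worker_ids_for_metrics(self) -> List[int]:
    #           return self.workers.healthy_worker_ids()[:2]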

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(Trainable)
    def setup(self, config: AlgorithmConfig) -> None:
        # Setup our config: Merge the user-supplied config dict (which could
        # be a partial config dict) with the class' default.
        if not isinstance(config, AlgorithmConfig):
            assert isinstance(config, PartialAlgorithmConfigDict)
            config_obj = self.get_default_config()
            if not isinstance(config_obj, AlgorithmConfig):
                assert isinstance(config, PartialAlgorithmConfigDict)
                config_obj = AlgorithmConfig().from_dict(config_obj)
            config_obj.update_from_dict(config)
            config_obj.env = self._env_id
            self.config = config_obj

        # Set Algorithm's seed after we have - if necessary - enabled
        # tf eager-execution.
        update_global_seed_if_necessary(self.config.framework_str, self.config.seed)

        self._record_usage(self.config)

        # Create the callbacks object.
        self.callbacks = self.config.callbacks_class()

        if self.config.log_level in ["WARN", "ERROR"]:
            logger.info(
                f"Current log_level is {self.config.log_level}. For more information, "
                "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                "-vv flags."
            )
        if self.config.log_level:
            logging.getLogger("ray.rllib").setLevel(self.config.log_level)

        # Create local replay buffer if necessary.
        self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
            self.config
        )

        # Create a dict, mapping ActorHandles to sets of open remote
        # requests (object refs). This way, we keep track of which actors
        # inside this Algorithm (e.g. a remote RolloutWorker) have
        # already been sent how many (e.g. `sample()`) requests.
        self.remote_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]
        ] = defaultdict(set)

        self.workers: Optional[WorkerSet] = None
        self.train_exec_impl = None

        # Offline RL settings.
        input_evaluation = self.config.get("input_evaluation")
        if input_evaluation is not None and input_evaluation is not DEPRECATED_VALUE:
            ope_dict = {str(ope): {"type": ope} for ope in input_evaluation}
            deprecation_warning(
                old="config.input_evaluation={}".format(input_evaluation),
                new="config.evaluation(evaluation_config=config.overrides("
                f"off_policy_estimation_methods={ope_dict}"
                "))",
                error=True,
                help="Running OPE during training is not recommended.",
            )
            self.config.off_policy_estimation_methods = ope_dict

        # Deprecated way of implementing Algorithm sub-classes (or "templates"
        # via the `build_trainer` utility function).
        # Instead, sub-classes should override the Trainable's `setup()`
        # method and call super().setup() from within that override at some
        # point.
        # Old design: Override `Algorithm._init`.
        _init = False
        try:
            self._init(self.config, self.env_creator)
            _init = True
        # New design: Override `Algorithm.setup()` (as intended by tune.Trainable)
        # and do or don't call `super().setup()` from within your override.
        # By default, `super().setup()` will create both worker sets:
        # "rollout workers" for collecting samples for training and - if
        # applicable - "evaluation workers" for evaluation runs in between or
        # parallel to training.
        # TODO: Deprecate `_init()` and remove this try/except block.
        except NotImplementedError:
            pass

        # Only if user did not override `_init()`:
        if _init is False:
            # Create a set of env runner actors via a WorkerSet.
            self.workers = WorkerSet(
                env_creator=self.env_creator,
                validate_env=self.validate_env,
                default_policy_class=self.get_default_policy_class(self.config),
                config=self.config,
                num_workers=self.config.num_rollout_workers,
                local_worker=True,
                logdir=self.logdir,
            )

            # TODO (avnishn): Remove the execution plan API by q1 2023.
            # Function defining one single training iteration's behavior.
            if self.config._disable_execution_plan_api:
                # Ensure remote workers are initially in sync with the local worker.
                self.workers.sync_weights()
            # LocalIterator-creating "execution plan".
            # Only call this once here to create `self.train_exec_impl`,
            # which is a ray.util.iter.LocalIterator that will be `next`'d
            # on each training iteration.
            else:
                self.train_exec_impl = self.execution_plan(
                    self.workers, self.config, **self._kwargs_for_execution_plan()
                )

        # Compile, validate, and freeze an evaluation config.
        self.evaluation_config = self.config.get_evaluation_config_object()
        self.evaluation_config.validate()
        self.evaluation_config.freeze()

        # Evaluation WorkerSet setup.
        # User would like to set up a separate evaluation worker set.
        # Note: We skip worker set creation if we need to do offline evaluation.
        if self._should_create_evaluation_rollout_workers(self.evaluation_config):
            _, env_creator = self._get_env_id_and_creator(
                self.evaluation_config.env, self.evaluation_config
            )

            # Create a separate evaluation worker set for evaluation.
            # If evaluation_num_workers=0, use the evaluation set's local
            # worker for evaluation, otherwise, use its remote workers
            # (parallelized evaluation).
            self.evaluation_workers: WorkerSet = WorkerSet(
                env_creator=env_creator,
                validate_env=None,
                default_policy_class=self.get_default_policy_class(self.config),
                config=self.evaluation_config,
                num_workers=self.config.evaluation_num_workers,
                logdir=self.logdir,
            )

            if self.config.enable_async_evaluation:
                self._evaluation_weights_seq_number = 0

        self.evaluation_dataset = None
        if (
            self.evaluation_config.off_policy_estimation_methods
            and not self.evaluation_config.ope_split_batch_by_episode
        ):
            # `num_workers` is set to 0 to avoid creating shards. The dataset will not
            # be repartitioned into `num_workers` blocks.
            logger.info("Creating evaluation dataset ...")
            self.evaluation_dataset, _ = get_dataset_and_shards(
                self.evaluation_config, num_workers=0
            )
            logger.info("Evaluation dataset created")

        self.reward_estimators: Dict[str, OffPolicyEstimator] = {}
        ope_types = {
            "is": ImportanceSampling,
            "wis": WeightedImportanceSampling,
            "dm": DirectMethod,
            "dr": DoublyRobust,
        }
        for name, method_config in self.config.off_policy_estimation_methods.items():
            method_type = method_config.pop("type")
            if method_type in ope_types:
                deprecation_warning(
                    old=method_type,
                    new=str(ope_types[method_type]),
                    error=True,
                )
                method_type = ope_types[method_type]
            elif isinstance(method_type, str):
                logger.log(0, "Trying to import from string: " + method_type)
                mod, obj = method_type.rsplit(".", 1)
                mod = importlib.import_module(mod)
                method_type = getattr(mod, obj)
            if isinstance(method_type, type) and issubclass(
                method_type, OfflineEvaluator
            ):
                # TODO (kourosh): Add an integration test for all these
                # offline evaluators.
                policy = self.get_policy()
                if issubclass(method_type, OffPolicyEstimator):
                    method_config["gamma"] = self.config.gamma
                self.reward_estimators[name] = method_type(policy, **method_config)
            else:
                raise ValueError(
                    f"Unknown off_policy_estimation type: {method_type}! Must be "
                    "either a class path or a sub-class of ray.rllib."
                    "offline.offline_evaluator::OfflineEvaluator"
                )
            # TODO (Rohan138): Refactor this and remove deprecated methods.
            # Need to add back `method_type` in case the Algorithm is restored from a
            # checkpoint.
            method_config["type"] = method_type

        self.learner_group = None
        if self.config._enable_learner_api:
            # TODO (Kourosh): This is an interim solution where policies and modules
            # co-exist. In this world we have both policy_map and MARLModule that need
            # to be consistent with one another. To make a consistent parity between
            # the two we need to loop through the policy modules and create a simple
            # MARLModule from the RLModule within each policy.
            local_worker = self.workers.local_worker()
            policy_dict, _ = self.config.get_multi_agent_setup(
                env=local_worker.env,
                spaces=getattr(local_worker, "spaces", None),
            )
            # TODO (Sven): Unify the inference of the MARLModuleSpec. Right now,
            # we get this from the RolloutWorker's `marl_module_spec` property.
            # However, this is hacky (information leak) and should not remain this
            # way. For other EnvRunner classes (that don't have this property),
            # Algorithm should infer this itself.
            if hasattr(local_worker, "marl_module_spec"):
                module_spec = local_worker.marl_module_spec
            else:
                module_spec = self.config.get_marl_module_spec(policy_dict=policy_dict)
            learner_group_config = self.config.get_learner_group_config(module_spec)
            self.learner_group = learner_group_config.build()

            # Check if there are modules to load from the module_spec.
            rl_module_ckpt_dirs = {}
            marl_module_ckpt_dir = module_spec.load_state_path
            modules_to_load = module_spec.modules_to_load
            for module_id, sub_module_spec in module_spec.module_specs.items():
                if sub_module_spec.load_state_path:
                    rl_module_ckpt_dirs[module_id] = sub_module_spec.load_state_path
            if marl_module_ckpt_dir or rl_module_ckpt_dirs:
                self.learner_group.load_module_state(
                    marl_module_ckpt_dir=marl_module_ckpt_dir,
                    modules_to_load=modules_to_load,
                    rl_module_ckpt_dirs=rl_module_ckpt_dirs,
                )
            # Sync the weights from the learner group to the rollout workers.
            weights = self.learner_group.get_weights()
            local_worker.set_weights(weights)
            self.workers.sync_weights()

        # Run `on_algorithm_init` callback after initialization is done.
        self.callbacks.on_algorithm_init(algorithm=self)

    # TODO: Deprecated: In your sub-classes of Algorithm, override `setup()`
    # directly and call super().setup() from within it if you would like the
    # default setup behavior plus some own setup logic.
    # If you don't need the env/workers/config/etc. setup for you by super,
    # simply do not call super().setup() from your overridden method.
    def _init(self, config: AlgorithmConfigDict, env_creator: EnvCreator) -> None:
        raise NotImplementedError

    @OverrideToImplementCustomLogic
    @classmethod
    def get_default_policy_class(
        cls,
        config: AlgorithmConfig,
    ) -> Optional[Type[Policy]]:
        """Returns a default Policy class to use, given a config.

        This class will be used by an Algorithm in case
        the policy class is not provided by the user in any single- or
        multi-agent PolicySpec.

        Note: This method is ignored when the RLModule API is enabled.
        """
        return None

    @override(Trainable)
    def step(self) -> ResultDict:
        """Implements the main `Algorithm.train()` logic.

        Takes n attempts to perform a single training step. Thereby
        catches RayErrors resulting from worker failures. After n attempts,
        fails gracefully.

        Override this method in your Algorithm sub-classes if you would like to
        handle worker failures yourself.
        Otherwise, override only `training_step()` to implement the core
        algorithm logic.

        Returns:
            The results dict with stats/infos on sampling, training,
            and - if required - evaluation.
        """
        # Do we have to run `self.evaluate()` this iteration?
        # `self.iteration` gets incremented after this function returns,
        # meaning that e.g. the first time this function is called,
        # self.iteration will be 0.
        evaluate_this_iter = (
            self.config.evaluation_interval is not None
            and (self.iteration + 1) % self.config.evaluation_interval == 0
        )

        # Results dict for training (and, if applicable, evaluation).
        results: ResultDict = {}

        # Parallel eval + training: Kick off evaluation-loop and parallel train() call.
        if evaluate_this_iter and self.config.evaluation_parallel_to_training:
            (
                results,
                train_iter_ctx,
            ) = self._run_one_training_iteration_and_evaluation_in_parallel()
        # - No evaluation necessary, just run the next training iteration.
        # - We have to evaluate in this training iteration, but no parallelism ->
        #   evaluate after the training iteration is entirely done.
        else:
            results, train_iter_ctx = self._run_one_training_iteration()

        # Sequential: Train (already done above), then evaluate.
        if evaluate_this_iter and not self.config.evaluation_parallel_to_training:
            results.update(self._run_one_evaluation(train_future=None))

        # Attach latest available evaluation results to train results,
        # if necessary.
        if not evaluate_this_iter and self.config.always_attach_evaluation_results:
            assert isinstance(
                self.evaluation_metrics, dict
            ), "Algorithm.evaluate() needs to return a dict."
            results.update(self.evaluation_metrics)

        if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
            # Sync filters on workers.
            self._sync_filters_if_needed(
                central_worker=self.workers.local_worker(),
                workers=self.workers,
                config=self.config,
            )
            # TODO (avnishn): Remove the execution plan API by q1 2023.
            # Collect worker metrics and combine them with `results`.
            if self.config._disable_execution_plan_api:
                episodes_this_iter = collect_episodes(
                    self.workers,
                    self._remote_worker_ids_for_metrics(),
                    timeout_seconds=self.config.metrics_episode_collection_timeout_s,
                )
                results = self._compile_iteration_results(
                    episodes_this_iter=episodes_this_iter,
                    step_ctx=train_iter_ctx,
                    iteration_results=results,
                )

        # Check `env_task_fn` for possible update of the env's task.
        if self.config.env_task_fn is not None:
            if not callable(self.config.env_task_fn):
                raise ValueError(
                    "`env_task_fn` must be None or a callable taking "
                    "[train_results, env, env_ctx] as args!"
                )

            def fn(env, env_context, task_fn):
                new_task = task_fn(results, env, env_context)
                cur_task = env.get_task()
                if cur_task != new_task:
                    env.set_task(new_task)

            fn = functools.partial(fn, task_fn=self.config.env_task_fn)
            self.workers.foreach_env_with_context(fn)

        return results
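
    # Sketch of how `step()` is typically driven via `Trainable.train()`, with
    # evaluation every 2nd iteration (config class and values are illustrative):
    #
    #   config = (
    #       SomeAlgoConfig()  # hypothetical Algorithm sub-class config
    #       .environment("CartPole-v1")
    #       .evaluation(evaluation_interval=2, evaluation_num_workers=1)
    #   )
    #   algo = config.build()
    #   for _ in range(4):
    #       results = algo.train()  # one call to `step()`
    #       # `results["evaluation"]` is present on every 2nd iteration here.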
  810. @PublicAPI
  811. def evaluate(
  812. self,
  813. duration_fn: Optional[Callable[[int], int]] = None,
  814. ) -> dict:
  815. """Evaluates current policy under `evaluation_config` settings.
  816. Args:
  817. duration_fn: An optional callable taking the already run
  818. num episodes as only arg and returning the number of
  819. episodes left to run. It's used to find out whether
  820. evaluation should continue.
  821. """
  822. # Call the `_before_evaluate` hook.
  823. self._before_evaluate()
  824. if self.evaluation_dataset is not None:
  825. return {"evaluation": self._run_offline_evaluation()}
  826. # Sync weights to the evaluation WorkerSet.
  827. if self.evaluation_workers is not None:
  828. self.evaluation_workers.sync_weights(
  829. from_worker_or_learner_group=self.workers.local_worker()
  830. )
  831. self._sync_filters_if_needed(
  832. central_worker=self.workers.local_worker(),
  833. workers=self.evaluation_workers,
  834. config=self.evaluation_config,
  835. )
  836. self.callbacks.on_evaluate_start(algorithm=self)
  837. if self.config.custom_evaluation_function:
  838. logger.info(
  839. "Running custom eval function {}".format(
  840. self.config.custom_evaluation_function
  841. )
  842. )
  843. metrics = self.config.custom_evaluation_function(
  844. self, self.evaluation_workers
  845. )
  846. if not metrics or not isinstance(metrics, dict):
  847. raise ValueError(
  848. "Custom eval function must return "
  849. "dict of metrics, got {}.".format(metrics)
  850. )
  851. else:
  852. if (
  853. self.evaluation_workers is None
  854. and self.workers.local_worker().input_reader is None
  855. ):
  856. raise ValueError(
  857. "Cannot evaluate w/o an evaluation worker set in "
  858. "the Algorithm or w/o an env on the local worker!\n"
  859. "Try one of the following:\n1) Set "
  860. "`evaluation_interval` >= 0 to force creating a "
  861. "separate evaluation worker set.\n2) Set "
  862. "`create_env_on_driver=True` to force the local "
  863. "(non-eval) worker to have an environment to "
  864. "evaluate on."
  865. )
  866. # How many episodes/timesteps do we need to run?
  867. # In "auto" mode (only for parallel eval + training): Run as long
  868. # as training lasts.
  869. unit = self.config.evaluation_duration_unit
  870. eval_cfg = self.evaluation_config
  871. rollout = eval_cfg.rollout_fragment_length
  872. num_envs = eval_cfg.num_envs_per_worker
  873. auto = self.config.evaluation_duration == "auto"
  874. duration = (
  875. self.config.evaluation_duration
  876. if not auto
  877. else (self.config.evaluation_num_workers or 1)
  878. * (1 if unit == "episodes" else rollout)
  879. )
  880. agent_steps_this_iter = 0
  881. env_steps_this_iter = 0
  882. # Default done-function returns True, whenever num episodes
  883. # have been completed.
  884. if duration_fn is None:
  885. def duration_fn(num_units_done):
  886. return duration - num_units_done
  887. logger.info(f"Evaluating current state of {self} for {duration} {unit}.")
  888. metrics = None
  889. all_batches = []
  890. # No evaluation worker set ->
  891. # Do evaluation using the local worker. Expect error due to the
  892. # local worker not having an env.
  893. if self.evaluation_workers is None:
  894. # If unit=episodes -> Run n times `sample()` (each sample
  895. # produces exactly 1 episode).
  896. # If unit=ts -> Run 1 `sample()` b/c the
  897. # `rollout_fragment_length` is exactly the desired ts.
  898. iters = duration if unit == "episodes" else 1
  899. for _ in range(iters):
  900. batch = self.workers.local_worker().sample()
  901. agent_steps_this_iter += batch.agent_steps()
  902. env_steps_this_iter += batch.env_steps()
  903. if self.reward_estimators:
  904. all_batches.append(batch)
  905. metrics = collect_metrics(
  906. self.workers,
  907. keep_custom_metrics=eval_cfg.keep_per_episode_custom_metrics,
  908. timeout_seconds=eval_cfg.metrics_episode_collection_timeout_s,
  909. )
  910. # Evaluation worker set only has local worker.
  911. elif self.evaluation_workers.num_remote_workers() == 0:
  912. # If unit=episodes -> Run n times `sample()` (each sample
  913. # produces exactly 1 episode).
  914. # If unit=ts -> Run 1 `sample()` b/c the
  915. # `rollout_fragment_length` is exactly the desired ts.
  916. iters = duration if unit == "episodes" else 1
  917. for _ in range(iters):
  918. batch = self.evaluation_workers.local_worker().sample()
  919. agent_steps_this_iter += batch.agent_steps()
  920. env_steps_this_iter += batch.env_steps()
  921. if self.reward_estimators:
  922. all_batches.append(batch)
  923. # Evaluation worker set has n remote workers.
  924. elif self.evaluation_workers.num_healthy_remote_workers() > 0:
  925. # How many episodes have we run (across all eval workers)?
  926. num_units_done = 0
  927. _round = 0
  928. # In case all of the remote evaluation workers die during a round
  929. # of evaluation, we need to stop.
  930. while True and self.evaluation_workers.num_healthy_remote_workers() > 0:
  931. units_left_to_do = duration_fn(num_units_done)
  932. if units_left_to_do <= 0:
  933. break
  934. _round += 1
  935. unit_per_remote_worker = (
  936. 1 if unit == "episodes" else rollout * num_envs
  937. )
  938. # Select proper number of evaluation workers for this round.
  939. selected_eval_worker_ids = [
  940. worker_id
  941. for i, worker_id in enumerate(
  942. self.evaluation_workers.healthy_worker_ids()
  943. )
  944. if i * unit_per_remote_worker < units_left_to_do
  945. ]
  946. batches = self.evaluation_workers.foreach_worker(
  947. func=lambda w: w.sample(),
  948. local_worker=False,
  949. remote_worker_ids=selected_eval_worker_ids,
  950. timeout_seconds=self.config.evaluation_sample_timeout_s,
  951. )
  952. if len(batches) != len(selected_eval_worker_ids):
  953. logger.warning(
  954. "Calling `sample()` on your remote evaluation worker(s) "
  955. "resulted in a timeout (after the configured "
  956. f"{self.config.evaluation_sample_timeout_s} seconds)! "
  957. "Try to set `evaluation_sample_timeout_s` in your config"
  958. " to a larger value."
  959. + (
  960. " If your episodes don't terminate easily, you may "
  961. "also want to set `evaluation_duration_unit` to "
  962. "'timesteps' (instead of 'episodes')."
  963. if unit == "episodes"
  964. else ""
  965. )
  966. )
  967. break
  968. _agent_steps = sum(b.agent_steps() for b in batches)
  969. _env_steps = sum(b.env_steps() for b in batches)
  970. # 1 episode per returned batch.
  971. if unit == "episodes":
  972. num_units_done += len(batches)
  973. # Make sure all batches are exactly one episode.
  974. for ma_batch in batches:
  975. ma_batch = ma_batch.as_multi_agent()
  976. for batch in ma_batch.policy_batches.values():
  977. assert batch.is_terminated_or_truncated()
  978. # n timesteps per returned batch.
  979. else:
  980. num_units_done += (
  981. _agent_steps
  982. if self.config.count_steps_by == "agent_steps"
  983. else _env_steps
  984. )
  985. if self.reward_estimators:
  986. # TODO: (kourosh) This approach will cause an OOM issue when
  987. # the dataset gets huge (should be ok for now).
  988. all_batches.extend(batches)
  989. agent_steps_this_iter += _agent_steps
  990. env_steps_this_iter += _env_steps
  991. logger.info(
  992. f"Ran round {_round} of non-parallel evaluation "
  993. f"({num_units_done}/{duration if not auto else '?'} "
  994. f"{unit} done)"
  995. )
  996. else:
  997. # Can't find a good way to run this evaluation.
  998. # Wait for next iteration.
  999. pass
  1000. if metrics is None:
  1001. metrics = collect_metrics(
  1002. self.evaluation_workers,
  1003. keep_custom_metrics=self.config.keep_per_episode_custom_metrics,
  1004. timeout_seconds=eval_cfg.metrics_episode_collection_timeout_s,
  1005. )
  1006. # TODO: Don't dump sampler results into top-level.
  1007. if not self.config.custom_evaluation_function:
  1008. metrics = dict({"sampler_results": metrics}, **metrics)
  1009. metrics[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps_this_iter
  1010. metrics[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps_this_iter
  1011. # TODO: Remove this key at some point. Here for backward compatibility.
  1012. metrics["timesteps_this_iter"] = env_steps_this_iter
  1013. # Compute off-policy estimates
  1014. estimates = defaultdict(list)
  1015. # for each batch run the estimator's fwd pass
  1016. for name, estimator in self.reward_estimators.items():
  1017. for batch in all_batches:
  1018. estimate_result = estimator.estimate(
  1019. batch,
  1020. split_batch_by_episode=self.config.ope_split_batch_by_episode,
  1021. )
  1022. estimates[name].append(estimate_result)
  1023. # collate estimates from all batches
  1024. if estimates:
  1025. metrics["off_policy_estimator"] = {}
  1026. for name, estimate_list in estimates.items():
  1027. avg_estimate = tree.map_structure(
  1028. lambda *x: np.mean(x, axis=0), *estimate_list
  1029. )
  1030. metrics["off_policy_estimator"][name] = avg_estimate
  1031. # Evaluation does not run for every step.
  1032. # Save evaluation metrics on Algorithm, so it can be attached to
  1033. # subsequent step results as latest evaluation result.
  1034. self.evaluation_metrics = {"evaluation": metrics}
  1035. # Trigger `on_evaluate_end` callback.
  1036. self.callbacks.on_evaluate_end(
  1037. algorithm=self, evaluation_metrics=self.evaluation_metrics
  1038. )
  1039. # Also return the results here for convenience.
  1040. return self.evaluation_metrics
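    # Example sketch (assumes PPO on "CartPole-v1"; illustrates how the duration
    # settings used above are typically configured and how `evaluate()` can be
    # triggered manually; adapt to your own setup):
    #   >>> from ray.rllib.algorithms.ppo import PPOConfig  # doctest: +SKIP
    #   >>> algo = (  # doctest: +SKIP
    #   ...     PPOConfig()
    #   ...     .environment("CartPole-v1")
    #   ...     .evaluation(
    #   ...         evaluation_num_workers=2,
    #   ...         evaluation_duration=10,
    #   ...         evaluation_duration_unit="episodes",
    #   ...     )
    #   ...     .build()
    #   ... )
    #   >>> eval_results = algo.evaluate()  # doctest: +SKIP
    #   >>> print(eval_results["evaluation"]["episode_reward_mean"])  # doctest: +SKIP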
  1041. @ExperimentalAPI
  1042. def _evaluate_async(
  1043. self,
  1044. duration_fn: Optional[Callable[[int], int]] = None,
  1045. ) -> dict:
  1046. """Evaluates current policy under `evaluation_config` settings.
  1047. Uses the AsyncParallelRequests manager to send frequent `sample.remote()`
  1048. requests to the evaluation RolloutWorkers and collect the results of these
  1049. calls. Handles worker failures (or slowdowns) gracefully due to the asynch'ness
  1050. and the fact that other eval RolloutWorkers can thus cover the workload.
  1051. Important Note: This will replace the current `self.evaluate()` method as the
  1052. default in the future.
  1053. Args:
  1054. duration_fn: An optional callable taking the already run
  1055. num episodes as only arg and returning the number of
  1056. episodes left to run. It's used to find out whether
  1057. evaluation should continue.
  1058. """
  1059. # How many episodes/timesteps do we need to run?
  1060. # In "auto" mode (only for parallel eval + training): Run as long
  1061. # as training lasts.
  1062. unit = self.config.evaluation_duration_unit
  1063. eval_cfg = self.evaluation_config
  1064. rollout = eval_cfg.rollout_fragment_length
  1065. num_envs = eval_cfg.num_envs_per_worker
  1066. auto = self.config.evaluation_duration == "auto"
  1067. duration = (
  1068. self.config.evaluation_duration
  1069. if not auto
  1070. else (self.config.evaluation_num_workers or 1)
  1071. * (1 if unit == "episodes" else rollout)
  1072. )
  1073. # Call the `_before_evaluate` hook.
  1074. self._before_evaluate()
  1075. # TODO(Jun): Implement solution via connectors.
  1076. self._sync_filters_if_needed(
  1077. central_worker=self.workers.local_worker(),
  1078. workers=self.evaluation_workers,
  1079. config=eval_cfg,
  1080. )
  1081. if self.config.custom_evaluation_function:
  1082. raise ValueError(
  1083. "`config.custom_evaluation_function` not supported in combination "
  1084. "with `enable_async_evaluation=True` config setting!"
  1085. )
  1086. if self.evaluation_workers is None and (
  1087. self.workers.local_worker().input_reader is None
  1088. or self.config.evaluation_num_workers == 0
  1089. ):
  1090. raise ValueError(
  1091. "Evaluation w/o eval workers (calling Algorithm.evaluate() w/o "
  1092. "evaluation specifically set up) OR evaluation without input reader "
  1093. "OR evaluation with only a local evaluation worker "
  1094. "(`evaluation_num_workers=0`) not supported in combination "
  1095. "with `enable_async_evaluation=True` config setting!"
  1096. )
  1097. agent_steps_this_iter = 0
  1098. env_steps_this_iter = 0
  1099. logger.info(f"Evaluating current state of {self} for {duration} {unit}.")
  1100. all_batches = []
  1101. # Default done-function returns True, whenever num episodes
  1102. # have been completed.
  1103. if duration_fn is None:
  1104. def duration_fn(num_units_done):
  1105. return duration - num_units_done
  1106. # Put weights only once into object store and use same object
  1107. # ref to synch to all workers.
  1108. self._evaluation_weights_seq_number += 1
  1109. weights_ref = ray.put(self.workers.local_worker().get_weights())
  1110. weights_seq_no = self._evaluation_weights_seq_number
  1111. def remote_fn(worker):
  1112. # Pass in seq-no so that eval workers may ignore this call if no update has
  1113. # happened since the last call to `remote_fn` (sample).
  1114. worker.set_weights(
  1115. weights=ray.get(weights_ref), weights_seq_no=weights_seq_no
  1116. )
  1117. batch = worker.sample()
  1118. metrics = worker.get_metrics()
  1119. return batch, metrics, weights_seq_no
  1120. rollout_metrics = []
  1121. # How many episodes have we run (across all eval workers)?
  1122. num_units_done = 0
  1123. _round = 0
  1124. while self.evaluation_workers.num_healthy_remote_workers() > 0:
  1125. units_left_to_do = duration_fn(num_units_done)
  1126. if units_left_to_do <= 0:
  1127. break
  1128. _round += 1
  1129. # Get ready evaluation results and metrics asynchronously.
  1130. self.evaluation_workers.foreach_worker_async(
  1131. func=remote_fn,
  1132. healthy_only=True,
  1133. )
  1134. eval_results = self.evaluation_workers.fetch_ready_async_reqs()
  1135. batches = []
  1136. i = 0
  1137. for _, result in eval_results:
  1138. batch, metrics, seq_no = result
  1139. # Ignore results, if the weights seq-number does not match (is
  1140. # from a previous evaluation step) OR if we have already reached
  1141. # the configured duration (e.g. number of episodes to evaluate
  1142. # for).
  1143. if seq_no == self._evaluation_weights_seq_number and (
  1144. i * (1 if unit == "episodes" else rollout * num_envs)
  1145. < units_left_to_do
  1146. ):
  1147. batches.append(batch)
  1148. rollout_metrics.extend(metrics)
  1149. i += 1
  1150. _agent_steps = sum(b.agent_steps() for b in batches)
  1151. _env_steps = sum(b.env_steps() for b in batches)
  1152. # 1 episode per returned batch.
  1153. if unit == "episodes":
  1154. num_units_done += len(batches)
  1155. # Make sure all batches are exactly one episode.
  1156. for ma_batch in batches:
  1157. ma_batch = ma_batch.as_multi_agent()
  1158. for batch in ma_batch.policy_batches.values():
  1159. assert batch.is_terminated_or_truncated()
  1160. # n timesteps per returned batch.
  1161. else:
  1162. num_units_done += (
  1163. _agent_steps
  1164. if self.config.count_steps_by == "agent_steps"
  1165. else _env_steps
  1166. )
  1167. if self.reward_estimators:
  1168. all_batches.extend(batches)
  1169. agent_steps_this_iter += _agent_steps
  1170. env_steps_this_iter += _env_steps
  1171. logger.info(
  1172. f"Ran round {_round} of parallel evaluation "
  1173. f"({num_units_done}/{duration if not auto else '?'} "
  1174. f"{unit} done)"
  1175. )
  1176. sampler_results = summarize_episodes(
  1177. rollout_metrics,
  1178. keep_custom_metrics=eval_cfg["keep_per_episode_custom_metrics"],
  1179. )
  1180. # TODO: Don't dump sampler results into top-level.
  1181. metrics = dict({"sampler_results": sampler_results}, **sampler_results)
  1182. metrics[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps_this_iter
  1183. metrics[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps_this_iter
  1184. # TODO: Remove this key at some point. Here for backward compatibility.
  1185. metrics["timesteps_this_iter"] = env_steps_this_iter
  1186. if self.reward_estimators:
  1187. # Compute off-policy estimates
  1188. metrics["off_policy_estimator"] = {}
  1189. total_batch = concat_samples(all_batches)
  1190. for name, estimator in self.reward_estimators.items():
  1191. estimates = estimator.estimate(total_batch)
  1192. metrics["off_policy_estimator"][name] = estimates
  1193. # Evaluation does not run for every step.
  1194. # Save evaluation metrics on Algorithm, so it can be attached to
  1195. # subsequent step results as latest evaluation result.
  1196. self.evaluation_metrics = {"evaluation": metrics}
  1197. # Trigger `on_evaluate_end` callback.
  1198. self.callbacks.on_evaluate_end(
  1199. algorithm=self, evaluation_metrics=self.evaluation_metrics
  1200. )
  1201. # Return evaluation results.
  1202. return self.evaluation_metrics
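    # Example sketch (parameter names follow the config keys referenced in this
    # method; they are assumptions about your RLlib version and may differ):
    #   >>> config = (  # doctest: +SKIP
    #   ...     PPOConfig()
    #   ...     .environment("CartPole-v1")
    #   ...     .evaluation(
    #   ...         evaluation_num_workers=2,
    #   ...         evaluation_interval=1,
    #   ...         evaluation_parallel_to_training=True,
    #   ...         evaluation_duration="auto",
    #   ...         enable_async_evaluation=True,
    #   ...     )
    #   ... )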
  1203. @OverrideToImplementCustomLogic
  1204. @DeveloperAPI
  1205. def restore_workers(self, workers: WorkerSet):
  1206. """Try to restore failed workers if necessary.
  1207. Algorithms that use custom RolloutWorkers may override this method to
  1208. disable default, and create custom restoration logics.
  1209. Args:
  1210. workers: The WorkerSet to restore. This may be Rollout or Evaluation
  1211. workers.
  1212. """
  1213. # If `workers` is None, or
  1214. # 1. `workers` (WorkerSet) does not have a local worker, and
  1215. # 2. `self.workers` (WorkerSet used for training) does not have a local worker
  1216. # -> we don't have a local worker to get state from, so we can't recover
  1217. # remote worker in this case.
  1218. if not workers or (
  1219. not workers.local_worker() and not self.workers.local_worker()
  1220. ):
  1221. return
  1222. # This is really cheap, since probe_unhealthy_workers() is a no-op
  1223. # if there are no unhealthy workers.
  1224. restored = workers.probe_unhealthy_workers()
  1225. if restored:
  1226. from_worker = workers.local_worker() or self.workers.local_worker()
  1227. # Get the state of the correct (reference) worker. E.g. The local worker
  1228. # of the main WorkerSet.
  1229. state_ref = ray.put(from_worker.get_state())
  1230. # By default, entire local worker state is synced after restoration
  1231. # to bring these workers up to date.
  1232. workers.foreach_worker(
  1233. func=lambda w: w.set_state(ray.get(state_ref)),
  1234. remote_worker_ids=restored,
  1235. # Don't update the local_worker, b/c it's the one we are synching from.
  1236. local_worker=False,
  1237. timeout_seconds=self.config.worker_restore_timeout_s,
  1238. # Bring back actor after successful state syncing.
  1239. mark_healthy=True,
  1240. )
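    # Example sketch (hypothetical subclass): overriding `restore_workers()` to
    # skip the default state re-sync, as the docstring above allows.
    #   >>> class MyAlgo(Algorithm):  # doctest: +SKIP
    #   ...     def restore_workers(self, workers):
    #   ...         # Only probe health; don't push local state to restored workers.
    #   ...         workers.probe_unhealthy_workers()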
  1241. @OverrideToImplementCustomLogic
  1242. @DeveloperAPI
  1243. def training_step(self) -> ResultDict:
  1244. """Default single iteration logic of an algorithm.
  1245. - Collect on-policy samples (SampleBatches) in parallel using the
  1246. Algorithm's RolloutWorkers (@ray.remote).
  1247. - Concatenate collected SampleBatches into one train batch.
  1248. - Note that we may have more than one policy in the multi-agent case:
  1249. Call the different policies' `learn_on_batch` (simple optimizer) OR
  1250. `load_batch_into_buffer` + `learn_on_loaded_batch` (multi-GPU
  1251. optimizer) methods to calculate loss and update the model(s).
  1252. - Return all collected metrics for the iteration.
  1253. Returns:
  1254. The results dict from executing the training iteration.
  1255. """
  1256. # Collect SampleBatches from sample workers until we have a full batch.
  1257. with self._timers[SAMPLE_TIMER]:
  1258. if self.config.count_steps_by == "agent_steps":
  1259. train_batch = synchronous_parallel_sample(
  1260. worker_set=self.workers,
  1261. max_agent_steps=self.config.train_batch_size,
  1262. )
  1263. else:
  1264. train_batch = synchronous_parallel_sample(
  1265. worker_set=self.workers, max_env_steps=self.config.train_batch_size
  1266. )
  1267. train_batch = train_batch.as_multi_agent()
  1268. self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
  1269. self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
  1270. # Only train if train_batch is not empty.
  1271. # In an extreme situation, all rollout workers die during the
  1272. # synchronous_parallel_sample() call above.
  1273. # In which case, we should skip training, wait a little bit, then probe again.
  1274. train_results = {}
  1275. if train_batch.agent_steps() > 0:
  1276. # Use simple optimizer (only for multi-agent or tf-eager; all other
  1277. # cases should use the multi-GPU optimizer, even if only using 1 GPU).
  1278. # TODO: (sven) rename MultiGPUOptimizer into something more
  1279. # meaningful.
  1280. if self.config._enable_learner_api:
  1281. is_module_trainable = self.workers.local_worker().is_policy_to_train
  1282. self.learner_group.set_is_module_trainable(is_module_trainable)
  1283. train_results = self.learner_group.update(train_batch)
  1284. elif self.config.get("simple_optimizer") is True:
  1285. train_results = train_one_step(self, train_batch)
  1286. else:
  1287. train_results = multi_gpu_train_one_step(self, train_batch)
  1288. else:
  1289. # Wait 1 sec before probing again via weight syncing.
  1290. time.sleep(1)
  1291. # Update weights and global_vars - after learning on the local worker - on all
  1292. # remote workers (only those policies that were actually trained).
  1293. global_vars = {
  1294. "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
  1295. }
  1296. with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
  1297. # TODO (Avnish): Implement this on learner_group.get_weights().
  1298. # TODO (Kourosh): figure out how we are going to sync MARLModule
  1299. # weights to MARLModule weights under the policy_map objects?
  1300. from_worker_or_trainer = None
  1301. if self.config._enable_learner_api:
  1302. from_worker_or_trainer = self.learner_group
  1303. self.workers.sync_weights(
  1304. from_worker_or_learner_group=from_worker_or_trainer,
  1305. policies=list(train_results.keys()),
  1306. global_vars=global_vars,
  1307. )
  1308. return train_results
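    # Example sketch (hypothetical subclass): a minimal custom `training_step()`
    # mirroring the default flow above (sample -> train -> sync weights), using
    # only helpers already referenced in this module.
    #   >>> class MyAlgo(Algorithm):  # doctest: +SKIP
    #   ...     def training_step(self) -> ResultDict:
    #   ...         train_batch = synchronous_parallel_sample(
    #   ...             worker_set=self.workers,
    #   ...             max_env_steps=self.config.train_batch_size,
    #   ...         ).as_multi_agent()
    #   ...         results = train_one_step(self, train_batch)
    #   ...         self.workers.sync_weights(policies=list(results.keys()))
    #   ...         return results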
    @staticmethod
    def execution_plan(workers, config, **kwargs):
        raise NotImplementedError(
            "It is no longer supported to use the `Algorithm.execution_plan()` API!"
            " Set `_disable_execution_plan_api=True` in your config and override the "
            "`Algorithm.training_step()` method with your algo's custom "
            "execution logic instead."
        )
  1317. @PublicAPI
  1318. def compute_single_action(
  1319. self,
  1320. observation: Optional[TensorStructType] = None,
  1321. state: Optional[List[TensorStructType]] = None,
  1322. *,
  1323. prev_action: Optional[TensorStructType] = None,
  1324. prev_reward: Optional[float] = None,
  1325. info: Optional[EnvInfoDict] = None,
  1326. input_dict: Optional[SampleBatch] = None,
  1327. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1328. full_fetch: bool = False,
  1329. explore: Optional[bool] = None,
  1330. timestep: Optional[int] = None,
  1331. episode: Optional[Episode] = None,
  1332. unsquash_action: Optional[bool] = None,
  1333. clip_action: Optional[bool] = None,
  1334. # Kwargs placeholder for future compatibility.
  1335. **kwargs,
  1336. ) -> Union[
  1337. TensorStructType,
  1338. Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]],
  1339. ]:
  1340. """Computes an action for the specified policy on the local worker.
  1341. Note that you can also access the policy object through
  1342. self.get_policy(policy_id) and call compute_single_action() on it
  1343. directly.
  1344. Args:
  1345. observation: Single (unbatched) observation from the
  1346. environment.
  1347. state: List of all RNN hidden (single, unbatched) state tensors.
  1348. prev_action: Single (unbatched) previous action value.
  1349. prev_reward: Single (unbatched) previous reward value.
  1350. info: Env info dict, if any.
  1351. input_dict: An optional SampleBatch that holds all the values
  1352. for: obs, state, prev_action, and prev_reward, plus maybe
  1353. custom defined views of the current env trajectory. Note
  1354. that only one of `obs` or `input_dict` must be non-None.
  1355. policy_id: Policy to query (only applies to multi-agent).
  1356. Default: "default_policy".
  1357. full_fetch: Whether to return extra action fetch results.
  1358. This is always set to True if `state` is specified.
  1359. explore: Whether to apply exploration to the action.
  1360. Default: None -> use self.config.explore.
  1361. timestep: The current (sampling) time step.
  1362. episode: This provides access to all of the internal episodes'
  1363. state, which may be useful for model-based or multi-agent
  1364. algorithms.
  1365. unsquash_action: Should actions be unsquashed according to the
  1366. env's/Policy's action space? If None, use the value of
  1367. self.config.normalize_actions.
  1368. clip_action: Should actions be clipped according to the
  1369. env's/Policy's action space? If None, use the value of
  1370. self.config.clip_actions.
  1371. Keyword Args:
  1372. kwargs: forward compatibility placeholder
        Returns:
            The computed action if full_fetch=False, or the full output of
            policy.compute_single_action() (a 3-tuple of action, RNN state outs,
            and extra action fetches) if full_fetch=True or we have an RNN-based
            Policy.
  1377. Raises:
  1378. KeyError: If the `policy_id` cannot be found in this Algorithm's local
  1379. worker.
  1380. """
        # `unsquash_action` is None: Use value of config['normalize_actions'].
        if unsquash_action is None:
            unsquash_action = self.config.normalize_actions
        # `clip_action` is None: Use value of config['clip_actions'].
        # Note: Independent `if` (not `elif`), so that `clip_action` also gets its
        # config default when `unsquash_action` was None as well.
        if clip_action is None:
            clip_action = self.config.clip_actions
  1387. # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state`
  1388. # are all None.
  1389. err_msg = (
  1390. "Provide either `input_dict` OR [`observation`, ...] as "
  1391. "args to `Algorithm.compute_single_action()`!"
  1392. )
  1393. if input_dict is not None:
  1394. assert (
  1395. observation is None
  1396. and prev_action is None
  1397. and prev_reward is None
  1398. and state is None
  1399. ), err_msg
  1400. observation = input_dict[SampleBatch.OBS]
  1401. else:
  1402. assert observation is not None, err_msg
  1403. # Get the policy to compute the action for (in the multi-agent case,
  1404. # Algorithm may hold >1 policies).
  1405. policy = self.get_policy(policy_id)
  1406. if policy is None:
  1407. raise KeyError(
  1408. f"PolicyID '{policy_id}' not found in PolicyMap of the "
  1409. f"Algorithm's local worker!"
  1410. )
  1411. local_worker = self.workers.local_worker()
  1412. if not self.config.get("enable_connectors"):
  1413. # Check the preprocessor and preprocess, if necessary.
  1414. pp = local_worker.preprocessors[policy_id]
  1415. if pp and type(pp).__name__ != "NoPreprocessor":
  1416. observation = pp.transform(observation)
  1417. observation = local_worker.filters[policy_id](observation, update=False)
  1418. else:
  1419. # Just preprocess observations, similar to how it used to be done before.
  1420. pp = policy.agent_connectors[ObsPreprocessorConnector]
  1421. # convert the observation to array if possible
  1422. if not isinstance(observation, (np.ndarray, dict, tuple)):
  1423. try:
  1424. observation = np.asarray(observation)
  1425. except Exception:
  1426. raise ValueError(
  1427. f"Observation type {type(observation)} cannot be converted to "
  1428. f"np.ndarray."
  1429. )
  1430. if pp:
  1431. assert len(pp) == 1, "Only one preprocessor should be in the pipeline"
  1432. pp = pp[0]
  1433. if not pp.is_identity():
  1434. # Note(Kourosh): This call will leave the policy's connector
  1435. # in eval mode. would that be a problem?
  1436. pp.in_eval()
  1437. if observation is not None:
  1438. _input_dict = {SampleBatch.OBS: observation}
  1439. elif input_dict is not None:
  1440. _input_dict = {SampleBatch.OBS: input_dict[SampleBatch.OBS]}
  1441. else:
  1442. raise ValueError(
  1443. "Either observation or input_dict must be provided."
  1444. )
  1445. # TODO (Kourosh): Create a new util method for algorithm that
  1446. # computes actions based on raw inputs from env and can keep track
  1447. # of its own internal state.
  1448. acd = AgentConnectorDataType("0", "0", _input_dict)
  1449. # make sure the state is reset since we are only applying the
  1450. # preprocessor
  1451. pp.reset(env_id="0")
  1452. ac_o = pp([acd])[0]
  1453. observation = ac_o.data[SampleBatch.OBS]
  1454. # Input-dict.
  1455. if input_dict is not None:
  1456. input_dict[SampleBatch.OBS] = observation
  1457. action, state, extra = policy.compute_single_action(
  1458. input_dict=input_dict,
  1459. explore=explore,
  1460. timestep=timestep,
  1461. episode=episode,
  1462. )
  1463. # Individual args.
  1464. else:
  1465. action, state, extra = policy.compute_single_action(
  1466. obs=observation,
  1467. state=state,
  1468. prev_action=prev_action,
  1469. prev_reward=prev_reward,
  1470. info=info,
  1471. explore=explore,
  1472. timestep=timestep,
  1473. episode=episode,
  1474. )
  1475. # If we work in normalized action space (normalize_actions=True),
  1476. # we re-translate here into the env's action space.
  1477. if unsquash_action:
  1478. action = space_utils.unsquash_action(action, policy.action_space_struct)
  1479. # Clip, according to env's action space.
  1480. elif clip_action:
  1481. action = space_utils.clip_action(action, policy.action_space_struct)
  1482. # Return 3-Tuple: Action, states, and extra-action fetches.
  1483. if state or full_fetch:
  1484. return action, state, extra
  1485. # Ensure backward compatibility.
  1486. else:
  1487. return action
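    # Example sketch (assumes a single-agent `algo`, e.g. PPO on "CartPole-v1",
    # and a Gymnasium/gym env; reset/step signatures may differ per env version):
    #   >>> import gymnasium as gym  # doctest: +SKIP
    #   >>> env = gym.make("CartPole-v1")  # doctest: +SKIP
    #   >>> obs, info = env.reset()  # doctest: +SKIP
    #   >>> terminated = truncated = False  # doctest: +SKIP
    #   >>> while not (terminated or truncated):  # doctest: +SKIP
    #   ...     action = algo.compute_single_action(obs, explore=False)
    #   ...     obs, reward, terminated, truncated, info = env.step(action)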
  1488. @PublicAPI
  1489. def compute_actions(
  1490. self,
  1491. observations: TensorStructType,
  1492. state: Optional[List[TensorStructType]] = None,
  1493. *,
  1494. prev_action: Optional[TensorStructType] = None,
  1495. prev_reward: Optional[TensorStructType] = None,
  1496. info: Optional[EnvInfoDict] = None,
  1497. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1498. full_fetch: bool = False,
  1499. explore: Optional[bool] = None,
  1500. timestep: Optional[int] = None,
  1501. episodes: Optional[List[Episode]] = None,
  1502. unsquash_actions: Optional[bool] = None,
  1503. clip_actions: Optional[bool] = None,
  1504. **kwargs,
  1505. ):
  1506. """Computes an action for the specified policy on the local Worker.
  1507. Note that you can also access the policy object through
  1508. self.get_policy(policy_id) and call compute_actions() on it directly.
  1509. Args:
            observations: Observations from the environment (a dict mapping
                agent IDs to individual agent observations).
  1511. state: RNN hidden state, if any. If state is not None,
  1512. then all of compute_single_action(...) is returned
  1513. (computed action, rnn state(s), logits dictionary).
  1514. Otherwise compute_single_action(...)[0] is returned
  1515. (computed action).
  1516. prev_action: Previous action value, if any.
  1517. prev_reward: Previous reward, if any.
  1518. info: Env info dict, if any.
  1519. policy_id: Policy to query (only applies to multi-agent).
  1520. full_fetch: Whether to return extra action fetch results.
  1521. This is always set to True if RNN state is specified.
  1522. explore: Whether to pick an exploitation or exploration
  1523. action (default: None -> use self.config.explore).
  1524. timestep: The current (sampling) time step.
  1525. episodes: This provides access to all of the internal episodes'
  1526. state, which may be useful for model-based or multi-agent
  1527. algorithms.
  1528. unsquash_actions: Should actions be unsquashed according
  1529. to the env's/Policy's action space? If None, use
  1530. self.config.normalize_actions.
  1531. clip_actions: Should actions be clipped according to the
  1532. env's/Policy's action space? If None, use
  1533. self.config.clip_actions.
  1534. Keyword Args:
  1535. kwargs: forward compatibility placeholder
  1536. Returns:
  1537. The computed action if full_fetch=False, or a tuple consisting of
  1538. the full output of policy.compute_actions_from_input_dict() if
  1539. full_fetch=True or we have an RNN-based Policy.
  1540. """
        # `unsquash_actions` is None: Use value of config['normalize_actions'].
        if unsquash_actions is None:
            unsquash_actions = self.config.normalize_actions
        # `clip_actions` is None: Use value of config['clip_actions'].
        # Note: Independent `if` (not `elif`), so that `clip_actions` also gets its
        # config default when `unsquash_actions` was None as well.
        if clip_actions is None:
            clip_actions = self.config.clip_actions
  1547. # Preprocess obs and states.
  1548. state_defined = state is not None
  1549. policy = self.get_policy(policy_id)
  1550. filtered_obs, filtered_state = [], []
  1551. for agent_id, ob in observations.items():
  1552. worker = self.workers.local_worker()
  1553. preprocessed = worker.preprocessors[policy_id].transform(ob)
  1554. filtered = worker.filters[policy_id](preprocessed, update=False)
  1555. filtered_obs.append(filtered)
  1556. if state is None:
  1557. continue
  1558. elif agent_id in state:
  1559. filtered_state.append(state[agent_id])
  1560. else:
  1561. filtered_state.append(policy.get_initial_state())
  1562. # Batch obs and states
  1563. obs_batch = np.stack(filtered_obs)
  1564. if state is None:
  1565. state = []
  1566. else:
  1567. state = list(zip(*filtered_state))
  1568. state = [np.stack(s) for s in state]
  1569. input_dict = {SampleBatch.OBS: obs_batch}
  1570. # prev_action and prev_reward can be None, np.ndarray, or tensor-like structure.
  1571. # Explicitly check for None here to avoid the error message "The truth value of
  1572. # an array with more than one element is ambiguous.", when np arrays are passed
  1573. # as arguments.
  1574. if prev_action is not None:
  1575. input_dict[SampleBatch.PREV_ACTIONS] = prev_action
  1576. if prev_reward is not None:
  1577. input_dict[SampleBatch.PREV_REWARDS] = prev_reward
  1578. if info:
  1579. input_dict[SampleBatch.INFOS] = info
  1580. for i, s in enumerate(state):
  1581. input_dict[f"state_in_{i}"] = s
  1582. # Batch compute actions
  1583. actions, states, infos = policy.compute_actions_from_input_dict(
  1584. input_dict=input_dict,
  1585. explore=explore,
  1586. timestep=timestep,
  1587. episodes=episodes,
  1588. )
  1589. # Unbatch actions for the environment into a multi-agent dict.
  1590. single_actions = space_utils.unbatch(actions)
  1591. actions = {}
  1592. for key, a in zip(observations, single_actions):
  1593. # If we work in normalized action space (normalize_actions=True),
  1594. # we re-translate here into the env's action space.
  1595. if unsquash_actions:
  1596. a = space_utils.unsquash_action(a, policy.action_space_struct)
  1597. # Clip, according to env's action space.
  1598. elif clip_actions:
  1599. a = space_utils.clip_action(a, policy.action_space_struct)
  1600. actions[key] = a
  1601. # Unbatch states into a multi-agent dict.
  1602. unbatched_states = {}
  1603. for idx, agent_id in enumerate(observations):
  1604. unbatched_states[agent_id] = [s[idx] for s in states]
  1605. # Return only actions or full tuple
  1606. if state_defined or full_fetch:
  1607. return actions, unbatched_states, infos
  1608. else:
  1609. return actions
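    # Example sketch (assumes a multi-agent `algo` and per-agent observations
    # `obs_0`/`obs_1`; `compute_actions()` expects a dict keyed by agent ID, see
    # the `.items()` loop above, and returns a dict of actions keyed the same way):
    #   >>> obs_dict = {"agent_0": obs_0, "agent_1": obs_1}  # doctest: +SKIP
    #   >>> actions = algo.compute_actions(obs_dict, explore=False)  # doctest: +SKIP
    #   >>> actions["agent_0"]  # doctest: +SKIP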
    @PublicAPI
    def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
        """Return policy for the specified id, or None.

        Args:
            policy_id: ID of the policy to return.
        """
        return self.workers.local_worker().get_policy(policy_id)

    @PublicAPI
    def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
        """Return a dictionary of policy ids to weights.

        Args:
            policies: Optional list of policies to return weights for,
                or None for all policies.
        """
        return self.workers.local_worker().get_weights(policies)

    @PublicAPI
    def set_weights(self, weights: Dict[PolicyID, dict]):
        """Set policy weights by policy id.

        Args:
            weights: Map of policy ids to weights to set.
        """
        self.workers.local_worker().set_weights(weights)
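    # Example sketch (assumes two live Algorithm instances, `algo` and
    # `other_algo`, with identical policy setups):
    #   >>> weights = algo.get_weights(["default_policy"])  # doctest: +SKIP
    #   >>> other_algo.set_weights(weights)  # doctest: +SKIP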
  1632. @PublicAPI
  1633. def add_policy(
  1634. self,
  1635. policy_id: PolicyID,
  1636. policy_cls: Optional[Type[Policy]] = None,
  1637. policy: Optional[Policy] = None,
  1638. *,
  1639. observation_space: Optional[gym.spaces.Space] = None,
  1640. action_space: Optional[gym.spaces.Space] = None,
  1641. config: Optional[Union[AlgorithmConfig, PartialAlgorithmConfigDict]] = None,
  1642. policy_state: Optional[PolicyState] = None,
  1643. policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
  1644. policies_to_train: Optional[
  1645. Union[
  1646. Container[PolicyID],
  1647. Callable[[PolicyID, Optional[SampleBatchType]], bool],
  1648. ]
  1649. ] = None,
  1650. evaluation_workers: bool = True,
  1651. module_spec: Optional[SingleAgentRLModuleSpec] = None,
  1652. ) -> Optional[Policy]:
  1653. """Adds a new policy to this Algorithm.
  1654. Args:
  1655. policy_id: ID of the policy to add.
  1656. IMPORTANT: Must not contain characters that
  1657. are also not allowed in Unix/Win filesystems, such as: `<>:"/|?*`,
  1658. or a dot, space or backslash at the end of the ID.
  1659. policy_cls: The Policy class to use for constructing the new Policy.
  1660. Note: Only one of `policy_cls` or `policy` must be provided.
  1661. policy: The Policy instance to add to this algorithm. If not None, the
  1662. given Policy object will be directly inserted into the Algorithm's
  1663. local worker and clones of that Policy will be created on all remote
  1664. workers as well as all evaluation workers.
  1665. Note: Only one of `policy_cls` or `policy` must be provided.
  1666. observation_space: The observation space of the policy to add.
  1667. If None, try to infer this space from the environment.
  1668. action_space: The action space of the policy to add.
  1669. If None, try to infer this space from the environment.
  1670. config: The config object or overrides for the policy to add.
  1671. policy_state: Optional state dict to apply to the new
  1672. policy instance, right after its construction.
  1673. policy_mapping_fn: An optional (updated) policy mapping function
  1674. to use from here on. Note that already ongoing episodes will
  1675. not change their mapping but will use the old mapping till
  1676. the end of the episode.
  1677. policies_to_train: An optional list of policy IDs to be trained
  1678. or a callable taking PolicyID and SampleBatchType and
  1679. returning a bool (trainable or not?).
  1680. If None, will keep the existing setup in place. Policies,
  1681. whose IDs are not in the list (or for which the callable
  1682. returns False) will not be updated.
  1683. evaluation_workers: Whether to add the new policy also
  1684. to the evaluation WorkerSet.
  1685. module_spec: In the new RLModule API we need to pass in the module_spec for
  1686. the new module that is supposed to be added. Knowing the policy spec is
  1687. not sufficient.
        Returns:
            The newly added policy (the copy that got added to the local
            worker).
        """
  1692. validate_policy_id(policy_id, error=True)
  1693. self.workers.add_policy(
  1694. policy_id,
  1695. policy_cls,
  1696. policy,
  1697. observation_space=observation_space,
  1698. action_space=action_space,
  1699. config=config,
  1700. policy_state=policy_state,
  1701. policy_mapping_fn=policy_mapping_fn,
  1702. policies_to_train=policies_to_train,
  1703. module_spec=module_spec,
  1704. )
  1705. # If learner API is enabled, we need to also add the underlying module
  1706. # to the learner group.
  1707. if self.config._enable_learner_api:
  1708. policy = self.get_policy(policy_id)
  1709. module = policy.model
  1710. self.learner_group.add_module(
  1711. module_id=policy_id,
  1712. module_spec=SingleAgentRLModuleSpec.from_module(module),
  1713. )
  1714. weights = policy.get_weights()
  1715. self.learner_group.set_weights({policy_id: weights})
  1716. # Add to evaluation workers, if necessary.
  1717. if evaluation_workers is True and self.evaluation_workers is not None:
  1718. self.evaluation_workers.add_policy(
  1719. policy_id,
  1720. policy_cls,
  1721. policy,
  1722. observation_space=observation_space,
  1723. action_space=action_space,
  1724. config=config,
  1725. policy_state=policy_state,
  1726. policy_mapping_fn=policy_mapping_fn,
  1727. policies_to_train=policies_to_train,
  1728. module_spec=module_spec,
  1729. )
  1730. # Return newly added policy (from the local rollout worker).
  1731. return self.get_policy(policy_id)
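    # Example sketch (assumes a running multi-agent `algo`; the mapping-fn
    # signature is simplified here and should follow your RLlib version):
    #   >>> algo.add_policy(  # doctest: +SKIP
    #   ...     policy_id="new_policy",
    #   ...     policy_cls=type(algo.get_policy("default_policy")),
    #   ...     policy_mapping_fn=lambda agent_id, *args, **kwargs: "new_policy",
    #   ... )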
  1732. @PublicAPI
  1733. def remove_policy(
  1734. self,
  1735. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1736. *,
  1737. policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
  1738. policies_to_train: Optional[
  1739. Union[
  1740. Container[PolicyID],
  1741. Callable[[PolicyID, Optional[SampleBatchType]], bool],
  1742. ]
  1743. ] = None,
  1744. evaluation_workers: bool = True,
  1745. ) -> None:
  1746. """Removes a new policy from this Algorithm.
  1747. Args:
  1748. policy_id: ID of the policy to be removed.
  1749. policy_mapping_fn: An optional (updated) policy mapping function
  1750. to use from here on. Note that already ongoing episodes will
  1751. not change their mapping but will use the old mapping till
  1752. the end of the episode.
  1753. policies_to_train: An optional list of policy IDs to be trained
  1754. or a callable taking PolicyID and SampleBatchType and
  1755. returning a bool (trainable or not?).
  1756. If None, will keep the existing setup in place. Policies,
  1757. whose IDs are not in the list (or for which the callable
  1758. returns False) will not be updated.
  1759. evaluation_workers: Whether to also remove the policy from the
  1760. evaluation WorkerSet.
  1761. """
  1762. def fn(worker):
  1763. worker.remove_policy(
  1764. policy_id=policy_id,
  1765. policy_mapping_fn=policy_mapping_fn,
  1766. policies_to_train=policies_to_train,
  1767. )
  1768. self.workers.foreach_worker(fn, local_worker=True, healthy_only=True)
  1769. if evaluation_workers and self.evaluation_workers is not None:
  1770. self.evaluation_workers.foreach_worker(
  1771. fn,
  1772. local_worker=True,
  1773. healthy_only=True,
  1774. )
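    # Example sketch (assumes the "new_policy" added above; restores routing to
    # the default policy afterwards):
    #   >>> algo.remove_policy(  # doctest: +SKIP
    #   ...     policy_id="new_policy",
    #   ...     policy_mapping_fn=lambda agent_id, *args, **kwargs: "default_policy",
    #   ... )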
  1775. @DeveloperAPI
  1776. def export_policy_model(
  1777. self,
  1778. export_dir: str,
  1779. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1780. onnx: Optional[int] = None,
  1781. ) -> None:
  1782. """Exports policy model with given policy_id to a local directory.
  1783. Args:
  1784. export_dir: Writable local directory.
  1785. policy_id: Optional policy id to export.
  1786. onnx: If given, will export model in ONNX format. The
  1787. value of this parameter set the ONNX OpSet version to use.
  1788. If None, the output format will be DL framework specific.
  1789. Example:
  1790. >>> from ray.rllib.algorithms.ppo import PPO
  1791. >>> # Use an Algorithm from RLlib or define your own.
  1792. >>> algo = PPO(...) # doctest: +SKIP
  1793. >>> for _ in range(10): # doctest: +SKIP
  1794. >>> algo.train() # doctest: +SKIP
  1795. >>> algo.export_policy_model("/tmp/dir") # doctest: +SKIP
  1796. >>> algo.export_policy_model("/tmp/dir/onnx", onnx=1) # doctest: +SKIP
  1797. """
  1798. self.get_policy(policy_id).export_model(export_dir, onnx)
  1799. @DeveloperAPI
  1800. def export_policy_checkpoint(
  1801. self,
  1802. export_dir: str,
  1803. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1804. ) -> None:
  1805. """Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.
  1806. Args:
  1807. export_dir: Writable local directory to store the AIR Checkpoint
  1808. information into.
  1809. policy_id: Optional policy ID to export. If not provided, will export
  1810. "default_policy". If `policy_id` does not exist in this Algorithm,
  1811. will raise a KeyError.
  1812. Raises:
  1813. KeyError if `policy_id` cannot be found in this Algorithm.
  1814. Example:
  1815. >>> from ray.rllib.algorithms.ppo import PPO
  1816. >>> # Use an Algorithm from RLlib or define your own.
  1817. >>> algo = PPO(...) # doctest: +SKIP
  1818. >>> for _ in range(10): # doctest: +SKIP
  1819. >>> algo.train() # doctest: +SKIP
  1820. >>> algo.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
  1821. """
  1822. policy = self.get_policy(policy_id)
  1823. if policy is None:
  1824. raise KeyError(f"Policy with ID {policy_id} not found in Algorithm!")
  1825. policy.export_checkpoint(export_dir)
  1826. @DeveloperAPI
  1827. def import_policy_model_from_h5(
  1828. self,
  1829. import_file: str,
  1830. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1831. ) -> None:
  1832. """Imports a policy's model with given policy_id from a local h5 file.
  1833. Args:
  1834. import_file: The h5 file to import from.
  1835. policy_id: Optional policy id to import into.
  1836. Example:
  1837. >>> from ray.rllib.algorithms.ppo import PPO
  1838. >>> algo = PPO(...) # doctest: +SKIP
  1839. >>> algo.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP
  1840. >>> for _ in range(10): # doctest: +SKIP
  1841. >>> algo.train() # doctest: +SKIP
  1842. """
  1843. self.get_policy(policy_id).import_model_from_h5(import_file)
  1844. # Sync new weights to remote workers.
  1845. self._sync_weights_to_workers(worker_set=self.workers)
    @override(Trainable)
    def save_checkpoint(self, checkpoint_dir: str) -> None:
        """Exports checkpoint to a local directory.

        The structure of an Algorithm checkpoint dir will be as follows::

            policies/
                pol_1/
                    policy_state.pkl
                pol_2/
                    policy_state.pkl
            learner/
                learner_state.json
                module_state/
                    module_1/
                        ...
                optimizer_state/
                    optimizers_module_1/
                        ...
            rllib_checkpoint.json
            algorithm_state.pkl

        Note: `rllib_checkpoint.json` contains a "version" key (e.g. with value 0.1)
        helping RLlib to remain backward compatible wrt. restoring from checkpoints
        from Ray 2.0 onwards.

        Args:
            checkpoint_dir: The directory where the checkpoint files will be stored.
        """
  1871. state = self.__getstate__()
  1872. # Extract policy states from worker state (Policies get their own
  1873. # checkpoint sub-dirs).
  1874. policy_states = {}
  1875. if "worker" in state and "policy_states" in state["worker"]:
  1876. policy_states = state["worker"].pop("policy_states", {})
  1877. # Add RLlib checkpoint version.
  1878. if self.config._enable_learner_api:
  1879. state["checkpoint_version"] = CHECKPOINT_VERSION_LEARNER
  1880. else:
  1881. state["checkpoint_version"] = CHECKPOINT_VERSION
  1882. # Write state (w/o policies) to disk.
  1883. state_file = os.path.join(checkpoint_dir, "algorithm_state.pkl")
  1884. with open(state_file, "wb") as f:
  1885. pickle.dump(state, f)
  1886. # Write rllib_checkpoint.json.
  1887. with open(os.path.join(checkpoint_dir, "rllib_checkpoint.json"), "w") as f:
  1888. json.dump(
  1889. {
  1890. "type": "Algorithm",
  1891. "checkpoint_version": str(state["checkpoint_version"]),
  1892. "format": "cloudpickle",
  1893. "state_file": state_file,
  1894. "policy_ids": list(policy_states.keys()),
  1895. "ray_version": ray.__version__,
  1896. "ray_commit": ray.__commit__,
  1897. },
  1898. f,
  1899. )
  1900. # Write individual policies to disk, each in their own sub-directory.
  1901. for pid, policy_state in policy_states.items():
  1902. # From here on, disallow policyIDs that would not work as directory names.
  1903. validate_policy_id(pid, error=True)
  1904. policy_dir = os.path.join(checkpoint_dir, "policies", pid)
  1905. os.makedirs(policy_dir, exist_ok=True)
  1906. policy = self.get_policy(pid)
  1907. policy.export_checkpoint(policy_dir, policy_state=policy_state)
  1908. # if we are using the learner API, save the learner group state
  1909. if self.config._enable_learner_api:
  1910. learner_state_dir = os.path.join(checkpoint_dir, "learner")
  1911. self.learner_group.save_state(learner_state_dir)
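    # Example sketch (assumes a built `algo`; `save()` is the public Trainable
    # entry point that ends up calling `save_checkpoint()`):
    #   >>> checkpoint_path = algo.save()  # doctest: +SKIP
    #   >>> restored = Algorithm.from_checkpoint(checkpoint_path)  # doctest: +SKIP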
    @override(Trainable)
    def load_checkpoint(self, checkpoint_dir: str) -> None:
        # Checkpoint is provided as a local directory.
        # Restore from the checkpoint file or dir.
        checkpoint_info = get_checkpoint_info(checkpoint_dir)
        checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(checkpoint_info)
        self.__setstate__(checkpoint_data)
        if self.config._enable_learner_api:
            learner_state_dir = os.path.join(checkpoint_dir, "learner")
            self.learner_group.load_state(learner_state_dir)

    @override(Trainable)
    def log_result(self, result: ResultDict) -> None:
        # Log after the callback is invoked, so that the user has a chance
        # to mutate the result.
        # TODO: Remove `algorithm` arg at some point to fully deprecate the old
        # signature.
        self.callbacks.on_train_result(algorithm=self, result=result)
        # Then log according to Trainable's logging logic.
        Trainable.log_result(self, result)

    @override(Trainable)
    def cleanup(self) -> None:
        # Stop all workers.
        if hasattr(self, "workers") and self.workers is not None:
            self.workers.stop()
        if hasattr(self, "evaluation_workers") and self.evaluation_workers is not None:
            self.evaluation_workers.stop()
  1938. @OverrideToImplementCustomLogic
  1939. @classmethod
  1940. @override(Trainable)
  1941. def default_resource_request(
  1942. cls, config: Union[AlgorithmConfig, PartialAlgorithmConfigDict]
  1943. ) -> Union[Resources, PlacementGroupFactory]:
  1944. # Default logic for RLlib Algorithms:
  1945. # Create one bundle per individual worker (local or remote).
  1946. # Use `num_cpus_for_local_worker` and `num_gpus` for the local worker and
  1947. # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote
  1948. # workers to determine their CPU/GPU resource needs.
  1949. # Convenience config handles.
  1950. cf = cls.get_default_config().update_from_dict(config)
  1951. cf.validate()
  1952. cf.freeze()
  1953. # get evaluation config
  1954. eval_cf = cf.get_evaluation_config_object()
  1955. eval_cf.validate()
  1956. eval_cf.freeze()
  1957. # resources for the driver of this trainable
  1958. if cf._enable_learner_api:
  1959. if cf.num_learner_workers == 0:
  1960. # in this case local_worker only does sampling and training is done on
  1961. # local learner worker
  1962. driver = cls._get_learner_bundles(cf)[0]
  1963. else:
  1964. # in this case local_worker only does sampling and training is done on
  1965. # remote learner workers
  1966. driver = {"CPU": cf.num_cpus_for_local_worker, "GPU": 0}
  1967. else:
  1968. driver = {
  1969. "CPU": cf.num_cpus_for_local_worker,
  1970. "GPU": 0 if cf._fake_gpus else cf.num_gpus,
  1971. }
  1972. # resources for remote rollout env samplers
  1973. rollout_bundles = [
  1974. {
  1975. "CPU": cf.num_cpus_per_worker,
  1976. "GPU": cf.num_gpus_per_worker,
  1977. **cf.custom_resources_per_worker,
  1978. }
  1979. for _ in range(cf.num_rollout_workers)
  1980. ]
  1981. # resources for remote evaluation env samplers or datasets (if any)
  1982. if cls._should_create_evaluation_rollout_workers(eval_cf):
  1983. # Evaluation workers.
  1984. # Note: The local eval worker is located on the driver CPU.
  1985. evaluation_bundles = [
  1986. {
  1987. "CPU": eval_cf.num_cpus_per_worker,
  1988. "GPU": eval_cf.num_gpus_per_worker,
  1989. **eval_cf.custom_resources_per_worker,
  1990. }
  1991. for _ in range(eval_cf.evaluation_num_workers)
  1992. ]
  1993. else:
  1994. # resources for offline dataset readers during evaluation
  1995. # Note (Kourosh): we should not claim extra workers for
  1996. # training on the offline dataset, since rollout workers have already
  1997. # claimed it.
  1998. # Another Note (Kourosh): dataset reader will not use placement groups so
  1999. # whatever we specify here won't matter because dataset won't even use it.
  2000. # Disclaimer: using ray dataset in tune may cause deadlock when multiple
  2001. # tune trials get scheduled on the same node and do not leave any spare
  2002. # resources for dataset operations. The workaround is to limit the
  2003. # max_concurrent trials so that some spare cpus are left for dataset
  2004. # operations. This behavior should get fixed by the dataset team. more info
  2005. # found here:
  2006. # https://docs.ray.io/en/master/data/dataset-internals.html#datasets-tune
  2007. evaluation_bundles = []
  2008. # resources for remote learner workers
  2009. learner_bundles = []
  2010. if cf._enable_learner_api and cf.num_learner_workers > 0:
  2011. learner_bundles = cls._get_learner_bundles(cf)
  2012. bundles = [driver] + rollout_bundles + evaluation_bundles + learner_bundles
  2013. # Return PlacementGroupFactory containing all needed resources
  2014. # (already properly defined as device bundles).
  2015. return PlacementGroupFactory(
  2016. bundles=bundles,
  2017. strategy=config.get("placement_strategy", "PACK"),
  2018. )
    @DeveloperAPI
    def _before_evaluate(self):
        """Pre-evaluation callback."""
        pass
  2023. @staticmethod
  2024. def _get_env_id_and_creator(
  2025. env_specifier: Union[str, EnvType, None], config: AlgorithmConfig
  2026. ) -> Tuple[Optional[str], EnvCreator]:
  2027. """Returns env_id and creator callable given original env id from config.
  2028. Args:
  2029. env_specifier: An env class, an already tune registered env ID, a known
  2030. gym env name, or None (if no env is used).
  2031. config: The AlgorithmConfig object.
  2032. Returns:
  2033. Tuple consisting of a) env ID string and b) env creator callable.
  2034. """
  2035. # Environment is specified via a string.
  2036. if isinstance(env_specifier, str):
  2037. # An already registered env.
  2038. if _global_registry.contains(ENV_CREATOR, env_specifier):
  2039. return env_specifier, _global_registry.get(ENV_CREATOR, env_specifier)
  2040. # A class path specifier.
  2041. elif "." in env_specifier:
  2042. def env_creator_from_classpath(env_context):
  2043. try:
  2044. env_obj = from_config(env_specifier, env_context)
  2045. except ValueError:
  2046. raise EnvError(
  2047. ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_specifier)
  2048. )
  2049. return env_obj
  2050. return env_specifier, env_creator_from_classpath
  2051. # Try gym/PyBullet/Vizdoom.
  2052. else:
  2053. return env_specifier, functools.partial(
  2054. _gym_env_creator, env_descriptor=env_specifier
  2055. )
  2056. elif isinstance(env_specifier, type):
  2057. env_id = env_specifier # .__name__
  2058. if config["remote_worker_envs"]:
  2059. # Check gym version (0.22 or higher?).
  2060. # If > 0.21, can't perform auto-wrapping of the given class as this
  2061. # would lead to a pickle error.
  2062. gym_version = pkg_resources.get_distribution("gym").version
  2063. if version.parse(gym_version) >= version.parse("0.22"):
  2064. raise ValueError(
  2065. "Cannot specify a gym.Env class via `config.env` while setting "
  2066. "`config.remote_worker_env=True` AND your gym version is >= "
  2067. "0.22! Try installing an older version of gym or set `config."
  2068. "remote_worker_env=False`."
  2069. )
  2070. @ray.remote(num_cpus=1)
  2071. class _wrapper(env_specifier):
  2072. # Add convenience `_get_spaces` and `_is_multi_agent`
  2073. # methods:
  2074. def _get_spaces(self):
  2075. return self.observation_space, self.action_space
  2076. def _is_multi_agent(self):
  2077. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  2078. return isinstance(self, MultiAgentEnv)
  2079. return env_id, lambda cfg: _wrapper.remote(cfg)
  2080. # gym.Env-subclass: Also go through our RLlib gym-creator.
  2081. elif issubclass(env_specifier, gym.Env):
  2082. return env_id, functools.partial(
  2083. _gym_env_creator,
  2084. env_descriptor=env_specifier,
  2085. auto_wrap_old_gym_envs=config.get("auto_wrap_old_gym_envs", True),
  2086. )
  2087. # All other env classes: Call c'tor directly.
  2088. else:
  2089. return env_id, lambda cfg: env_specifier(cfg)
  2090. # No env -> Env creator always returns None.
  2091. elif env_specifier is None:
  2092. return None, lambda env_config: None
  2093. else:
  2094. raise ValueError(
  2095. "{} is an invalid env specifier. ".format(env_specifier)
  2096. + "You can specify a custom env as either a class "
  2097. '(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
  2098. )
  2099. def _sync_filters_if_needed(
  2100. self,
  2101. *,
  2102. central_worker: RolloutWorker,
  2103. workers: WorkerSet,
  2104. config: AlgorithmConfig,
  2105. ) -> None:
  2106. """Synchronizes the filter stats from `workers` to `central_worker`.
  2107. .. and broadcasts the central_worker's filter stats back to all `workers`
  2108. (if configured).
  2109. Args:
  2110. central_worker: The worker to sync/aggregate all `workers`' filter stats to
  2111. and from which to (possibly) broadcast the updated filter stats back to
  2112. `workers`.
  2113. workers: The WorkerSet, whose workers' filter stats should be used for
  2114. aggregation on `central_worker` and which (possibly) get updated
  2115. from `central_worker` after the sync.
  2116. config: The algorithm config instance. This is used to determine, whether
  2117. syncing from `workers` should happen at all and whether broadcasting
  2118. back to `workers` (after possible syncing) should happen.
  2119. """
  2120. if central_worker and config.observation_filter != "NoFilter":
  2121. FilterManager.synchronize(
  2122. central_worker.filters,
  2123. workers,
  2124. update_remote=config.update_worker_filter_stats,
  2125. timeout_seconds=config.sync_filters_on_rollout_workers_timeout_s,
  2126. use_remote_data_for_update=config.use_worker_filter_stats,
  2127. )
    @DeveloperAPI
    def _sync_weights_to_workers(
        self,
        *,
        worker_set: WorkerSet,
    ) -> None:
        """Sync "main" weights to given WorkerSet or list of workers."""
        # Broadcast the new policy weights to all remote workers in worker_set.
        logger.info("Synchronizing weights to workers.")
        worker_set.sync_weights()
    @classmethod
    @override(Trainable)
    def resource_help(cls, config: Union[AlgorithmConfig, AlgorithmConfigDict]) -> str:
        return (
            "\n\nYou can adjust the resource requests of RLlib Algorithms by calling "
            "`AlgorithmConfig.resources("
            "num_gpus=.., num_cpus_per_worker=.., num_gpus_per_worker=.., ..)` or "
            "`AlgorithmConfig.rollouts(num_rollout_workers=..)`. See "
            "the `ray.rllib.algorithms.algorithm_config.AlgorithmConfig` classes "
            "(each Algorithm has its own subclass of this class) for more info.\n\n"
            f"The config of this Algorithm is: {config}"
        )
  2150. @override(Trainable)
  2151. def get_auto_filled_metrics(
  2152. self,
  2153. now: Optional[datetime] = None,
  2154. time_this_iter: Optional[float] = None,
  2155. timestamp: Optional[int] = None,
  2156. debug_metrics_only: bool = False,
  2157. ) -> dict:
  2158. # Override this method to make sure, the `config` key of the returned results
  2159. # contains the proper Tune config dict (instead of an AlgorithmConfig object).
  2160. auto_filled = super().get_auto_filled_metrics(
  2161. now, time_this_iter, timestamp, debug_metrics_only
  2162. )
  2163. if "config" not in auto_filled:
  2164. raise KeyError("`config` key not found in auto-filled results dict!")
  2165. # If `config` key is no dict (but AlgorithmConfig object) ->
  2166. # make sure, it's a dict to not break Tune APIs.
  2167. if not isinstance(auto_filled["config"], dict):
  2168. assert isinstance(auto_filled["config"], AlgorithmConfig)
  2169. auto_filled["config"] = auto_filled["config"].to_dict()
  2170. return auto_filled
  2171. @classmethod
  2172. def merge_algorithm_configs(
  2173. cls,
  2174. config1: AlgorithmConfigDict,
  2175. config2: PartialAlgorithmConfigDict,
  2176. _allow_unknown_configs: Optional[bool] = None,
  2177. ) -> AlgorithmConfigDict:
  2178. """Merges a complete Algorithm config dict with a partial override dict.
  2179. Respects nested structures within the config dicts. The values in the
  2180. partial override dict take priority.
  2181. Args:
  2182. config1: The complete Algorithm's dict to be merged (overridden)
  2183. with `config2`.
  2184. config2: The partial override config dict to merge on top of
  2185. `config1`.
  2186. _allow_unknown_configs: If True, keys in `config2` that don't exist
  2187. in `config1` are allowed and will be added to the final config.
  2188. Returns:
  2189. The merged full algorithm config dict.
  2190. """
  2191. config1 = copy.deepcopy(config1)
  2192. if "callbacks" in config2 and type(config2["callbacks"]) is dict:
  2193. deprecation_warning(
  2194. "callbacks dict interface",
  2195. "a class extending rllib.algorithms.callbacks.DefaultCallbacks; "
  2196. "see `rllib/examples/custom_metrics_and_callbacks.py` for an example.",
  2197. error=True,
  2198. )
  2199. if _allow_unknown_configs is None:
  2200. _allow_unknown_configs = cls._allow_unknown_configs
  2201. return deep_update(
  2202. config1,
  2203. config2,
  2204. _allow_unknown_configs,
  2205. cls._allow_unknown_subkeys,
  2206. cls._override_all_subkeys_if_type_changes,
  2207. cls._override_all_key_list,
  2208. )
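    # Example sketch (hypothetical partial override dict; keys must be valid
    # config keys unless `_allow_unknown_configs=True`):
    #   >>> full_config = Algorithm.get_default_config().to_dict()  # doctest: +SKIP
    #   >>> merged = Algorithm.merge_algorithm_configs(  # doctest: +SKIP
    #   ...     full_config, {"train_batch_size": 4000}
    #   ... )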
  2209. @staticmethod
  2210. @ExperimentalAPI
  2211. def validate_env(env: EnvType, env_context: EnvContext) -> None:
  2212. """Env validator function for this Algorithm class.
  2213. Override this in child classes to define custom validation
  2214. behavior.
  2215. Args:
  2216. env: The (sub-)environment to validate. This is normally a
  2217. single sub-environment (e.g. a gym.Env) within a vectorized
  2218. setup.
  2219. env_context: The EnvContext to configure the environment.
  2220. Raises:
  2221. Exception in case something is wrong with the given environment.
  2222. """
  2223. pass
  2224. @override(Trainable)
  2225. def _export_model(
  2226. self, export_formats: List[str], export_dir: str
  2227. ) -> Dict[str, str]:
  2228. ExportFormat.validate(export_formats)
  2229. exported = {}
  2230. if ExportFormat.CHECKPOINT in export_formats:
  2231. path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
  2232. self.export_policy_checkpoint(path)
  2233. exported[ExportFormat.CHECKPOINT] = path
  2234. if ExportFormat.MODEL in export_formats:
  2235. path = os.path.join(export_dir, ExportFormat.MODEL)
  2236. self.export_policy_model(path)
  2237. exported[ExportFormat.MODEL] = path
  2238. if ExportFormat.ONNX in export_formats:
  2239. path = os.path.join(export_dir, ExportFormat.ONNX)
  2240. self.export_policy_model(path, onnx=int(os.getenv("ONNX_OPSET", "11")))
  2241. exported[ExportFormat.ONNX] = path
  2242. return exported
  2243. def import_model(self, import_file: str):
  2244. """Imports a model from import_file.
  2245. Note: Currently, only h5 files are supported.
  2246. Args:
  2247. import_file: The file to import the model from.
  2248. Returns:
  2249. A dict that maps ExportFormats to successfully exported models.
  2250. """
  2251. # Check for existence.
  2252. if not os.path.exists(import_file):
  2253. raise FileNotFoundError(
  2254. "`import_file` '{}' does not exist! Can't import Model.".format(
  2255. import_file
  2256. )
  2257. )
  2258. # Get the format of the given file.
  2259. import_format = "h5" # TODO(sven): Support checkpoint loading.
  2260. ExportFormat.validate([import_format])
  2261. if import_format != ExportFormat.H5:
  2262. raise NotImplementedError
  2263. else:
  2264. return self.import_policy_model_from_h5(import_file)
  2265. @PublicAPI
  2266. def __getstate__(self) -> Dict:
  2267. """Returns current state of Algorithm, sufficient to restore it from scratch.
  2268. Returns:
  2269. The current state dict of this Algorithm, which can be used to sufficiently
  2270. restore the algorithm from scratch without any other information.
  2271. """
  2272. # Add config to state so complete Algorithm can be reproduced w/o it.
  2273. state = {
  2274. "algorithm_class": type(self),
  2275. "config": self.config,
  2276. }
  2277. if hasattr(self, "workers"):
  2278. state["worker"] = self.workers.local_worker().get_state()
  2279. # TODO: Experimental functionality: Store contents of replay buffer
  2280. # to checkpoint, only if user has configured this.
  2281. if self.local_replay_buffer is not None and self.config.get(
  2282. "store_buffer_in_checkpoints"
  2283. ):
  2284. state["local_replay_buffer"] = self.local_replay_buffer.get_state()
  2285. if self.train_exec_impl is not None:
  2286. state["train_exec_impl"] = self.train_exec_impl.shared_metrics.get().save()
  2287. else:
  2288. state["counters"] = self._counters
  2289. state["training_iteration"] = self.training_iteration
  2290. return state

    @PublicAPI
    def __setstate__(self, state) -> None:
        """Sets the algorithm to the provided state.

        Args:
            state: The state dict to restore this Algorithm instance to. `state` may
                have been returned by a call to an Algorithm's `__getstate__()`
                method.
        """
        # TODO (sven): Validate that our config and the config in state are
        #  compatible. For example, the model architectures may differ.
        #  Also, what should the behavior be if e.g. some training parameter
        #  (e.g. lr) changed?

        if hasattr(self, "workers") and "worker" in state:
            self.workers.local_worker().set_state(state["worker"])
            remote_state = ray.put(state["worker"])
            self.workers.foreach_worker(
                lambda w: w.set_state(ray.get(remote_state)),
                local_worker=False,
                healthy_only=False,
            )
            if self.evaluation_workers:
                # If evaluation workers are used, also restore the policies
                # there in case they are used for evaluation purposes.
                self.evaluation_workers.foreach_worker(
                    lambda w: w.set_state(ray.get(remote_state)),
                    healthy_only=False,
                )

        # If necessary, restore replay data as well.
        if self.local_replay_buffer is not None:
            # TODO: Experimental functionality: Restore contents of replay
            #  buffer from checkpoint, only if user has configured this.
            if self.config.get("store_buffer_in_checkpoints"):
                if "local_replay_buffer" in state:
                    self.local_replay_buffer.set_state(state["local_replay_buffer"])
                else:
                    logger.warning(
                        "`store_buffer_in_checkpoints` is True, but no replay "
                        "data found in state!"
                    )
            elif "local_replay_buffer" in state and log_once(
                "no_store_buffer_in_checkpoints_but_data_found"
            ):
                logger.warning(
                    "`store_buffer_in_checkpoints` is False, but some replay "
                    "data found in state!"
                )

        if self.train_exec_impl is not None:
            self.train_exec_impl.shared_metrics.get().restore(state["train_exec_impl"])
        elif "counters" in state:
            self._counters = state["counters"]

        if "training_iteration" in state:
            self._iteration = state["training_iteration"]
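
    # Illustrative sketch (not part of the original source): manually carrying an
    # Algorithm's state from one instance to another via __getstate__/__setstate__.
    # The config object and the idea of mirroring two instances are assumptions
    # made only for this example.
    #
    #     algo_a = config.build()
    #     algo_a.train()
    #     state = algo_a.__getstate__()
    #
    #     algo_b = config.build()
    #     algo_b.__setstate__(state)  # algo_b now mirrors algo_a's weights/counters.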

    @staticmethod
    def _checkpoint_info_to_algorithm_state(
        checkpoint_info: dict,
        policy_ids: Optional[Container[PolicyID]] = None,
        policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
        policies_to_train: Optional[
            Union[
                Container[PolicyID],
                Callable[[PolicyID, Optional[SampleBatchType]], bool],
            ]
        ] = None,
    ) -> Dict:
        """Converts a checkpoint info or object to a proper Algorithm state dict.

        The returned state dict can be used inside self.__setstate__().

        Args:
            checkpoint_info: A checkpoint info dict as returned by
                `ray.rllib.utils.checkpoints.get_checkpoint_info(
                [checkpoint dir or AIR Checkpoint])`.
            policy_ids: Optional list/set of PolicyIDs. If not None, only those
                policies listed here will be included in the returned state. Note
                that state items such as filters, the `is_policy_to_train` function,
                as well as the multi-agent `policy_ids` dict will be adjusted as
                well, based on this arg.
            policy_mapping_fn: An optional (updated) policy mapping function
                to include in the returned state.
            policies_to_train: An optional list of policy IDs to be trained
                or a callable taking PolicyID and SampleBatchType and
                returning a bool (trainable or not?) to include in the returned
                state.

        Returns:
            The state dict usable within the `self.__setstate__()` method.
        """
        if checkpoint_info["type"] != "Algorithm":
            raise ValueError(
                "`checkpoint` arg passed to "
                "`Algorithm._checkpoint_info_to_algorithm_state()` must be an "
                f"Algorithm checkpoint (but is {checkpoint_info['type']})!"
            )

        msgpack = None
        if checkpoint_info.get("format") == "msgpack":
            msgpack = try_import_msgpack(error=True)

        with open(checkpoint_info["state_file"], "rb") as f:
            if msgpack is not None:
                state = msgpack.load(f)
            else:
                state = pickle.load(f)

        # New checkpoint format: Policies are in separate sub-dirs.
        # Note: Algorithms like ES/ARS don't have a WorkerSet, so we just return
        # the plain state here.
        if (
            checkpoint_info["checkpoint_version"] > version.Version("0.1")
            and state.get("worker") is not None
        ):
            worker_state = state["worker"]

            # Retrieve the set of all required policy IDs.
            policy_ids = set(
                policy_ids if policy_ids is not None else worker_state["policy_ids"]
            )

            # Remove those policies entirely from filters that are not in
            # `policy_ids`.
            worker_state["filters"] = {
                pid: filter
                for pid, filter in worker_state["filters"].items()
                if pid in policy_ids
            }

            # Get Algorithm class.
            if isinstance(state["algorithm_class"], str):
                # Try deserializing from a full classpath.
                # Or as a last resort: Tune registered algorithm name.
                state["algorithm_class"] = deserialize_type(
                    state["algorithm_class"]
                ) or get_trainable_cls(state["algorithm_class"])

            # Compile actual config object.
            default_config = state["algorithm_class"].get_default_config()
            if isinstance(default_config, AlgorithmConfig):
                new_config = default_config.update_from_dict(state["config"])
            else:
                new_config = Algorithm.merge_algorithm_configs(
                    default_config, state["config"]
                )

            # Remove policies from multiagent dict that are not in `policy_ids`.
            new_policies = new_config.policies
            if isinstance(new_policies, (set, list, tuple)):
                new_policies = {pid for pid in new_policies if pid in policy_ids}
            else:
                new_policies = {
                    pid: spec for pid, spec in new_policies.items() if pid in policy_ids
                }
            new_config.multi_agent(
                policies=new_policies,
                policies_to_train=policies_to_train,
                **(
                    {"policy_mapping_fn": policy_mapping_fn}
                    if policy_mapping_fn is not None
                    else {}
                ),
            )
            state["config"] = new_config

            # Prepare local `worker` state to add policies' states into it,
            # read from separate policy checkpoint files.
            worker_state["policy_states"] = {}
            for pid in policy_ids:
                policy_state_file = os.path.join(
                    checkpoint_info["checkpoint_dir"],
                    "policies",
                    pid,
                    "policy_state."
                    + ("msgpck" if checkpoint_info["format"] == "msgpack" else "pkl"),
                )
                if not os.path.isfile(policy_state_file):
                    raise ValueError(
                        "Given checkpoint does not seem to be valid! No policy "
                        f"state file found for PID={pid}. "
                        f"The file not found is: {policy_state_file}."
                    )

                with open(policy_state_file, "rb") as f:
                    if msgpack is not None:
                        worker_state["policy_states"][pid] = msgpack.load(f)
                    else:
                        worker_state["policy_states"][pid] = pickle.load(f)

            # These two functions are never serialized in a msgpack checkpoint (which
            # does not store code, unlike a cloudpickle checkpoint). Hence the user
            # has to provide them with the `Algorithm.from_checkpoint()` call.
            if policy_mapping_fn is not None:
                worker_state["policy_mapping_fn"] = policy_mapping_fn
            if (
                policies_to_train is not None
                # `policies_to_train` might be left None in case all policies should
                # be trained.
                or worker_state["is_policy_to_train"] == NOT_SERIALIZABLE
            ):
                worker_state["is_policy_to_train"] = policies_to_train

        return state
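
    # Illustrative sketch (not part of the original source): restoring only a subset
    # of policies from a multi-agent checkpoint via the public
    # `Algorithm.from_checkpoint()` entry point, which uses this helper internally.
    # The checkpoint path, policy IDs, and mapping function are assumptions made
    # only for this example.
    #
    #     algo = Algorithm.from_checkpoint(
    #         "/tmp/my_multi_agent_checkpoint",
    #         policy_ids={"pol_0"},
    #         policy_mapping_fn=lambda agent_id, episode, worker, **kw: "pol_0",
    #     )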

    @DeveloperAPI
    def _create_local_replay_buffer_if_necessary(
        self, config: PartialAlgorithmConfigDict
    ) -> Optional[MultiAgentReplayBuffer]:
        """Create a MultiAgentReplayBuffer instance if necessary.

        Args:
            config: Algorithm-specific configuration data.

        Returns:
            MultiAgentReplayBuffer instance based on algorithm config.
            None, if local replay buffer is not needed.
        """
        if not config.get("replay_buffer_config") or config["replay_buffer_config"].get(
            "no_local_replay_buffer"
        ):
            return

        return from_config(ReplayBuffer, config["replay_buffer_config"])
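
    # Illustrative sketch (not part of the original source): a replay buffer config
    # as it might appear in an off-policy algorithm's config dict. The concrete
    # capacity value is an assumption made only for this example.
    #
    #     config["replay_buffer_config"] = {
    #         "type": "MultiAgentReplayBuffer",
    #         "capacity": 50_000,
    #     }
    #     # Setting "no_local_replay_buffer": True would make this method return None.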

    @DeveloperAPI
    def _kwargs_for_execution_plan(self):
        kwargs = {}
        if self.local_replay_buffer is not None:
            kwargs["local_replay_buffer"] = self.local_replay_buffer
        return kwargs

    def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
        """Runs one training iteration (self.iteration will be +1 after this).

        Calls `self.training_step()` repeatedly until the configured minimum time
        (sec), minimum sample-, or minimum training steps have been reached.

        Returns:
            The results dict from the training iteration.
        """
        # In case we are training (in a thread) parallel to evaluation,
        # we may have to re-enable eager mode here (gets disabled in the
        # thread).
        if self.config.get("framework") == "tf2" and not tf.executing_eagerly():
            tf1.enable_eager_execution()

        results = None
        # Create a step context ...
        with TrainIterCtx(algo=self) as train_iter_ctx:
            # ... so we can query it whether we should stop the iteration loop (e.g.
            # when we have reached `min_time_s_per_iteration`).
            while not train_iter_ctx.should_stop(results):
                # Try to train one step.
                # TODO (avnishn): Remove the execution plan API by q1 2023.
                with self._timers[TRAINING_ITERATION_TIMER]:
                    if self.config._disable_execution_plan_api:
                        results = self.training_step()
                    else:
                        results = next(self.train_exec_impl)

        # With the training step done, try to bring failed workers back.
        self.restore_workers(self.workers)

        return results, train_iter_ctx
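
    # Illustrative sketch (not part of the original source): config settings that
    # control how long the `training_step()` loop above keeps iterating. The
    # concrete values are assumptions made only for this example.
    #
    #     config.reporting(
    #         min_time_s_per_iteration=10,
    #         min_sample_timesteps_per_iteration=1000,
    #     )
    #     # One call to `train()` then repeats `training_step()` until at least
    #     # 10 seconds have passed AND 1000 timesteps have been sampled.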

    def _run_one_evaluation(
        self,
        train_future: Optional[concurrent.futures.ThreadPoolExecutor] = None,
    ) -> ResultDict:
        """Runs an evaluation step via `self.evaluate()`, handling worker failures.

        Args:
            train_future: In case we are training and evaluating in parallel,
                this arg carries the currently running ThreadPoolExecutor
                object that runs the training iteration.

        Returns:
            The results dict from the evaluation call.
        """
        eval_func_to_use = (
            self._evaluate_async
            if self.config.enable_async_evaluation
            else self.evaluate
        )

        if self.config.evaluation_duration == "auto":
            assert (
                train_future is not None
                and self.config.evaluation_parallel_to_training
            )
            unit = self.config.evaluation_duration_unit
            eval_results = eval_func_to_use(
                duration_fn=functools.partial(
                    self._automatic_evaluation_duration_fn,
                    unit,
                    self.config.evaluation_num_workers,
                    self.evaluation_config,
                    train_future,
                )
            )
        # Run `self.evaluate()` only once per training iteration.
        else:
            eval_results = eval_func_to_use()

        if self.evaluation_workers is not None:
            # After evaluation, do a round of health check to see if any of
            # the failed workers are back.
            self.restore_workers(self.evaluation_workers)

            # Add number of healthy evaluation workers after this iteration.
            eval_results["evaluation"][
                "num_healthy_workers"
            ] = self.evaluation_workers.num_healthy_remote_workers()
            eval_results["evaluation"][
                "num_in_flight_async_reqs"
            ] = self.evaluation_workers.num_in_flight_async_reqs()
            eval_results["evaluation"][
                "num_remote_worker_restarts"
            ] = self.evaluation_workers.num_remote_worker_restarts()

        return eval_results
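
    # Illustrative sketch (not part of the original source): an evaluation setup in
    # which evaluation runs in parallel to training and automatically matches the
    # training iteration's duration. The worker count is an assumption made only
    # for this example.
    #
    #     config.evaluation(
    #         evaluation_interval=1,
    #         evaluation_num_workers=2,
    #         evaluation_duration="auto",
    #         evaluation_parallel_to_training=True,
    #     )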

    def _run_one_training_iteration_and_evaluation_in_parallel(
        self,
    ) -> Tuple[ResultDict, "TrainIterCtx"]:
        """Runs one training iteration and one evaluation step in parallel.

        First starts the training iteration (via `self._run_one_training_iteration()`)
        within a ThreadPoolExecutor, then runs the evaluation step in parallel.

        In auto-duration mode (config.evaluation_duration=auto), makes sure the
        evaluation step takes roughly the same time as the training iteration.

        Returns:
            The accumulated training and evaluation results.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            train_future = executor.submit(lambda: self._run_one_training_iteration())
            # Pass the train_future into `self._run_one_evaluation()` to allow it
            # to run exactly as long as the training iteration takes in case
            # evaluation_duration=auto.
            results = self._run_one_evaluation(train_future)
            # Collect the training results from the future.
            train_results, train_iter_ctx = train_future.result()
            results.update(train_results)

        return results, train_iter_ctx

    def _run_offline_evaluation(self):
        """Runs offline evaluation via the `OfflineEvaluator.estimate_on_dataset()` API.

        This method is used when an `evaluation_dataset` is provided.
        Note: This will only work if the policy is a single-agent policy.

        Returns:
            The results dict from the offline evaluation call.
        """
        assert len(self.workers.local_worker().policy_map) == 1

        parallelism = self.evaluation_config.evaluation_num_workers or 1
        offline_eval_results = {"off_policy_estimator": {}}
        for evaluator_name, offline_evaluator in self.reward_estimators.items():
            offline_eval_results["off_policy_estimator"][
                evaluator_name
            ] = offline_evaluator.estimate_on_dataset(
                self.evaluation_dataset,
                n_parallelism=parallelism,
            )
        return offline_eval_results
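
    # Illustrative sketch (not part of the original source): configuring off-policy
    # estimation over an offline dataset, which (together with a dataset-based
    # evaluation input) routes evaluation through this method. The estimator choice
    # and settings are assumptions made only for this example.
    #
    #     from ray.rllib.offline.estimators import ImportanceSampling
    #
    #     config.evaluation(
    #         off_policy_estimation_methods={
    #             "is": {"type": ImportanceSampling},
    #         },
    #         ope_split_batch_by_episode=False,
    #     )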

    @classmethod
    def _should_create_evaluation_rollout_workers(cls, eval_config: "AlgorithmConfig"):
        """Determines whether we need to create evaluation workers.

        Returns False if we need to run offline evaluation
        (with the ope.estimate_on_dataset API) or when the local worker is to be
        used for evaluation. Note: We only use the estimate_on_dataset API with
        bandits for now. That is when ope_split_batch_by_episode is False.
        TODO: In the future, we will do the same for episodic RL OPE.
        """
        run_offline_evaluation = (
            eval_config.off_policy_estimation_methods
            and not eval_config.ope_split_batch_by_episode
        )
        return not run_offline_evaluation and (
            eval_config.evaluation_num_workers > 0 or eval_config.evaluation_interval
        )

    @staticmethod
    def _automatic_evaluation_duration_fn(
        unit, num_eval_workers, eval_cfg, train_future, num_units_done
    ):
        # Training is done and we already ran at least one
        # evaluation -> Nothing left to run.
        if num_units_done > 0 and train_future.done():
            return 0
        # Count by episodes. -> Run n more
        # (n=num eval workers).
        elif unit == "episodes":
            return num_eval_workers
        # Count by timesteps. -> Run n*m*p more
        # (n=num eval workers; m=rollout fragment length;
        # p=num-envs-per-worker).
        else:
            return (
                num_eval_workers
                * eval_cfg["rollout_fragment_length"]
                * eval_cfg["num_envs_per_worker"]
            )
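
    # Illustrative note (not part of the original source): with timestep units,
    # 2 evaluation workers, rollout_fragment_length=200, and num_envs_per_worker=2
    # (all assumed values), each call to the duration_fn above requests another
    # 2 * 200 * 2 = 800 evaluation timesteps until the training future is done.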

    def _compile_iteration_results(
        self, *, episodes_this_iter, step_ctx, iteration_results=None
    ):
        # Return dict.
        results: ResultDict = {}
        iteration_results = iteration_results or {}

        # Evaluation results.
        if "evaluation" in iteration_results:
            results["evaluation"] = iteration_results.pop("evaluation")

        # Custom metrics and episode media.
        results["custom_metrics"] = iteration_results.pop("custom_metrics", {})
        results["episode_media"] = iteration_results.pop("episode_media", {})

        # Learner info.
        results["info"] = {LEARNER_INFO: iteration_results}

        # Calculate how many (if any) of the older, historical episodes we have to
        # add to `episodes_this_iter` in order to reach the required smoothing window.
        episodes_for_metrics = episodes_this_iter[:]
        missing = self.config.metrics_num_episodes_for_smoothing - len(
            episodes_this_iter
        )
        # We have to add some older episodes to reach the smoothing window size.
        if missing > 0:
            episodes_for_metrics = self._episode_history[-missing:] + episodes_this_iter
            assert (
                len(episodes_for_metrics)
                <= self.config.metrics_num_episodes_for_smoothing
            )
        # Note that when there are more than `metrics_num_episodes_for_smoothing`
        # episodes in `episodes_for_metrics`, leave them as-is. In this case, we'll
        # compute the stats over that larger number.

        # Add new episodes to our history and make sure it doesn't grow larger than
        # needed.
        self._episode_history.extend(episodes_this_iter)
        self._episode_history = self._episode_history[
            -self.config.metrics_num_episodes_for_smoothing :
        ]
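        # Illustrative note (not part of the original source): with an assumed
        # `metrics_num_episodes_for_smoothing=100` and 30 episodes collected this
        # iteration, `missing` is 70, so the 70 most recent historical episodes are
        # prepended and the stats below are computed over 100 episodes in total.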
        results["sampler_results"] = summarize_episodes(
            episodes_for_metrics,
            episodes_this_iter,
            self.config.keep_per_episode_custom_metrics,
        )
        # TODO: Don't dump sampler results into top-level.
        results.update(results["sampler_results"])

        results["num_healthy_workers"] = self.workers.num_healthy_remote_workers()
        results["num_in_flight_async_reqs"] = self.workers.num_in_flight_async_reqs()
        results[
            "num_remote_worker_restarts"
        ] = self.workers.num_remote_worker_restarts()

        # Train-steps- and env/agent-steps this iteration.
        for c in [
            NUM_AGENT_STEPS_SAMPLED,
            NUM_AGENT_STEPS_TRAINED,
            NUM_ENV_STEPS_SAMPLED,
            NUM_ENV_STEPS_TRAINED,
        ]:
            results[c] = self._counters[c]

        time_taken_sec = step_ctx.get_time_taken_sec()
        if self.config.count_steps_by == "agent_steps":
            results[NUM_AGENT_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
            results[NUM_AGENT_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
            results[NUM_AGENT_STEPS_SAMPLED + "_throughput_per_sec"] = (
                step_ctx.sampled / time_taken_sec
            )
            results[NUM_AGENT_STEPS_TRAINED + "_throughput_per_sec"] = (
                step_ctx.trained / time_taken_sec
            )
            # TODO: For CQL and other algos, count by trained steps.
            results["timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED]
        else:
            results[NUM_ENV_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
            results[NUM_ENV_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
            results[NUM_ENV_STEPS_SAMPLED + "_throughput_per_sec"] = (
                step_ctx.sampled / time_taken_sec
            )
            results[NUM_ENV_STEPS_TRAINED + "_throughput_per_sec"] = (
                step_ctx.trained / time_taken_sec
            )
            # TODO: For CQL and other algos, count by trained steps.
            results["timesteps_total"] = self._counters[NUM_ENV_STEPS_SAMPLED]
        # TODO: Backward compatibility.
        results[STEPS_TRAINED_THIS_ITER_COUNTER] = step_ctx.trained
        results["agent_timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED]

        # Process timer results.
        timers = {}
        for k, timer in self._timers.items():
            timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
            if timer.has_units_processed():
                timers["{}_throughput".format(k)] = round(timer.mean_throughput, 3)
        results["timers"] = timers

        # Process counter results.
        counters = {}
        for k, counter in self._counters.items():
            counters[k] = counter
        results["counters"] = counters
        # TODO: Backward compatibility.
        results["info"].update(counters)

        return results

    def __repr__(self):
        return type(self).__name__

    def _record_usage(self, config):
        """Record the framework and algorithm used.

        Args:
            config: Algorithm config dict.
        """
        record_extra_usage_tag(TagKey.RLLIB_FRAMEWORK, config["framework"])
        record_extra_usage_tag(TagKey.RLLIB_NUM_WORKERS, str(config["num_workers"]))
        alg = self.__class__.__name__
        # We do not want to collect user defined algorithm names.
        if alg not in ALL_ALGORITHMS:
            alg = "USER_DEFINED"
        record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg)

    @Deprecated(new="AlgorithmConfig.validate()", error=True)
    def validate_config(self, config):
        pass


# TODO: Create a dict that throws a deprecation warning once we have fully moved
#  to AlgorithmConfig() objects (some algos are still missing).
COMMON_CONFIG: AlgorithmConfigDict = AlgorithmConfig(Algorithm).to_dict()


class TrainIterCtx:
    def __init__(self, algo: Algorithm):
        self.algo = algo
        self.time_start = None
        self.time_stop = None

    def __enter__(self):
        # Before first call to `step()`, `results` is expected to be None ->
        # Start with self.failures=-1 -> set to 0 before the very first call
        # to `self.step()`.
        self.failures = -1

        self.time_start = time.time()
        self.sampled = 0
        self.trained = 0
        self.init_env_steps_sampled = self.algo._counters[NUM_ENV_STEPS_SAMPLED]
        self.init_env_steps_trained = self.algo._counters[NUM_ENV_STEPS_TRAINED]
        self.init_agent_steps_sampled = self.algo._counters[NUM_AGENT_STEPS_SAMPLED]
        self.init_agent_steps_trained = self.algo._counters[NUM_AGENT_STEPS_TRAINED]
        self.failure_tolerance = self.algo.config[
            "num_consecutive_worker_failures_tolerance"
        ]
        return self

    def __exit__(self, *args):
        self.time_stop = time.time()

    def get_time_taken_sec(self) -> float:
        """Returns the time we spent in the context, in seconds."""
        return self.time_stop - self.time_start

    def should_stop(self, results):
        # Before first call to `step()`.
        if results is None:
            # Fail after n retries.
            self.failures += 1
            if self.failures > self.failure_tolerance:
                raise RuntimeError(
                    "More than `num_consecutive_worker_failures_tolerance="
                    f"{self.failure_tolerance}` consecutive worker failures! "
                    "Exiting."
                )
            # Continue to very first `step()` call or retry `step()` after
            # a (tolerable) failure.
            return False

        # Stopping criteria.
        elif self.algo.config._disable_execution_plan_api:
            if self.algo.config.count_steps_by == "agent_steps":
                self.sampled = (
                    self.algo._counters[NUM_AGENT_STEPS_SAMPLED]
                    - self.init_agent_steps_sampled
                )
                self.trained = (
                    self.algo._counters[NUM_AGENT_STEPS_TRAINED]
                    - self.init_agent_steps_trained
                )
            else:
                self.sampled = (
                    self.algo._counters[NUM_ENV_STEPS_SAMPLED]
                    - self.init_env_steps_sampled
                )
                self.trained = (
                    self.algo._counters[NUM_ENV_STEPS_TRAINED]
                    - self.init_env_steps_trained
                )

            min_t = self.algo.config["min_time_s_per_iteration"]
            min_sample_ts = self.algo.config["min_sample_timesteps_per_iteration"]
            min_train_ts = self.algo.config["min_train_timesteps_per_iteration"]
            # Repeat if not enough time has passed or if not enough
            # env|train timesteps have been processed (or these min
            # values are not provided by the user).
            if (
                (not min_t or time.time() - self.time_start >= min_t)
                and (not min_sample_ts or self.sampled >= min_sample_ts)
                and (not min_train_ts or self.trained >= min_train_ts)
            ):
                return True
            else:
                return False
        # No errors (we got results != None) -> Return True
        # (meaning: yes, should stop -> no further step attempts).
        else:
            return True