# trainer.py

from collections import defaultdict
import concurrent
import copy
from datetime import datetime
import functools
import gym
import logging
import math
import numpy as np
import os
import pickle
import tempfile
import time
from typing import Callable, Container, DefaultDict, Dict, List, Optional, \
    Set, Tuple, Type, Union

import ray
from ray.actor import ActorHandle
from ray.exceptions import RayError
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.utils import gym_env_creator
from ray.rllib.evaluation.collectors.simple_list_collector import \
    SimpleListCollector
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import collect_episodes, collect_metrics, \
    summarize_episodes
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
    MultiAgentReplayBuffer
from ray.rllib.execution.common import WORKER_UPDATE_TIMER
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts, \
    synchronous_parallel_sample
from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep, \
    train_one_step, multi_gpu_train_one_step
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils import deep_update, FilterManager, merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, \
    override, PublicAPI
from ray.rllib.utils.debug import update_global_seed_if_necessary
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning, \
    DEPRECATED_VALUE
from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED, \
    NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_TRAINED, NUM_AGENT_STEPS_TRAINED
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.typing import AgentID, EnvCreator, EnvInfoDict, EnvType, \
    EpisodeID, PartialTrainerConfigDict, PolicyID, PolicyState, ResultDict, \
    TensorStructType, TensorType, TrainerConfigDict
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.resources import Resources
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trainable import Trainable
from ray.tune.trial import ExportFormat
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.util import log_once
from ray.util.timer import _Timer

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)
# Max number of times to retry a worker failure. We shouldn't try too many
# times in a row since that would indicate a persistent cluster issue.
MAX_WORKER_FAILURE_RETRIES = 3

# yapf: disable
# __sphinx_doc_begin__
COMMON_CONFIG: TrainerConfigDict = {
    # === Settings for Rollout Worker processes ===
    # Number of rollout worker actors to create for parallel sampling. Setting
    # this to 0 will force rollouts to be done in the trainer actor.
    "num_workers": 2,
    # Number of environments to evaluate vector-wise per worker. This enables
    # model inference batching, which can improve performance for inference
    # bottlenecked workloads.
    "num_envs_per_worker": 1,
    # When `num_workers` > 0, the driver (local_worker; worker-idx=0) does not
    # need an environment. This is because it doesn't have to sample (done by
    # remote_workers; worker_indices > 0) nor evaluate (done by evaluation
    # workers; see below).
    "create_env_on_driver": False,
    # Divide episodes into fragments of this many steps each during rollouts.
    # Sample batches of this size are collected from rollout workers and
    # combined into a larger batch of `train_batch_size` for learning.
    #
    # For example, given rollout_fragment_length=100 and train_batch_size=1000:
    # 1. RLlib collects 10 fragments of 100 steps each from rollout workers.
    # 2. These fragments are concatenated and we perform an epoch of SGD.
    #
    # When using multiple envs per worker, the fragment size is multiplied by
    # `num_envs_per_worker`. This is because we are collecting steps from
    # multiple envs in parallel. For example, if num_envs_per_worker=5, then
    # rollout workers will return experiences in chunks of 5*100 = 500 steps.
    # (See the worked example right after this config dict.)
    #
    # The dataflow here can vary per algorithm. For example, PPO further
    # divides the train batch into minibatches for multi-epoch SGD.
    "rollout_fragment_length": 200,
    # How to build per-Sampler (RolloutWorker) batches, which are then
    # usually concat'd to form the train batch. Note that "steps" below can
    # mean different things (either env- or agent-steps) and depends on the
    # `count_steps_by` (multiagent) setting below.
    # truncate_episodes: Each produced batch (when calling
    #   RolloutWorker.sample()) will contain exactly `rollout_fragment_length`
    #   steps. This mode guarantees evenly sized batches, but increases
    #   variance as the future return must now be estimated at truncation
    #   boundaries.
    # complete_episodes: Each unroll happens exactly over one episode, from
    #   beginning to end. Data collection will not stop unless the episode
    #   terminates or a configured horizon (hard or soft) is hit.
    "batch_mode": "truncate_episodes",
    # === Settings for the Trainer process ===
    # Discount factor of the MDP.
    "gamma": 0.99,
    # The default learning rate.
    "lr": 0.0001,
    # Training batch size, if applicable. Should be >= rollout_fragment_length.
    # Sample batches will be concatenated together to a batch of this size,
    # which is then passed to SGD.
    "train_batch_size": 200,
    # Arguments to pass to the policy model. See models/catalog.py for a full
    # list of the available model options.
    "model": MODEL_DEFAULTS,
    # Arguments to pass to the policy optimizer. These vary by optimizer.
    "optimizer": {},
    # === Environment Settings ===
    # Number of steps after which the episode is forced to terminate. Defaults
    # to `env.spec.max_episode_steps` (if present) for Gym envs.
    "horizon": None,
    # Calculate rewards but don't reset the environment when the horizon is
    # hit. This allows value estimation and RNN state to span across logical
    # episodes denoted by horizon. This only has an effect if horizon != inf.
    "soft_horizon": False,
    # Don't set 'done' at the end of the episode.
    # In combination with `soft_horizon`, this works as follows:
    # - no_done_at_end=False soft_horizon=False:
    #   Reset env and add `done=True` at end of each episode.
    # - no_done_at_end=True soft_horizon=False:
    #   Reset env, but do NOT add `done=True` at end of the episode.
    # - no_done_at_end=False soft_horizon=True:
    #   Do NOT reset env at horizon, but add `done=True` at the horizon
    #   (pretending the episode has terminated).
    # - no_done_at_end=True soft_horizon=True:
    #   Do NOT reset env at horizon and do NOT add `done=True` at the horizon.
    "no_done_at_end": False,
    # The environment specifier:
    # This can either be a tune-registered env, via
    # `tune.register_env([name], lambda env_ctx: [env object])`,
    # or a string specifier of an RLlib supported type. In the latter case,
    # RLlib will try to interpret the specifier as either an openAI gym env,
    # a PyBullet env, a ViZDoomGym env, or a fully qualified classpath to an
    # Env class, e.g. "ray.rllib.examples.env.random_env.RandomEnv".
    "env": None,
    # The observation- and action spaces for the Policies of this Trainer.
    # Use None for automatically inferring these from the given env.
    "observation_space": None,
    "action_space": None,
    # Arguments dict passed to the env creator as an EnvContext object (which
    # is a dict plus the properties: num_workers, worker_index, vector_index,
    # and remote).
    "env_config": {},
    # If using num_envs_per_worker > 1, whether to create those new envs in
    # remote processes instead of in the same worker. This adds overheads, but
    # can make sense if your envs take a long time to step / reset
    # (e.g., for StarCraft). Use this cautiously; overheads are significant.
    "remote_worker_envs": False,
    # Timeout that remote workers are waiting on when polling environments.
    # 0 (continue when at least one env is ready) is a reasonable default,
    # but the optimal value could be obtained by measuring your environment
    # step / reset and model inference perf.
    "remote_env_batch_wait_ms": 0,
    # A callable taking the last train results, the base env and the env
    # context as args and returning a new task to set the env to.
    # The env must be a `TaskSettableEnv` sub-class for this to work.
    # See `examples/curriculum_learning.py` for an example.
    "env_task_fn": None,
    # If True, try to render the environment on the local worker or on worker
    # 1 (if num_workers > 0). For vectorized envs, this usually means that only
    # the first sub-environment will be rendered.
    # In order for this to work, your env will have to implement the
    # `render()` method which either:
    # a) handles window generation and rendering itself (returning True) or
    # b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].
    "render_env": False,
    # If True, stores videos in this relative directory inside the default
    # output dir (~/ray_results/...). Alternatively, you can specify an
    # absolute path (str), in which case the env recordings will be
    # stored instead.
    # Set to False for not recording anything.
    # Note: This setting replaces the deprecated `monitor` key.
    "record_env": False,
    # Whether to clip rewards during Policy's postprocessing.
    # None (default): Clip for Atari only (r=sign(r)).
    # True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0.
    # False: Never clip.
    # [float value]: Clip at -value and + value.
    # Tuple[value1, value2]: Clip at value1 and value2.
    "clip_rewards": None,
    # If True, RLlib will learn entirely inside a normalized action space
    # (0.0 centered with small stddev; only affecting Box components).
    # We will unsquash actions (and clip, just in case) to the bounds of
    # the env's action space before sending actions back to the env.
    "normalize_actions": True,
    # If True, RLlib will clip actions according to the env's bounds
    # before sending them back to the env.
    # TODO: (sven) This option should be obsoleted and always be False.
    "clip_actions": False,
    # Whether to use "rllib" or "deepmind" preprocessors by default.
    # Set to None for using no preprocessor. In this case, the model will have
    # to handle possibly complex observations from the environment.
    "preprocessor_pref": "deepmind",
    # === Debug Settings ===
    # Set the ray.rllib.* log level for the agent process and its workers.
    # Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
    # periodically print out summaries of relevant internal dataflow (this is
    # also printed out once at startup at the INFO level). When using the
    # `rllib train` command, you can also use the `-v` and `-vv` flags as
    # shorthand for INFO and DEBUG.
    "log_level": "WARN",
    # Callbacks that will be run during various phases of training. See the
    # `DefaultCallbacks` class and `examples/custom_metrics_and_callbacks.py`
    # for more usage information.
    "callbacks": DefaultCallbacks,
    # Whether to attempt to continue training if a worker crashes. The number
    # of currently healthy workers is reported as the "num_healthy_workers"
    # metric.
    "ignore_worker_failures": False,
    # Log system resource metrics to results. This requires `psutil` to be
    # installed for sys stats, and `gputil` for GPU metrics.
    "log_sys_usage": True,
    # Use fake (infinite speed) sampler. For testing only.
    "fake_sampler": False,
    # === Deep Learning Framework Settings ===
    # tf: TensorFlow (static-graph)
    # tf2: TensorFlow 2.x (eager or traced, if eager_tracing=True)
    # tfe: TensorFlow eager (or traced, if eager_tracing=True)
    # torch: PyTorch
    "framework": "tf",
    # Enable tracing in eager mode. This greatly improves performance
    # (speedup ~2x), but makes it slightly harder to debug since Python
    # code won't be evaluated after the initial eager pass.
    # Only possible if framework=[tf2|tfe].
    "eager_tracing": False,
    # Maximum number of tf.function re-traces before a runtime error is raised.
    # This is to prevent unnoticed retraces of methods inside the
    # `..._eager_traced` Policy, which could slow down execution by a
    # factor of 4, without the user noticing what the root cause for this
    # slowdown could be.
    # Only necessary for framework=[tf2|tfe].
    # Set to None to ignore the re-trace count and never throw an error.
    "eager_max_retraces": 20,
    # === Exploration Settings ===
    # Default exploration behavior, iff `explore`=None is passed into
    # compute_action(s).
    # Set to False for no exploration behavior (e.g., for evaluation).
    "explore": True,
    # Provide a dict specifying the Exploration object's config.
    "exploration_config": {
        # The Exploration class to use. In the simplest case, this is the name
        # (str) of any class present in the `rllib.utils.exploration` package.
        # You can also provide the python class directly or the full location
        # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
        # EpsilonGreedy").
        "type": "StochasticSampling",
        # Add constructor kwargs here (if any).
    },
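    # Illustrative sketch (not a default): the block above could be swapped
    # for an epsilon-greedy setup. `EpsilonGreedy` is an exploration class in
    # `rllib.utils.exploration`; the kwargs shown are assumptions and depend
    # on the chosen Exploration class:
    #
    #   "exploration_config": {
    #       "type": "EpsilonGreedy",
    #       "initial_epsilon": 1.0,
    #       "final_epsilon": 0.02,
    #   },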
    # === Evaluation Settings ===
    # Evaluate with every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that for Ape-X metrics are already only reported for the lowest
    # epsilon workers (least random workers).
    # Set to None (or 0) for no evaluation.
    "evaluation_interval": None,
    # Duration for which to run evaluation each `evaluation_interval`.
    # The unit for the duration can be set via `evaluation_duration_unit` to
    # either "episodes" (default) or "timesteps".
    # If using multiple evaluation workers (evaluation_num_workers > 1),
    # the load to run will be split amongst these.
    # If the value is "auto":
    # - For `evaluation_parallel_to_training=True`: Will run as many
    #   episodes/timesteps that fit into the (parallel) training step.
    # - For `evaluation_parallel_to_training=False`: Error.
    "evaluation_duration": 10,
    # The unit, with which to count the evaluation duration. Either "episodes"
    # (default) or "timesteps".
    "evaluation_duration_unit": "episodes",
    # Whether to run evaluation in parallel to a Trainer.train() call
    # using threading. Default=False.
    # E.g. evaluation_interval=2 -> For every other training iteration,
    # the Trainer.train() and Trainer.evaluate() calls run in parallel.
    # Note: This is experimental. Possible pitfalls could be race conditions
    # for weight synching at the beginning of the evaluation loop.
    "evaluation_parallel_to_training": False,
    # Internal flag that is set to True for evaluation workers.
    "in_evaluation": False,
    # Typical usage is to pass extra args to evaluation env creator
    # and to disable exploration by computing deterministic actions.
    # IMPORTANT NOTE: Policy gradient algorithms are able to find the optimal
    # policy, even if this is a stochastic one. Setting "explore=False" here
    # will result in the evaluation workers not using this optimal policy!
    "evaluation_config": {
        # Example: overriding env_config, exploration, etc:
        # "env_config": {...},
        # "explore": False
    },
    # Number of parallel workers to use for evaluation. Note that this is set
    # to zero by default, which means evaluation will be run in the trainer
    # process (only if evaluation_interval is not None). If you increase this,
    # it will increase the Ray resource usage of the trainer since evaluation
    # workers are created separately from rollout workers (used to sample data
    # for training).
    "evaluation_num_workers": 0,
    # Customize the evaluation method. This must be a function of signature
    # (trainer: Trainer, eval_workers: WorkerSet) -> metrics: dict. See the
    # Trainer.evaluate() method to see the default implementation.
    # The Trainer guarantees all eval workers have the latest policy state
    # before this function is called.
    "custom_eval_function": None,
    # Make sure the latest available evaluation results are always attached to
    # a step result dict.
    # This may be useful if Tune or some other meta controller needs access
    # to evaluation metrics all the time.
    "always_attach_evaluation_results": False,
    # === Advanced Rollout Settings ===
    # Use a background thread for sampling (slightly off-policy, usually not
    # advisable to turn on unless your env specifically requires it).
    "sample_async": False,
    # The SampleCollector class to be used to collect and retrieve
    # environment-, model-, and sampler data. Override the SampleCollector base
    # class to implement your own collection/buffering/retrieval logic.
    "sample_collector": SimpleListCollector,
    # Element-wise observation filter, either "NoFilter" or "MeanStdFilter".
    "observation_filter": "NoFilter",
    # Whether to synchronize the statistics of remote filters.
    "synchronize_filters": True,
    # Configures TF for single-process operation by default.
    "tf_session_args": {
        # note: overridden by `local_tf_session_args`
        "intra_op_parallelism_threads": 2,
        "inter_op_parallelism_threads": 2,
        "gpu_options": {
            "allow_growth": True,
        },
        "log_device_placement": False,
        "device_count": {
            "CPU": 1
        },
        # Required by multi-GPU (num_gpus > 1).
        "allow_soft_placement": True,
    },
    # Override the following tf session args on the local worker
    "local_tf_session_args": {
        # Allow a higher level of parallelism by default, but not unlimited
        # since that can cause crashes with many concurrent drivers.
        "intra_op_parallelism_threads": 8,
        "inter_op_parallelism_threads": 8,
    },
    # Whether to LZ4 compress individual observations.
    "compress_observations": False,
    # Wait for metric batches for at most this many seconds. Those that
    # have not returned in time will be collected in the next train iteration.
    "metrics_episode_collection_timeout_s": 180,
    # Smooth metrics over this many episodes.
    "metrics_num_episodes_for_smoothing": 100,
    # Minimum time interval to run one `train()` call for:
    # If - after one `step_attempt()`, this time limit has not been reached,
    # will perform n more `step_attempt()` calls until this minimum time has
    # been consumed. Set to None or 0 for no minimum time.
    "min_time_s_per_reporting": None,
    # Minimum train/sample timesteps to optimize for per `train()` call.
    # This value does not affect learning, only the length of train iterations.
    # If - after one `step_attempt()`, the timestep counts (sampling or
    # training) have not been reached, will perform n more `step_attempt()`
    # calls until the minimum timesteps have been executed.
    # Set to None or 0 for no minimum timesteps.
    "min_train_timesteps_per_reporting": None,
    "min_sample_timesteps_per_reporting": None,
    # This argument, in conjunction with worker_index, sets the random seed of
    # each worker, so that identically configured trials will have identical
    # results. This makes experiments reproducible.
    "seed": None,
    # Any extra python env vars to set in the trainer process, e.g.,
    # {"OMP_NUM_THREADS": "16"}
    "extra_python_environs_for_driver": {},
    # Any extra python env vars to set in the worker processes.
    "extra_python_environs_for_worker": {},
    # === Resource Settings ===
    # Number of GPUs to allocate to the trainer process. Note that not all
    # algorithms can take advantage of trainer GPUs. Support for multi-GPU
    # is currently only available for tf-[PPO/IMPALA/DQN/PG].
    # This can be fractional (e.g., 0.3 GPUs).
    "num_gpus": 0,
    # Set to True for debugging (multi-)GPU functionality on a CPU machine.
    # GPU towers will be simulated by graphs located on CPUs in this case.
    # Use `num_gpus` to test for different numbers of fake GPUs.
    "_fake_gpus": False,
    # Number of CPUs to allocate per worker.
    "num_cpus_per_worker": 1,
    # Number of GPUs to allocate per worker. This can be fractional. This is
    # usually needed only if your env itself requires a GPU (i.e., it is a
    # GPU-intensive video game), or model inference is unusually expensive.
    "num_gpus_per_worker": 0,
    # Any custom Ray resources to allocate per worker.
    "custom_resources_per_worker": {},
    # Number of CPUs to allocate for the trainer. Note: this only takes effect
    # when running in Tune. Otherwise, the trainer runs in the main program.
    "num_cpus_for_driver": 1,
    # The strategy for the placement group factory returned by
    # `Trainer.default_resource_request()`. A PlacementGroup defines, which
    # devices (resources) should always be co-located on the same node.
    # For example, a Trainer with 2 rollout workers, running with
    # num_gpus=1 will request a placement group with the bundles:
    # [{"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where the first bundle is
    # for the driver and the other 2 bundles are for the two workers.
    # These bundles can now be "placed" on the same or different
    # nodes depending on the value of `placement_strategy`:
    # "PACK": Packs bundles into as few nodes as possible.
    # "SPREAD": Places bundles across distinct nodes as evenly as possible.
    # "STRICT_PACK": Packs bundles onto one node. The group is not allowed
    #   to span multiple nodes.
    # "STRICT_SPREAD": Places each bundle on a distinct node.
    "placement_strategy": "PACK",
    # === Offline Datasets ===
    # Specify how to generate experiences:
    #  - "sampler": Generate experiences via online (env) simulation (default).
    #  - A local directory or file glob expression (e.g., "/tmp/*.json").
    #  - A list of individual file paths/URIs (e.g., ["/tmp/1.json",
    #    "s3://bucket/2.json"]).
    #  - A dict with string keys and sampling probabilities as values (e.g.,
    #    {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
    #  - A callable that takes an `IOContext` object as only arg and returns a
    #    ray.rllib.offline.InputReader.
    #  - A string key that indexes a callable with tune.registry.register_input
    "input": "sampler",
    # Arguments accessible from the IOContext for configuring custom input
    "input_config": {},
    # True, if the actions in a given offline "input" are already normalized
    # (between -1.0 and 1.0). This is usually the case when the offline
    # file has been generated by another RLlib algorithm (e.g. PPO or SAC),
    # while "normalize_actions" was set to True.
    "actions_in_input_normalized": False,
    # Specify how to evaluate the current policy. This only has an effect when
    # reading offline experiences ("input" is not "sampler").
    # Available options:
    #  - "wis": the weighted step-wise importance sampling estimator.
    #  - "is": the step-wise importance sampling estimator.
    #  - "simulation": run the environment in the background, but use
    #    this data for evaluation only and not for learning.
    "input_evaluation": ["is", "wis"],
    # Whether to run postprocess_trajectory() on the trajectory fragments from
    # offline inputs. Note that postprocessing will be done using the *current*
    # policy, not the *behavior* policy, which is typically undesirable for
    # on-policy algorithms.
    "postprocess_inputs": False,
    # If positive, input batches will be shuffled via a sliding window buffer
    # of this number of batches. Use this if the input data is not in random
    # enough order. Input is delayed until the shuffle buffer is filled.
    "shuffle_buffer_size": 0,
    # Specify where experiences should be saved:
    #  - None: don't save any experiences
    #  - "logdir" to save to the agent log dir
    #  - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
    #  - a function that returns a rllib.offline.OutputWriter
    "output": None,
    # What sample batch columns to LZ4 compress in the output data.
    "output_compress_columns": ["obs", "new_obs"],
    # Max output file size before rolling over to a new file.
    "output_max_file_size": 64 * 1024 * 1024,
    # === Settings for Multi-Agent Environments ===
    "multiagent": {
        # Map of type MultiAgentPolicyConfigDict from policy ids to tuples
        # of (policy_cls, obs_space, act_space, config). This defines the
        # observation and action spaces of the policies and any extra config.
        "policies": {},
        # Keep this many policies in the "policy_map" (before writing
        # least-recently used ones to disk/S3).
        "policy_map_capacity": 100,
        # Where to store overflowing (least-recently used) policies?
        # Could be a directory (str) or an S3 location. None for using
        # the default output dir.
        "policy_map_cache": None,
        # Function mapping agent ids to policy ids.
        "policy_mapping_fn": None,
        # Optional list of policies to train, or None for all policies.
        "policies_to_train": None,
        # Optional function that can be used to enhance the local agent
        # observations to include more state.
        # See rllib/evaluation/observation_function.py for more info.
        "observation_fn": None,
        # When replay_mode=lockstep, RLlib will replay all the agent
        # transitions at a particular timestep together in a batch. This allows
        # the policy to implement differentiable shared computations between
        # agents it controls at that timestep. When replay_mode=independent,
        # transitions are replayed independently per policy.
        "replay_mode": "independent",
        # Which metric to use as the "batch size" when building a
        # MultiAgentBatch. The two supported values are:
        # env_steps: Count each time the env is "stepped" (no matter how many
        #   multi-agent actions are passed/how many multi-agent observations
        #   have been returned in the previous step).
        # agent_steps: Count each individual agent step as one step.
        "count_steps_by": "env_steps",
    },
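    # Illustrative sketch (not a default) of a two-policy multi-agent setup;
    # the policy ids, spaces, and mapping function below are made up:
    #
    #   "multiagent": {
    #       "policies": {
    #           # (policy_cls or None=use default, obs_space, act_space, cfg)
    #           "policy_0": (None, obs_space, act_space, {}),
    #           "policy_1": (None, obs_space, act_space, {}),
    #       },
    #       "policy_mapping_fn":
    #           lambda agent_id, **kwargs: f"policy_{agent_id % 2}",
    #   },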
    # === Logger ===
    # Define logger-specific configuration to be used inside Logger
    # Default value None allows overwriting with nested dicts
    "logger_config": None,
    # === API deprecations/simplifications/changes ===
    # Experimental flag.
    # If True, TFPolicy will handle more than one loss/optimizer.
    # Set this to True, if you would like to return more than
    # one loss term from your `loss_fn` and an equal number of optimizers
    # from your `optimizer_fn`.
    # In the future, the default for this will be True.
    "_tf_policy_handles_more_than_one_loss": False,
    # Experimental flag.
    # If True, no (observation) preprocessor will be created and
    # observations will arrive in model as they are returned by the env.
    # In the future, the default for this will be True.
    "_disable_preprocessor_api": False,
    # Experimental flag.
    # If True, RLlib will no longer flatten the policy-computed actions into
    # a single tensor (for storage in SampleCollectors/output files/etc..),
    # but leave (possibly nested) actions as-is. Disabling flattening affects:
    # - SampleCollectors: Have to store possibly nested action structs.
    # - Models that have the previous action(s) as part of their input.
    # - Algorithms reading from offline files (incl. action information).
    "_disable_action_flattening": False,
    # Experimental flag.
    # If True, the execution plan API will not be used. Instead,
    # a Trainer's `training_iteration` method will be called as-is each
    # training iteration.
    "_disable_execution_plan_api": False,
    # === Deprecated keys ===
    # Uses the sync samples optimizer instead of the multi-gpu one. This is
    # usually slower, but you might want to try it if you run into issues with
    # the default optimizer.
    # This will be set automatically from now on.
    "simple_optimizer": DEPRECATED_VALUE,
    # Whether to write episode stats and videos to the agent log dir. This is
    # typically located in ~/ray_results.
    "monitor": DEPRECATED_VALUE,
    # Replaced by `evaluation_duration=10` and
    # `evaluation_duration_unit=episodes`.
    "evaluation_num_episodes": DEPRECATED_VALUE,
    # Use `metrics_num_episodes_for_smoothing` instead.
    "metrics_smoothing_episodes": DEPRECATED_VALUE,
    # Use `min_[sample|train]_timesteps_per_reporting` instead.
    "timesteps_per_iteration": 0,
    # Use `min_time_s_per_reporting` instead.
    "min_iter_time_s": DEPRECATED_VALUE,
    # Use `metrics_episode_collection_timeout_s` instead.
    "collect_metrics_timeout": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
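
# Worked example (illustrative only) of how the sampling-related sizes in
# COMMON_CONFIG interact; the numbers below are assumptions, not defaults:
#
#   num_workers = 2, num_envs_per_worker = 5, rollout_fragment_length = 100
#   -> each `RolloutWorker.sample()` call returns 5 * 100 = 500 env steps.
#   With train_batch_size = 1000, two such sample batches are concatenated
#   into one train batch per iteration (2 * 500 = 1000 steps passed to SGD).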


@DeveloperAPI
def with_common_config(
        extra_config: PartialTrainerConfigDict) -> TrainerConfigDict:
    """Returns the given config dict merged with common agent confs.

    Args:
        extra_config (PartialTrainerConfigDict): A user defined partial config
            which will get merged with COMMON_CONFIG and returned.

    Returns:
        TrainerConfigDict: The merged config dict resulting from COMMON_CONFIG
            plus `extra_config`.
    """
    return Trainer.merge_trainer_configs(
        COMMON_CONFIG, extra_config, _allow_unknown_configs=True)
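
# Illustrative usage sketch of `with_common_config` (the key values chosen
# here are arbitrary, not recommendations):
#
#   my_config = with_common_config({
#       "env": "CartPole-v0",
#       "num_workers": 4,
#       "train_batch_size": 4000,
#   })
#   # Unspecified keys fall back to COMMON_CONFIG, e.g.:
#   assert my_config["gamma"] == 0.99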


@PublicAPI
class Trainer(Trainable):
    """An RLlib algorithm responsible for optimizing one or more Policies.

    Trainers contain a WorkerSet under `self.workers`. A WorkerSet is
    normally composed of a single local worker
    (self.workers.local_worker()), used to compute and apply learning updates,
    and optionally one or more remote workers (self.workers.remote_workers()),
    used to generate environment samples in parallel.

    Each worker (remote or local) contains a PolicyMap, which itself
    may contain either one policy for single-agent training or one or more
    policies for multi-agent training. Policies are synchronized
    automatically from time to time using ray.remote calls. The exact
    synchronization logic depends on the specific algorithm (Trainer) used,
    but this usually happens from the local worker to all remote workers and
    after each training update.

    You can write your own Trainer classes by sub-classing from `Trainer`
    or any of its built-in sub-classes.
    This allows you to override the `execution_plan` method to implement
    your own algorithm logic. You can find the different built-in
    algorithms' execution plans in their respective main .py files,
    e.g. rllib.agents.dqn.dqn.py or rllib.agents.impala.impala.py.

    The most important API methods a Trainer exposes are `train()`,
    `evaluate()`, `save()` and `restore()`. Trainer objects retain internal
    model state between calls to train(), so you should create a new
    Trainer instance for each training session. (A usage sketch follows the
    class-level attribute definitions below.)
    """

    # Whether to allow unknown top-level config keys.
    _allow_unknown_configs = False

    # List of top-level keys with value=dict, for which new sub-keys are
    # allowed to be added to the value dict.
    _allow_unknown_subkeys = [
        "tf_session_args", "local_tf_session_args", "env_config", "model",
        "optimizer", "multiagent", "custom_resources_per_worker",
        "evaluation_config", "exploration_config",
        "extra_python_environs_for_driver", "extra_python_environs_for_worker",
        "input_config"
    ]

    # List of top level keys with value=dict, for which we always override the
    # entire value (dict), iff the "type" key in that value dict changes.
    _override_all_subkeys_if_type_changes = ["exploration_config"]

    # TODO: Deprecate. Instead, override `Trainer.get_default_config()`.
    _default_config = COMMON_CONFIG
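
    # Illustrative usage sketch of the Trainer API described in the class
    # docstring above. `PPOTrainer` and the config values are examples only:
    #
    #   from ray.rllib.agents.ppo import PPOTrainer
    #   trainer = PPOTrainer(config={"env": "CartPole-v0", "num_workers": 2})
    #   for _ in range(10):
    #       print(trainer.train())        # one training iteration per call
    #   checkpoint = trainer.save()       # save trainer/model state
    #   trainer.restore(checkpoint)       # restore from a checkpoint
    #   eval_results = trainer.evaluate() # run evaluation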

    @PublicAPI
    def __init__(self,
                 config: Optional[PartialTrainerConfigDict] = None,
                 env: Optional[Union[str, EnvType]] = None,
                 logger_creator: Optional[Callable[[], Logger]] = None,
                 remote_checkpoint_dir: Optional[str] = None,
                 sync_function_tpl: Optional[str] = None):
        """Initializes a Trainer instance.

        Args:
            config: Algorithm-specific configuration dict.
            env: Name of the environment to use (e.g. a gym-registered str),
                a full class path (e.g.
                "ray.rllib.examples.env.random_env.RandomEnv"), or an Env
                class directly. Note that this arg can also be specified via
                the "env" key in `config`.
            logger_creator: Callable that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """

        # User provided (partial) config (this may be w/o the default
        # Trainer's `COMMON_CONFIG` (see above)). Will get merged with
        # COMMON_CONFIG in self.setup().
        config = config or {}

        # Convert `env` provided in config into a string:
        # - If `env` is a string: `self._env_id` = `env`.
        # - If `env` is a class: `self._env_id` = `env.__name__` -> Already
        #   register it with an auto-generated env creator.
        # - If `env` is None: `self._env_id` is None.
        self._env_id: Optional[str] = self._register_if_needed(
            env or config.get("env"), config)

        # The env creator callable, taking an EnvContext (config dict)
        # as arg and returning an RLlib supported Env type (e.g. a gym.Env).
        self.env_creator: EnvCreator = None

        # Placeholder for a local replay buffer instance.
        self.local_replay_buffer = None

        # Create a default logger creator if no logger_creator is specified.
        if logger_creator is None:
            # Default logdir prefix containing the agent's name and the
            # env id.
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            logdir_prefix = "{}_{}_{}".format(str(self), self._env_id, timestr)
            if not os.path.exists(DEFAULT_RESULTS_DIR):
                os.makedirs(DEFAULT_RESULTS_DIR)
            logdir = tempfile.mkdtemp(
                prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)

            # Allow users to more precisely configure the created logger
            # via "logger_config.type".
            if config.get(
                    "logger_config") and "type" in config["logger_config"]:

                def default_logger_creator(config):
                    """Creates a custom logger with the default prefix."""
                    cfg = config["logger_config"].copy()
                    cls = cfg.pop("type")
                    # Provide default for logdir, in case the user does
                    # not specify this in the "logger_config" dict.
                    logdir_ = cfg.pop("logdir", logdir)
                    return from_config(cls=cls, _args=[cfg], logdir=logdir_)

            # If no `type` given, use tune's UnifiedLogger as last resort.
            else:

                def default_logger_creator(config):
                    """Creates a Unified logger with the default prefix."""
                    return UnifiedLogger(config, logdir, loggers=None)

            logger_creator = default_logger_creator

        # Metrics-related properties.
        self._timers = defaultdict(_Timer)
        self._counters = defaultdict(int)
        self._episode_history = []
        self._episodes_to_be_collected = []

        # Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
        self.evaluation_workers: Optional[WorkerSet] = None
        # Initialize common evaluation_metrics to nan, before they become
        # available. We want to make sure the metrics are always present
        # (although their values may be nan), so that Tune does not complain
        # when we use these as stopping criteria.
        self.evaluation_metrics = {
            "evaluation": {
                "episode_reward_max": np.nan,
                "episode_reward_min": np.nan,
                "episode_reward_mean": np.nan,
            }
        }

        super().__init__(config, logger_creator, remote_checkpoint_dir,
                         sync_function_tpl)

    @ExperimentalAPI
    @classmethod
    def get_default_config(cls) -> TrainerConfigDict:
        return cls._default_config or COMMON_CONFIG

    @override(Trainable)
    def setup(self, config: PartialTrainerConfigDict):

        # Setup our config: Merge the user-supplied config (which could
        # be a partial config dict) with the class' default.
        self.config = self.merge_trainer_configs(
            self.get_default_config(), config, self._allow_unknown_configs)
        self.config["env"] = self._env_id

        # Validate the framework settings in config.
        self.validate_framework(self.config)

        # Setup the self.env_creator callable (to be passed
        # e.g. to RolloutWorkers' c'tors).
        self.env_creator = self._get_env_creator_from_env_id(self._env_id)

        # Set Trainer's seed after we have - if necessary - enabled
        # tf eager-execution.
        update_global_seed_if_necessary(self.config["framework"],
                                        self.config["seed"])

        self.validate_config(self.config)
        self.callbacks = self.config["callbacks"]()
        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info("Current log_level is {}. For more information, "
                        "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                        "-vv flags.".format(log_level))
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        # Create local replay buffer if necessary.
        self.local_replay_buffer = (
            self._create_local_replay_buffer_if_necessary(self.config))

        # Create a dict, mapping ActorHandles to sets of open remote
        # requests (object refs). This way, we keep track of which actors
        # inside this Trainer (e.g. a remote RolloutWorker) have
        # already been sent how many (e.g. `sample()`) requests.
        self.remote_requests_in_flight: \
            DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

        self.workers: Optional[WorkerSet] = None
        self.train_exec_impl = None

        # Deprecated way of implementing Trainer sub-classes (or "templates"
        # via the `build_trainer` utility function).
        # Instead, sub-classes should override the Trainable's `setup()`
        # method and call super().setup() from within that override at some
        # point.
        # Old design: Override `Trainer._init` (or use `build_trainer()`, which
        # will do this for you).
        try:
            self._init(self.config, self.env_creator)
        # New design: Override `Trainable.setup()` (as intended by Trainable)
        # and do or don't call super().setup() from within your override.
        # By default, `super().setup()` will create both worker sets:
        # "rollout workers" for collecting samples for training and - if
        # applicable - "evaluation workers" for evaluation runs in between or
        # parallel to training.
        # TODO: Deprecate `_init()` and remove this try/except block.
        except NotImplementedError:
            # Only if user did not override `_init()`:
            # - Create rollout workers here automatically.
            # - Run the execution plan to create the local iterator to `next()`
            #   in each training iteration.
            # This matches the behavior of using `build_trainer()`, which
            # should no longer be used.
            self.workers = self._make_workers(
                env_creator=self.env_creator,
                validate_env=self.validate_env,
                policy_class=self.get_default_policy_class(self.config),
                config=self.config,
                num_workers=self.config["num_workers"])

            # Function defining one single training iteration's behavior.
            if self.config["_disable_execution_plan_api"]:
                # Ensure remote workers are initially in sync with the
                # local worker.
                self.workers.sync_weights()
            # LocalIterator-creating "execution plan".
            # Only call this once here to create `self.train_exec_impl`,
            # which is a ray.util.iter.LocalIterator that will be `next`'d
            # on each training iteration.
            else:
                self.train_exec_impl = self.execution_plan(
                    self.workers, self.config,
                    **self._kwargs_for_execution_plan())

            # Now that workers have been created, update our policies
            # dict in config[multiagent] (with the correct original/
            # unpreprocessed spaces).
            self.config["multiagent"]["policies"] = \
                self.workers.local_worker().policy_dict

        # Evaluation WorkerSet setup.
        # User would like to set up a separate evaluation worker set.

        # Update with evaluation settings:
        user_eval_config = copy.deepcopy(self.config["evaluation_config"])

        # Assert that user has not unset "in_evaluation".
        assert "in_evaluation" not in user_eval_config or \
            user_eval_config["in_evaluation"] is True

        # Merge user-provided eval config with the base config. This makes sure
        # the eval config is always complete, no matter whether we have eval
        # workers or perform evaluation on the (non-eval) local worker.
        eval_config = merge_dicts(self.config, user_eval_config)
        self.config["evaluation_config"] = eval_config

        if self.config.get("evaluation_num_workers", 0) > 0 or \
                self.config.get("evaluation_interval"):
            logger.debug(f"Using evaluation_config: {user_eval_config}.")

            # Validate evaluation config.
            self.validate_config(eval_config)

            # Set the `in_evaluation` flag.
            eval_config["in_evaluation"] = True

            # Evaluation duration unit: episodes.
            # Switch on `complete_episodes` rollouts. Also, make sure
            # rollout fragments are short so we never have more than one
            # episode in one rollout.
            if eval_config["evaluation_duration_unit"] == "episodes":
                eval_config.update({
                    "batch_mode": "complete_episodes",
                    "rollout_fragment_length": 1,
                })
            # Evaluation duration unit: timesteps.
            # - Set `batch_mode=truncate_episodes` so we don't perform rollouts
            #   strictly along episode borders.
            # - Set `rollout_fragment_length` such that desired steps are
            #   divided equally amongst workers or - in "auto" duration mode -
            #   set it to a reasonably small number (10), such that a single
            #   `sample()` call doesn't take too much time so we can stop
            #   evaluation as soon as possible after the train step is
            #   completed.
            else:
                eval_config.update({
                    "batch_mode": "truncate_episodes",
                    "rollout_fragment_length": 10
                    if self.config["evaluation_duration"] == "auto" else int(
                        math.ceil(
                            self.config["evaluation_duration"] /
                            (self.config["evaluation_num_workers"] or 1))),
                })
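
            # Worked example (illustrative, assumed numbers): with
            # evaluation_duration=10 (unit "timesteps") and
            # evaluation_num_workers=4, each eval worker samples
            # rollout_fragment_length = ceil(10 / 4) = 3 timesteps per
            # `sample()` call.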

            self.config["evaluation_config"] = eval_config

            # Create a separate evaluation worker set for evaluation.
            # If evaluation_num_workers=0, use the evaluation set's local
            # worker for evaluation, otherwise, use its remote workers
            # (parallelized evaluation).
            self.evaluation_workers: WorkerSet = self._make_workers(
                env_creator=self.env_creator,
                validate_env=None,
                policy_class=self.get_default_policy_class(self.config),
                config=eval_config,
                num_workers=self.config["evaluation_num_workers"],
                # Don't even create a local worker if num_workers > 0.
                local_worker=False,
            )
  845. # TODO: Deprecated: In your sub-classes of Trainer, override `setup()`
  846. # directly and call super().setup() from within it if you would like the
  847. # default setup behavior plus some own setup logic.
  848. # If you don't need the env/workers/config/etc.. setup for you by super,
  849. # simply do not call super().setup() from your overridden method.
  850. def _init(self, config: TrainerConfigDict,
  851. env_creator: EnvCreator) -> None:
  852. raise NotImplementedError
  853. @ExperimentalAPI
  854. def get_default_policy_class(self, config: TrainerConfigDict) -> \
  855. Type[Policy]:
  856. """Returns a default Policy class to use, given a config.
  857. This class will be used inside RolloutWorkers' PolicyMaps in case
  858. the policy class is not provided by the user in any single- or
  859. multi-agent PolicySpec.
860. This method is experimental and currently only used if the Trainer
861. class was not created using the `build_trainer` utility and if
862. the Trainer sub-class does not override `_init()` and create its
863. own WorkerSet in `_init()`.
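Example: A Trainer sub-class may pin its default Policy class by
overriding this method (a minimal, hypothetical sketch; `MyTrainer`
and `MyTorchPolicy` are illustrative names, not defined here):
>>> class MyTrainer(Trainer):
...     def get_default_policy_class(self, config):
...         return MyTorchPolicy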
  864. """
  865. return getattr(self, "_policy_class", None)
  866. @override(Trainable)
  867. def step(self) -> ResultDict:
  868. """Implements the main `Trainer.train()` logic.
869. Takes up to n attempts to perform a single training step, catching
870. RayErrors that result from worker failures. After n failed attempts,
871. fails gracefully.
  872. Override this method in your Trainer sub-classes if you would like to
  873. handle worker failures yourself. Otherwise, override
  874. `self.step_attempt()` to keep the n attempts (catch worker failures).
  875. Returns:
  876. The results dict with stats/infos on sampling, training,
  877. and - if required - evaluation.
  878. """
  879. step_attempt_results = None
  880. with self._step_context() as step_ctx:
  881. while not step_ctx.should_stop(step_attempt_results):
  882. # Try to train one step.
  883. try:
  884. step_attempt_results = self.step_attempt()
  885. # @ray.remote RolloutWorker failure.
  886. except RayError as e:
  887. # Try to recover w/o the failed worker.
  888. if self.config["ignore_worker_failures"]:
  889. logger.exception(
  890. "Error in train call, attempting to recover")
  891. self.try_recover_from_step_attempt()
  892. # Error out.
  893. else:
  894. logger.warning(
  895. "Worker crashed during call to `step_attempt()`. "
  896. "To try to continue training without the failed "
  897. "worker, set `ignore_worker_failures=True`.")
  898. raise e
  899. # Any other exception.
  900. except Exception as e:
901. # Allow log messages to propagate.
  902. time.sleep(0.5)
  903. raise e
  904. result = step_attempt_results
  905. if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
  906. # Sync filters on workers.
  907. self._sync_filters_if_needed(self.workers)
  908. # Collect worker metrics.
  909. if self.config["_disable_execution_plan_api"]:
  910. result = self._compile_step_results(
  911. step_ctx=step_ctx,
  912. step_attempt_results=step_attempt_results,
  913. )
  914. return result
  915. @ExperimentalAPI
  916. def step_attempt(self) -> ResultDict:
  917. """Attempts a single training step, including evaluation, if required.
  918. Override this method in your Trainer sub-classes if you would like to
  919. keep the n step-attempts logic (catch worker failures) in place or
  920. override `step()` directly if you would like to handle worker
  921. failures yourself.
  922. Returns:
  923. The results dict with stats/infos on sampling, training,
  924. and - if required - evaluation.
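Example: A sub-class could re-use the default attempt logic and only
append its own data to the results (an illustrative sketch; `MyTrainer`
and the "my_custom_metric" key are hypothetical):
>>> class MyTrainer(Trainer):
...     def step_attempt(self):
...         results = super().step_attempt()
...         results["my_custom_metric"] = 1.0
...         return results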
  925. """
  926. def auto_duration_fn(unit, num_eval_workers, eval_cfg, num_units_done):
  927. # Training is done and we already ran at least one
  928. # evaluation -> Nothing left to run.
  929. if num_units_done > 0 and \
  930. train_future.done():
  931. return 0
  932. # Count by episodes. -> Run n more
  933. # (n=num eval workers).
  934. elif unit == "episodes":
  935. return num_eval_workers
  936. # Count by timesteps. -> Run n*m*p more
  937. # (n=num eval workers; m=rollout fragment length;
  938. # p=num-envs-per-worker).
  939. else:
  940. return num_eval_workers * \
  941. eval_cfg["rollout_fragment_length"] * \
  942. eval_cfg["num_envs_per_worker"]
  943. # self._iteration gets incremented after this function returns,
944. # meaning that e.g. the first time this function is called,
  945. # self._iteration will be 0.
  946. evaluate_this_iter = \
  947. self.config["evaluation_interval"] and \
  948. (self._iteration + 1) % self.config["evaluation_interval"] == 0
  949. step_results = {}
  950. # No evaluation necessary, just run the next training iteration.
  951. if not evaluate_this_iter:
  952. step_results = self._exec_plan_or_training_iteration_fn()
  953. # We have to evaluate in this training iteration.
  954. else:
  955. # No parallelism.
  956. if not self.config["evaluation_parallel_to_training"]:
  957. step_results = self._exec_plan_or_training_iteration_fn()
  958. # Kick off evaluation-loop (and parallel train() call,
  959. # if requested).
  960. # Parallel eval + training.
  961. if self.config["evaluation_parallel_to_training"]:
  962. with concurrent.futures.ThreadPoolExecutor() as executor:
  963. train_future = executor.submit(
  964. lambda: self._exec_plan_or_training_iteration_fn())
  965. # Automatically determine duration of the evaluation.
  966. if self.config["evaluation_duration"] == "auto":
  967. unit = self.config["evaluation_duration_unit"]
  968. step_results.update(
  969. self.evaluate(
  970. duration_fn=functools.partial(
  971. auto_duration_fn, unit, self.config[
  972. "evaluation_num_workers"], self.config[
  973. "evaluation_config"])))
  974. else:
  975. step_results.update(self.evaluate())
  976. # Collect the training results from the future.
  977. step_results.update(train_future.result())
  978. # Sequential: train (already done above), then eval.
  979. else:
  980. step_results.update(self.evaluate())
  981. # Attach latest available evaluation results to train results,
  982. # if necessary.
  983. if (not evaluate_this_iter
  984. and self.config["always_attach_evaluation_results"]):
  985. assert isinstance(self.evaluation_metrics, dict), \
  986. "Trainer.evaluate() needs to return a dict."
  987. step_results.update(self.evaluation_metrics)
  988. # Check `env_task_fn` for possible update of the env's task.
  989. if self.config["env_task_fn"] is not None:
  990. if not callable(self.config["env_task_fn"]):
  991. raise ValueError(
  992. "`env_task_fn` must be None or a callable taking "
  993. "[train_results, env, env_ctx] as args!")
  994. def fn(env, env_context, task_fn):
  995. new_task = task_fn(step_results, env, env_context)
  996. cur_task = env.get_task()
  997. if cur_task != new_task:
  998. env.set_task(new_task)
  999. fn = functools.partial(fn, task_fn=self.config["env_task_fn"])
  1000. self.workers.foreach_env_with_context(fn)
  1001. return step_results
  1002. @PublicAPI
  1003. def evaluate(
  1004. self,
  1005. episodes_left_fn=None, # deprecated
  1006. duration_fn: Optional[Callable[[int], int]] = None,
  1007. ) -> dict:
  1008. """Evaluates current policy under `evaluation_config` settings.
1009. Merges `evaluation_config` with the normal trainer config, runs the
1010. evaluation rollouts, and collects metrics from them.
  1011. Args:
1012. duration_fn: An optional callable taking the number of
1013. episodes (or timesteps) already run as its only arg and
1014. returning the number of units left to run. It is used to
1015. determine whether evaluation should continue.
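Example: Limiting evaluation to a fixed budget of 10 episodes via a
custom `duration_fn` (a minimal sketch; `MyTrainer` and the chosen
budget are illustrative assumptions):
>>> trainer = MyTrainer(config={"evaluation_interval": 1})
>>> results = trainer.evaluate(
...     duration_fn=lambda num_units_done: 10 - num_units_done)
>>> print(results["evaluation"]["episode_reward_mean"])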
  1016. """
  1017. if episodes_left_fn is not None:
  1018. deprecation_warning(
  1019. old="Trainer.evaluate(episodes_left_fn)",
  1020. new="Trainer.evaluate(duration_fn)",
  1021. error=False)
  1022. duration_fn = episodes_left_fn
  1023. # In case we are evaluating (in a thread) parallel to training,
  1024. # we may have to re-enable eager mode here (gets disabled in the
  1025. # thread).
  1026. if self.config.get("framework") in ["tf2", "tfe"] and \
  1027. not tf.executing_eagerly():
  1028. tf1.enable_eager_execution()
  1029. # Call the `_before_evaluate` hook.
  1030. self._before_evaluate()
  1031. # Sync weights to the evaluation WorkerSet.
  1032. if self.evaluation_workers is not None:
  1033. self.evaluation_workers.sync_weights(
  1034. from_worker=self.workers.local_worker())
  1035. self._sync_filters_if_needed(self.evaluation_workers)
  1036. if self.config["custom_eval_function"]:
  1037. logger.info("Running custom eval function {}".format(
  1038. self.config["custom_eval_function"]))
  1039. metrics = self.config["custom_eval_function"](
  1040. self, self.evaluation_workers)
  1041. if not metrics or not isinstance(metrics, dict):
  1042. raise ValueError("Custom eval function must return "
  1043. "dict of metrics, got {}.".format(metrics))
  1044. else:
  1045. if self.evaluation_workers is None and \
  1046. self.workers.local_worker().input_reader is None:
  1047. raise ValueError(
  1048. "Cannot evaluate w/o an evaluation worker set in "
  1049. "the Trainer or w/o an env on the local worker!\n"
  1050. "Try one of the following:\n1) Set "
  1051. "`evaluation_interval` >= 0 to force creating a "
  1052. "separate evaluation worker set.\n2) Set "
  1053. "`create_env_on_driver=True` to force the local "
  1054. "(non-eval) worker to have an environment to "
  1055. "evaluate on.")
  1056. # How many episodes/timesteps do we need to run?
  1057. # In "auto" mode (only for parallel eval + training): Run as long
  1058. # as training lasts.
  1059. unit = self.config["evaluation_duration_unit"]
  1060. eval_cfg = self.config["evaluation_config"]
  1061. rollout = eval_cfg["rollout_fragment_length"]
  1062. num_envs = eval_cfg["num_envs_per_worker"]
  1063. duration = self.config["evaluation_duration"] if \
  1064. self.config["evaluation_duration"] != "auto" else \
  1065. (self.config["evaluation_num_workers"] or 1) * \
  1066. (1 if unit == "episodes" else rollout)
  1067. num_ts_run = 0
1068. # Default duration_fn: Return the number of units (episodes or
1069. # timesteps) still left to run.
  1070. if duration_fn is None:
  1071. def duration_fn(num_units_done):
  1072. return duration - num_units_done
  1073. logger.info(f"Evaluating current policy for {duration} {unit}.")
  1074. metrics = None
  1075. # No evaluation worker set ->
1076. # Do evaluation using the local worker (which must have an env
1077. # or an input reader to sample from; validated above).
  1078. if self.evaluation_workers is None:
  1079. # If unit=episodes -> Run n times `sample()` (each sample
  1080. # produces exactly 1 episode).
  1081. # If unit=ts -> Run 1 `sample()` b/c the
  1082. # `rollout_fragment_length` is exactly the desired ts.
  1083. iters = duration if unit == "episodes" else 1
  1084. for _ in range(iters):
  1085. num_ts_run += len(self.workers.local_worker().sample())
  1086. metrics = collect_metrics(self.workers.local_worker())
  1087. # Evaluation worker set only has local worker.
  1088. elif self.config["evaluation_num_workers"] == 0:
  1089. # If unit=episodes -> Run n times `sample()` (each sample
  1090. # produces exactly 1 episode).
  1091. # If unit=ts -> Run 1 `sample()` b/c the
  1092. # `rollout_fragment_length` is exactly the desired ts.
  1093. iters = duration if unit == "episodes" else 1
  1094. for _ in range(iters):
  1095. num_ts_run += len(
  1096. self.evaluation_workers.local_worker().sample())
  1097. # Evaluation worker set has n remote workers.
  1098. else:
  1099. # How many episodes have we run (across all eval workers)?
  1100. num_units_done = 0
  1101. round_ = 0
  1102. while True:
  1103. units_left_to_do = duration_fn(num_units_done)
  1104. if units_left_to_do <= 0:
  1105. break
  1106. round_ += 1
  1107. batches = ray.get([
  1108. w.sample.remote() for i, w in enumerate(
  1109. self.evaluation_workers.remote_workers())
  1110. if i * (1 if unit == "episodes" else rollout *
  1111. num_envs) < units_left_to_do
  1112. ])
  1113. # 1 episode per returned batch.
  1114. if unit == "episodes":
  1115. num_units_done += len(batches)
  1116. # n timesteps per returned batch.
  1117. else:
  1118. ts = sum(len(b) for b in batches)
  1119. num_ts_run += ts
  1120. num_units_done += ts
  1121. logger.info(f"Ran round {round_} of parallel evaluation "
  1122. f"({num_units_done}/{duration} {unit} done)")
  1123. if metrics is None:
  1124. metrics = collect_metrics(
  1125. self.evaluation_workers.local_worker(),
  1126. self.evaluation_workers.remote_workers())
  1127. metrics["timesteps_this_iter"] = num_ts_run
  1128. # Evaluation does not run for every step.
  1129. # Save evaluation metrics on trainer, so it can be attached to
  1130. # subsequent step results as latest evaluation result.
  1131. self.evaluation_metrics = {"evaluation": metrics}
  1132. # Also return the results here for convenience.
  1133. return self.evaluation_metrics
  1134. @ExperimentalAPI
  1135. def training_iteration(self) -> ResultDict:
  1136. """Default single iteration logic of an algorithm.
  1137. - Collect on-policy samples (SampleBatches) in parallel using the
  1138. Trainer's RolloutWorkers (@ray.remote).
  1139. - Concatenate collected SampleBatches into one train batch.
  1140. - Note that we may have more than one policy in the multi-agent case:
  1141. Call the different policies' `learn_on_batch` (simple optimizer) OR
  1142. `load_batch_into_buffer` + `learn_on_loaded_batch` (multi-GPU
  1143. optimizer) methods to calculate loss and update the model(s).
  1144. - Return all collected metrics for the iteration.
  1145. Returns:
  1146. The results dict from executing the training iteration.
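Example: Overriding the default iteration logic while still re-using
it (a hedged sketch; `MyTrainer` is a hypothetical sub-class):
>>> class MyTrainer(Trainer):
...     def training_iteration(self):
...         results = super().training_iteration()
...         # Post-process or extend the train results here, if needed.
...         return results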
  1147. """
  1148. # Some shortcuts.
  1149. batch_size = self.config["train_batch_size"]
  1150. # Collects SampleBatches in parallel and synchronously
  1151. # from the Trainer's RolloutWorkers until we hit the
  1152. # configured `train_batch_size`.
  1153. sample_batches = []
  1154. num_env_steps = 0
  1155. num_agent_steps = 0
  1156. while (not self._by_agent_steps and num_env_steps < batch_size) or \
  1157. (self._by_agent_steps and num_agent_steps < batch_size):
  1158. new_sample_batches = synchronous_parallel_sample(self.workers)
  1159. sample_batches.extend(new_sample_batches)
  1160. num_env_steps += sum(len(s) for s in new_sample_batches)
  1161. num_agent_steps += sum(
  1162. len(s) if isinstance(s, SampleBatch) else s.agent_steps()
  1163. for s in new_sample_batches)
  1164. self._counters[NUM_ENV_STEPS_SAMPLED] += num_env_steps
  1165. self._counters[NUM_AGENT_STEPS_SAMPLED] += num_agent_steps
  1166. # Combine all batches at once
  1167. train_batch = SampleBatch.concat_samples(sample_batches)
  1168. # Use simple optimizer (only for multi-agent or tf-eager; all other
  1169. # cases should use the multi-GPU optimizer, even if only using 1 GPU).
  1170. # TODO: (sven) rename MultiGPUOptimizer into something more
  1171. # meaningful.
  1172. if self.config.get("simple_optimizer") is True:
  1173. train_results = train_one_step(self, train_batch)
  1174. else:
  1175. train_results = multi_gpu_train_one_step(self, train_batch)
  1176. # Update weights - after learning on the local worker - on all remote
  1177. # workers.
  1178. if self.workers.remote_workers():
  1179. with self._timers[WORKER_UPDATE_TIMER]:
  1180. self.workers.sync_weights()
  1181. return train_results
  1182. @DeveloperAPI
  1183. @staticmethod
  1184. def execution_plan(workers, config, **kwargs):
  1185. # Collects experiences in parallel from multiple RolloutWorker actors.
  1186. rollouts = ParallelRollouts(workers, mode="bulk_sync")
  1187. # Combine experiences batches until we hit `train_batch_size` in size.
  1188. # Then, train the policy on those experiences and update the workers.
  1189. train_op = rollouts.combine(
  1190. ConcatBatches(
  1191. min_batch_size=config["train_batch_size"],
  1192. count_steps_by=config["multiagent"]["count_steps_by"],
  1193. ))
  1194. if config.get("simple_optimizer") is True:
  1195. train_op = train_op.for_each(TrainOneStep(workers))
  1196. else:
  1197. train_op = train_op.for_each(
  1198. MultiGPUTrainOneStep(
  1199. workers=workers,
  1200. sgd_minibatch_size=config.get("sgd_minibatch_size",
  1201. config["train_batch_size"]),
  1202. num_sgd_iter=config.get("num_sgd_iter", 1),
  1203. num_gpus=config["num_gpus"],
  1204. _fake_gpus=config["_fake_gpus"]))
  1205. # Add on the standard episode reward, etc. metrics reporting. This
  1206. # returns a LocalIterator[metrics_dict] representing metrics for each
  1207. # train step.
  1208. return StandardMetricsReporting(train_op, workers, config)
  1209. @PublicAPI
  1210. def compute_single_action(
  1211. self,
  1212. observation: Optional[TensorStructType] = None,
  1213. state: Optional[List[TensorStructType]] = None,
  1214. *,
  1215. prev_action: Optional[TensorStructType] = None,
  1216. prev_reward: Optional[float] = None,
  1217. info: Optional[EnvInfoDict] = None,
  1218. input_dict: Optional[SampleBatch] = None,
  1219. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1220. full_fetch: bool = False,
  1221. explore: Optional[bool] = None,
  1222. timestep: Optional[int] = None,
  1223. episode: Optional[Episode] = None,
  1224. unsquash_action: Optional[bool] = None,
  1225. clip_action: Optional[bool] = None,
  1226. # Deprecated args.
  1227. unsquash_actions=DEPRECATED_VALUE,
  1228. clip_actions=DEPRECATED_VALUE,
  1229. # Kwargs placeholder for future compatibility.
  1230. **kwargs,
  1231. ) -> Union[TensorStructType, Tuple[TensorStructType, List[TensorType],
  1232. Dict[str, TensorType]]]:
  1233. """Computes an action for the specified policy on the local worker.
  1234. Note that you can also access the policy object through
  1235. self.get_policy(policy_id) and call compute_single_action() on it
  1236. directly.
  1237. Args:
  1238. observation: Single (unbatched) observation from the
  1239. environment.
  1240. state: List of all RNN hidden (single, unbatched) state tensors.
  1241. prev_action: Single (unbatched) previous action value.
  1242. prev_reward: Single (unbatched) previous reward value.
  1243. info: Env info dict, if any.
  1244. input_dict: An optional SampleBatch that holds all the values
  1245. for: obs, state, prev_action, and prev_reward, plus maybe
  1246. custom defined views of the current env trajectory. Note
1247. that exactly one of `observation` or `input_dict` must be non-None.
  1248. policy_id: Policy to query (only applies to multi-agent).
  1249. Default: "default_policy".
  1250. full_fetch: Whether to return extra action fetch results.
  1251. This is always set to True if `state` is specified.
  1252. explore: Whether to apply exploration to the action.
  1253. Default: None -> use self.config["explore"].
  1254. timestep: The current (sampling) time step.
  1255. episode: This provides access to all of the internal episodes'
  1256. state, which may be useful for model-based or multi-agent
  1257. algorithms.
  1258. unsquash_action: Should actions be unsquashed according to the
  1259. env's/Policy's action space? If None, use the value of
  1260. self.config["normalize_actions"].
  1261. clip_action: Should actions be clipped according to the
  1262. env's/Policy's action space? If None, use the value of
  1263. self.config["clip_actions"].
  1264. Keyword Args:
  1265. kwargs: forward compatibility placeholder
  1266. Returns:
1267. The computed action if full_fetch=False, or the full 3-tuple
1268. (action, RNN state outs, extra action fetches) if full_fetch=True
1269. or we have an RNN-based Policy.
  1270. Raises:
  1271. KeyError: If the `policy_id` cannot be found in this Trainer's
  1272. local worker.
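Example (a minimal sketch; assumes a trainer built for "CartPole-v0",
whose observations are 4-dim float vectors):
>>> import numpy as np
>>> obs = np.array([0.0, 0.1, 0.0, 0.0], dtype=np.float32)
>>> action = trainer.compute_single_action(obs, explore=False)
>>> action, state_outs, extra = trainer.compute_single_action(
...     obs, full_fetch=True)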
  1273. """
  1274. if clip_actions != DEPRECATED_VALUE:
  1275. deprecation_warning(
  1276. old="Trainer.compute_single_action(`clip_actions`=...)",
  1277. new="Trainer.compute_single_action(`clip_action`=...)",
  1278. error=False)
  1279. clip_action = clip_actions
  1280. if unsquash_actions != DEPRECATED_VALUE:
  1281. deprecation_warning(
  1282. old="Trainer.compute_single_action(`unsquash_actions`=...)",
  1283. new="Trainer.compute_single_action(`unsquash_action`=...)",
  1284. error=False)
  1285. unsquash_action = unsquash_actions
  1286. # `unsquash_action` is None: Use value of config['normalize_actions'].
  1287. if unsquash_action is None:
  1288. unsquash_action = self.config["normalize_actions"]
  1289. # `clip_action` is None: Use value of config['clip_actions'].
1290. if clip_action is None:
  1291. clip_action = self.config["clip_actions"]
  1292. # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state`
  1293. # are all None.
  1294. err_msg = "Provide either `input_dict` OR [`observation`, ...] as " \
  1295. "args to Trainer.compute_single_action!"
  1296. if input_dict is not None:
  1297. assert observation is None and prev_action is None and \
  1298. prev_reward is None and state is None, err_msg
  1299. observation = input_dict[SampleBatch.OBS]
  1300. else:
  1301. assert observation is not None, err_msg
  1302. # Get the policy to compute the action for (in the multi-agent case,
  1303. # Trainer may hold >1 policies).
  1304. policy = self.get_policy(policy_id)
  1305. if policy is None:
  1306. raise KeyError(
  1307. f"PolicyID '{policy_id}' not found in PolicyMap of the "
  1308. f"Trainer's local worker!")
  1309. local_worker = self.workers.local_worker()
  1310. # Check the preprocessor and preprocess, if necessary.
  1311. pp = local_worker.preprocessors[policy_id]
  1312. if pp and type(pp).__name__ != "NoPreprocessor":
  1313. observation = pp.transform(observation)
  1314. observation = local_worker.filters[policy_id](
  1315. observation, update=False)
  1316. # Input-dict.
  1317. if input_dict is not None:
  1318. input_dict[SampleBatch.OBS] = observation
  1319. action, state, extra = policy.compute_single_action(
  1320. input_dict=input_dict,
  1321. explore=explore,
  1322. timestep=timestep,
  1323. episode=episode,
  1324. )
  1325. # Individual args.
  1326. else:
  1327. action, state, extra = policy.compute_single_action(
  1328. obs=observation,
  1329. state=state,
  1330. prev_action=prev_action,
  1331. prev_reward=prev_reward,
  1332. info=info,
  1333. explore=explore,
  1334. timestep=timestep,
  1335. episode=episode,
  1336. )
  1337. # If we work in normalized action space (normalize_actions=True),
  1338. # we re-translate here into the env's action space.
  1339. if unsquash_action:
  1340. action = space_utils.unsquash_action(action,
  1341. policy.action_space_struct)
  1342. # Clip, according to env's action space.
  1343. elif clip_action:
  1344. action = space_utils.clip_action(action,
  1345. policy.action_space_struct)
  1346. # Return 3-Tuple: Action, states, and extra-action fetches.
  1347. if state or full_fetch:
  1348. return action, state, extra
  1349. # Ensure backward compatibility.
  1350. else:
  1351. return action
  1352. @PublicAPI
  1353. def compute_actions(
  1354. self,
  1355. observations: TensorStructType,
  1356. state: Optional[List[TensorStructType]] = None,
  1357. *,
  1358. prev_action: Optional[TensorStructType] = None,
  1359. prev_reward: Optional[TensorStructType] = None,
  1360. info: Optional[EnvInfoDict] = None,
  1361. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1362. full_fetch: bool = False,
  1363. explore: Optional[bool] = None,
  1364. timestep: Optional[int] = None,
  1365. episodes: Optional[List[Episode]] = None,
  1366. unsquash_actions: Optional[bool] = None,
  1367. clip_actions: Optional[bool] = None,
  1368. # Deprecated.
  1369. normalize_actions=None,
  1370. **kwargs,
  1371. ):
  1372. """Computes an action for the specified policy on the local Worker.
  1373. Note that you can also access the policy object through
  1374. self.get_policy(policy_id) and call compute_actions() on it directly.
  1375. Args:
1376. observations: Multi-agent dict mapping agent IDs to observations
1377. from the environment.
1378. state: RNN hidden states, if any. If state is not None,
1379. the full tuple is returned (computed actions, rnn state(s),
1380. extra action fetches). Otherwise, only the computed actions
1381. are returned.
  1382. prev_action: Previous action value, if any.
  1383. prev_reward: Previous reward, if any.
  1384. info: Env info dict, if any.
  1385. policy_id: Policy to query (only applies to multi-agent).
  1386. full_fetch: Whether to return extra action fetch results.
  1387. This is always set to True if RNN state is specified.
  1388. explore: Whether to pick an exploitation or exploration
  1389. action (default: None -> use self.config["explore"]).
  1390. timestep: The current (sampling) time step.
  1391. episodes: This provides access to all of the internal episodes'
  1392. state, which may be useful for model-based or multi-agent
  1393. algorithms.
  1394. unsquash_actions: Should actions be unsquashed according
  1395. to the env's/Policy's action space? If None, use
  1396. self.config["normalize_actions"].
  1397. clip_actions: Should actions be clipped according to the
  1398. env's/Policy's action space? If None, use
  1399. self.config["clip_actions"].
  1400. Keyword Args:
  1401. kwargs: forward compatibility placeholder
  1402. Returns:
  1403. The computed action if full_fetch=False, or a tuple consisting of
  1404. the full output of policy.compute_actions_from_input_dict() if
  1405. full_fetch=True or we have an RNN-based Policy.
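Example (an illustrative sketch; assumes a multi-agent trainer whose
agents "agent_0" and "agent_1" both map to `policy_id` and observe
4-dim float vectors):
>>> import numpy as np
>>> obs = {
...     "agent_0": np.zeros(4, dtype=np.float32),
...     "agent_1": np.zeros(4, dtype=np.float32),
... }
>>> actions = trainer.compute_actions(obs, explore=False)
>>> # -> {"agent_0": <action>, "agent_1": <action>}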
  1406. """
  1407. if normalize_actions is not None:
  1408. deprecation_warning(
  1409. old="Trainer.compute_actions(`normalize_actions`=...)",
  1410. new="Trainer.compute_actions(`unsquash_actions`=...)",
  1411. error=False)
  1412. unsquash_actions = normalize_actions
  1413. # `unsquash_actions` is None: Use value of config['normalize_actions'].
  1414. if unsquash_actions is None:
  1415. unsquash_actions = self.config["normalize_actions"]
  1416. # `clip_actions` is None: Use value of config['clip_actions'].
1417. if clip_actions is None:
  1418. clip_actions = self.config["clip_actions"]
  1419. # Preprocess obs and states.
  1420. state_defined = state is not None
  1421. policy = self.get_policy(policy_id)
  1422. filtered_obs, filtered_state = [], []
  1423. for agent_id, ob in observations.items():
  1424. worker = self.workers.local_worker()
  1425. preprocessed = worker.preprocessors[policy_id].transform(ob)
  1426. filtered = worker.filters[policy_id](preprocessed, update=False)
  1427. filtered_obs.append(filtered)
  1428. if state is None:
  1429. continue
  1430. elif agent_id in state:
  1431. filtered_state.append(state[agent_id])
  1432. else:
  1433. filtered_state.append(policy.get_initial_state())
  1434. # Batch obs and states
  1435. obs_batch = np.stack(filtered_obs)
  1436. if state is None:
  1437. state = []
  1438. else:
  1439. state = list(zip(*filtered_state))
  1440. state = [np.stack(s) for s in state]
  1441. input_dict = {SampleBatch.OBS: obs_batch}
  1442. if prev_action:
  1443. input_dict[SampleBatch.PREV_ACTIONS] = prev_action
  1444. if prev_reward:
  1445. input_dict[SampleBatch.PREV_REWARDS] = prev_reward
  1446. if info:
  1447. input_dict[SampleBatch.INFOS] = info
  1448. for i, s in enumerate(state):
  1449. input_dict[f"state_in_{i}"] = s
  1450. # Batch compute actions
  1451. actions, states, infos = policy.compute_actions_from_input_dict(
  1452. input_dict=input_dict,
  1453. explore=explore,
  1454. timestep=timestep,
  1455. episodes=episodes,
  1456. )
  1457. # Unbatch actions for the environment into a multi-agent dict.
  1458. single_actions = space_utils.unbatch(actions)
  1459. actions = {}
  1460. for key, a in zip(observations, single_actions):
  1461. # If we work in normalized action space (normalize_actions=True),
  1462. # we re-translate here into the env's action space.
  1463. if unsquash_actions:
  1464. a = space_utils.unsquash_action(a, policy.action_space_struct)
  1465. # Clip, according to env's action space.
  1466. elif clip_actions:
  1467. a = space_utils.clip_action(a, policy.action_space_struct)
  1468. actions[key] = a
  1469. # Unbatch states into a multi-agent dict.
  1470. unbatched_states = {}
  1471. for idx, agent_id in enumerate(observations):
  1472. unbatched_states[agent_id] = [s[idx] for s in states]
  1473. # Return only actions or full tuple
  1474. if state_defined or full_fetch:
  1475. return actions, unbatched_states, infos
  1476. else:
  1477. return actions
  1478. @PublicAPI
  1479. def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
  1480. """Return policy for the specified id, or None.
  1481. Args:
  1482. policy_id: ID of the policy to return.
  1483. """
  1484. return self.workers.local_worker().get_policy(policy_id)
  1485. @PublicAPI
  1486. def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
  1487. """Return a dictionary of policy ids to weights.
  1488. Args:
  1489. policies: Optional list of policies to return weights for,
  1490. or None for all policies.
  1491. """
  1492. return self.workers.local_worker().get_weights(policies)
  1493. @PublicAPI
  1494. def set_weights(self, weights: Dict[PolicyID, dict]):
  1495. """Set policy weights by policy id.
  1496. Args:
  1497. weights: Map of policy ids to weights to set.
  1498. """
  1499. self.workers.local_worker().set_weights(weights)
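# Example usage of get_weights()/set_weights() (illustrative only;
# `trainer_a` and `trainer_b` stand for two compatible Trainer
# instances that are not defined in this file):
#     weights = trainer_a.get_weights()
#     trainer_b.set_weights(weights)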
  1500. @PublicAPI
  1501. def add_policy(
  1502. self,
  1503. policy_id: PolicyID,
  1504. policy_cls: Type[Policy],
  1505. *,
  1506. observation_space: Optional[gym.spaces.Space] = None,
  1507. action_space: Optional[gym.spaces.Space] = None,
  1508. config: Optional[PartialTrainerConfigDict] = None,
  1509. policy_state: Optional[PolicyState] = None,
  1510. policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID],
  1511. PolicyID]] = None,
  1512. policies_to_train: Optional[Container[PolicyID]] = None,
  1513. evaluation_workers: bool = True,
  1514. workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = None,
  1515. ) -> Policy:
  1516. """Adds a new policy to this Trainer.
  1517. Args:
  1518. policy_id: ID of the policy to add.
  1519. policy_cls: The Policy class to use for
  1520. constructing the new Policy.
  1521. observation_space: The observation space of the policy to add.
  1522. If None, try to infer this space from the environment.
  1523. action_space: The action space of the policy to add.
  1524. If None, try to infer this space from the environment.
  1525. config: The config overrides for the policy to add.
  1526. policy_state: Optional state dict to apply to the new
  1527. policy instance, right after its construction.
  1528. policy_mapping_fn: An optional (updated) policy mapping function
  1529. to use from here on. Note that already ongoing episodes will
  1530. not change their mapping but will use the old mapping till
  1531. the end of the episode.
  1532. policies_to_train: An optional list/set of policy IDs to be
  1533. trained. If None, will keep the existing list in place.
1534. Policies whose IDs are not in the list will not be updated.
  1535. evaluation_workers: Whether to add the new policy also
  1536. to the evaluation WorkerSet.
  1537. workers: A list of RolloutWorker/ActorHandles (remote
  1538. RolloutWorkers) to add this policy to. If defined, will only
  1539. add the given policy to these workers.
  1540. Returns:
  1541. The newly added policy (the copy that got added to the local
  1542. worker).
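Example (a hedged sketch; re-uses the class of the existing default
policy and maps every agent to the new policy; all names are
illustrative only):
>>> new_policy = trainer.add_policy(
...     policy_id="new_policy",
...     policy_cls=type(trainer.get_policy()),
...     policy_mapping_fn=lambda agent_id, episode, **kw: "new_policy",
... )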
  1543. """
  1544. kwargs = dict(
  1545. policy_id=policy_id,
  1546. policy_cls=policy_cls,
  1547. observation_space=observation_space,
  1548. action_space=action_space,
  1549. config=config,
  1550. policy_state=policy_state,
  1551. policy_mapping_fn=policy_mapping_fn,
1552. policies_to_train=list(policies_to_train) if policies_to_train else None,
  1553. )
  1554. def fn(worker: RolloutWorker):
1555. # `foreach_worker` function: Adds the policy to the worker (and
  1556. # maybe changes its policy_mapping_fn - if provided here).
  1557. worker.add_policy(**kwargs)
  1558. if workers is not None:
  1559. ray_gets = []
  1560. for worker in workers:
  1561. if isinstance(worker, ActorHandle):
  1562. ray_gets.append(worker.add_policy.remote(**kwargs))
  1563. else:
  1564. fn(worker)
  1565. ray.get(ray_gets)
  1566. else:
  1567. # Run foreach_worker fn on all workers.
  1568. self.workers.foreach_worker(fn)
  1569. # Update evaluation workers, if necessary.
  1570. if evaluation_workers and self.evaluation_workers is not None:
  1571. self.evaluation_workers.foreach_worker(fn)
  1572. # Return newly added policy (from the local rollout worker).
  1573. return self.get_policy(policy_id)
  1574. @PublicAPI
  1575. def remove_policy(
  1576. self,
  1577. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1578. *,
  1579. policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
  1580. policies_to_train: Optional[List[PolicyID]] = None,
  1581. evaluation_workers: bool = True,
  1582. ) -> None:
  1583. """Removes a new policy from this Trainer.
  1584. Args:
  1585. policy_id: ID of the policy to be removed.
  1586. policy_mapping_fn: An optional (updated) policy mapping function
  1587. to use from here on. Note that already ongoing episodes will
  1588. not change their mapping but will use the old mapping till
  1589. the end of the episode.
  1590. policies_to_train: An optional list of policy IDs to be trained.
1591. If None, will keep the existing list in place. Policies
1592. whose IDs are not in the list will not be updated.
  1593. evaluation_workers: Whether to also remove the policy from the
  1594. evaluation WorkerSet.
  1595. """
  1596. def fn(worker):
  1597. worker.remove_policy(
  1598. policy_id=policy_id,
  1599. policy_mapping_fn=policy_mapping_fn,
  1600. policies_to_train=policies_to_train,
  1601. )
  1602. self.workers.foreach_worker(fn)
  1603. if evaluation_workers and self.evaluation_workers is not None:
  1604. self.evaluation_workers.foreach_worker(fn)
  1605. @DeveloperAPI
  1606. def export_policy_model(self,
  1607. export_dir: str,
  1608. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1609. onnx: Optional[int] = None) -> None:
  1610. """Exports policy model with given policy_id to a local directory.
  1611. Args:
  1612. export_dir: Writable local directory.
  1613. policy_id: Optional policy id to export.
  1614. onnx: If given, will export model in ONNX format. The
1615. value of this parameter sets the ONNX OpSet version to use.
  1616. If None, the output format will be DL framework specific.
  1617. Example:
  1618. >>> trainer = MyTrainer()
  1619. >>> for _ in range(10):
  1620. >>> trainer.train()
  1621. >>> trainer.export_policy_model("/tmp/dir")
  1622. >>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1)
  1623. """
  1624. self.get_policy(policy_id).export_model(export_dir, onnx)
  1625. @DeveloperAPI
  1626. def export_policy_checkpoint(
  1627. self,
  1628. export_dir: str,
  1629. filename_prefix: str = "model",
  1630. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1631. ) -> None:
  1632. """Exports policy model checkpoint to a local directory.
  1633. Args:
  1634. export_dir: Writable local directory.
  1635. filename_prefix: file name prefix of checkpoint files.
  1636. policy_id: Optional policy id to export.
  1637. Example:
  1638. >>> trainer = MyTrainer()
  1639. >>> for _ in range(10):
  1640. >>> trainer.train()
  1641. >>> trainer.export_policy_checkpoint("/tmp/export_dir")
  1642. """
  1643. self.get_policy(policy_id).export_checkpoint(export_dir,
  1644. filename_prefix)
  1645. @DeveloperAPI
  1646. def import_policy_model_from_h5(
  1647. self,
  1648. import_file: str,
  1649. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1650. ) -> None:
  1651. """Imports a policy's model with given policy_id from a local h5 file.
  1652. Args:
  1653. import_file: The h5 file to import from.
  1654. policy_id: Optional policy id to import into.
  1655. Example:
  1656. >>> trainer = MyTrainer()
  1657. >>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
  1658. >>> for _ in range(10):
  1659. >>> trainer.train()
  1660. """
  1661. self.get_policy(policy_id).import_model_from_h5(import_file)
  1662. # Sync new weights to remote workers.
  1663. self._sync_weights_to_workers(worker_set=self.workers)
  1664. @override(Trainable)
  1665. def save_checkpoint(self, checkpoint_dir: str) -> str:
  1666. checkpoint_path = os.path.join(checkpoint_dir,
  1667. "checkpoint-{}".format(self.iteration))
  1668. pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
  1669. return checkpoint_path
  1670. @override(Trainable)
  1671. def load_checkpoint(self, checkpoint_path: str) -> None:
  1672. extra_data = pickle.load(open(checkpoint_path, "rb"))
  1673. self.__setstate__(extra_data)
  1674. @override(Trainable)
  1675. def log_result(self, result: ResultDict) -> None:
  1676. # Log after the callback is invoked, so that the user has a chance
  1677. # to mutate the result.
  1678. self.callbacks.on_train_result(trainer=self, result=result)
  1679. # Then log according to Trainable's logging logic.
  1680. Trainable.log_result(self, result)
  1681. @override(Trainable)
  1682. def cleanup(self) -> None:
  1683. # Stop all workers.
  1684. if hasattr(self, "workers"):
  1685. self.workers.stop()
  1686. @classmethod
  1687. @override(Trainable)
  1688. def default_resource_request(
  1689. cls, config: PartialTrainerConfigDict) -> \
  1690. Union[Resources, PlacementGroupFactory]:
  1691. # Default logic for RLlib algorithms (Trainers):
  1692. # Create one bundle per individual worker (local or remote).
  1693. # Use `num_cpus_for_driver` and `num_gpus` for the local worker and
  1694. # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote
  1695. # workers to determine their CPU/GPU resource needs.
  1696. # Convenience config handles.
  1697. cf = dict(cls.get_default_config(), **config)
  1698. eval_cf = cf["evaluation_config"]
  1699. # TODO(ekl): add custom resources here once tune supports them
  1700. # Return PlacementGroupFactory containing all needed resources
  1701. # (already properly defined as device bundles).
  1702. return PlacementGroupFactory(
  1703. bundles=[{
  1704. # Local worker.
  1705. "CPU": cf["num_cpus_for_driver"],
  1706. "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
  1707. }] + [
  1708. {
  1709. # RolloutWorkers.
  1710. "CPU": cf["num_cpus_per_worker"],
  1711. "GPU": cf["num_gpus_per_worker"],
  1712. } for _ in range(cf["num_workers"])
  1713. ] + ([
  1714. {
  1715. # Evaluation workers.
  1716. # Note: The local eval worker is located on the driver CPU.
  1717. "CPU": eval_cf.get("num_cpus_per_worker",
  1718. cf["num_cpus_per_worker"]),
  1719. "GPU": eval_cf.get("num_gpus_per_worker",
  1720. cf["num_gpus_per_worker"]),
  1721. } for _ in range(cf["evaluation_num_workers"])
  1722. ] if cf["evaluation_interval"] else []),
  1723. strategy=config.get("placement_strategy", "PACK"))
  1724. @DeveloperAPI
  1725. def _before_evaluate(self):
  1726. """Pre-evaluation callback."""
  1727. pass
  1728. def _get_env_creator_from_env_id(
  1729. self, env_id: Optional[str] = None) -> EnvCreator:
  1730. """Returns an env creator callable, given an `env_id` (e.g. "CartPole-v0").
  1731. Args:
  1732. env_id: An already tune registered env ID, a known gym env name,
  1733. or None (if no env is used).
1734. Returns:
A callable taking an EnvContext and returning an env instance
(or None, if `env_id` is None).
  1735. """
  1736. if env_id:
  1737. # An already registered env.
  1738. if _global_registry.contains(ENV_CREATOR, env_id):
  1739. return _global_registry.get(ENV_CREATOR, env_id)
  1740. # A class path specifier.
  1741. elif "." in env_id:
  1742. def env_creator_from_classpath(env_context):
  1743. try:
  1744. env_obj = from_config(env_id, env_context)
  1745. except ValueError:
  1746. raise EnvError(
  1747. ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_id))
  1748. return env_obj
  1749. return env_creator_from_classpath
  1750. # Try gym/PyBullet/Vizdoom.
  1751. else:
  1752. return functools.partial(
  1753. gym_env_creator, env_descriptor=env_id)
  1754. # No env -> Env creator always returns None.
  1755. else:
  1756. return lambda env_config: None
  1757. @DeveloperAPI
  1758. def _make_workers(
  1759. self,
  1760. *,
  1761. env_creator: EnvCreator,
  1762. validate_env: Optional[Callable[[EnvType, EnvContext], None]],
  1763. policy_class: Type[Policy],
  1764. config: TrainerConfigDict,
  1765. num_workers: int,
  1766. local_worker: bool = True,
  1767. ) -> WorkerSet:
  1768. """Default factory method for a WorkerSet running under this Trainer.
  1769. Override this method by passing a custom `make_workers` into
  1770. `build_trainer`.
  1771. Args:
1772. env_creator: A function that returns an Env given an env
  1773. config.
  1774. validate_env: Optional callable to validate the generated
  1775. environment. The env to be checked is the one returned from
  1776. the env creator, which may be a (single, not-yet-vectorized)
  1777. gym.Env or your custom RLlib env type (e.g. MultiAgentEnv,
1778. VectorEnv, BaseEnv, etc.).
  1779. policy_class: The Policy class to use for creating the policies
  1780. of the workers.
  1781. config: The Trainer's config.
  1782. num_workers: Number of remote rollout workers to create.
  1783. 0 for local only.
  1784. local_worker: Whether to create a local (non @ray.remote) worker
  1785. in the returned set as well (default: True). If `num_workers`
  1786. is 0, always create a local worker.
  1787. Returns:
  1788. The created WorkerSet.
  1789. """
  1790. return WorkerSet(
  1791. env_creator=env_creator,
  1792. validate_env=validate_env,
  1793. policy_class=policy_class,
  1794. trainer_config=config,
  1795. num_workers=num_workers,
  1796. local_worker=local_worker,
  1797. logdir=self.logdir,
  1798. )
  1799. def _sync_filters_if_needed(self, workers: WorkerSet):
  1800. if self.config.get("observation_filter", "NoFilter") != "NoFilter":
  1801. FilterManager.synchronize(
  1802. workers.local_worker().filters,
  1803. workers.remote_workers(),
  1804. update_remote=self.config["synchronize_filters"])
  1805. logger.debug("synchronized filters: {}".format(
  1806. workers.local_worker().filters))
  1807. @DeveloperAPI
  1808. def _sync_weights_to_workers(
  1809. self,
  1810. *,
  1811. worker_set: Optional[WorkerSet] = None,
  1812. workers: Optional[List[RolloutWorker]] = None,
  1813. ) -> None:
  1814. """Sync "main" weights to given WorkerSet or list of workers."""
  1815. assert worker_set is not None
  1816. # Broadcast the new policy weights to all evaluation workers.
  1817. logger.info("Synchronizing weights to workers.")
  1818. weights = ray.put(self.workers.local_worker().save())
  1819. worker_set.foreach_worker(lambda w: w.restore(ray.get(weights)))
  1820. def _exec_plan_or_training_iteration_fn(self):
  1821. if self.config["_disable_execution_plan_api"]:
  1822. results = self.training_iteration()
  1823. else:
  1824. results = next(self.train_exec_impl)
  1825. return results
  1826. @classmethod
  1827. @override(Trainable)
  1828. def resource_help(cls, config: TrainerConfigDict) -> str:
  1829. return ("\n\nYou can adjust the resource requests of RLlib agents by "
  1830. "setting `num_workers`, `num_gpus`, and other configs. See "
  1831. "the DEFAULT_CONFIG defined by each agent for more info.\n\n"
  1832. "The config of this agent is: {}".format(config))
  1833. @classmethod
  1834. def merge_trainer_configs(cls,
  1835. config1: TrainerConfigDict,
  1836. config2: PartialTrainerConfigDict,
  1837. _allow_unknown_configs: Optional[bool] = None
  1838. ) -> TrainerConfigDict:
  1839. """Merges a complete Trainer config with a partial override dict.
  1840. Respects nested structures within the config dicts. The values in the
  1841. partial override dict take priority.
  1842. Args:
  1843. config1: The complete Trainer's dict to be merged (overridden)
  1844. with `config2`.
  1845. config2: The partial override config dict to merge on top of
  1846. `config1`.
  1847. _allow_unknown_configs: If True, keys in `config2` that don't exist
  1848. in `config1` are allowed and will be added to the final config.
  1849. Returns:
  1850. The merged full trainer config dict.
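Example (illustrative values only; `deep_update` merges nested dicts,
so keys of `config1` not present in `config2` are preserved):
>>> full_config = Trainer.merge_trainer_configs(
...     config1={"lr": 0.001, "train_batch_size": 4000},
...     config2={"train_batch_size": 2000},
...     _allow_unknown_configs=True)
>>> full_config["train_batch_size"]
2000
>>> full_config["lr"]
0.001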
  1851. """
  1852. config1 = copy.deepcopy(config1)
  1853. if "callbacks" in config2 and type(config2["callbacks"]) is dict:
  1854. legacy_callbacks_dict = config2["callbacks"]
  1855. def make_callbacks():
  1856. # Deprecation warning will be logged by DefaultCallbacks.
  1857. return DefaultCallbacks(
  1858. legacy_callbacks_dict=legacy_callbacks_dict)
  1859. config2["callbacks"] = make_callbacks
  1860. if _allow_unknown_configs is None:
  1861. _allow_unknown_configs = cls._allow_unknown_configs
  1862. return deep_update(config1, config2, _allow_unknown_configs,
  1863. cls._allow_unknown_subkeys,
  1864. cls._override_all_subkeys_if_type_changes)
  1865. @staticmethod
  1866. def validate_framework(config: PartialTrainerConfigDict) -> None:
  1867. """Validates the config dictionary wrt the framework settings.
  1868. Args:
  1869. config: The config dictionary to be validated.
  1870. """
  1871. _tf1, _tf, _tfv = None, None, None
  1872. _torch = None
  1873. framework = config["framework"]
  1874. tf_valid_frameworks = {"tf", "tf2", "tfe"}
  1875. if framework not in tf_valid_frameworks and framework != "torch":
  1876. return
  1877. elif framework in tf_valid_frameworks:
  1878. _tf1, _tf, _tfv = try_import_tf()
  1879. else:
  1880. _torch, _ = try_import_torch()
  1881. def check_if_correct_nn_framework_installed():
  1882. """Check if tf/torch experiment is running and tf/torch installed.
  1883. """
  1884. if framework in tf_valid_frameworks:
  1885. if not (_tf1 or _tf):
  1886. raise ImportError((
  1887. "TensorFlow was specified as the 'framework' "
  1888. "inside of your config dictionary. However, there was "
  1889. "no installation found. You can install TensorFlow "
  1890. "via `pip install tensorflow`"))
  1891. elif framework == "torch":
  1892. if not _torch:
  1893. raise ImportError(
  1894. ("PyTorch was specified as the 'framework' inside "
  1895. "of your config dictionary. However, there was no "
  1896. "installation found. You can install PyTorch via "
  1897. "`pip install torch`"))
  1898. def resolve_tf_settings():
  1899. """Check and resolve tf settings."""
  1900. if _tf1 and config["framework"] in ["tf2", "tfe"]:
  1901. if config["framework"] == "tf2" and _tfv < 2:
  1902. raise ValueError(
  1903. "You configured `framework`=tf2, but your installed "
  1904. "pip tf-version is < 2.0! Make sure your TensorFlow "
  1905. "version is >= 2.x.")
  1906. if not _tf1.executing_eagerly():
  1907. _tf1.enable_eager_execution()
  1908. # Recommend setting tracing to True for speedups.
  1909. logger.info(
  1910. f"Executing eagerly (framework='{config['framework']}'),"
  1911. f" with eager_tracing={config['eager_tracing']}. For "
  1912. "production workloads, make sure to set eager_tracing=True"
  1913. " in order to match the speed of tf-static-graph "
  1914. "(framework='tf'). For debugging purposes, "
  1915. "`eager_tracing=False` is the best choice.")
  1916. # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
  1917. # enabling eager tracing for similar speed.
  1918. elif _tf1 and config["framework"] == "tf":
  1919. logger.info(
  1920. "Your framework setting is 'tf', meaning you are using "
  1921. "static-graph mode. Set framework='tf2' to enable eager "
  1922. "execution with tf2.x. You may also then want to set "
  1923. "eager_tracing=True in order to reach similar execution "
  1924. "speed as with static-graph mode.")
  1925. check_if_correct_nn_framework_installed()
  1926. resolve_tf_settings()
  1927. @ExperimentalAPI
  1928. def validate_config(self, config: TrainerConfigDict) -> None:
  1929. """Validates a given config dict for this Trainer.
  1930. Users should override this method to implement custom validation
  1931. behavior. It is recommended to call `super().validate_config()` in
  1932. this override.
  1933. Args:
  1934. config: The given config dict to check.
  1935. Raises:
  1936. ValueError: If there is something wrong with the config.
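Example: A custom Trainer adding its own checks on top of the base
validation (a minimal sketch; `MyTrainer` and the checked key are
hypothetical):
>>> class MyTrainer(Trainer):
...     def validate_config(self, config):
...         super().validate_config(config)
...         if config["train_batch_size"] <= 0:
...             raise ValueError("`train_batch_size` must be > 0!")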
  1937. """
  1938. model_config = config.get("model")
  1939. if model_config is None:
  1940. config["model"] = model_config = {}
  1941. # Monitor should be replaced by `record_env`.
  1942. if config.get("monitor", DEPRECATED_VALUE) != DEPRECATED_VALUE:
  1943. deprecation_warning("monitor", "record_env", error=False)
  1944. config["record_env"] = config.get("monitor", False)
  1945. # Empty string would fail some if-blocks checking for this setting.
  1946. # Set to True instead, meaning: use default output dir to store
  1947. # the videos.
  1948. if config.get("record_env") == "":
  1949. config["record_env"] = True
  1950. # Use DefaultCallbacks class, if callbacks is None.
  1951. if config["callbacks"] is None:
  1952. config["callbacks"] = DefaultCallbacks
  1953. # Check, whether given `callbacks` is a callable.
  1954. if not callable(config["callbacks"]):
  1955. raise ValueError("`callbacks` must be a callable method that "
  1956. "returns a subclass of DefaultCallbacks, got "
  1957. f"{config['callbacks']}!")
  1958. # Multi-GPU settings.
  1959. simple_optim_setting = config.get("simple_optimizer", DEPRECATED_VALUE)
  1960. if simple_optim_setting != DEPRECATED_VALUE:
  1961. deprecation_warning(old="simple_optimizer", error=False)
  1962. # Validate "multiagent" sub-dict and convert policy 4-tuples to
  1963. # PolicySpec objects.
  1964. policies, is_multi_agent = check_multi_agent(config)
  1965. framework = config.get("framework")
  1966. # Multi-GPU setting: Must use MultiGPUTrainOneStep.
  1967. if config.get("num_gpus", 0) > 1:
  1968. if framework in ["tfe", "tf2"]:
  1969. raise ValueError("`num_gpus` > 1 not supported yet for "
  1970. "framework={}!".format(framework))
  1971. elif simple_optim_setting is True:
  1972. raise ValueError(
  1973. "Cannot use `simple_optimizer` if `num_gpus` > 1! "
  1974. "Consider not setting `simple_optimizer` in your config.")
  1975. config["simple_optimizer"] = False
  1976. # Auto-setting: Use simple-optimizer for tf-eager or multiagent,
  1977. # otherwise: MultiGPUTrainOneStep (if supported by the algo's execution
  1978. # plan).
  1979. elif simple_optim_setting == DEPRECATED_VALUE:
  1980. # tf-eager: Must use simple optimizer.
  1981. if framework not in ["tf", "torch"]:
  1982. config["simple_optimizer"] = True
  1983. # Multi-agent case: Try using MultiGPU optimizer (only
  1984. # if all policies used are DynamicTFPolicies or TorchPolicies).
  1985. elif is_multi_agent:
  1986. from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
  1987. from ray.rllib.policy.torch_policy import TorchPolicy
  1988. default_policy_cls = self.get_default_policy_class(config)
  1989. if any((p[0] or default_policy_cls) is None
  1990. or not issubclass(p[0] or default_policy_cls,
  1991. (DynamicTFPolicy, TorchPolicy))
  1992. for p in config["multiagent"]["policies"].values()):
  1993. config["simple_optimizer"] = True
  1994. else:
  1995. config["simple_optimizer"] = False
  1996. else:
  1997. config["simple_optimizer"] = False
  1998. # User manually set simple-optimizer to False -> Error if tf-eager.
  1999. elif simple_optim_setting is False:
  2000. if framework in ["tfe", "tf2"]:
  2001. raise ValueError("`simple_optimizer=False` not supported for "
  2002. "framework={}!".format(framework))
  2003. # Offline RL settings.
  2004. if isinstance(config["input_evaluation"], tuple):
  2005. config["input_evaluation"] = list(config["input_evaluation"])
  2006. elif not isinstance(config["input_evaluation"], list):
  2007. raise ValueError(
  2008. "`input_evaluation` must be a list of strings, got {}!".format(
  2009. config["input_evaluation"]))
  2010. # Check model config.
  2011. # If no preprocessing, propagate into model's config as well
  2012. # (so model will know, whether inputs are preprocessed or not).
  2013. if config["_disable_preprocessor_api"] is True:
  2014. model_config["_disable_preprocessor_api"] = True
  2015. # If no action flattening, propagate into model's config as well
  2016. # (so model will know, whether action inputs are already flattened or
  2017. # not).
  2018. if config["_disable_action_flattening"] is True:
  2019. model_config["_disable_action_flattening"] = True
  2020. # Prev_a/r settings.
  2021. prev_a_r = model_config.get("lstm_use_prev_action_reward",
  2022. DEPRECATED_VALUE)
  2023. if prev_a_r != DEPRECATED_VALUE:
  2024. deprecation_warning(
  2025. "model.lstm_use_prev_action_reward",
  2026. "model.lstm_use_prev_action and model.lstm_use_prev_reward",
  2027. error=False)
  2028. model_config["lstm_use_prev_action"] = prev_a_r
  2029. model_config["lstm_use_prev_reward"] = prev_a_r
  2030. # Check batching/sample collection settings.
  2031. if config["batch_mode"] not in [
  2032. "truncate_episodes", "complete_episodes"
  2033. ]:
  2034. raise ValueError("`batch_mode` must be one of [truncate_episodes|"
  2035. "complete_episodes]! Got {}".format(
  2036. config["batch_mode"]))
  2037. # Check multi-agent batch count mode.
  2038. if config["multiagent"].get("count_steps_by", "env_steps") not in \
  2039. ["env_steps", "agent_steps"]:
  2040. raise ValueError(
  2041. "`count_steps_by` must be one of [env_steps|agent_steps]! "
  2042. "Got {}".format(config["multiagent"]["count_steps_by"]))
  2043. self._by_agent_steps = self.config["multiagent"].get(
  2044. "count_steps_by") == "agent_steps"
  2045. # Metrics settings.
  2046. if config["metrics_smoothing_episodes"] != DEPRECATED_VALUE:
  2047. deprecation_warning(
  2048. old="metrics_smoothing_episodes",
  2049. new="metrics_num_episodes_for_smoothing",
  2050. error=False,
  2051. )
  2052. config["metrics_num_episodes_for_smoothing"] = \
  2053. config["metrics_smoothing_episodes"]
  2054. if config["min_iter_time_s"] != DEPRECATED_VALUE:
  2055. deprecation_warning(
  2056. old="min_iter_time_s",
  2057. new="min_time_s_per_reporting",
  2058. error=False,
  2059. )
  2060. config["min_time_s_per_reporting"] = config["min_iter_time_s"]
  2061. if config["collect_metrics_timeout"] != DEPRECATED_VALUE:
  2062. # TODO: Warn once all algos use the `training_iteration` method.
  2063. # deprecation_warning(
  2064. # old="collect_metrics_timeout",
  2065. # new="metrics_episode_collection_timeout_s",
  2066. # error=False,
  2067. # )
  2068. config["metrics_episode_collection_timeout_s"] = \
  2069. config["collect_metrics_timeout"]
  2070. if config["timesteps_per_iteration"] != DEPRECATED_VALUE:
  2071. # TODO: Warn once all algos use the `training_iteration` method.
  2072. # deprecation_warning(
  2073. # old="timesteps_per_iteration",
  2074. # new="min_sample_timesteps_per_reporting",
  2075. # error=False,
  2076. # )
  2077. config["min_sample_timesteps_per_reporting"] = \
  2078. config["timesteps_per_iteration"]
  2088. # Evaluation settings.
  2089. # Deprecated setting: `evaluation_num_episodes`.
  2090. if config["evaluation_num_episodes"] != DEPRECATED_VALUE:
  2091. deprecation_warning(
  2092. old="evaluation_num_episodes",
  2093. new="`evaluation_duration` and `evaluation_duration_unit="
  2094. "episodes`",
  2095. error=False)
  2096. config["evaluation_duration"] = config["evaluation_num_episodes"]
  2097. config["evaluation_duration_unit"] = "episodes"
  2098. config["evaluation_num_episodes"] = DEPRECATED_VALUE
  2099. # If `evaluation_num_workers` > 0, warn if `evaluation_interval` is
2100. # None (evaluation will then not run automatically with each train()).
  2101. if config["evaluation_num_workers"] > 0 and \
  2102. not config["evaluation_interval"]:
  2103. logger.warning(
  2104. f"You have specified {config['evaluation_num_workers']} "
  2105. "evaluation workers, but your `evaluation_interval` is None! "
  2106. "Therefore, evaluation will not occur automatically with each"
  2107. " call to `Trainer.train()`. Instead, you will have to call "
  2108. "`Trainer.evaluate()` manually in order to trigger an "
  2109. "evaluation run.")
  2110. # If `evaluation_num_workers=0` and
  2111. # `evaluation_parallel_to_training=True`, warn that you need
  2112. # at least one remote eval worker for parallel training and
  2113. # evaluation, and set `evaluation_parallel_to_training` to False.
  2114. elif config["evaluation_num_workers"] == 0 and \
  2115. config.get("evaluation_parallel_to_training", False):
  2116. logger.warning(
  2117. "`evaluation_parallel_to_training` can only be done if "
  2118. "`evaluation_num_workers` > 0! Setting "
  2119. "`evaluation_parallel_to_training` to False.")
  2120. config["evaluation_parallel_to_training"] = False
  2121. # If `evaluation_duration=auto`, error if
  2122. # `evaluation_parallel_to_training=False`.
  2123. if config["evaluation_duration"] == "auto":
  2124. if not config["evaluation_parallel_to_training"]:
  2125. raise ValueError(
  2126. "`evaluation_duration=auto` not supported for "
  2127. "`evaluation_parallel_to_training=False`!")
  2128. # Make sure, it's an int otherwise.
  2129. elif not isinstance(config["evaluation_duration"], int) or \
  2130. config["evaluation_duration"] <= 0:
  2131. raise ValueError("`evaluation_duration` ({}) must be an int and "
  2132. ">0!".format(config["evaluation_duration"]))

    @ExperimentalAPI
    @staticmethod
    def validate_env(env: EnvType, env_context: EnvContext) -> None:
        """Env validator function for this Trainer class.

        Override this in child classes to define custom validation
        behavior.

        Args:
            env: The (sub-)environment to validate. This is normally a
                single sub-environment (e.g. a gym.Env) within a vectorized
                setup.
            env_context: The EnvContext to configure the environment.

        Raises:
            Exception in case something is wrong with the given environment.
        """
        pass
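
    # Illustrative sketch (not part of the original file): a sub-class could
    # override `validate_env` to reject unusable envs early. The class and
    # check below are hypothetical examples:
    #
    #     class MyTrainer(Trainer):
    #         @staticmethod
    #         def validate_env(env, env_context):
    #             if not hasattr(env, "observation_space"):
    #                 raise ValueError(
    #                     "Env must expose an `observation_space`!")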

    def try_recover_from_step_attempt(self) -> None:
        """Try to identify and remove any unhealthy workers.

        This method is called after an unexpected remote error is encountered
        from a worker during the call to `self.step_attempt()` (within
        `self.step()`). It issues check requests to all current workers and
        removes any that respond with error. If no healthy workers remain,
        an error is raised. Otherwise, tries to re-build the execution plan
        with the remaining (healthy) workers.
        """
        workers = getattr(self, "workers", None)
        if not isinstance(workers, WorkerSet):
            return

        logger.info("Health checking all workers...")
        checks = []
        for ev in workers.remote_workers():
            _, obj_ref = ev.sample_with_count.remote()
            checks.append(obj_ref)

        healthy_workers = []
        for i, obj_ref in enumerate(checks):
            w = workers.remote_workers()[i]
            try:
                ray.get(obj_ref)
                healthy_workers.append(w)
                logger.info("Worker {} looks healthy".format(i + 1))
            except RayError:
                logger.exception("Removing unhealthy worker {}".format(i + 1))
                try:
                    w.__ray_terminate__.remote()
                except Exception:
                    logger.exception("Error terminating unhealthy worker")

        if len(healthy_workers) < 1:
            raise RuntimeError(
                "Not enough healthy workers remain to continue.")

        logger.warning("Resetting workers to the remaining healthy set.")
        workers.reset(healthy_workers)
        if not self.config.get("_disable_execution_plan_api") and \
                callable(self.execution_plan):
            logger.warning("Recreating execution plan after failure.")
            self.train_exec_impl = self.execution_plan(
                workers, self.config, **self._kwargs_for_execution_plan())

    @override(Trainable)
    def _export_model(self, export_formats: List[str],
                      export_dir: str) -> Dict[str, str]:
        ExportFormat.validate(export_formats)
        exported = {}
        if ExportFormat.CHECKPOINT in export_formats:
            path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
            self.export_policy_checkpoint(path)
            exported[ExportFormat.CHECKPOINT] = path
        if ExportFormat.MODEL in export_formats:
            path = os.path.join(export_dir, ExportFormat.MODEL)
            self.export_policy_model(path)
            exported[ExportFormat.MODEL] = path
        if ExportFormat.ONNX in export_formats:
            path = os.path.join(export_dir, ExportFormat.ONNX)
            self.export_policy_model(
                path, onnx=int(os.getenv("ONNX_OPSET", "11")))
            exported[ExportFormat.ONNX] = path
        return exported
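
    # Illustrative sketch (not part of the original file): `_export_model` is
    # typically reached via Tune's public `Trainable.export_model()` API. The
    # formats and path below are hypothetical examples:
    #
    #     # Returns a dict mapping each requested format to its output path.
    #     paths = trainer.export_model(["checkpoint", "model"],
    #                                  export_dir="/tmp/export_dir")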

    def import_model(self, import_file: str):
        """Imports a model from import_file.

        Note: Currently, only h5 files are supported.

        Args:
            import_file (str): The file to import the model from.

        Returns:
            The results of calling `self.import_policy_model_from_h5()`
            on the given file.
        """
        # Check for existence.
        if not os.path.exists(import_file):
            raise FileNotFoundError(
                "`import_file` '{}' does not exist! Can't import Model.".
                format(import_file))

        # Get the format of the given file.
        import_format = "h5"  # TODO(sven): Support checkpoint loading.

        ExportFormat.validate([import_format])
        if import_format != ExportFormat.H5:
            raise NotImplementedError
        else:
            return self.import_policy_model_from_h5(import_file)
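
    # Illustrative sketch (not part of the original file): importing weights
    # from a Keras h5 file (the path is a hypothetical example):
    #
    #     trainer.import_model("/tmp/weights/my_policy_model.h5")
    #
    # A non-existing path raises FileNotFoundError, as per the check above.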

    def __getstate__(self) -> dict:
        state = {}
        if hasattr(self, "workers"):
            state["worker"] = self.workers.local_worker().save()
        # TODO: Experimental functionality: Store contents of replay buffer
        #  to checkpoint, only if user has configured this.
        if self.local_replay_buffer is not None and \
                self.config.get("store_buffer_in_checkpoints"):
            state["local_replay_buffer"] = \
                self.local_replay_buffer.get_state()

        if self.train_exec_impl is not None:
            state["train_exec_impl"] = (
                self.train_exec_impl.shared_metrics.get().save())

        return state

    def __setstate__(self, state: dict):
        if hasattr(self, "workers") and "worker" in state:
            self.workers.local_worker().restore(state["worker"])
            remote_state = ray.put(state["worker"])
            for r in self.workers.remote_workers():
                r.restore.remote(remote_state)
        # If necessary, restore replay data as well.
        if self.local_replay_buffer is not None:
            # TODO: Experimental functionality: Restore contents of replay
            #  buffer from checkpoint, only if user has configured this.
            if self.config.get("store_buffer_in_checkpoints"):
                if "local_replay_buffer" in state:
                    self.local_replay_buffer.set_state(
                        state["local_replay_buffer"])
                else:
                    logger.warning(
                        "`store_buffer_in_checkpoints` is True, but no replay "
                        "data found in state!")
            elif "local_replay_buffer" in state and \
                    log_once("no_store_buffer_in_checkpoints_but_data_found"):
                logger.warning(
                    "`store_buffer_in_checkpoints` is False, but some replay "
                    "data found in state!")

        if self.train_exec_impl is not None:
            self.train_exec_impl.shared_metrics.get().restore(
                state["train_exec_impl"])
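
    # Illustrative sketch (not part of the original file): to have replay
    # data included in the state handled by `__getstate__`/`__setstate__`
    # above, an off-policy algo's config would set (example value, not a
    # default):
    #
    #     config["store_buffer_in_checkpoints"] = True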

    # TODO: Deprecate this method (`build_trainer` should no longer be used).
    @staticmethod
    def with_updates(**overrides) -> Type["Trainer"]:
        raise NotImplementedError(
            "`with_updates` may only be called on Trainer sub-classes "
            "that were generated via the `ray.rllib.agents.trainer_template."
            "build_trainer()` function (which has been deprecated)!")

    @DeveloperAPI
    def _create_local_replay_buffer_if_necessary(
            self, config: PartialTrainerConfigDict
    ) -> Optional[MultiAgentReplayBuffer]:
        """Create a MultiAgentReplayBuffer instance if necessary.

        Args:
            config: Algorithm-specific configuration data.

        Returns:
            MultiAgentReplayBuffer instance based on trainer config.
            None, if local replay buffer is not needed.
        """
        # Only algorithms that specify a `replay_buffer_config` use a local
        # replay buffer.
        if ("replay_buffer_config" not in config
                or not config["replay_buffer_config"]):
            # Does not need a replay buffer.
            return None

        replay_buffer_config = config["replay_buffer_config"]
        if ("type" not in replay_buffer_config
                or replay_buffer_config["type"] != "MultiAgentReplayBuffer"):
            # DistributedReplayBuffer coming soon.
            return None

        capacity = config.get("buffer_size", DEPRECATED_VALUE)
        if capacity != DEPRECATED_VALUE:
            # Print a deprecation warning.
            deprecation_warning(
                old="config['buffer_size']",
                new="config['replay_buffer_config']['capacity']",
                error=False)
        else:
            # Get capacity out of replay_buffer_config.
            capacity = replay_buffer_config["capacity"]

        # Configure prio. replay parameters.
        if config.get("prioritized_replay"):
            prio_args = {
                "prioritized_replay_alpha": config["prioritized_replay_alpha"],
                "prioritized_replay_beta": config["prioritized_replay_beta"],
                "prioritized_replay_eps": config["prioritized_replay_eps"],
            }
        # Switch off prioritization (alpha=0.0).
        else:
            prio_args = {"prioritized_replay_alpha": 0.0}

        return MultiAgentReplayBuffer(
            num_shards=1,
            learning_starts=config["learning_starts"],
            capacity=capacity,
            replay_batch_size=config["train_batch_size"],
            replay_mode=config["multiagent"]["replay_mode"],
            replay_sequence_length=config.get("replay_sequence_length", 1),
            replay_burn_in=config.get("burn_in", 0),
            replay_zero_init_states=config.get("zero_init_states", True),
            **prio_args)
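
    # Illustrative sketch (not part of the original file): a config that makes
    # the helper above return a prioritized MultiAgentReplayBuffer. All values
    # are hypothetical examples, not defaults:
    #
    #     config.update({
    #         "replay_buffer_config": {"type": "MultiAgentReplayBuffer",
    #                                  "capacity": 50000},
    #         "prioritized_replay": True,
    #         "prioritized_replay_alpha": 0.6,
    #         "prioritized_replay_beta": 0.4,
    #         "prioritized_replay_eps": 1e-6,
    #         "learning_starts": 1000,
    #     })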

    @DeveloperAPI
    def _kwargs_for_execution_plan(self):
        kwargs = {}
        if self.local_replay_buffer:
            kwargs["local_replay_buffer"] = self.local_replay_buffer
        return kwargs

    def _register_if_needed(self, env_object: Union[str, EnvType, None],
                            config) -> Optional[str]:
        if isinstance(env_object, str):
            return env_object
        elif isinstance(env_object, type):
            name = env_object.__name__

            if config.get("remote_worker_envs"):

                @ray.remote(num_cpus=0)
                class _wrapper(env_object):
                    # Add convenience `_get_spaces` and `_is_multi_agent`
                    # methods.
                    def _get_spaces(self):
                        return self.observation_space, self.action_space

                    def _is_multi_agent(self):
                        return isinstance(self, MultiAgentEnv)

                register_env(name, lambda cfg: _wrapper.remote(cfg))
            else:
                register_env(name, lambda cfg: env_object(cfg))
            return name
        elif env_object is None:
            return None

        raise ValueError(
            "{} is an invalid env specification. ".format(env_object) +
            "You can specify a custom env as either a class "
            "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")

    def _step_context(trainer):
        class StepCtx:
            def __enter__(self):
                # On the very first call to `should_stop()`, `result` is
                # expected to be None -> Start with self.failures=-1 so that
                # it becomes 0 on that first call.
                self.failures = -1

                self.time_start = time.time()
                self.sampled = 0
                self.trained = 0
                self.init_env_steps_sampled = trainer._counters[
                    NUM_ENV_STEPS_SAMPLED]
                self.init_env_steps_trained = trainer._counters[
                    NUM_ENV_STEPS_TRAINED]
                self.init_agent_steps_sampled = trainer._counters[
                    NUM_AGENT_STEPS_SAMPLED]
                self.init_agent_steps_trained = trainer._counters[
                    NUM_AGENT_STEPS_TRAINED]
                return self

            def __exit__(self, *args):
                pass

            def should_stop(self, result):
                # On the first call, `result` is expected to be None ->
                # self.failures becomes 0.
                if result is None:
                    # Fail after n retries.
                    self.failures += 1
                    if self.failures > MAX_WORKER_FAILURE_RETRIES:
                        raise RuntimeError(
                            "Failed to recover from worker crash.")

                # Stopping criteria: Only when using the `training_iteration`
                # API, b/c for the `exec_plan` API, the logic to stop is
                # already built into the execution plans via the
                # `StandardMetricsReporting` op.
                elif trainer.config["_disable_execution_plan_api"]:
                    if trainer._by_agent_steps:
                        self.sampled = \
                            trainer._counters[NUM_AGENT_STEPS_SAMPLED] - \
                            self.init_agent_steps_sampled
                        self.trained = \
                            trainer._counters[NUM_AGENT_STEPS_TRAINED] - \
                            self.init_agent_steps_trained
                    else:
                        self.sampled = \
                            trainer._counters[NUM_ENV_STEPS_SAMPLED] - \
                            self.init_env_steps_sampled
                        self.trained = \
                            trainer._counters[NUM_ENV_STEPS_TRAINED] - \
                            self.init_env_steps_trained

                    min_t = trainer.config["min_time_s_per_reporting"]
                    min_sample_ts = trainer.config[
                        "min_sample_timesteps_per_reporting"]
                    min_train_ts = trainer.config[
                        "min_train_timesteps_per_reporting"]
                    # Repeat if not enough time has passed or if not enough
                    # env|train timesteps have been processed (or these min
                    # values are not provided by the user).
                    if result is not None and \
                            (not min_t or
                             time.time() - self.time_start >= min_t) and \
                            (not min_sample_ts or
                             self.sampled >= min_sample_ts) and \
                            (not min_train_ts or
                             self.trained >= min_train_ts):
                        return True
                # No errors (we got results) -> Break.
                elif result is not None:
                    return True

                return False

        return StepCtx()
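
    # Illustrative sketch (not part of the original file) of how such a step
    # context is typically consumed by the surrounding `step()` logic
    # (simplified pseudo-code; the exact call sites are assumptions):
    #
    #     with self._step_context() as step_ctx:
    #         result = None
    #         while not step_ctx.should_stop(result):
    #             result = self.step_attempt()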

    def _compile_step_results(self, *, step_ctx, step_attempt_results=None):
        """Compiles the results dict for one `step()` call.

        Merges the given `step_attempt_results` with rollout worker metrics,
        step counters, and timer stats into a single ResultDict.
        """
        # Return dict.
        results: ResultDict = {}
        step_attempt_results = step_attempt_results or {}

        # Evaluation results.
        if "evaluation" in step_attempt_results:
            results["evaluation"] = step_attempt_results.pop("evaluation")

        # Custom metrics and episode media.
        results["custom_metrics"] = step_attempt_results.pop(
            "custom_metrics", {})
        results["episode_media"] = step_attempt_results.pop(
            "episode_media", {})

        # Learner info.
        results["info"] = {LEARNER_INFO: step_attempt_results}

        # Collect rollout worker metrics.
        episodes, self._episodes_to_be_collected = collect_episodes(
            self.workers.local_worker(),
            self.workers.remote_workers(),
            self._episodes_to_be_collected,
            timeout_seconds=self.config[
                "metrics_episode_collection_timeout_s"])
        orig_episodes = list(episodes)
        missing = self.config["metrics_num_episodes_for_smoothing"] - \
            len(episodes)
        if missing > 0:
            episodes = self._episode_history[-missing:] + episodes
            assert len(episodes) <= \
                self.config["metrics_num_episodes_for_smoothing"]
        self._episode_history.extend(orig_episodes)
        self._episode_history = \
            self._episode_history[
                -self.config["metrics_num_episodes_for_smoothing"]:]
        results["sampler_results"] = summarize_episodes(
            episodes, orig_episodes)
        # TODO: Don't dump sampler results into top-level.
        results.update(results["sampler_results"])

        results["num_healthy_workers"] = len(self.workers.remote_workers())

        # Train-steps- and env/agent-steps this iteration.
        for c in [
                NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED,
                NUM_ENV_STEPS_SAMPLED, NUM_ENV_STEPS_TRAINED
        ]:
            results[c] = self._counters[c]
        if self._by_agent_steps:
            results[NUM_AGENT_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
            results[NUM_AGENT_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
            # TODO: For CQL and other algos, count by trained steps.
            results["timesteps_total"] = self._counters[
                NUM_AGENT_STEPS_SAMPLED]
        else:
            results[NUM_ENV_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
            results[NUM_ENV_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
            # TODO: For CQL and other algos, count by trained steps.
            results["timesteps_total"] = self._counters[NUM_ENV_STEPS_SAMPLED]
        # TODO: Backward compatibility.
        results["agent_timesteps_total"] = self._counters[
            NUM_AGENT_STEPS_SAMPLED]

        # Process timer results.
        timers = {}
        for k, timer in self._timers.items():
            timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
            if timer.has_units_processed():
                timers["{}_throughput".format(k)] = round(
                    timer.mean_throughput, 3)
        results["timers"] = timers

        # Process counter results.
        counters = {}
        for k, counter in self._counters.items():
            counters[k] = counter
        results["counters"] = counters
        # TODO: Backward compatibility.
        results["info"].update(counters)

        return results
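
    # Illustrative sketch (not part of the original file): the dict compiled
    # above is what a `trainer.train()` call ultimately reports, so callers
    # can read the keys produced here, e.g. (access pattern hypothetical):
    #
    #     result = trainer.train()
    #     print(result["timesteps_total"], result["num_healthy_workers"])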

    def __repr__(self):
        return type(self).__name__

    @Deprecated(new="Trainer.evaluate()", error=True)
    def _evaluate(self) -> dict:
        return self.evaluate()

    @Deprecated(new="Trainer.compute_single_action()", error=False)
    def compute_action(self, *args, **kwargs):
        return self.compute_single_action(*args, **kwargs)

    @Deprecated(new="Trainer.try_recover_from_step_attempt()", error=False)
    def _try_recover(self):
        return self.try_recover_from_step_attempt()

    @staticmethod
    @Deprecated(new="Trainer.validate_config()", error=False)
    def _validate_config(config, trainer_or_none):
        assert trainer_or_none is not None
        return trainer_or_none.validate_config(config)