# algorithm_config.py

import copy
import logging
import math
import os
import sys
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Container,
    Dict,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)

from packaging import version

import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core.learner.learner import LearnerHyperparameters
from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig, ModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.core.learner.learner import TorchCompileWhatToCompile
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.wrappers.atari_wrappers import is_atari
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils.annotations import (
    ExperimentalAPI,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    Deprecated,
    deprecation_warning,
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import NotProvided, from_config
from ray.rllib.utils.gym import (
    convert_old_gym_space_to_gymnasium_space,
    try_import_gymnasium_and_gym,
)
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.serialization import (
    NOT_SERIALIZABLE,
    deserialize_type,
    serialize_type,
)
from ray.rllib.utils.torch_utils import TORCH_COMPILE_REQUIRED_VERSION
from ray.rllib.utils.typing import (
    AgentID,
    AlgorithmConfigDict,
    EnvConfigDict,
    EnvType,
    LearningRateOrSchedule,
    MultiAgentPolicyConfigDict,
    PartialAlgorithmConfigDict,
    PolicyID,
    ResultDict,
    SampleBatchType,
)
from ray.tune.logger import Logger
from ray.tune.registry import get_trainable_cls
from ray.tune.result import TRIAL_INFO
from ray.tune.tune import _Config
from ray.util import log_once

gym, old_gym = try_import_gymnasium_and_gym()
Space = gym.Space

"""TODO(jungong, sven): in "offline_data" we can potentially unify all input types
under input and input_config keys. E.g.
input: sample
input_config {
    env: CartPole-v1
}
or:
input: json_reader
input_config {
    path: /tmp/
}
or:
input: dataset
input_config {
    format: parquet
    path: /tmp/
}
"""

if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.core.learner import Learner

logger = logging.getLogger(__name__)


def _check_rl_module_spec(module_spec: ModuleSpec) -> None:
    if not isinstance(module_spec, (SingleAgentRLModuleSpec, MultiAgentRLModuleSpec)):
        raise ValueError(
            "rl_module_spec must be an instance of "
            "SingleAgentRLModuleSpec or MultiAgentRLModuleSpec. "
            f"Got {type(module_spec)} instead."
        )


class AlgorithmConfig(_Config):
    """An RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration.

    Example:
        >>> from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
        >>> from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks
        >>> # Construct a generic config object, specifying values within different
        >>> # sub-categories, e.g. "training".
        >>> config = AlgorithmConfig().training(gamma=0.9, lr=0.01)  # doctest: +SKIP
        ...     .environment(env="CartPole-v1")
        ...     .resources(num_gpus=0)
        ...     .rollouts(num_rollout_workers=4)
        ...     .callbacks(MemoryTrackingCallbacks)
        >>> # A config object can be used to construct the respective Algorithm.
        >>> rllib_algo = config.build()  # doctest: +SKIP

    Example:
        >>> from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
        >>> from ray import tune
        >>> # In combination with a tune.grid_search:
        >>> config = AlgorithmConfig()
        >>> config.training(lr=tune.grid_search([0.01, 0.001]))  # doctest: +SKIP
        >>> # Use `to_dict()` method to get the legacy plain python config dict
        >>> # for usage with `tune.Tuner().fit()`.
        >>> tune.Tuner(  # doctest: +SKIP
        ...     "[registered Algorithm class]", param_space=config.to_dict()
        ... ).fit()
    """

    @staticmethod
    def DEFAULT_POLICY_MAPPING_FN(aid, episode, worker, **kwargs):
        # The default policy mapping function to use if None provided.
        # Map any agent ID to "default_policy".
        return DEFAULT_POLICY_ID

    @classmethod
    def from_dict(cls, config_dict: dict) -> "AlgorithmConfig":
        """Creates an AlgorithmConfig from a legacy python config dict.

        Examples:
            >>> from ray.rllib.algorithms.ppo.ppo import PPOConfig  # doctest: +SKIP
            >>> ppo_config = PPOConfig.from_dict({...})  # doctest: +SKIP
            >>> ppo = ppo_config.build(env="Pendulum-v1")  # doctest: +SKIP

        Args:
            config_dict: The legacy formatted python config dict for some algorithm.

        Returns:
            A new AlgorithmConfig object that matches the given python config dict.
        """
        # Create a default config object of this class.
        config_obj = cls()

        # Remove `_is_frozen` flag from config dict in case the AlgorithmConfig that
        # the dict was derived from was already frozen (we don't want to copy the
        # frozenness).
        config_dict.pop("_is_frozen", None)

        config_obj.update_from_dict(config_dict)

        return config_obj

    @classmethod
    def overrides(cls, **kwargs):
        """Generates and validates a set of config key/value pairs (passed via kwargs).

        Validation whether given config keys are valid is done immediately upon
        construction (by comparing against the properties of a default AlgorithmConfig
        object of this class).

        Allows combination with a full AlgorithmConfig object to yield a new
        AlgorithmConfig object.

        Used wherever the user should only have to define the few config settings
        that differ from some main config, e.g. in multi-agent setups or in
        evaluation configs.

        Examples:
            >>> from ray.rllib.algorithms.ppo import PPOConfig
            >>> from ray.rllib.policy.policy import PolicySpec
            >>> config = (
            ...     PPOConfig()
            ...     .multi_agent(
            ...         policies={
            ...             "pol0": PolicySpec(config=PPOConfig.overrides(lambda_=0.95))
            ...         },
            ...     )
            ... )

            >>> from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
            >>> from ray.rllib.algorithms.pg import PGConfig
            >>> config = (
            ...     PGConfig()
            ...     .evaluation(
            ...         evaluation_num_workers=1,
            ...         evaluation_interval=1,
            ...         evaluation_config=AlgorithmConfig.overrides(explore=False),
            ...     )
            ... )

        Returns:
            A dict mapping valid config property-names to values.

        Raises:
            KeyError: In case a non-existing property name (kwargs key) is being
                passed in. Valid property names are taken from a default
                AlgorithmConfig object of `cls`.
        """
        default_config = cls()
        config_overrides = {}
        for key, value in kwargs.items():
            if not hasattr(default_config, key):
                raise KeyError(
                    f"Invalid property name {key} for config class {cls.__name__}!"
                )
            # Allow things like "lambda" as well.
            key = cls._translate_special_keys(key, warn_deprecated=True)
            config_overrides[key] = value

        return config_overrides

    def __init__(self, algo_class=None):
        # Define all settings and their default values.

        # Define the default RLlib Algorithm class that this AlgorithmConfig will be
        # applied to.
        self.algo_class = algo_class

        # `self.python_environment()`
        self.extra_python_environs_for_driver = {}
        self.extra_python_environs_for_worker = {}

        # `self.resources()`
        self.num_gpus = 0
        self.num_cpus_per_worker = 1
        self.num_gpus_per_worker = 0
        self._fake_gpus = False
        self.num_cpus_for_local_worker = 1
        self.num_learner_workers = 0
        self.num_gpus_per_learner_worker = 0
        self.num_cpus_per_learner_worker = 1
        self.local_gpu_idx = 0
        self.custom_resources_per_worker = {}
        self.placement_strategy = "PACK"

        # `self.framework()`
        self.framework_str = "torch"
        self.eager_tracing = True
        self.eager_max_retraces = 20
        self.tf_session_args = {
            # note: overridden by `local_tf_session_args`
            "intra_op_parallelism_threads": 2,
            "inter_op_parallelism_threads": 2,
            "gpu_options": {
                "allow_growth": True,
            },
            "log_device_placement": False,
            "device_count": {"CPU": 1},
            # Required by multi-GPU (num_gpus > 1).
            "allow_soft_placement": True,
        }
        self.local_tf_session_args = {
            # Allow a higher level of parallelism by default, but not unlimited
            # since that can cause crashes with many concurrent drivers.
            "intra_op_parallelism_threads": 8,
            "inter_op_parallelism_threads": 8,
        }
        # Torch compile settings.
        self.torch_compile_learner = False
        self.torch_compile_learner_what_to_compile = (
            TorchCompileWhatToCompile.FORWARD_TRAIN
        )
        # AOT Eager is a dummy backend and will not result in speedups.
        self.torch_compile_learner_dynamo_backend = (
            "aot_eager" if sys.platform == "darwin" else "inductor"
        )
        self.torch_compile_learner_dynamo_mode = None
        self.torch_compile_worker = False
        # AOT Eager is a dummy backend and will not result in speedups.
        self.torch_compile_worker_dynamo_backend = (
            "aot_eager" if sys.platform == "darwin" else "onnxrt"
        )
        self.torch_compile_worker_dynamo_mode = None

        # `self.environment()`
        self.env = None
        self.env_config = {}
        self.observation_space = None
        self.action_space = None
        self.env_task_fn = None
        self.render_env = False
        self.clip_rewards = None
        self.normalize_actions = True
        self.clip_actions = False
        self.disable_env_checking = False
        self.auto_wrap_old_gym_envs = True
        self.action_mask_key = "action_mask"
        # Whether this env is an atari env (for atari-specific preprocessing).
        # If not specified, we will try to auto-detect this.
        self._is_atari = None

        # `self.rollouts()`
        self.env_runner_cls = None
        self.num_rollout_workers = 0
        self.num_envs_per_worker = 1
        self.sample_collector = SimpleListCollector
        self.create_env_on_local_worker = False
        self.sample_async = False
        self.enable_connectors = True
        self.update_worker_filter_stats = True
        self.use_worker_filter_stats = True
        self.rollout_fragment_length = 200
        self.batch_mode = "truncate_episodes"
        self.remote_worker_envs = False
        self.remote_env_batch_wait_ms = 0
        self.validate_workers_after_construction = True
        self.preprocessor_pref = "deepmind"
        self.observation_filter = "NoFilter"
        self.compress_observations = False
        self.enable_tf1_exec_eagerly = False
        self.sampler_perf_stats_ema_coef = None

        # `self.training()`
        self.gamma = 0.99
        self.lr = 0.001
        self.grad_clip = None
        self.grad_clip_by = "global_norm"
        self.train_batch_size = 32
        # TODO (sven): Unsolved problem with RLModules sometimes requiring settings
        #  from the main AlgorithmConfig. We should not require the user to provide
        #  those settings in both the AlgorithmConfig (as a property) AND the model
        #  config dict. We should generally move to a world in which there exists an
        #  AlgorithmConfig that a) has-a user-provided model config object and b)
        #  is given a chance to compile a final model config (dict or object) that
        #  is then passed into the RLModule/Catalog. This design would then match
        #  our "compilation" pattern, where we automatically compile those settings
        #  that should NOT be touched by the user.
        # In case an Algorithm already uses the above-described pattern (and has
        # `self.model` as a @property), ignore the AttributeError raised when trying
        # to set this property.
        try:
            self.model = copy.deepcopy(MODEL_DEFAULTS)
        except AttributeError:
            pass
        self.optimizer = {}
        self.max_requests_in_flight_per_sampler_worker = 2
        self._learner_class = None
        self._enable_learner_api = False

        # `self.callbacks()`
        self.callbacks_class = DefaultCallbacks

        # `self.explore()`
        self.explore = True
        # This is not compatible with RLModules, which have a method
        # `forward_exploration` to specify custom exploration behavior.
        self.exploration_config = {}

        # `self.multi_agent()`
        self.policies = {DEFAULT_POLICY_ID: PolicySpec()}
        self.algorithm_config_overrides_per_module = {}
        self.policy_map_capacity = 100
        self.policy_mapping_fn = self.DEFAULT_POLICY_MAPPING_FN
        self.policies_to_train = None
        self.policy_states_are_swappable = False
        self.observation_fn = None
        self.count_steps_by = "env_steps"

        # `self.offline_data()`
        self.input_ = "sampler"
        self.input_config = {}
        self.actions_in_input_normalized = False
        self.postprocess_inputs = False
        self.shuffle_buffer_size = 0
        self.output = None
        self.output_config = {}
        self.output_compress_columns = ["obs", "new_obs"]
        self.output_max_file_size = 64 * 1024 * 1024
        self.offline_sampling = False

        # `self.evaluation()`
        self.evaluation_interval = None
        self.evaluation_duration = 10
        self.evaluation_duration_unit = "episodes"
        self.evaluation_sample_timeout_s = 180.0
        self.evaluation_parallel_to_training = False
        self.evaluation_config = None
        self.off_policy_estimation_methods = {}
        self.ope_split_batch_by_episode = True
        self.evaluation_num_workers = 0
        self.custom_evaluation_function = None
        self.always_attach_evaluation_results = False
        self.enable_async_evaluation = False
        # TODO: Set this flag still in the config or - much better - in the
        #  RolloutWorker as a property.
        self.in_evaluation = False
        self.sync_filters_on_rollout_workers_timeout_s = 60.0

        # `self.reporting()`
        self.keep_per_episode_custom_metrics = False
        self.metrics_episode_collection_timeout_s = 60.0
        self.metrics_num_episodes_for_smoothing = 100
        self.min_time_s_per_iteration = None
        self.min_train_timesteps_per_iteration = 0
        self.min_sample_timesteps_per_iteration = 0

        # `self.checkpointing()`
        self.export_native_model_files = False
        self.checkpoint_trainable_policies_only = False

        # `self.debugging()`
        self.logger_creator = None
        self.logger_config = None
        self.log_level = "WARN"
        self.log_sys_usage = True
        self.fake_sampler = False
        self.seed = None

        # `self.fault_tolerance()`
        self.ignore_worker_failures = False
        self.recreate_failed_workers = False
        # By default, restart failed workers a thousand times.
        # This should be enough to handle normal transient failures.
        # This also prevents an infinite number of restarts in case
        # the worker or env has a bug.
        self.max_num_worker_restarts = 1000
        # Small delay between worker restarts. In case rollout or
        # evaluation workers have remote dependencies, this delay can be
        # adjusted to make sure we don't flood them with re-connection
        # requests, and allow them enough time to recover.
        # This delay also gives Ray time to stream back error logging
        # and exceptions.
        self.delay_between_worker_restarts_s = 60.0
        self.restart_failed_sub_environments = False
        self.num_consecutive_worker_failures_tolerance = 100
        self.worker_health_probe_timeout_s = 60
        self.worker_restore_timeout_s = 1800

        # `self.rl_module()`
        self.rl_module_spec = None
        self._enable_rl_module_api = False
        # Helper to keep track of the original exploration config when dis-/enabling
        # rl modules.
        self.__prior_exploration_config = None

        # `self.experimental()`
        self._tf_policy_handles_more_than_one_loss = False
        self._disable_preprocessor_api = False
        self._disable_action_flattening = False
        self._disable_execution_plan_api = True
        self._disable_initialize_loss_from_dummy_batch = False

        # Has this config object been frozen (cannot alter its attributes anymore)?
        self._is_frozen = False

        # TODO: Remove, once all deprecation_warning calls upon using these keys
        #  have been removed.
        # === Deprecated keys ===
        self.simple_optimizer = DEPRECATED_VALUE
        self.monitor = DEPRECATED_VALUE
        self.evaluation_num_episodes = DEPRECATED_VALUE
        self.metrics_smoothing_episodes = DEPRECATED_VALUE
        self.timesteps_per_iteration = DEPRECATED_VALUE
        self.min_iter_time_s = DEPRECATED_VALUE
        self.collect_metrics_timeout = DEPRECATED_VALUE
        self.min_time_s_per_reporting = DEPRECATED_VALUE
        self.min_train_timesteps_per_reporting = DEPRECATED_VALUE
        self.min_sample_timesteps_per_reporting = DEPRECATED_VALUE
        self.input_evaluation = DEPRECATED_VALUE
        self.policy_map_cache = DEPRECATED_VALUE
        self.worker_cls = DEPRECATED_VALUE
        self.synchronize_filters = DEPRECATED_VALUE
        # The following values have moved because of the new ReplayBuffer API.
        self.buffer_size = DEPRECATED_VALUE
        self.prioritized_replay = DEPRECATED_VALUE
        self.learning_starts = DEPRECATED_VALUE
        self.replay_batch_size = DEPRECATED_VALUE
        # -1 = DEPRECATED_VALUE is a valid value for replay_sequence_length.
        self.replay_sequence_length = None
        self.replay_mode = DEPRECATED_VALUE
        self.prioritized_replay_alpha = DEPRECATED_VALUE
        self.prioritized_replay_beta = DEPRECATED_VALUE
        self.prioritized_replay_eps = DEPRECATED_VALUE
        self.min_time_s_per_reporting = DEPRECATED_VALUE
        self.min_train_timesteps_per_reporting = DEPRECATED_VALUE
        self.min_sample_timesteps_per_reporting = DEPRECATED_VALUE

    def to_dict(self) -> AlgorithmConfigDict:
        """Converts all settings into a legacy config dict for backward compatibility.

        Returns:
            A complete AlgorithmConfigDict, usable in backward-compatible Tune/RLlib
            use cases, e.g. w/ `tune.Tuner().fit()`.
        """
        config = copy.deepcopy(vars(self))
        config.pop("algo_class")
        config.pop("_is_frozen")

        # Worst naming convention ever: NEVER EVER use reserved key-words...
        if "lambda_" in config:
            assert hasattr(self, "lambda_")
            config["lambda"] = getattr(self, "lambda_")
            config.pop("lambda_")
        if "input_" in config:
            assert hasattr(self, "input_")
            config["input"] = getattr(self, "input_")
            config.pop("input_")

        # Convert `policies` (PolicySpecs?) into dict.
        # Convert the policies dict such that each policy ID maps to an old-style
        # 4-tuple: (policy class, obs space, action space, config).
        if "policies" in config and isinstance(config["policies"], dict):
            policies_dict = {}
            for policy_id, policy_spec in config.pop("policies").items():
                if isinstance(policy_spec, PolicySpec):
                    policies_dict[policy_id] = (
                        policy_spec.policy_class,
                        policy_spec.observation_space,
                        policy_spec.action_space,
                        policy_spec.config,
                    )
                else:
                    policies_dict[policy_id] = policy_spec
            config["policies"] = policies_dict

        # Switch out deprecated vs new config keys.
        config["callbacks"] = config.pop("callbacks_class", DefaultCallbacks)
        config["create_env_on_driver"] = config.pop("create_env_on_local_worker", 1)
        config["custom_eval_function"] = config.pop("custom_evaluation_function", None)
        config["framework"] = config.pop("framework_str", None)
        config["num_cpus_for_driver"] = config.pop("num_cpus_for_local_worker", 1)
        config["num_workers"] = config.pop("num_rollout_workers", 0)

        # Simplify: Remove all deprecated keys that have as value `DEPRECATED_VALUE`.
        # These would be useless in the returned dict anyways.
        for dep_k in [
            "monitor",
            "evaluation_num_episodes",
            "metrics_smoothing_episodes",
            "timesteps_per_iteration",
            "min_iter_time_s",
            "collect_metrics_timeout",
            "buffer_size",
            "prioritized_replay",
            "learning_starts",
            "replay_batch_size",
            "replay_mode",
            "prioritized_replay_alpha",
            "prioritized_replay_beta",
            "prioritized_replay_eps",
            "min_time_s_per_reporting",
            "min_train_timesteps_per_reporting",
            "min_sample_timesteps_per_reporting",
            "input_evaluation",
        ]:
            if config.get(dep_k) == DEPRECATED_VALUE:
                config.pop(dep_k, None)

        return config
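
    # Usage sketch (illustrative only; relies on the chainable setters shown in the
    # class docstring): round-trip a config through the legacy dict format. Note
    # that `to_dict()` renames some attributes to their legacy keys (e.g.
    # `num_rollout_workers` -> "num_workers", `framework_str` -> "framework"),
    # while `from_dict()` accepts either naming via `_translate_special_keys`.
    #
    #   config = AlgorithmConfig().training(gamma=0.9, lr=0.01)
    #   legacy_dict = config.to_dict()  # plain dict, e.g. for tune.Tuner(...)
    #   restored = AlgorithmConfig.from_dict(legacy_dict)
    #   assert restored.gamma == 0.9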

    def update_from_dict(
        self,
        config_dict: PartialAlgorithmConfigDict,
    ) -> "AlgorithmConfig":
        """Modifies this AlgorithmConfig via the provided python config dict.

        Warns if `config_dict` contains deprecated keys.
        Silently sets even properties of `self` that do NOT exist. This way, this
        method may be used to configure custom Policies which do not have their own
        specific AlgorithmConfig classes, e.g.
        `ray.rllib.examples.policy.random_policy::RandomPolicy`.

        Args:
            config_dict: The old-style python config dict (PartialAlgorithmConfigDict)
                to use for overriding some properties defined in there.

        Returns:
            This updated AlgorithmConfig object.
        """
        eval_call = {}

        # We deal with this special key before all others because it may influence
        # stuff like "exploration_config".
        # Namely, we want to re-instantiate the exploration config this config had
        # inside `self.rl_module()` before potentially overwriting it in the following.
        if "_enable_rl_module_api" in config_dict:
            self.rl_module(_enable_rl_module_api=config_dict["_enable_rl_module_api"])

        # Modify our properties one by one.
        for key, value in config_dict.items():
            key = self._translate_special_keys(key, warn_deprecated=False)

            # Ray Tune saves additional data under this magic keyword.
            # This should not get treated as an AlgorithmConfig field.
            if key == TRIAL_INFO:
                continue

            if key == "_enable_rl_module_api":
                # We've dealt with this above.
                continue
            # Set our multi-agent settings.
            elif key == "multiagent":
                kwargs = {
                    k: value[k]
                    for k in [
                        "policies",
                        "policy_map_capacity",
                        "policy_mapping_fn",
                        "policies_to_train",
                        "policy_states_are_swappable",
                        "observation_fn",
                        "count_steps_by",
                    ]
                    if k in value
                }
                self.multi_agent(**kwargs)
            # Some keys specify config sub-dicts and therefore should go through the
            # correct methods to properly `.update()` those from given config dict
            # (to not lose any sub-keys).
            elif key == "callbacks_class" and value != NOT_SERIALIZABLE:
                # For backward compatibility reasons, only resolve possible
                # classpath if value is a str type.
                if isinstance(value, str):
                    value = deserialize_type(value, error=True)
                self.callbacks(callbacks_class=value)
            elif key == "env_config":
                self.environment(env_config=value)
            elif key.startswith("evaluation_"):
                eval_call[key] = value
            elif key == "exploration_config":
                if config_dict.get("_enable_rl_module_api", False):
                    self.exploration_config = value
                    continue
                if isinstance(value, dict) and "type" in value:
                    value["type"] = deserialize_type(value["type"])
                self.exploration(exploration_config=value)
            elif key == "model":
                # Resolve possible classpath.
                if isinstance(value, dict) and value.get("custom_model"):
                    value["custom_model"] = deserialize_type(value["custom_model"])
                self.training(**{key: value})
            elif key == "optimizer":
                self.training(**{key: value})
            elif key == "replay_buffer_config":
                if isinstance(value, dict) and "type" in value:
                    value["type"] = deserialize_type(value["type"])
                self.training(**{key: value})
            elif key == "sample_collector":
                # Resolve possible classpath.
                value = deserialize_type(value)
                self.rollouts(sample_collector=value)
            # If config key matches a property, just set it, otherwise, warn and set.
            else:
                if not hasattr(self, key) and log_once(
                    "unknown_property_in_algo_config"
                ):
                    logger.warning(
                        f"Cannot create {type(self).__name__} from given "
                        f"`config_dict`! Property {key} not supported."
                    )
                setattr(self, key, value)

        self.evaluation(**eval_call)

        return self
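
    # Usage sketch (illustrative only): update a config in place from a partial
    # legacy dict. "evaluation_*" keys are batched into a single `evaluation()`
    # call and nested "multiagent" dicts are routed through `multi_agent()`, as
    # implemented above. The env_config contents here are arbitrary example values.
    #
    #   config = AlgorithmConfig()
    #   config.update_from_dict(
    #       {
    #           "gamma": 0.95,
    #           "env_config": {"some_env_option": 42},
    #           "evaluation_interval": 1,
    #       }
    #   )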

    # TODO(sven): We might want to have a `deserialize` method as well. Right now,
    #  simply using the from_dict() API works in this same (deserializing) manner,
    #  whether the dict used is actually code-free (already serialized) or not
    #  (i.e. a classic RLlib config dict with e.g. "callbacks" key still pointing to
    #  a class).
    def serialize(self) -> Mapping[str, Any]:
        """Returns a mapping from str to JSON'able values representing this config.

        The resulting values will not have any code in them.
        Classes (such as `callbacks_class`) will be converted to their full
        classpath, e.g. `ray.rllib.algorithms.callbacks.DefaultCallbacks`.
        Actual code such as lambda functions will be written as their source
        code (str) plus any closure information for properly restoring the
        code inside the AlgorithmConfig object made from the returned dict data.
        Dataclass objects get converted to dicts.

        Returns:
            A mapping from str to JSON'able values.
        """
        config = self.to_dict()
        return self._serialize_dict(config)

    def copy(self, copy_frozen: Optional[bool] = None) -> "AlgorithmConfig":
        """Creates a deep copy of this config and (un)freezes if necessary.

        Args:
            copy_frozen: Whether the created deep copy will be frozen or not. If None,
                keep the same frozen status that `self` currently has.

        Returns:
            A deep copy of `self` that is (un)frozen.
        """
        cp = copy.deepcopy(self)
        if copy_frozen is True:
            cp.freeze()
        elif copy_frozen is False:
            cp._is_frozen = False
            if isinstance(cp.evaluation_config, AlgorithmConfig):
                cp.evaluation_config._is_frozen = False
        return cp

    def freeze(self) -> None:
        """Freezes this config object, such that no attributes can be set anymore.

        Algorithms should use this method to make sure that their config objects
        remain read-only after this.
        """
        if self._is_frozen:
            return
        self._is_frozen = True

        # Also freeze underlying eval config, if applicable.
        if isinstance(self.evaluation_config, AlgorithmConfig):
            self.evaluation_config.freeze()

        # TODO: Flip out all set/dict/list values into frozen versions
        #  of themselves? This way, users won't even be able to alter those values
        #  directly anymore.
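
    # Usage sketch (illustrative only): per the docstrings above, `freeze()` makes
    # the config read-only (Algorithms are expected to call it on their configs),
    # while `copy(copy_frozen=False)` returns a mutable deep copy that can be
    # altered further.
    #
    #   config.freeze()
    #   tweaked = config.copy(copy_frozen=False)
    #   tweaked.lr = 0.0005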

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def validate(self) -> None:
        """Validates all values in this config."""

        # Validate rollout settings.
        if not (
            (
                isinstance(self.rollout_fragment_length, int)
                and self.rollout_fragment_length > 0
            )
            or self.rollout_fragment_length == "auto"
        ):
            raise ValueError("`rollout_fragment_length` must be int >0 or 'auto'!")
        if self.batch_mode not in ["truncate_episodes", "complete_episodes"]:
            raise ValueError(
                "`config.batch_mode` must be one of [truncate_episodes|"
                "complete_episodes]! Got {}".format(self.batch_mode)
            )
        if self.preprocessor_pref not in ["rllib", "deepmind", None]:
            raise ValueError(
                "`config.preprocessor_pref` must be either 'rllib', 'deepmind' or None!"
            )
        if self.num_envs_per_worker <= 0:
            raise ValueError(
                f"`num_envs_per_worker` ({self.num_envs_per_worker}) must be "
                f"larger than 0!"
            )

        # Check correct framework settings, and whether configured framework is
        # installed.
        _tf1, _tf, _tfv = None, None, None
        _torch = None
        if self.framework_str not in {"tf", "tf2"} and self.framework_str != "torch":
            return
        elif self.framework_str in {"tf", "tf2"}:
            _tf1, _tf, _tfv = try_import_tf()
        else:
            _torch, _ = try_import_torch()

        # Check if torch framework supports torch.compile.
        if (
            _torch is not None
            and self.framework_str == "torch"
            and version.parse(_torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
            and (self.torch_compile_learner or self.torch_compile_worker)
        ):
            raise ValueError("torch.compile is only supported from torch 2.0.0")

        self._check_if_correct_nn_framework_installed(_tf1, _tf, _torch)
        self._resolve_tf_settings(_tf1, _tfv)

        # Check `policies_to_train` for invalid entries.
        if isinstance(self.policies_to_train, (list, set, tuple)):
            for pid in self.policies_to_train:
                if pid not in self.policies:
                    raise ValueError(
                        "`config.multi_agent(policies_to_train=..)` contains "
                        f"policy ID ({pid}) that was not defined in "
                        f"`config.multi_agent(policies=..)`!"
                    )

        # If `evaluation_num_workers` > 0, warn if `evaluation_interval` is
        # None.
        if self.evaluation_num_workers > 0 and not self.evaluation_interval:
            logger.warning(
                f"You have specified {self.evaluation_num_workers} "
                "evaluation workers, but your `evaluation_interval` is None! "
                "Therefore, evaluation will not occur automatically with each"
                " call to `Algorithm.train()`. Instead, you will have to call "
                "`Algorithm.evaluate()` manually in order to trigger an "
                "evaluation run."
            )
        # If `evaluation_num_workers=0` and
        # `evaluation_parallel_to_training=True`, raise, since you need
        # at least one remote eval worker for parallel training and
        # evaluation.
        elif self.evaluation_num_workers == 0 and self.evaluation_parallel_to_training:
            raise ValueError(
                "`evaluation_parallel_to_training` can only be done if "
                "`evaluation_num_workers` > 0! Try setting "
                "`config.evaluation_parallel_to_training` to False."
            )

        # If `evaluation_duration=auto`, error if
        # `evaluation_parallel_to_training=False`.
        if self.evaluation_duration == "auto":
            if not self.evaluation_parallel_to_training:
                raise ValueError(
                    "`evaluation_duration=auto` not supported for "
                    "`evaluation_parallel_to_training=False`!"
                )
        # Make sure, it's an int otherwise.
        elif (
            not isinstance(self.evaluation_duration, int)
            or self.evaluation_duration <= 0
        ):
            raise ValueError(
                f"`evaluation_duration` ({self.evaluation_duration}) must be an "
                f"int and >0!"
            )

        # Check model config.
        # If no preprocessing, propagate into model's config as well
        # (so the model will know whether inputs are preprocessed or not).
        if self._disable_preprocessor_api is True:
            self.model["_disable_preprocessor_api"] = True
        # If no action flattening, propagate into model's config as well
        # (so the model will know whether action inputs are already flattened or
        # not).
        if self._disable_action_flattening is True:
            self.model["_disable_action_flattening"] = True
        if self.model.get("custom_preprocessor"):
            deprecation_warning(
                old="AlgorithmConfig.training(model={'custom_preprocessor': ...})",
                help="Custom preprocessors are deprecated, "
                "since they sometimes conflict with the built-in "
                "preprocessors for handling complex observation spaces. "
                "Please use wrapper classes around your environment "
                "instead.",
                error=True,
            )

        # RLModule API only works with connectors and with Learner API.
        if not self.enable_connectors and self._enable_rl_module_api:
            raise ValueError(
                "RLModule API only works with connectors. "
                "Please enable connectors via "
                "`config.rollouts(enable_connectors=True)`."
            )

        # Learner API requires RLModule API.
        if self._enable_learner_api is not self._enable_rl_module_api:
            raise ValueError(
                "Learner API requires RLModule API and vice-versa! "
                "Enable RLModule API via "
                "`config.rl_module(_enable_rl_module_api=True)` and the Learner API "
                "via `config.training(_enable_learner_api=True)` (or set both to "
                "False)."
            )

        # TODO @Avnishn: This is a short-term work around due to
        #  https://github.com/ray-project/ray/issues/35409
        #  Remove this once we are able to specify placement group bundle index in
        #  RLlib.
        if (
            self.num_cpus_per_learner_worker > 1
            and self.num_gpus_per_learner_worker > 0
        ):
            raise ValueError(
                "Cannot set both `num_cpus_per_learner_worker` and "
                "`num_gpus_per_learner_worker` > 0! Users must set one"
                " or the other due to issues with placement group"
                " fragmentation. See "
                "https://github.com/ray-project/ray/issues/35409 for more details."
            )

        if bool(os.environ.get("RLLIB_ENABLE_RL_MODULE", False)):
            # Enable RLModule API and connectors if env variable is set
            # (to be used in unittesting).
            self.rl_module(_enable_rl_module_api=True)
            self.training(_enable_learner_api=True)
            self.enable_connectors = True

        # LR-schedule checking.
        if self._enable_learner_api:
            Scheduler.validate(
                fixed_value_or_schedule=self.lr,
                setting_name="lr",
                description="learning rate",
            )

        # Validate grad clipping settings.
        if self.grad_clip_by not in ["value", "norm", "global_norm"]:
            raise ValueError(
                f"`grad_clip_by` ({self.grad_clip_by}) must be one of: 'value', "
                "'norm', or 'global_norm'!"
            )

        # TODO: Deprecate self.simple_optimizer!
        # Multi-GPU settings.
        if self.simple_optimizer is True:
            pass
        # Multi-GPU setting: Must use MultiGPUTrainOneStep.
        elif not self._enable_learner_api and self.num_gpus > 1:
            # TODO: AlphaStar uses >1 GPUs differently (1 per policy actor), so this
            #  is ok for tf2 here.
            #  Remove this hacky check, once we have fully moved to the Learner API.
            if self.framework_str == "tf2" and type(self).__name__ != "AlphaStar":
                raise ValueError(
                    "`num_gpus` > 1 not supported yet for "
                    f"framework={self.framework_str}!"
                )
            elif self.simple_optimizer is True:
                raise ValueError(
                    "Cannot use `simple_optimizer` if `num_gpus` > 1! "
                    "Consider not setting `simple_optimizer` in your config."
                )
            self.simple_optimizer = False
        # Auto-setting: Use simple-optimizer for tf-eager or multiagent,
        # otherwise: MultiGPUTrainOneStep (if supported by the algo's execution
        # plan).
        elif self.simple_optimizer == DEPRECATED_VALUE:
            # tf-eager: Must use simple optimizer.
            if self.framework_str not in ["tf", "torch"]:
                self.simple_optimizer = True
            # Multi-agent case: Try using MultiGPU optimizer (only
            # if all policies used are DynamicTFPolicies or TorchPolicies).
            elif self.is_multi_agent():
                from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
                from ray.rllib.policy.torch_policy import TorchPolicy

                default_policy_cls = None
                if self.algo_class:
                    default_policy_cls = self.algo_class.get_default_policy_class(self)

                policies = self.policies
                policy_specs = (
                    [
                        PolicySpec(*spec) if isinstance(spec, (tuple, list)) else spec
                        for spec in policies.values()
                    ]
                    if isinstance(policies, dict)
                    else [PolicySpec() for _ in policies]
                )

                if any(
                    (spec.policy_class or default_policy_cls) is None
                    or not issubclass(
                        spec.policy_class or default_policy_cls,
                        (DynamicTFPolicy, TorchPolicy),
                    )
                    for spec in policy_specs
                ):
                    self.simple_optimizer = True
                else:
                    self.simple_optimizer = False
            else:
                self.simple_optimizer = False
        # User manually set simple-optimizer to False -> Error if tf-eager.
        elif self.simple_optimizer is False:
            if self.framework_str == "tf2":
                raise ValueError(
                    "`simple_optimizer=False` not supported for "
                    f"config.framework({self.framework_str})!"
                )

        if self.input_ == "sampler" and self.off_policy_estimation_methods:
            raise ValueError(
                "Off-policy estimation methods can only be used if the input is a "
                "dataset. We currently do not support applying off_policy_estimation "
                "methods on a sampler input."
            )

        if self.input_ == "dataset":
            # If we need to read a ray dataset, set the parallelism and
            # num_cpus_per_read_task from rollout worker settings.
            self.input_config["num_cpus_per_read_task"] = self.num_cpus_per_worker
            if self.in_evaluation:
                # If using dataset for evaluation, the parallelism gets set to
                # evaluation_num_workers for backward compatibility and num_cpus gets
                # set to num_cpus_per_worker from rollout worker. User only needs to
                # set evaluation_num_workers.
                self.input_config["parallelism"] = self.evaluation_num_workers or 1
            else:
                # If using dataset for training, the parallelism and num_cpus get set
                # based on rollout worker parameters. This is for backwards
                # compatibility for now. User only needs to set num_rollout_workers.
                self.input_config["parallelism"] = self.num_rollout_workers or 1

        if self._enable_rl_module_api:
            default_rl_module_spec = self.get_default_rl_module_spec()
            _check_rl_module_spec(default_rl_module_spec)

            if self.rl_module_spec is not None:
                # Merge provided RL Module spec class with defaults.
                _check_rl_module_spec(self.rl_module_spec)
                # We can only merge if we have SingleAgentRLModuleSpecs.
                # TODO(Artur): Support merging for MultiAgentRLModuleSpecs.
                if isinstance(self.rl_module_spec, SingleAgentRLModuleSpec):
                    if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec):
                        default_rl_module_spec.update(self.rl_module_spec)
                        self.rl_module_spec = default_rl_module_spec
                    elif isinstance(default_rl_module_spec, MultiAgentRLModuleSpec):
                        raise ValueError(
                            "Cannot merge MultiAgentRLModuleSpec with "
                            "SingleAgentRLModuleSpec!"
                        )
            else:
                self.rl_module_spec = default_rl_module_spec

            not_compatible_w_rlm_msg = (
                "Cannot use `{}` option with the RLModule API. `{}` is part of the "
                "ModelV2 API and Policy API, which are not compatible with the "
                "RLModule API. You can either deactivate the RLModule API by "
                "setting `config.rl_module(_enable_rl_module_api=False)` and "
                "`config.training(_enable_learner_api=False)`, or use the RLModule "
                "API and implement your custom model as an RLModule."
            )

            if self.model["custom_model"] is not None:
                raise ValueError(
                    not_compatible_w_rlm_msg.format("custom_model", "custom_model")
                )
            if self.model["custom_model_config"] != {}:
                raise ValueError(
                    not_compatible_w_rlm_msg.format(
                        "custom_model_config", "custom_model_config"
                    )
                )

            if self.exploration_config:
                # This is not compatible with RLModules, which have a method
                # `forward_exploration` to specify custom exploration behavior.
                raise ValueError(
                    "When the RLModule API is enabled, exploration_config cannot be "
                    "set. If you want to implement custom exploration behaviour, "
                    "please modify the `forward_exploration` method of the "
                    "RLModule at hand. On configs that have a default exploration "
                    "config, this must be done with "
                    "`config.exploration_config={}`."
                )

        # Make sure the resource requirements for the learner_group are valid.
        if self.num_learner_workers == 0 and self.num_gpus_per_worker > 1:
            raise ValueError(
                "num_gpus_per_worker must be 0 (cpu) or 1 (gpu) when using local mode "
                "(i.e. num_learner_workers = 0)"
            )
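
    # Note (sketch, based on the checks above): `validate()` raises a ValueError on
    # inconsistent settings, e.g. `evaluation_parallel_to_training=True` with
    # `evaluation_num_workers=0`, or a `grad_clip_by` value outside
    # {"value", "norm", "global_norm"}. Subclasses that override it are recommended
    # to call `super().validate()` (see the decorator).
    #
    #   config = AlgorithmConfig().rollouts(num_rollout_workers=2)
    #   config.validate()  # raises ValueError if any setting combination is invalid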
  968. def build(
  969. self,
  970. env: Optional[Union[str, EnvType]] = None,
  971. logger_creator: Optional[Callable[[], Logger]] = None,
  972. use_copy: bool = True,
  973. ) -> "Algorithm":
  974. """Builds an Algorithm from this AlgorithmConfig (or a copy thereof).
  975. Args:
  976. env: Name of the environment to use (e.g. a gym-registered str),
  977. a full class path (e.g.
  978. "ray.rllib.examples.env.random_env.RandomEnv"), or an Env
  979. class directly. Note that this arg can also be specified via
  980. the "env" key in `config`.
  981. logger_creator: Callable that creates a ray.tune.Logger
  982. object. If unspecified, a default logger is created.
  983. use_copy: Whether to deepcopy `self` and pass the copy to the Algorithm
  984. (instead of `self`) as config. This is useful in case you would like to
  985. recycle the same AlgorithmConfig over and over, e.g. in a test case, in
  986. which we loop over different DL-frameworks.
  987. Returns:
  988. A ray.rllib.algorithms.algorithm.Algorithm object.
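Example (an illustrative sketch; `PPOConfig` and the "CartPole-v1" env are
placeholder choices, any AlgorithmConfig subclass and RLlib-supported env work
the same way):

    from ray.rllib.algorithms.ppo import PPOConfig

    config = PPOConfig().environment("CartPole-v1").rollouts(num_rollout_workers=2)
    # Build the Algorithm from a deep copy of the config (default behavior).
    algo = config.build(use_copy=True)
    print(algo.train())
    algo.stop()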
  989. """
  990. if env is not None:
  991. self.env = env
  992. if self.evaluation_config is not None:
  993. self.evaluation_config["env"] = env
  994. if logger_creator is not None:
  995. self.logger_creator = logger_creator
  996. algo_class = self.algo_class
  997. if isinstance(self.algo_class, str):
  998. algo_class = get_trainable_cls(self.algo_class)
  999. return algo_class(
  1000. config=self if not use_copy else copy.deepcopy(self),
  1001. logger_creator=self.logger_creator,
  1002. )
  1003. def python_environment(
  1004. self,
  1005. *,
  1006. extra_python_environs_for_driver: Optional[dict] = NotProvided,
  1007. extra_python_environs_for_worker: Optional[dict] = NotProvided,
  1008. ) -> "AlgorithmConfig":
  1009. """Sets the config's python environment settings.
  1010. Args:
  1011. extra_python_environs_for_driver: Any extra python env vars to set in the
  1012. algorithm's process, e.g., {"OMP_NUM_THREADS": "16"}.
1013. extra_python_environs_for_worker: Any extra python env vars to set in the
1014. worker processes.
  1015. Returns:
  1016. This updated AlgorithmConfig object.
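Example (an illustrative sketch; the env var values shown are arbitrary):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().python_environment(
        # Give the driver process more threads than each worker process.
        extra_python_environs_for_driver={"OMP_NUM_THREADS": "16"},
        extra_python_environs_for_worker={"OMP_NUM_THREADS": "2"},
    )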
  1017. """
  1018. if extra_python_environs_for_driver is not NotProvided:
  1019. self.extra_python_environs_for_driver = extra_python_environs_for_driver
  1020. if extra_python_environs_for_worker is not NotProvided:
  1021. self.extra_python_environs_for_worker = extra_python_environs_for_worker
  1022. return self
  1023. def resources(
  1024. self,
  1025. *,
  1026. num_gpus: Optional[Union[float, int]] = NotProvided,
  1027. _fake_gpus: Optional[bool] = NotProvided,
  1028. num_cpus_per_worker: Optional[Union[float, int]] = NotProvided,
  1029. num_gpus_per_worker: Optional[Union[float, int]] = NotProvided,
  1030. num_cpus_for_local_worker: Optional[int] = NotProvided,
  1031. num_learner_workers: Optional[int] = NotProvided,
  1032. num_cpus_per_learner_worker: Optional[Union[float, int]] = NotProvided,
  1033. num_gpus_per_learner_worker: Optional[Union[float, int]] = NotProvided,
  1034. local_gpu_idx: Optional[int] = NotProvided,
  1035. custom_resources_per_worker: Optional[dict] = NotProvided,
  1036. placement_strategy: Optional[str] = NotProvided,
  1037. ) -> "AlgorithmConfig":
  1038. """Specifies resources allocated for an Algorithm and its ray actors/workers.
  1039. Args:
  1040. num_gpus: Number of GPUs to allocate to the algorithm process.
  1041. Note that not all algorithms can take advantage of GPUs.
  1042. Support for multi-GPU is currently only available for
  1043. tf-[PPO/IMPALA/DQN/PG]. This can be fractional (e.g., 0.3 GPUs).
1044. _fake_gpus: Set to True for debugging (multi-)GPU functionality on a
  1045. CPU machine. GPU towers will be simulated by graphs located on
  1046. CPUs in this case. Use `num_gpus` to test for different numbers of
  1047. fake GPUs.
  1048. num_cpus_per_worker: Number of CPUs to allocate per worker.
  1049. num_gpus_per_worker: Number of GPUs to allocate per worker. This can be
  1050. fractional. This is usually needed only if your env itself requires a
  1051. GPU (i.e., it is a GPU-intensive video game), or model inference is
  1052. unusually expensive.
  1053. num_learner_workers: Number of workers used for training. A value of 0
  1054. means training will take place on a local worker on head node CPUs or 1
  1055. GPU (determined by `num_gpus_per_learner_worker`). For multi-gpu
  1056. training, set number of workers greater than 1 and set
  1057. `num_gpus_per_learner_worker` accordingly (e.g. 4 GPUs total, and model
  1058. needs 2 GPUs: `num_learner_workers = 2` and
  1059. `num_gpus_per_learner_worker = 2`)
  1060. num_cpus_per_learner_worker: Number of CPUs allocated per Learner worker.
  1061. Only necessary for custom processing pipeline inside each Learner
  1062. requiring multiple CPU cores. Ignored if `num_learner_workers = 0`.
1063. num_gpus_per_learner_worker: Number of GPUs allocated per Learner worker. If
  1064. `num_learner_workers = 0`, any value greater than 0 will run the
  1065. training on a single GPU on the head node, while a value of 0 will run
  1066. the training on head node CPU cores. If num_gpus_per_learner_worker is
  1067. set, then num_cpus_per_learner_worker cannot be set.
  1068. local_gpu_idx: if num_gpus_per_worker > 0, and num_workers<2, then this gpu
  1069. index will be used for training. This is an index into the available
  1070. cuda devices. For example if os.environ["CUDA_VISIBLE_DEVICES"] = "1"
  1071. then a local_gpu_idx of 0 will use the gpu with id 1 on the node.
  1072. custom_resources_per_worker: Any custom Ray resources to allocate per
  1073. worker.
  1074. num_cpus_for_local_worker: Number of CPUs to allocate for the algorithm.
  1075. Note: this only takes effect when running in Tune. Otherwise,
  1076. the algorithm runs in the main program (driver).
  1079. placement_strategy: The strategy for the placement group factory returned by
1080. `Algorithm.default_resource_request()`. A PlacementGroup defines which
1081. devices (resources) should always be co-located on the same node.
  1082. For example, an Algorithm with 2 rollout workers, running with
  1083. num_gpus=1 will request a placement group with the bundles:
  1084. [{"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where the first bundle
  1085. is for the driver and the other 2 bundles are for the two workers.
  1086. These bundles can now be "placed" on the same or different
  1087. nodes depending on the value of `placement_strategy`:
  1088. "PACK": Packs bundles into as few nodes as possible.
  1089. "SPREAD": Places bundles across distinct nodes as even as possible.
  1090. "STRICT_PACK": Packs bundles into one node. The group is not allowed
  1091. to span multiple nodes.
  1092. "STRICT_SPREAD": Packs bundles across distinct nodes.
  1093. Returns:
  1094. This updated AlgorithmConfig object.
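Example (an illustrative sketch; the numbers are arbitrary and assume a
machine/cluster that actually provides these resources):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().resources(
        num_gpus=1,                 # GPU for the algorithm (driver) process.
        num_cpus_per_worker=1,      # One CPU per rollout worker.
        num_gpus_per_worker=0,      # Rollout workers don't need GPUs here.
        placement_strategy="PACK",  # Co-locate bundles on as few nodes as possible.
    )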
  1095. """
  1096. if num_gpus is not NotProvided:
  1097. self.num_gpus = num_gpus
  1098. if _fake_gpus is not NotProvided:
  1099. self._fake_gpus = _fake_gpus
  1100. if num_cpus_per_worker is not NotProvided:
  1101. self.num_cpus_per_worker = num_cpus_per_worker
  1102. if num_gpus_per_worker is not NotProvided:
  1103. self.num_gpus_per_worker = num_gpus_per_worker
  1104. if num_cpus_for_local_worker is not NotProvided:
  1105. self.num_cpus_for_local_worker = num_cpus_for_local_worker
  1106. if custom_resources_per_worker is not NotProvided:
  1107. self.custom_resources_per_worker = custom_resources_per_worker
  1108. if placement_strategy is not NotProvided:
  1109. self.placement_strategy = placement_strategy
  1110. if num_learner_workers is not NotProvided:
  1111. self.num_learner_workers = num_learner_workers
  1112. if num_cpus_per_learner_worker is not NotProvided:
  1113. self.num_cpus_per_learner_worker = num_cpus_per_learner_worker
  1114. if num_gpus_per_learner_worker is not NotProvided:
  1115. self.num_gpus_per_learner_worker = num_gpus_per_learner_worker
  1116. if local_gpu_idx is not NotProvided:
  1117. self.local_gpu_idx = local_gpu_idx
  1118. return self
  1119. def framework(
  1120. self,
  1121. framework: Optional[str] = NotProvided,
  1122. *,
  1123. eager_tracing: Optional[bool] = NotProvided,
  1124. eager_max_retraces: Optional[int] = NotProvided,
  1125. tf_session_args: Optional[Dict[str, Any]] = NotProvided,
  1126. local_tf_session_args: Optional[Dict[str, Any]] = NotProvided,
  1127. torch_compile_learner: Optional[bool] = NotProvided,
  1128. torch_compile_learner_what_to_compile: Optional[str] = NotProvided,
  1129. torch_compile_learner_dynamo_mode: Optional[str] = NotProvided,
  1130. torch_compile_learner_dynamo_backend: Optional[str] = NotProvided,
  1131. torch_compile_worker: Optional[bool] = NotProvided,
  1132. torch_compile_worker_dynamo_backend: Optional[str] = NotProvided,
  1133. torch_compile_worker_dynamo_mode: Optional[str] = NotProvided,
  1134. ) -> "AlgorithmConfig":
  1135. """Sets the config's DL framework settings.
  1136. Args:
  1137. framework: torch: PyTorch; tf2: TensorFlow 2.x (eager execution or traced
  1138. if eager_tracing=True); tf: TensorFlow (static-graph);
  1139. eager_tracing: Enable tracing in eager mode. This greatly improves
  1140. performance (speedup ~2x), but makes it slightly harder to debug
  1141. since Python code won't be evaluated after the initial eager pass.
  1142. Only possible if framework=tf2.
  1143. eager_max_retraces: Maximum number of tf.function re-traces before a
  1144. runtime error is raised. This is to prevent unnoticed retraces of
  1145. methods inside the `..._eager_traced` Policy, which could slow down
  1146. execution by a factor of 4, without the user noticing what the root
  1147. cause for this slowdown could be.
  1148. Only necessary for framework=tf2.
  1149. Set to None to ignore the re-trace count and never throw an error.
  1150. tf_session_args: Configures TF for single-process operation by default.
1151. local_tf_session_args: Override the `tf_session_args` settings for the local
1152. worker.
  1153. torch_compile_learner: If True, forward_train methods on TorchRLModules
  1154. on the learner are compiled. If not specified, the default is to compile
  1155. forward train on the learner.
  1156. torch_compile_learner_what_to_compile: A TorchCompileWhatToCompile
  1157. mode specifying what to compile on the learner side if
  1158. torch_compile_learner is True. See TorchCompileWhatToCompile for
  1159. details and advice on its usage.
  1160. torch_compile_learner_dynamo_backend: The torch dynamo backend to use on
  1161. the learner.
  1162. torch_compile_learner_dynamo_mode: The torch dynamo mode to use on the
  1163. learner.
  1164. torch_compile_worker: If True, forward exploration and inference methods on
  1165. TorchRLModules on the workers are compiled. If not specified,
  1166. the default is to not compile forward methods on the workers because
  1167. retracing can be expensive.
  1168. torch_compile_worker_dynamo_backend: The torch dynamo backend to use on
  1169. the workers.
  1170. torch_compile_worker_dynamo_mode: The torch dynamo mode to use on the
  1171. workers.
  1172. Returns:
  1173. This updated AlgorithmConfig object.
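Example (an illustrative sketch; "torch" and "tf2" are just two of the
supported framework strings):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    # PyTorch:
    config = AlgorithmConfig().framework("torch")
    # Or TensorFlow 2.x with eager tracing enabled for speed:
    config = AlgorithmConfig().framework("tf2", eager_tracing=True)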
  1174. """
  1175. if framework is not NotProvided:
  1176. if framework == "tfe":
  1177. deprecation_warning(
  1178. old="AlgorithmConfig.framework('tfe')",
  1179. new="AlgorithmConfig.framework('tf2')",
  1180. error=True,
  1181. )
  1182. self.framework_str = framework
  1183. if eager_tracing is not NotProvided:
  1184. self.eager_tracing = eager_tracing
  1185. if eager_max_retraces is not NotProvided:
  1186. self.eager_max_retraces = eager_max_retraces
  1187. if tf_session_args is not NotProvided:
  1188. self.tf_session_args = tf_session_args
  1189. if local_tf_session_args is not NotProvided:
  1190. self.local_tf_session_args = local_tf_session_args
  1191. if torch_compile_learner is not NotProvided:
  1192. self.torch_compile_learner = torch_compile_learner
  1193. if torch_compile_learner_dynamo_backend is not NotProvided:
  1194. self.torch_compile_learner_dynamo_backend = (
  1195. torch_compile_learner_dynamo_backend
  1196. )
  1197. if torch_compile_learner_dynamo_mode is not NotProvided:
  1198. self.torch_compile_learner_dynamo_mode = torch_compile_learner_dynamo_mode
  1199. if torch_compile_learner_what_to_compile is not NotProvided:
  1200. self.torch_compile_learner_what_to_compile = (
  1201. torch_compile_learner_what_to_compile
  1202. )
  1203. if torch_compile_worker is not NotProvided:
  1204. self.torch_compile_worker = torch_compile_worker
  1205. if torch_compile_worker_dynamo_backend is not NotProvided:
  1206. self.torch_compile_worker_dynamo_backend = (
  1207. torch_compile_worker_dynamo_backend
  1208. )
  1209. if torch_compile_worker_dynamo_mode is not NotProvided:
  1210. self.torch_compile_worker_dynamo_mode = torch_compile_worker_dynamo_mode
  1211. return self
  1212. def environment(
  1213. self,
  1214. env: Optional[Union[str, EnvType]] = NotProvided,
  1215. *,
  1216. env_config: Optional[EnvConfigDict] = NotProvided,
  1217. observation_space: Optional[gym.spaces.Space] = NotProvided,
  1218. action_space: Optional[gym.spaces.Space] = NotProvided,
  1219. env_task_fn: Optional[
  1220. Callable[[ResultDict, EnvType, EnvContext], Any]
  1221. ] = NotProvided,
  1222. render_env: Optional[bool] = NotProvided,
  1223. clip_rewards: Optional[Union[bool, float]] = NotProvided,
  1224. normalize_actions: Optional[bool] = NotProvided,
  1225. clip_actions: Optional[bool] = NotProvided,
  1226. disable_env_checking: Optional[bool] = NotProvided,
  1227. is_atari: Optional[bool] = NotProvided,
  1228. auto_wrap_old_gym_envs: Optional[bool] = NotProvided,
  1229. action_mask_key: Optional[str] = NotProvided,
  1230. ) -> "AlgorithmConfig":
  1231. """Sets the config's RL-environment settings.
  1232. Args:
  1233. env: The environment specifier. This can either be a tune-registered env,
  1234. via `tune.register_env([name], lambda env_ctx: [env object])`,
  1235. or a string specifier of an RLlib supported type. In the latter case,
1236. RLlib will try to interpret the specifier as either a Farama-Foundation
  1237. gymnasium env, a PyBullet env, a ViZDoomGym env, or a fully qualified
  1238. classpath to an Env class, e.g.
  1239. "ray.rllib.examples.env.random_env.RandomEnv".
  1240. env_config: Arguments dict passed to the env creator as an EnvContext
  1241. object (which is a dict plus the properties: num_rollout_workers,
  1242. worker_index, vector_index, and remote).
  1243. observation_space: The observation space for the Policies of this Algorithm.
  1244. action_space: The action space for the Policies of this Algorithm.
  1245. env_task_fn: A callable taking the last train results, the base env and the
  1246. env context as args and returning a new task to set the env to.
  1247. The env must be a `TaskSettableEnv` sub-class for this to work.
  1248. See `examples/curriculum_learning.py` for an example.
  1249. render_env: If True, try to render the environment on the local worker or on
  1250. worker 1 (if num_rollout_workers > 0). For vectorized envs, this usually
  1251. means that only the first sub-environment will be rendered.
  1252. In order for this to work, your env will have to implement the
  1253. `render()` method which either:
  1254. a) handles window generation and rendering itself (returning True) or
  1255. b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].
  1256. clip_rewards: Whether to clip rewards during Policy's postprocessing.
  1257. None (default): Clip for Atari only (r=sign(r)).
  1258. True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0.
  1259. False: Never clip.
  1260. [float value]: Clip at -value and + value.
  1261. Tuple[value1, value2]: Clip at value1 and value2.
  1262. normalize_actions: If True, RLlib will learn entirely inside a normalized
  1263. action space (0.0 centered with small stddev; only affecting Box
  1264. components). We will unsquash actions (and clip, just in case) to the
  1265. bounds of the env's action space before sending actions back to the env.
  1266. clip_actions: If True, RLlib will clip actions according to the env's bounds
  1267. before sending them back to the env.
  1268. TODO: (sven) This option should be deprecated and always be False.
  1269. disable_env_checking: If True, disable the environment pre-checking module.
  1270. is_atari: This config can be used to explicitly specify whether the env is
  1271. an Atari env or not. If not specified, RLlib will try to auto-detect
  1272. this.
  1273. auto_wrap_old_gym_envs: Whether to auto-wrap old gym environments (using
  1274. the pre 0.24 gym APIs, e.g. reset() returning single obs and no info
  1275. dict). If True, RLlib will automatically wrap the given gym env class
  1276. with the gym-provided compatibility wrapper
  1277. (gym.wrappers.EnvCompatibility). If False, RLlib will produce a
  1278. descriptive error on which steps to perform to upgrade to gymnasium
  1279. (or to switch this flag to True).
1280. action_mask_key: If the observation is a dictionary, expect the value under
1281. the key `action_mask_key` to contain a valid action mask (`numpy.int8`
  1282. array of zeros and ones). Defaults to "action_mask".
  1283. Returns:
  1284. This updated AlgorithmConfig object.
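Example (an illustrative sketch; "CartPole-v1" and the env_config contents are
placeholders for your own env and its creator arguments):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().environment(
        env="CartPole-v1",
        env_config={"my_env_option": 42},  # Passed to the env creator as EnvContext.
        normalize_actions=True,
        clip_rewards=False,
    )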
  1285. """
  1286. if env is not NotProvided:
  1287. self.env = env
  1288. if env_config is not NotProvided:
  1289. deep_update(
  1290. self.env_config,
  1291. env_config,
  1292. True,
  1293. )
  1294. if observation_space is not NotProvided:
  1295. self.observation_space = observation_space
  1296. if action_space is not NotProvided:
  1297. self.action_space = action_space
  1298. if env_task_fn is not NotProvided:
  1299. self.env_task_fn = env_task_fn
  1300. if render_env is not NotProvided:
  1301. self.render_env = render_env
  1302. if clip_rewards is not NotProvided:
  1303. self.clip_rewards = clip_rewards
  1304. if normalize_actions is not NotProvided:
  1305. self.normalize_actions = normalize_actions
  1306. if clip_actions is not NotProvided:
  1307. self.clip_actions = clip_actions
  1308. if disable_env_checking is not NotProvided:
  1309. self.disable_env_checking = disable_env_checking
  1310. if is_atari is not NotProvided:
  1311. self._is_atari = is_atari
  1312. if auto_wrap_old_gym_envs is not NotProvided:
  1313. self.auto_wrap_old_gym_envs = auto_wrap_old_gym_envs
  1314. if action_mask_key is not NotProvided:
  1315. self.action_mask_key = action_mask_key
  1316. return self
  1317. def rollouts(
  1318. self,
  1319. *,
  1320. env_runner_cls: Optional[type] = NotProvided,
  1321. num_rollout_workers: Optional[int] = NotProvided,
  1322. num_envs_per_worker: Optional[int] = NotProvided,
  1323. create_env_on_local_worker: Optional[bool] = NotProvided,
  1324. sample_collector: Optional[Type[SampleCollector]] = NotProvided,
  1325. sample_async: Optional[bool] = NotProvided,
  1326. enable_connectors: Optional[bool] = NotProvided,
  1327. use_worker_filter_stats: Optional[bool] = NotProvided,
  1328. update_worker_filter_stats: Optional[bool] = NotProvided,
  1329. rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
  1330. batch_mode: Optional[str] = NotProvided,
  1331. remote_worker_envs: Optional[bool] = NotProvided,
  1332. remote_env_batch_wait_ms: Optional[float] = NotProvided,
  1333. validate_workers_after_construction: Optional[bool] = NotProvided,
  1334. preprocessor_pref: Optional[str] = NotProvided,
  1335. observation_filter: Optional[str] = NotProvided,
  1336. compress_observations: Optional[bool] = NotProvided,
  1337. enable_tf1_exec_eagerly: Optional[bool] = NotProvided,
  1338. sampler_perf_stats_ema_coef: Optional[float] = NotProvided,
  1339. ignore_worker_failures=DEPRECATED_VALUE,
  1340. recreate_failed_workers=DEPRECATED_VALUE,
  1341. restart_failed_sub_environments=DEPRECATED_VALUE,
  1342. num_consecutive_worker_failures_tolerance=DEPRECATED_VALUE,
  1343. worker_health_probe_timeout_s=DEPRECATED_VALUE,
  1344. worker_restore_timeout_s=DEPRECATED_VALUE,
  1345. synchronize_filter=DEPRECATED_VALUE,
  1346. ) -> "AlgorithmConfig":
  1347. """Sets the rollout worker configuration.
  1348. Args:
  1349. env_runner_cls: The EnvRunner class to use for environment rollouts (data
  1350. collection).
  1351. num_rollout_workers: Number of rollout worker actors to create for
  1352. parallel sampling. Setting this to 0 will force rollouts to be done in
  1353. the local worker (driver process or the Algorithm's actor when using
  1354. Tune).
  1355. num_envs_per_worker: Number of environments to evaluate vector-wise per
  1356. worker. This enables model inference batching, which can improve
  1357. performance for inference bottlenecked workloads.
  1358. sample_collector: The SampleCollector class to be used to collect and
  1359. retrieve environment-, model-, and sampler data. Override the
  1360. SampleCollector base class to implement your own
  1361. collection/buffering/retrieval logic.
  1362. create_env_on_local_worker: When `num_rollout_workers` > 0, the driver
  1363. (local_worker; worker-idx=0) does not need an environment. This is
  1364. because it doesn't have to sample (done by remote_workers;
  1365. worker_indices > 0) nor evaluate (done by evaluation workers;
  1366. see below).
  1367. sample_async: Use a background thread for sampling (slightly off-policy,
  1368. usually not advisable to turn on unless your env specifically requires
  1369. it).
  1370. enable_connectors: Use connector based environment runner, so that all
  1371. preprocessing of obs and postprocessing of actions are done in agent
  1372. and action connectors.
  1373. use_worker_filter_stats: Whether to use the workers in the WorkerSet to
  1374. update the central filters (held by the local worker). If False, stats
1375. from the workers are discarded and not used to update the central filters.
  1376. update_worker_filter_stats: Whether to push filter updates from the central
  1377. filters (held by the local worker) to the remote workers' filters.
  1378. Setting this to True might be useful within the evaluation config in
  1379. order to disable the usage of evaluation trajectories for synching
  1380. the central filter (used for training).
  1381. rollout_fragment_length: Divide episodes into fragments of this many steps
  1382. each during rollouts. Trajectories of this size are collected from
  1383. rollout workers and combined into a larger batch of `train_batch_size`
  1384. for learning.
  1385. For example, given rollout_fragment_length=100 and
  1386. train_batch_size=1000:
  1387. 1. RLlib collects 10 fragments of 100 steps each from rollout workers.
  1388. 2. These fragments are concatenated and we perform an epoch of SGD.
  1389. When using multiple envs per worker, the fragment size is multiplied by
1390. `num_envs_per_worker`. This is because we are collecting steps from
  1391. multiple envs in parallel. For example, if num_envs_per_worker=5, then
  1392. rollout workers will return experiences in chunks of 5*100 = 500 steps.
  1393. The dataflow here can vary per algorithm. For example, PPO further
  1394. divides the train batch into minibatches for multi-epoch SGD.
  1395. Set to "auto" to have RLlib compute an exact `rollout_fragment_length`
  1396. to match the given batch size.
  1397. batch_mode: How to build individual batches with the EnvRunner(s). Batches
  1398. coming from distributed EnvRunners are usually concat'd to form the
  1399. train batch. Note that "steps" below can mean different things (either
  1400. env- or agent-steps) and depends on the `count_steps_by` setting,
  1401. adjustable via `AlgorithmConfig.multi_agent(count_steps_by=..)`:
  1402. 1) "truncate_episodes": Each call to `EnvRunner.sample()` will return a
  1403. batch of at most `rollout_fragment_length * num_envs_per_worker` in
  1404. size. The batch will be exactly `rollout_fragment_length * num_envs`
  1405. in size if postprocessing does not change batch sizes. Episodes
  1406. may be truncated in order to meet this size requirement.
  1407. This mode guarantees evenly sized batches, but increases
  1408. variance as the future return must now be estimated at truncation
  1409. boundaries.
  1410. 2) "complete_episodes": Each call to `EnvRunner.sample()` will return a
  1411. batch of at least `rollout_fragment_length * num_envs_per_worker` in
  1412. size. Episodes will not be truncated, but multiple episodes
  1413. may be packed within one batch to meet the (minimum) batch size.
  1414. Note that when `num_envs_per_worker > 1`, episode steps will be buffered
  1415. until the episode completes, and hence batches may contain
  1416. significant amounts of off-policy data.
  1417. remote_worker_envs: If using num_envs_per_worker > 1, whether to create
  1418. those new envs in remote processes instead of in the same worker.
  1419. This adds overheads, but can make sense if your envs can take much
  1420. time to step / reset (e.g., for StarCraft). Use this cautiously;
  1421. overheads are significant.
  1422. remote_env_batch_wait_ms: Timeout that remote workers are waiting when
  1423. polling environments. 0 (continue when at least one env is ready) is
1424. a reasonable default, but the optimal value can be obtained by measuring
1425. your environment step/reset and model inference performance.
  1426. validate_workers_after_construction: Whether to validate that each created
  1427. remote worker is healthy after its construction process.
  1428. preprocessor_pref: Whether to use "rllib" or "deepmind" preprocessors by
  1429. default. Set to None for using no preprocessor. In this case, the
  1430. model will have to handle possibly complex observations from the
  1431. environment.
  1432. observation_filter: Element-wise observation filter, either "NoFilter"
  1433. or "MeanStdFilter".
  1434. compress_observations: Whether to LZ4 compress individual observations
  1435. in the SampleBatches collected during rollouts.
  1436. enable_tf1_exec_eagerly: Explicitly tells the rollout worker to enable
  1437. TF eager execution. This is useful for example when framework is
  1438. "torch", but a TF2 policy needs to be restored for evaluation or
  1439. league-based purposes.
1440. sampler_perf_stats_ema_coef: If specified, perf stats are reported as EMAs.
1441. This is the coefficient determining how much new data points contribute
1442. to the averages. Default is None, which uses a simple global average instead.
  1443. The EMA update rule is: updated = (1 - ema_coef) * old + ema_coef * new
  1444. Returns:
  1445. This updated AlgorithmConfig object.
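Example (an illustrative sketch; the numbers are arbitrary):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().rollouts(
        num_rollout_workers=4,           # Four parallel sampling actors.
        num_envs_per_worker=2,           # Two vectorized sub-envs per worker.
        rollout_fragment_length="auto",  # Let RLlib derive the fragment length.
        batch_mode="truncate_episodes",
    )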
  1446. """
  1447. if env_runner_cls is not NotProvided:
  1448. self.env_runner_cls = env_runner_cls
  1449. if num_rollout_workers is not NotProvided:
  1450. self.num_rollout_workers = num_rollout_workers
  1451. if num_envs_per_worker is not NotProvided:
  1452. self.num_envs_per_worker = num_envs_per_worker
  1453. if sample_collector is not NotProvided:
  1454. self.sample_collector = sample_collector
  1455. if create_env_on_local_worker is not NotProvided:
  1456. self.create_env_on_local_worker = create_env_on_local_worker
  1457. if sample_async is not NotProvided:
  1458. self.sample_async = sample_async
  1459. if enable_connectors is not NotProvided:
  1460. self.enable_connectors = enable_connectors
  1461. if use_worker_filter_stats is not NotProvided:
  1462. self.use_worker_filter_stats = use_worker_filter_stats
  1463. if update_worker_filter_stats is not NotProvided:
  1464. self.update_worker_filter_stats = update_worker_filter_stats
  1465. if rollout_fragment_length is not NotProvided:
  1466. self.rollout_fragment_length = rollout_fragment_length
  1467. if batch_mode is not NotProvided:
  1468. self.batch_mode = batch_mode
  1469. if remote_worker_envs is not NotProvided:
  1470. self.remote_worker_envs = remote_worker_envs
  1471. if remote_env_batch_wait_ms is not NotProvided:
  1472. self.remote_env_batch_wait_ms = remote_env_batch_wait_ms
  1473. if validate_workers_after_construction is not NotProvided:
  1474. self.validate_workers_after_construction = (
  1475. validate_workers_after_construction
  1476. )
  1477. if preprocessor_pref is not NotProvided:
  1478. self.preprocessor_pref = preprocessor_pref
  1479. if observation_filter is not NotProvided:
  1480. self.observation_filter = observation_filter
1481. if synchronize_filter != DEPRECATED_VALUE:
1482. self.synchronize_filters = synchronize_filter
  1483. if compress_observations is not NotProvided:
  1484. self.compress_observations = compress_observations
  1485. if enable_tf1_exec_eagerly is not NotProvided:
  1486. self.enable_tf1_exec_eagerly = enable_tf1_exec_eagerly
  1487. if sampler_perf_stats_ema_coef is not NotProvided:
  1488. self.sampler_perf_stats_ema_coef = sampler_perf_stats_ema_coef
  1489. # Deprecated settings.
  1490. if synchronize_filter != DEPRECATED_VALUE:
  1491. deprecation_warning(
  1492. old="AlgorithmConfig.rollouts(synchronize_filter=..)",
  1493. new="AlgorithmConfig.rollouts(update_worker_filter_stats=..)",
  1494. error=False,
  1495. )
  1496. self.update_worker_filter_stats = synchronize_filter
  1497. if ignore_worker_failures != DEPRECATED_VALUE:
  1498. deprecation_warning(
  1499. old="ignore_worker_failures is deprecated, and will soon be a no-op",
  1500. error=False,
  1501. )
  1502. self.ignore_worker_failures = ignore_worker_failures
  1503. if recreate_failed_workers != DEPRECATED_VALUE:
  1504. deprecation_warning(
  1505. old="AlgorithmConfig.rollouts(recreate_failed_workers=..)",
  1506. new="AlgorithmConfig.fault_tolerance(recreate_failed_workers=..)",
  1507. error=False,
  1508. )
  1509. self.recreate_failed_workers = recreate_failed_workers
  1510. if restart_failed_sub_environments != DEPRECATED_VALUE:
  1511. deprecation_warning(
  1512. old="AlgorithmConfig.rollouts(restart_failed_sub_environments=..)",
  1513. new=(
  1514. "AlgorithmConfig.fault_tolerance("
  1515. "restart_failed_sub_environments=..)"
  1516. ),
  1517. error=False,
  1518. )
  1519. self.restart_failed_sub_environments = restart_failed_sub_environments
  1520. if num_consecutive_worker_failures_tolerance != DEPRECATED_VALUE:
  1521. deprecation_warning(
  1522. old=(
  1523. "AlgorithmConfig.rollouts("
  1524. "num_consecutive_worker_failures_tolerance=..)"
  1525. ),
  1526. new=(
  1527. "AlgorithmConfig.fault_tolerance("
  1528. "num_consecutive_worker_failures_tolerance=..)"
  1529. ),
  1530. error=False,
  1531. )
  1532. self.num_consecutive_worker_failures_tolerance = (
  1533. num_consecutive_worker_failures_tolerance
  1534. )
  1535. if worker_health_probe_timeout_s != DEPRECATED_VALUE:
  1536. deprecation_warning(
  1537. old="AlgorithmConfig.rollouts(worker_health_probe_timeout_s=..)",
  1538. new="AlgorithmConfig.fault_tolerance(worker_health_probe_timeout_s=..)",
  1539. error=False,
  1540. )
  1541. self.worker_health_probe_timeout_s = worker_health_probe_timeout_s
  1542. if worker_restore_timeout_s != DEPRECATED_VALUE:
  1543. deprecation_warning(
  1544. old="AlgorithmConfig.rollouts(worker_restore_timeout_s=..)",
  1545. new="AlgorithmConfig.fault_tolerance(worker_restore_timeout_s=..)",
  1546. error=False,
  1547. )
  1548. self.worker_restore_timeout_s = worker_restore_timeout_s
  1549. return self
  1550. def training(
  1551. self,
  1552. *,
  1553. gamma: Optional[float] = NotProvided,
  1554. lr: Optional[LearningRateOrSchedule] = NotProvided,
  1555. grad_clip: Optional[float] = NotProvided,
  1556. grad_clip_by: Optional[str] = NotProvided,
  1557. train_batch_size: Optional[int] = NotProvided,
  1558. model: Optional[dict] = NotProvided,
  1559. optimizer: Optional[dict] = NotProvided,
  1560. max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided,
  1561. _enable_learner_api: Optional[bool] = NotProvided,
  1562. learner_class: Optional[Type["Learner"]] = NotProvided,
  1563. ) -> "AlgorithmConfig":
  1564. """Sets the training related configuration.
  1565. Args:
  1566. gamma: Float specifying the discount factor of the Markov Decision process.
  1567. lr: The learning rate (float) or learning rate schedule in the format of
  1568. [[timestep, lr-value], [timestep, lr-value], ...]
  1569. In case of a schedule, intermediary timesteps will be assigned to
  1570. linearly interpolated learning rate values. A schedule config's first
  1571. entry must start with timestep 0, i.e.: [[0, initial_value], [...]].
  1572. Note: If you require a) more than one optimizer (per RLModule),
  1573. b) optimizer types that are not Adam, c) a learning rate schedule that
  1574. is not a linearly interpolated, piecewise schedule as described above,
  1575. or d) specifying c'tor arguments of the optimizer that are not the
  1576. learning rate (e.g. Adam's epsilon), then you must override your
  1577. Learner's `configure_optimizer_for_module()` method and handle
  1578. lr-scheduling yourself.
  1579. grad_clip: If None, no gradient clipping will be applied. Otherwise,
  1580. depending on the setting of `grad_clip_by`, the (float) value of
  1581. `grad_clip` will have the following effect:
  1582. If `grad_clip_by=value`: Will clip all computed gradients individually
  1583. inside the interval [-`grad_clip`, +`grad_clip`].
  1584. If `grad_clip_by=norm`, will compute the L2-norm of each weight/bias
  1585. gradient tensor individually and then clip all gradients such that these
  1586. L2-norms do not exceed `grad_clip`. The L2-norm of a tensor is computed
  1587. via: `sqrt(SUM(w0^2, w1^2, ..., wn^2))` where w[i] are the elements of
  1588. the tensor (no matter what the shape of this tensor is).
  1589. If `grad_clip_by=global_norm`, will compute the square of the L2-norm of
  1590. each weight/bias gradient tensor individually, sum up all these squared
  1591. L2-norms across all given gradient tensors (e.g. the entire module to
  1592. be updated), square root that overall sum, and then clip all gradients
  1593. such that this global L2-norm does not exceed the given value.
  1594. The global L2-norm over a list of tensors (e.g. W and V) is computed
  1595. via:
  1596. `sqrt[SUM(w0^2, w1^2, ..., wn^2) + SUM(v0^2, v1^2, ..., vm^2)]`, where
  1597. w[i] and v[j] are the elements of the tensors W and V (no matter what
  1598. the shapes of these tensors are).
  1599. grad_clip_by: See `grad_clip` for the effect of this setting on gradient
  1600. clipping. Allowed values are `value`, `norm`, and `global_norm`.
  1601. train_batch_size: Training batch size, if applicable.
  1602. model: Arguments passed into the policy model. See models/catalog.py for a
  1603. full list of the available model options.
  1604. TODO: Provide ModelConfig objects instead of dicts.
  1605. optimizer: Arguments to pass to the policy optimizer. This setting is not
  1606. used when `_enable_learner_api=True`.
  1607. max_requests_in_flight_per_sampler_worker: Max number of inflight requests
  1608. to each sampling worker. See the FaultTolerantActorManager class for
  1609. more details.
1610. Tuning these values is important when running experiments with
  1611. large sample batches, where there is the risk that the object store may
  1612. fill up, causing spilling of objects to disk. This can cause any
  1613. asynchronous requests to become very slow, making your experiment run
  1614. slow as well. You can inspect the object store during your experiment
  1615. via a call to ray memory on your headnode, and by using the ray
  1616. dashboard. If you're seeing that the object store is filling up,
1617. turn down the number of remote requests in flight, or enable observation
1618. compression in your experiment.
1619. _enable_learner_api: Whether to enable the LearnerGroup and Learner
1620. for training. This API uses ray.train to run the training loop, which
1621. allows for more flexible distributed training.
learner_class: The `Learner` class to use for (distributed) updating of the
RLModule. Only used when `_enable_learner_api=True`.
  1622. Returns:
  1623. This updated AlgorithmConfig object.
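Example (an illustrative sketch; the values are arbitrary and the lr schedule
follows the piecewise format described above):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().training(
        gamma=0.99,
        # Linearly interpolated schedule: 0.001 at t=0, 0.0001 at t=1M.
        lr=[[0, 0.001], [1_000_000, 0.0001]],
        train_batch_size=4000,
        grad_clip=40.0,
        grad_clip_by="global_norm",
    )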
  1624. """
  1625. if gamma is not NotProvided:
  1626. self.gamma = gamma
  1627. if lr is not NotProvided:
  1628. self.lr = lr
  1629. if grad_clip is not NotProvided:
  1630. self.grad_clip = grad_clip
  1631. if grad_clip_by is not NotProvided:
  1632. self.grad_clip_by = grad_clip_by
  1633. if train_batch_size is not NotProvided:
  1634. self.train_batch_size = train_batch_size
  1635. if model is not NotProvided:
  1636. # Validate prev_a/r settings.
  1637. prev_a_r = model.get("lstm_use_prev_action_reward", DEPRECATED_VALUE)
  1638. if prev_a_r != DEPRECATED_VALUE:
  1639. deprecation_warning(
  1640. "model.lstm_use_prev_action_reward",
  1641. "model.lstm_use_prev_action and model.lstm_use_prev_reward",
  1642. error=True,
  1643. )
  1644. self.model.update(model)
  1645. if (
  1646. model.get("_use_default_native_models", DEPRECATED_VALUE)
  1647. != DEPRECATED_VALUE
  1648. ):
  1649. deprecation_warning(
  1650. old="AlgorithmConfig.training(_use_default_native_models=True)",
  1651. help="_use_default_native_models is not supported "
  1652. "anymore. To get rid of this error, set `rl_module("
  1653. "_enable_rl_module_api` to True. Native models will "
  1654. "be better supported by the upcoming RLModule API.",
  1655. # Error out if user tries to enable this
  1656. error=model["_use_default_native_models"],
  1657. )
  1658. if optimizer is not NotProvided:
  1659. self.optimizer = merge_dicts(self.optimizer, optimizer)
  1660. if max_requests_in_flight_per_sampler_worker is not NotProvided:
  1661. self.max_requests_in_flight_per_sampler_worker = (
  1662. max_requests_in_flight_per_sampler_worker
  1663. )
  1664. if _enable_learner_api is not NotProvided:
  1665. self._enable_learner_api = _enable_learner_api
  1666. if learner_class is not NotProvided:
  1667. self._learner_class = learner_class
  1668. return self
  1669. def callbacks(self, callbacks_class) -> "AlgorithmConfig":
  1670. """Sets the callbacks configuration.
  1671. Args:
  1672. callbacks_class: Callbacks class, whose methods will be run during
  1673. various phases of training and environment sample collection.
  1674. See the `DefaultCallbacks` class and
  1675. `examples/custom_metrics_and_callbacks.py` for more usage information.
  1676. Returns:
  1677. This updated AlgorithmConfig object.
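Example (an illustrative sketch; `MyCallbacks` is a hypothetical user-defined
subclass of DefaultCallbacks):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
    from ray.rllib.algorithms.callbacks import DefaultCallbacks

    class MyCallbacks(DefaultCallbacks):
        def on_train_result(self, *, algorithm, result, **kwargs):
            # Tag every train result dict with a custom marker.
            result["custom_marker"] = True

    config = AlgorithmConfig().callbacks(MyCallbacks)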
  1678. """
  1679. if callbacks_class is None:
  1680. callbacks_class = DefaultCallbacks
  1681. # Check, whether given `callbacks` is a callable.
  1682. if not callable(callbacks_class):
  1683. raise ValueError(
  1684. "`config.callbacks_class` must be a callable method that "
  1685. "returns a subclass of DefaultCallbacks, got "
  1686. f"{callbacks_class}!"
  1687. )
  1688. self.callbacks_class = callbacks_class
  1689. return self
  1690. def exploration(
  1691. self,
  1692. *,
  1693. explore: Optional[bool] = NotProvided,
  1694. exploration_config: Optional[dict] = NotProvided,
  1695. ) -> "AlgorithmConfig":
  1696. """Sets the config's exploration settings.
  1697. Args:
  1698. explore: Default exploration behavior, iff `explore=None` is passed into
  1699. compute_action(s). Set to False for no exploration behavior (e.g.,
  1700. for evaluation).
  1701. exploration_config: A dict specifying the Exploration object's config.
  1702. Returns:
  1703. This updated AlgorithmConfig object.
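Example (an illustrative sketch; the EpsilonGreedy settings are placeholders
and only make sense for algorithms whose default exploration is EpsilonGreedy):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().exploration(
        explore=True,
        exploration_config={
            "type": "EpsilonGreedy",
            "initial_epsilon": 1.0,
            "final_epsilon": 0.02,
        },
    )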
  1704. """
  1705. if explore is not NotProvided:
  1706. self.explore = explore
  1707. if exploration_config is not NotProvided:
  1708. # Override entire `exploration_config` if `type` key changes.
  1709. # Update, if `type` key remains the same or is not specified.
  1710. new_exploration_config = deep_update(
  1711. {"exploration_config": self.exploration_config},
  1712. {"exploration_config": exploration_config},
  1713. False,
  1714. ["exploration_config"],
  1715. ["exploration_config"],
  1716. )
  1717. self.exploration_config = new_exploration_config["exploration_config"]
  1718. return self
  1719. def evaluation(
  1720. self,
  1721. *,
  1722. evaluation_interval: Optional[int] = NotProvided,
  1723. evaluation_duration: Optional[Union[int, str]] = NotProvided,
  1724. evaluation_duration_unit: Optional[str] = NotProvided,
  1725. evaluation_sample_timeout_s: Optional[float] = NotProvided,
  1726. evaluation_parallel_to_training: Optional[bool] = NotProvided,
  1727. evaluation_config: Optional[
  1728. Union["AlgorithmConfig", PartialAlgorithmConfigDict]
  1729. ] = NotProvided,
  1730. off_policy_estimation_methods: Optional[Dict] = NotProvided,
  1731. ope_split_batch_by_episode: Optional[bool] = NotProvided,
  1732. evaluation_num_workers: Optional[int] = NotProvided,
  1733. custom_evaluation_function: Optional[Callable] = NotProvided,
  1734. always_attach_evaluation_results: Optional[bool] = NotProvided,
  1735. enable_async_evaluation: Optional[bool] = NotProvided,
  1736. # Deprecated args.
  1737. evaluation_num_episodes=DEPRECATED_VALUE,
  1738. ) -> "AlgorithmConfig":
  1739. """Sets the config's evaluation settings.
  1740. Args:
  1741. evaluation_interval: Evaluate with every `evaluation_interval` training
  1742. iterations. The evaluation stats will be reported under the "evaluation"
1743. metric key. Note that for Ape-X, metrics are already only reported for
1744. the lowest-epsilon workers (i.e. the least random workers).
  1745. Set to None (or 0) for no evaluation.
  1746. evaluation_duration: Duration for which to run evaluation each
  1747. `evaluation_interval`. The unit for the duration can be set via
  1748. `evaluation_duration_unit` to either "episodes" (default) or
  1749. "timesteps". If using multiple evaluation workers
  1750. (evaluation_num_workers > 1), the load to run will be split amongst
  1751. these.
  1752. If the value is "auto":
  1753. - For `evaluation_parallel_to_training=True`: Will run as many
1754. episodes/timesteps as fit into the (parallel) training step.
  1755. - For `evaluation_parallel_to_training=False`: Error.
  1756. evaluation_duration_unit: The unit, with which to count the evaluation
  1757. duration. Either "episodes" (default) or "timesteps".
  1758. evaluation_sample_timeout_s: The timeout (in seconds) for the ray.get call
  1759. to the remote evaluation worker(s) `sample()` method. After this time,
  1760. the user will receive a warning and instructions on how to fix the
  1761. issue. This could be either to make sure the episode ends, increasing
  1762. the timeout, or switching to `evaluation_duration_unit=timesteps`.
  1763. evaluation_parallel_to_training: Whether to run evaluation in parallel to
1764. an Algorithm.train() call using threading. Default=False.
  1765. E.g. evaluation_interval=2 -> For every other training iteration,
  1766. the Algorithm.train() and Algorithm.evaluate() calls run in parallel.
  1767. Note: This is experimental. Possible pitfalls could be race conditions
1768. for weight syncing at the beginning of the evaluation loop.
  1769. evaluation_config: Typical usage is to pass extra args to evaluation env
  1770. creator and to disable exploration by computing deterministic actions.
  1771. IMPORTANT NOTE: Policy gradient algorithms are able to find the optimal
  1772. policy, even if this is a stochastic one. Setting "explore=False" here
  1773. will result in the evaluation workers not using this optimal policy!
  1774. off_policy_estimation_methods: Specify how to evaluate the current policy,
  1775. along with any optional config parameters. This only has an effect when
  1776. reading offline experiences ("input" is not "sampler").
  1777. Available keys:
  1778. {ope_method_name: {"type": ope_type, ...}} where `ope_method_name`
  1779. is a user-defined string to save the OPE results under, and
  1780. `ope_type` can be any subclass of OffPolicyEstimator, e.g.
  1781. ray.rllib.offline.estimators.is::ImportanceSampling
  1782. or your own custom subclass, or the full class path to the subclass.
  1783. You can also add additional config arguments to be passed to the
  1784. OffPolicyEstimator in the dict, e.g.
  1785. {"qreg_dr": {"type": DoublyRobust, "q_model_type": "qreg", "k": 5}}
  1786. ope_split_batch_by_episode: Whether to use SampleBatch.split_by_episode() to
1787. split the input batch into episodes before estimating the OPE metrics. For
1788. bandits, you should set this to False to speed up OPE evaluation, since it
1789. is ok not to split by episode there (each record is already one timestep).
1790. The default is True.
  1791. evaluation_num_workers: Number of parallel workers to use for evaluation.
  1792. Note that this is set to zero by default, which means evaluation will
  1793. be run in the algorithm process (only if evaluation_interval is not
  1794. None). If you increase this, it will increase the Ray resource usage of
  1795. the algorithm since evaluation workers are created separately from
  1796. rollout workers (used to sample data for training).
  1797. custom_evaluation_function: Customize the evaluation method. This must be a
  1798. function of signature (algo: Algorithm, eval_workers: WorkerSet) ->
  1799. metrics: dict. See the Algorithm.evaluate() method to see the default
  1800. implementation. The Algorithm guarantees all eval workers have the
  1801. latest policy state before this function is called.
  1802. always_attach_evaluation_results: Make sure the latest available evaluation
  1803. results are always attached to a step result dict. This may be useful
  1804. if Tune or some other meta controller needs access to evaluation metrics
  1805. all the time.
  1806. enable_async_evaluation: If True, use an AsyncRequestsManager for
  1807. the evaluation workers and use this manager to send `sample()` requests
  1808. to the evaluation workers. This way, the Algorithm becomes more robust
  1809. against long running episodes and/or failing (and restarting) workers.
  1810. Returns:
  1811. This updated AlgorithmConfig object.
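Example (an illustrative sketch; the values are arbitrary):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().evaluation(
        evaluation_interval=5,    # Evaluate every 5 training iterations.
        evaluation_duration=10,
        evaluation_duration_unit="episodes",
        evaluation_num_workers=1,
        # Evaluate deterministically (see the note on `explore=False` above).
        evaluation_config={"explore": False},
    )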
  1812. """
  1813. if evaluation_num_episodes != DEPRECATED_VALUE:
  1814. deprecation_warning(
  1815. old="AlgorithmConfig.evaluation(evaluation_num_episodes=..)",
  1816. new="AlgorithmConfig.evaluation(evaluation_duration=.., "
  1817. "evaluation_duration_unit='episodes')",
  1818. error=False,
  1819. )
  1820. evaluation_duration = evaluation_num_episodes
  1821. if evaluation_interval is not NotProvided:
  1822. self.evaluation_interval = evaluation_interval
  1823. if evaluation_duration is not NotProvided:
  1824. self.evaluation_duration = evaluation_duration
  1825. if evaluation_duration_unit is not NotProvided:
  1826. self.evaluation_duration_unit = evaluation_duration_unit
  1827. if evaluation_sample_timeout_s is not NotProvided:
  1828. self.evaluation_sample_timeout_s = evaluation_sample_timeout_s
  1829. if evaluation_parallel_to_training is not NotProvided:
  1830. self.evaluation_parallel_to_training = evaluation_parallel_to_training
  1831. if evaluation_config is not NotProvided:
  1832. # If user really wants to set this to None, we should allow this here,
  1833. # instead of creating an empty dict.
  1834. if evaluation_config is None:
  1835. self.evaluation_config = None
  1836. # Update (don't replace) the existing overrides with the provided ones.
  1837. else:
  1838. from ray.rllib.algorithms.algorithm import Algorithm
  1839. self.evaluation_config = deep_update(
  1840. self.evaluation_config or {},
  1841. evaluation_config,
  1842. True,
  1843. Algorithm._allow_unknown_subkeys,
  1844. Algorithm._override_all_subkeys_if_type_changes,
  1845. Algorithm._override_all_key_list,
  1846. )
  1847. if off_policy_estimation_methods is not NotProvided:
  1848. self.off_policy_estimation_methods = off_policy_estimation_methods
  1849. if evaluation_num_workers is not NotProvided:
  1850. self.evaluation_num_workers = evaluation_num_workers
  1851. if custom_evaluation_function is not NotProvided:
  1852. self.custom_evaluation_function = custom_evaluation_function
  1853. if always_attach_evaluation_results is not NotProvided:
  1854. self.always_attach_evaluation_results = always_attach_evaluation_results
  1855. if enable_async_evaluation is not NotProvided:
  1856. self.enable_async_evaluation = enable_async_evaluation
  1857. if ope_split_batch_by_episode is not NotProvided:
  1858. self.ope_split_batch_by_episode = ope_split_batch_by_episode
  1859. return self
  1860. def offline_data(
  1861. self,
  1862. *,
  1863. input_=NotProvided,
  1864. input_config=NotProvided,
  1865. actions_in_input_normalized=NotProvided,
  1866. input_evaluation=NotProvided,
  1867. postprocess_inputs=NotProvided,
  1868. shuffle_buffer_size=NotProvided,
  1869. output=NotProvided,
  1870. output_config=NotProvided,
  1871. output_compress_columns=NotProvided,
  1872. output_max_file_size=NotProvided,
  1873. offline_sampling=NotProvided,
  1874. ) -> "AlgorithmConfig":
  1875. """Sets the config's offline data settings.
  1876. Args:
  1877. input_: Specify how to generate experiences:
  1878. - "sampler": Generate experiences via online (env) simulation (default).
  1879. - A local directory or file glob expression (e.g., "/tmp/*.json").
  1880. - A list of individual file paths/URIs (e.g., ["/tmp/1.json",
  1881. "s3://bucket/2.json"]).
  1882. - A dict with string keys and sampling probabilities as values (e.g.,
  1883. {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
  1884. - A callable that takes an `IOContext` object as only arg and returns a
  1885. ray.rllib.offline.InputReader.
  1886. - A string key that indexes a callable with tune.registry.register_input
  1887. input_config: Arguments that describe the settings for reading the input.
1888. If the input is `sampler`, this will be the environment configuration, e.g.
  1889. `env_name` and `env_config`, etc. See `EnvContext` for more info.
  1890. If the input is `dataset`, this will be e.g. `format`, `path`.
  1891. actions_in_input_normalized: True, if the actions in a given offline "input"
  1892. are already normalized (between -1.0 and 1.0). This is usually the case
  1893. when the offline file has been generated by another RLlib algorithm
  1894. (e.g. PPO or SAC), while "normalize_actions" was set to True.
  1895. postprocess_inputs: Whether to run postprocess_trajectory() on the
  1896. trajectory fragments from offline inputs. Note that postprocessing will
  1897. be done using the *current* policy, not the *behavior* policy, which
  1898. is typically undesirable for on-policy algorithms.
  1899. shuffle_buffer_size: If positive, input batches will be shuffled via a
  1900. sliding window buffer of this number of batches. Use this if the input
  1901. data is not in random enough order. Input is delayed until the shuffle
  1902. buffer is filled.
  1903. output: Specify where experiences should be saved:
  1904. - None: don't save any experiences
  1905. - "logdir" to save to the agent log dir
  1906. - a path/URI to save to a custom output directory (e.g., "s3://bckt/")
  1907. - a function that returns a rllib.offline.OutputWriter
  1908. output_config: Arguments accessible from the IOContext for configuring
  1909. custom output.
  1910. output_compress_columns: What sample batch columns to LZ4 compress in the
  1911. output data.
  1912. output_max_file_size: Max output file size (in bytes) before rolling over
  1913. to a new file.
  1914. offline_sampling: Whether sampling for the Algorithm happens via
  1915. reading from offline data. If True, RolloutWorkers will NOT limit the
  1916. number of collected batches within the same `sample()` call based on
  1917. the number of sub-environments within the worker (no sub-environments
  1918. present).
  1919. Returns:
  1920. This updated AlgorithmConfig object.
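Example (an illustrative sketch; the JSON glob path is a placeholder for your
own offline dataset):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().offline_data(
        input_="/tmp/my_offline_data/*.json",
        actions_in_input_normalized=True,
        postprocess_inputs=False,
        offline_sampling=True,
    )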
  1921. """
  1922. if input_ is not NotProvided:
  1923. self.input_ = input_
  1924. if input_config is not NotProvided:
  1925. if not isinstance(input_config, dict):
  1926. raise ValueError(
  1927. f"input_config must be a dict, got {type(input_config)}."
  1928. )
1929. # TODO (Kourosh) Once we use a complete separation between rollout worker
1930. # and input dataset reader, we can remove this.
1931. # For now, error out if the user attempts to set these parameters.
  1932. msg = "{} should not be set in the input_config. RLlib will use {} instead."
  1933. if input_config.get("num_cpus_per_read_task") is not None:
  1934. raise ValueError(
  1935. msg.format(
  1936. "num_cpus_per_read_task",
  1937. "config.resources(num_cpus_per_worker=..)",
  1938. )
  1939. )
  1940. if input_config.get("parallelism") is not None:
  1941. if self.in_evaluation:
  1942. raise ValueError(
  1943. msg.format(
  1944. "parallelism",
  1945. "config.evaluation(evaluation_num_workers=..)",
  1946. )
  1947. )
  1948. else:
  1949. raise ValueError(
  1950. msg.format(
  1951. "parallelism", "config.rollouts(num_rollout_workers=..)"
  1952. )
  1953. )
  1954. self.input_config = input_config
  1955. if actions_in_input_normalized is not NotProvided:
  1956. self.actions_in_input_normalized = actions_in_input_normalized
  1957. if input_evaluation is not NotProvided:
  1958. deprecation_warning(
  1959. old="offline_data(input_evaluation={})".format(input_evaluation),
  1960. new="evaluation(off_policy_estimation_methods={})".format(
  1961. input_evaluation
  1962. ),
  1963. error=True,
  1964. help="Running OPE during training is not recommended.",
  1965. )
  1966. if postprocess_inputs is not NotProvided:
  1967. self.postprocess_inputs = postprocess_inputs
  1968. if shuffle_buffer_size is not NotProvided:
  1969. self.shuffle_buffer_size = shuffle_buffer_size
  1970. if output is not NotProvided:
  1971. self.output = output
  1972. if output_config is not NotProvided:
  1973. self.output_config = output_config
  1974. if output_compress_columns is not NotProvided:
  1975. self.output_compress_columns = output_compress_columns
  1976. if output_max_file_size is not NotProvided:
  1977. self.output_max_file_size = output_max_file_size
  1978. if offline_sampling is not NotProvided:
  1979. self.offline_sampling = offline_sampling
  1980. return self
  1981. def multi_agent(
  1982. self,
  1983. *,
  1984. policies=NotProvided,
  1985. algorithm_config_overrides_per_module: Optional[
  1986. Dict[ModuleID, PartialAlgorithmConfigDict]
  1987. ] = NotProvided,
  1988. policy_map_capacity: Optional[int] = NotProvided,
  1989. policy_mapping_fn: Optional[
  1990. Callable[[AgentID, "Episode"], PolicyID]
  1991. ] = NotProvided,
  1992. policies_to_train: Optional[
  1993. Union[Container[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
  1994. ] = NotProvided,
  1995. policy_states_are_swappable: Optional[bool] = NotProvided,
  1996. observation_fn: Optional[Callable] = NotProvided,
  1997. count_steps_by: Optional[str] = NotProvided,
  1998. # Deprecated args:
  1999. replay_mode=DEPRECATED_VALUE,
  2000. # Now done via Ray object store, which has its own cloud-supported
  2001. # spillover mechanism.
  2002. policy_map_cache=DEPRECATED_VALUE,
  2003. ) -> "AlgorithmConfig":
  2004. """Sets the config's multi-agent settings.
  2005. Validates the new multi-agent settings and translates everything into
  2006. a unified multi-agent setup format. For example a `policies` list or set
  2007. of IDs is properly converted into a dict mapping these IDs to PolicySpecs.
  2008. Args:
  2009. policies: Map of type MultiAgentPolicyConfigDict from policy ids to either
  2010. 4-tuples of (policy_cls, obs_space, act_space, config) or PolicySpecs.
  2011. These tuples or PolicySpecs define the class of the policy, the
  2012. observation- and action spaces of the policies, and any extra config.
  2013. algorithm_config_overrides_per_module: Only used if both
  2014. `_enable_learner_api` and `_enable_rl_module_api` are True.
  2015. A mapping from ModuleIDs to
  2016. per-module AlgorithmConfig override dicts, which apply certain settings,
  2017. e.g. the learning rate, from the main AlgorithmConfig only to this
  2018. particular module (within a MultiAgentRLModule).
  2019. You can create override dicts by using the `AlgorithmConfig.overrides`
  2020. utility. For example, to override your learning rate and (PPO) lambda
  2021. setting just for a single RLModule with your MultiAgentRLModule, do:
  2022. config.multi_agent(algorithm_config_overrides_per_module={
  2023. "module_1": PPOConfig.overrides(lr=0.0002, lambda_=0.75),
  2024. })
  2025. policy_map_capacity: Keep this many policies in the "policy_map" (before
  2026. writing least-recently used ones to disk/S3).
  2027. policy_mapping_fn: Function mapping agent ids to policy ids. The signature
  2028. is: `(agent_id, episode, worker, **kwargs) -> PolicyID`.
  2029. policies_to_train: Determines those policies that should be updated.
  2030. Options are:
  2031. - None, for training all policies.
  2032. - An iterable of PolicyIDs that should be trained.
  2033. - A callable, taking a PolicyID and a SampleBatch or MultiAgentBatch
  2034. and returning a bool (indicating whether the given policy is trainable
  2035. or not, given the particular batch). This allows you to have a policy
  2036. trained only on certain data (e.g. when playing against a certain
  2037. opponent).
  2038. policy_states_are_swappable: Whether all Policy objects in this map can be
  2039. "swapped out" via a simple `state = A.get_state(); B.set_state(state)`,
  2040. where `A` and `B` are policy instances in this map. You should set
  2041. this to True for significantly speeding up the PolicyMap's cache lookup
  2042. times, iff your policies all share the same neural network
  2043. architecture and optimizer types. If True, the PolicyMap will not
  2044. have to garbage collect old, least recently used policies, but instead
  2045. keep them in memory and simply override their state with the state of
  2046. the most recently accessed one.
  2047. For example, in a league-based training setup, you might have 100s of
  2048. the same policies in your map (playing against each other in various
  2049. combinations), but all of them share the same state structure
  2050. (are "swappable").
  2051. observation_fn: Optional function that can be used to enhance the local
  2052. agent observations to include more state. See
  2053. rllib/evaluation/observation_function.py for more info.
  2054. count_steps_by: Which metric to use as the "batch size" when building a
  2055. MultiAgentBatch. The two supported values are:
  2056. "env_steps": Count each time the env is "stepped" (no matter how many
  2057. multi-agent actions are passed/how many multi-agent observations
  2058. have been returned in the previous step).
  2059. "agent_steps": Count each individual agent step as one step.
  2060. Returns:
  2061. This updated AlgorithmConfig object.
  2062. """
  2063. if policies is not NotProvided:
  2064. # Make sure our Policy IDs are ok (this should work whether `policies`
  2065. # is a dict or just any Sequence).
  2066. for pid in policies:
  2067. validate_policy_id(pid, error=True)
  2068. # Validate each policy spec in a given dict.
  2069. if isinstance(policies, dict):
  2070. for pid, spec in policies.items():
  2071. # If not a PolicySpec object, values must be lists/tuples of len 4.
  2072. if not isinstance(spec, PolicySpec):
  2073. if not isinstance(spec, (list, tuple)) or len(spec) != 4:
  2074. raise ValueError(
  2075. "Policy specs must be tuples/lists of "
  2076. "(cls or None, obs_space, action_space, config), "
  2077. f"got {spec} for PolicyID={pid}"
  2078. )
  2079. # TODO: Switch from dict to AlgorithmConfigOverride, once available.
  2080. # Config not a dict.
  2081. elif (
  2082. not isinstance(spec.config, (AlgorithmConfig, dict))
  2083. and spec.config is not None
  2084. ):
  2085. raise ValueError(
  2086. f"Multi-agent policy config for {pid} must be a dict or "
  2087. f"AlgorithmConfig object, but got {type(spec.config)}!"
  2088. )
  2089. self.policies = policies
  2090. if algorithm_config_overrides_per_module is not NotProvided:
  2091. self.algorithm_config_overrides_per_module = (
  2092. algorithm_config_overrides_per_module
  2093. )
  2094. if policy_map_capacity is not NotProvided:
  2095. self.policy_map_capacity = policy_map_capacity
  2096. if policy_mapping_fn is not NotProvided:
  2097. # Create `policy_mapping_fn` from a config dict.
2098. # Helpful if users would like to specify custom callable classes in
2099. # yaml files.
  2100. if isinstance(policy_mapping_fn, dict):
  2101. policy_mapping_fn = from_config(policy_mapping_fn)
  2102. self.policy_mapping_fn = policy_mapping_fn
  2103. if observation_fn is not NotProvided:
  2104. self.observation_fn = observation_fn
  2105. if policy_map_cache != DEPRECATED_VALUE:
  2106. deprecation_warning(
  2107. old="AlgorithmConfig.multi_agent(policy_map_cache=..)",
  2108. error=True,
  2109. )
  2110. if replay_mode != DEPRECATED_VALUE:
  2111. deprecation_warning(
  2112. old="AlgorithmConfig.multi_agent(replay_mode=..)",
  2113. new="AlgorithmConfig.training("
  2114. "replay_buffer_config={'replay_mode': ..})",
  2115. error=True,
  2116. )
  2117. if count_steps_by is not NotProvided:
  2118. if count_steps_by not in ["env_steps", "agent_steps"]:
  2119. raise ValueError(
  2120. "config.multi_agent(count_steps_by=..) must be one of "
  2121. f"[env_steps|agent_steps], not {count_steps_by}!"
  2122. )
  2123. self.count_steps_by = count_steps_by
  2124. if policies_to_train is not NotProvided:
  2125. assert (
  2126. isinstance(policies_to_train, (list, set, tuple))
  2127. or callable(policies_to_train)
  2128. or policies_to_train is None
  2129. ), (
  2130. "ERROR: `policies_to_train` must be a [list|set|tuple] or a "
  2131. "callable taking PolicyID and SampleBatch and returning "
  2132. "True|False (trainable or not?) or None (for always training all "
  2133. "policies)."
  2134. )
  2135. # Check `policies_to_train` for invalid entries.
  2136. if isinstance(policies_to_train, (list, set, tuple)):
  2137. if len(policies_to_train) == 0:
  2138. logger.warning(
  2139. "`config.multi_agent(policies_to_train=..)` is empty! "
  2140. "Make sure - if you would like to learn at least one policy - "
  2141. "to add its ID to that list."
  2142. )
  2143. self.policies_to_train = policies_to_train
  2144. if policy_states_are_swappable is not NotProvided:
  2145. self.policy_states_are_swappable = policy_states_are_swappable
  2146. return self
  2147. def is_multi_agent(self) -> bool:
  2148. """Returns whether this config specifies a multi-agent setup.
  2149. Returns:
2150. True, if a) more than one policy is defined OR b) only one policy is defined,
2151. but its ID is NOT DEFAULT_POLICY_ID.
  2152. """
  2153. return len(self.policies) > 1 or DEFAULT_POLICY_ID not in self.policies
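# Quick example (sketch, not part of the original source; assumes a freshly
# constructed config only contains the default policy ID):
#
#   cfg = AlgorithmConfig()
#   cfg.is_multi_agent()                      # -> False (only DEFAULT_POLICY_ID)
#   cfg.multi_agent(policies={"p0", "p1"})
#   cfg.is_multi_agent()                      # -> True (two explicit policy IDs)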
  2154. def reporting(
  2155. self,
  2156. *,
  2157. keep_per_episode_custom_metrics: Optional[bool] = NotProvided,
  2158. metrics_episode_collection_timeout_s: Optional[float] = NotProvided,
  2159. metrics_num_episodes_for_smoothing: Optional[int] = NotProvided,
  2160. min_time_s_per_iteration: Optional[int] = NotProvided,
  2161. min_train_timesteps_per_iteration: Optional[int] = NotProvided,
  2162. min_sample_timesteps_per_iteration: Optional[int] = NotProvided,
  2163. ) -> "AlgorithmConfig":
  2164. """Sets the config's reporting settings.
  2165. Args:
  2166. keep_per_episode_custom_metrics: Store raw custom metrics without
2167. calculating max, min, or mean.
  2168. metrics_episode_collection_timeout_s: Wait for metric batches for at most
  2169. this many seconds. Those that have not returned in time will be
  2170. collected in the next train iteration.
  2171. metrics_num_episodes_for_smoothing: Smooth rollout metrics over this many
  2172. episodes, if possible.
  2173. In case rollouts (sample collection) just started, there may be fewer
  2174. than this many episodes in the buffer and we'll compute metrics
  2175. over this smaller number of available episodes.
  2176. In case there are more than this many episodes collected in a single
  2177. training iteration, use all of these episodes for metrics computation,
  2178. meaning don't ever cut any "excess" episodes.
  2179. Set this to 1 to disable smoothing and to always report only the most
  2180. recently collected episode's return.
  2181. min_time_s_per_iteration: Minimum time to accumulate within a single
  2182. `train()` call. This value does not affect learning,
  2183. only the number of times `Algorithm.training_step()` is called by
2184. `Algorithm.train()`. If - after one such step attempt - the time taken
  2185. has not reached `min_time_s_per_iteration`, will perform n more
  2186. `training_step()` calls until the minimum time has been
  2187. consumed. Set to 0 or None for no minimum time.
  2188. min_train_timesteps_per_iteration: Minimum training timesteps to accumulate
  2189. within a single `train()` call. This value does not affect learning,
  2190. only the number of times `Algorithm.training_step()` is called by
2191. `Algorithm.train()`. If - after one such step attempt - the training
  2192. timestep count has not been reached, will perform n more
  2193. `training_step()` calls until the minimum timesteps have been
  2194. executed. Set to 0 or None for no minimum timesteps.
  2195. min_sample_timesteps_per_iteration: Minimum env sampling timesteps to
  2196. accumulate within a single `train()` call. This value does not affect
  2197. learning, only the number of times `Algorithm.training_step()` is
2198. called by `Algorithm.train()`. If - after one such step attempt - the env
  2199. sampling timestep count has not been reached, will perform n more
  2200. `training_step()` calls until the minimum timesteps have been
  2201. executed. Set to 0 or None for no minimum timesteps.
  2202. Returns:
  2203. This updated AlgorithmConfig object.
  2204. """
  2205. if keep_per_episode_custom_metrics is not NotProvided:
  2206. self.keep_per_episode_custom_metrics = keep_per_episode_custom_metrics
  2207. if metrics_episode_collection_timeout_s is not NotProvided:
  2208. self.metrics_episode_collection_timeout_s = (
  2209. metrics_episode_collection_timeout_s
  2210. )
  2211. if metrics_num_episodes_for_smoothing is not NotProvided:
  2212. self.metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing
  2213. if min_time_s_per_iteration is not NotProvided:
  2214. self.min_time_s_per_iteration = min_time_s_per_iteration
  2215. if min_train_timesteps_per_iteration is not NotProvided:
  2216. self.min_train_timesteps_per_iteration = min_train_timesteps_per_iteration
  2217. if min_sample_timesteps_per_iteration is not NotProvided:
  2218. self.min_sample_timesteps_per_iteration = min_sample_timesteps_per_iteration
  2219. return self
  2220. def checkpointing(
  2221. self,
  2222. export_native_model_files: Optional[bool] = NotProvided,
  2223. checkpoint_trainable_policies_only: Optional[bool] = NotProvided,
  2224. ) -> "AlgorithmConfig":
  2225. """Sets the config's checkpointing settings.
  2226. Args:
  2227. export_native_model_files: Whether an individual Policy-
  2228. or the Algorithm's checkpoints also contain (tf or torch) native
  2229. model files. These could be used to restore just the NN models
  2230. from these files w/o requiring RLlib. These files are generated
  2231. by calling the tf- or torch- built-in saving utility methods on
  2232. the actual models.
  2233. checkpoint_trainable_policies_only: Whether to only add Policies to the
  2234. Algorithm checkpoint (in sub-directory "policies/") that are trainable
  2235. according to the `is_trainable_policy` callable of the local worker.
  2236. Returns:
  2237. This updated AlgorithmConfig object.
  2238. """
  2239. if export_native_model_files is not NotProvided:
  2240. self.export_native_model_files = export_native_model_files
  2241. if checkpoint_trainable_policies_only is not NotProvided:
  2242. self.checkpoint_trainable_policies_only = checkpoint_trainable_policies_only
  2243. return self
  2244. def debugging(
  2245. self,
  2246. *,
  2247. logger_creator: Optional[Callable[[], Logger]] = NotProvided,
  2248. logger_config: Optional[dict] = NotProvided,
  2249. log_level: Optional[str] = NotProvided,
  2250. log_sys_usage: Optional[bool] = NotProvided,
  2251. fake_sampler: Optional[bool] = NotProvided,
  2252. seed: Optional[int] = NotProvided,
  2253. # deprecated
  2254. worker_cls=None,
  2255. ) -> "AlgorithmConfig":
  2256. """Sets the config's debugging settings.
  2257. Args:
  2258. logger_creator: Callable that creates a ray.tune.Logger
  2259. object. If unspecified, a default logger is created.
2260. logger_config: Define logger-specific configuration to be used inside the Logger.
  2261. Default value None allows overwriting with nested dicts.
  2262. log_level: Set the ray.rllib.* log level for the agent process and its
  2263. workers. Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level
  2264. will also periodically print out summaries of relevant internal dataflow
  2265. (this is also printed out once at startup at the INFO level). When using
  2266. the `rllib train` command, you can also use the `-v` and `-vv` flags as
  2267. shorthand for INFO and DEBUG.
  2268. log_sys_usage: Log system resource metrics to results. This requires
  2269. `psutil` to be installed for sys stats, and `gputil` for GPU metrics.
  2270. fake_sampler: Use fake (infinite speed) sampler. For testing only.
  2271. seed: This argument, in conjunction with worker_index, sets the random
  2272. seed of each worker, so that identically configured trials will have
  2273. identical results. This makes experiments reproducible.
  2274. Returns:
  2275. This updated AlgorithmConfig object.
  2276. """
  2277. if worker_cls is not None:
  2278. deprecation_warning(
  2279. old="AlgorithmConfig.debugging(worker_cls=..)",
  2280. new="AlgorithmConfig.rollouts(env_runner_cls=...)",
  2281. error=True,
  2282. )
  2283. if logger_creator is not NotProvided:
  2284. self.logger_creator = logger_creator
  2285. if logger_config is not NotProvided:
  2286. self.logger_config = logger_config
  2287. if log_level is not NotProvided:
  2288. self.log_level = log_level
  2289. if log_sys_usage is not NotProvided:
  2290. self.log_sys_usage = log_sys_usage
  2291. if fake_sampler is not NotProvided:
  2292. self.fake_sampler = fake_sampler
  2293. if seed is not NotProvided:
  2294. self.seed = seed
  2295. return self
  2296. def fault_tolerance(
  2297. self,
  2298. recreate_failed_workers: Optional[bool] = NotProvided,
  2299. max_num_worker_restarts: Optional[int] = NotProvided,
  2300. delay_between_worker_restarts_s: Optional[float] = NotProvided,
  2301. restart_failed_sub_environments: Optional[bool] = NotProvided,
  2302. num_consecutive_worker_failures_tolerance: Optional[int] = NotProvided,
  2303. worker_health_probe_timeout_s: int = NotProvided,
  2304. worker_restore_timeout_s: int = NotProvided,
  2305. ):
  2306. """Sets the config's fault tolerance settings.
  2307. Args:
  2308. recreate_failed_workers: Whether - upon a worker failure - RLlib will try to
  2309. recreate the lost worker as an identical copy of the failed one. The new
  2310. worker will only differ from the failed one in its
  2311. `self.recreated_worker=True` property value. It will have the same
  2312. `worker_index` as the original one. If True, the
  2313. `ignore_worker_failures` setting will be ignored.
  2314. max_num_worker_restarts: The maximum number of times a worker is allowed to
  2315. be restarted (if `recreate_failed_workers` is True).
  2316. delay_between_worker_restarts_s: The delay (in seconds) between two
  2317. consecutive worker restarts (if `recreate_failed_workers` is True).
  2318. restart_failed_sub_environments: If True and any sub-environment (within
  2319. a vectorized env) throws any error during env stepping, the
  2320. Sampler will try to restart the faulty sub-environment. This is done
  2321. without disturbing the other (still intact) sub-environment and without
  2322. the RolloutWorker crashing.
  2323. num_consecutive_worker_failures_tolerance: The number of consecutive times
  2324. a rollout worker (or evaluation worker) failure is tolerated before
  2325. finally crashing the Algorithm. Only useful if either
  2326. `ignore_worker_failures` or `recreate_failed_workers` is True.
  2327. Note that for `restart_failed_sub_environments` and sub-environment
  2328. failures, the worker itself is NOT affected and won't throw any errors
  2329. as the flawed sub-environment is silently restarted under the hood.
  2330. worker_health_probe_timeout_s: Max amount of time we should spend waiting
  2331. for health probe calls to finish. Health pings are very cheap, so the
  2332. default is 1 minute.
  2333. worker_restore_timeout_s: Max amount of time we should wait to restore
  2334. states on recovered worker actors. Default is 30 mins.
  2335. Returns:
  2336. This updated AlgorithmConfig object.
  2337. """
  2338. if recreate_failed_workers is not NotProvided:
  2339. self.recreate_failed_workers = recreate_failed_workers
  2340. if max_num_worker_restarts is not NotProvided:
  2341. self.max_num_worker_restarts = max_num_worker_restarts
  2342. if delay_between_worker_restarts_s is not NotProvided:
  2343. self.delay_between_worker_restarts_s = delay_between_worker_restarts_s
  2344. if restart_failed_sub_environments is not NotProvided:
  2345. self.restart_failed_sub_environments = restart_failed_sub_environments
  2346. if num_consecutive_worker_failures_tolerance is not NotProvided:
  2347. self.num_consecutive_worker_failures_tolerance = (
  2348. num_consecutive_worker_failures_tolerance
  2349. )
  2350. if worker_health_probe_timeout_s is not NotProvided:
  2351. self.worker_health_probe_timeout_s = worker_health_probe_timeout_s
  2352. if worker_restore_timeout_s is not NotProvided:
  2353. self.worker_restore_timeout_s = worker_restore_timeout_s
  2354. return self
  2355. @ExperimentalAPI
  2356. def rl_module(
  2357. self,
  2358. *,
  2359. rl_module_spec: Optional[ModuleSpec] = NotProvided,
  2360. _enable_rl_module_api: Optional[bool] = NotProvided,
  2361. ) -> "AlgorithmConfig":
  2362. """Sets the config's RLModule settings.
  2363. Args:
  2364. rl_module_spec: The RLModule spec to use for this config. It can be either
  2365. a SingleAgentRLModuleSpec or a MultiAgentRLModuleSpec. If the
  2366. observation_space, action_space, catalog_class, or the model config is
  2367. not specified it will be inferred from the env and other parts of the
  2368. algorithm config object.
  2369. _enable_rl_module_api: Whether to enable the RLModule API for this config.
  2370. By default if you call `config.rl_module(...)`, the
  2371. RLModule API will NOT be enabled. If you want to enable it, you can call
  2372. `config.rl_module(_enable_rl_module_api=True)`.
  2373. Returns:
  2374. This updated AlgorithmConfig object.
  2375. """
  2376. if rl_module_spec is not NotProvided:
  2377. self.rl_module_spec = rl_module_spec
  2378. if _enable_rl_module_api is not NotProvided:
  2379. self._enable_rl_module_api = _enable_rl_module_api
  2380. if _enable_rl_module_api is True and self.exploration_config:
  2381. logger.warning(
  2382. "Setting `exploration_config={}` because you set "
  2383. "`_enable_rl_module_api=True`. When RLModule API are "
  2384. "enabled, exploration_config can not be "
  2385. "set. If you want to implement custom exploration behaviour, "
  2386. "please modify the `forward_exploration` method of the "
  2387. "RLModule at hand. On configs that have a default exploration "
  2388. "config, this must be done with "
  2389. "`config.exploration_config={}`."
  2390. )
  2391. self.__prior_exploration_config = self.exploration_config
  2392. self.exploration_config = {}
  2393. elif _enable_rl_module_api is False and not self.exploration_config:
  2394. if self.__prior_exploration_config is not None:
  2395. logger.warning(
  2396. "Setting `exploration_config="
  2397. f"{self.__prior_exploration_config}` because you set "
  2398. "`_enable_rl_module_api=False`. This exploration config was "
  2399. "restored from a prior exploration config that was overriden "
  2400. "when setting `_enable_rl_module_api=True`. This occurs "
  2401. "because when RLModule API are enabled, exploration_config "
  2402. "can not be set."
  2403. )
  2404. self.exploration_config = self.__prior_exploration_config
  2405. self.__prior_exploration_config = None
  2406. else:
  2407. logger.warning(
  2408. "config._enable_rl_module_api was set to False, but no prior "
  2409. "exploration config was found to be restored."
  2410. )
  2411. else:
2412. # Throw a warning if the user has used this API but not enabled it.
  2413. logger.warning(
  2414. "You have called `config.rl_module(...)` but "
  2415. "have not enabled the RLModule API. To enable it, call "
  2416. "`config.rl_module(_enable_rl_module_api=True)`."
  2417. )
  2418. return self
  2419. def experimental(
  2420. self,
  2421. *,
  2422. _tf_policy_handles_more_than_one_loss: Optional[bool] = NotProvided,
  2423. _disable_preprocessor_api: Optional[bool] = NotProvided,
  2424. _disable_action_flattening: Optional[bool] = NotProvided,
  2425. _disable_execution_plan_api: Optional[bool] = NotProvided,
  2426. _disable_initialize_loss_from_dummy_batch: Optional[bool] = NotProvided,
  2427. ) -> "AlgorithmConfig":
  2428. """Sets the config's experimental settings.
  2429. Args:
  2430. _tf_policy_handles_more_than_one_loss: Experimental flag.
  2431. If True, TFPolicy will handle more than one loss/optimizer.
  2432. Set this to True, if you would like to return more than
  2433. one loss term from your `loss_fn` and an equal number of optimizers
  2434. from your `optimizer_fn`. In the future, the default for this will be
  2435. True.
  2436. _disable_preprocessor_api: Experimental flag.
  2437. If True, no (observation) preprocessor will be created and
  2438. observations will arrive in model as they are returned by the env.
  2439. In the future, the default for this will be True.
  2440. _disable_action_flattening: Experimental flag.
  2441. If True, RLlib will no longer flatten the policy-computed actions into
  2442. a single tensor (for storage in SampleCollectors/output files/etc..),
  2443. but leave (possibly nested) actions as-is. Disabling flattening affects:
  2444. - SampleCollectors: Have to store possibly nested action structs.
  2445. - Models that have the previous action(s) as part of their input.
  2446. - Algorithms reading from offline files (incl. action information).
  2447. _disable_execution_plan_api: Experimental flag.
  2448. If True, the execution plan API will not be used. Instead,
2449. the Algorithm's `training_iteration` method will be called as-is each
  2450. training iteration.
  2451. Returns:
  2452. This updated AlgorithmConfig object.
  2453. """
  2454. if _tf_policy_handles_more_than_one_loss is not NotProvided:
  2455. self._tf_policy_handles_more_than_one_loss = (
  2456. _tf_policy_handles_more_than_one_loss
  2457. )
  2458. if _disable_preprocessor_api is not NotProvided:
  2459. self._disable_preprocessor_api = _disable_preprocessor_api
  2460. if _disable_action_flattening is not NotProvided:
  2461. self._disable_action_flattening = _disable_action_flattening
  2462. if _disable_execution_plan_api is not NotProvided:
  2463. self._disable_execution_plan_api = _disable_execution_plan_api
  2464. if _disable_initialize_loss_from_dummy_batch is not NotProvided:
  2465. self._disable_initialize_loss_from_dummy_batch = (
  2466. _disable_initialize_loss_from_dummy_batch
  2467. )
  2468. return self
  2469. @property
  2470. def learner_class(self) -> Type["Learner"]:
  2471. """Returns the Learner sub-class to use by this Algorithm.
  2472. Either
  2473. a) User sets a specific learner class via calling `.training(learner_class=...)`
  2474. b) User leaves learner class unset (None) and the AlgorithmConfig itself
  2475. figures out the actual learner class by calling its own
  2476. `.get_default_learner_class()` method.
  2477. """
  2478. return self._learner_class or self.get_default_learner_class()
  2479. @property
  2480. def is_atari(self) -> bool:
  2481. """True if if specified env is an Atari env."""
  2482. # Not yet determined, try to figure this out.
  2483. if self._is_atari is None:
  2484. # Atari envs are usually specified via a string like "PongNoFrameskip-v4"
  2485. # or "ALE/Breakout-v5".
  2486. # We do NOT attempt to auto-detect Atari env for other specified types like
2487. # a callable, to avoid running heavy logic in validate().
  2488. # For these cases, users can explicitly set `environment(atari=True)`.
  2489. if type(self.env) is not str:
  2490. return False
  2491. try:
  2492. env = gym.make(self.env)
  2493. # Any gymnasium error -> Cannot be an Atari env.
  2494. except gym.error.Error:
  2495. return False
  2496. self._is_atari = is_atari(env)
  2497. # Clean up env's resources, if any.
  2498. env.close()
  2499. return self._is_atari
2500. # TODO: Make rollout_fragment_length a read-only property and replace the current
2501. # self.rollout_fragment_length with a private variable.
  2502. def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
  2503. """Automatically infers a proper rollout_fragment_length setting if "auto".
  2504. Uses the simple formula:
  2505. `rollout_fragment_length` = `train_batch_size` /
  2506. (`num_envs_per_worker` * `num_rollout_workers`)
2507. If the result is a fraction AND `worker_index` is provided, will make
2508. some workers add one additional timestep, such that the overall batch size
2509. (across the workers) adds up to exactly the `train_batch_size`.
  2510. Returns:
  2511. The user-provided `rollout_fragment_length` or a computed one (if user
  2512. value is "auto").
  2513. """
  2514. if self.rollout_fragment_length == "auto":
  2515. # Example:
  2516. # 2 workers, 2 envs per worker, 2000 train batch size:
  2517. # -> 2000 / 4 -> 500
  2518. # 4 workers, 3 envs per worker, 2500 train batch size:
  2519. # -> 2500 / 12 -> 208.333 -> diff=4 (208 * 12 = 2496)
  2520. # -> worker 1: 209, workers 2-4: 208
  2521. rollout_fragment_length = self.train_batch_size / (
  2522. self.num_envs_per_worker * (self.num_rollout_workers or 1)
  2523. )
  2524. if int(rollout_fragment_length) != rollout_fragment_length:
  2525. diff = self.train_batch_size - int(
  2526. rollout_fragment_length
  2527. ) * self.num_envs_per_worker * (self.num_rollout_workers or 1)
  2528. if (worker_index * self.num_envs_per_worker) <= diff:
  2529. return int(rollout_fragment_length) + 1
  2530. return int(rollout_fragment_length)
  2531. else:
  2532. return self.rollout_fragment_length
2533. # TODO: Make evaluation_config a read-only property and replace the current
2534. # self.evaluation_config with a private variable.
  2535. def get_evaluation_config_object(
  2536. self,
  2537. ) -> Optional["AlgorithmConfig"]:
  2538. """Creates a full AlgorithmConfig object from `self.evaluation_config`.
  2539. Returns:
  2540. A fully valid AlgorithmConfig object that can be used for the evaluation
  2541. WorkerSet. If `self` is already an evaluation config object, return None.
  2542. """
  2543. if self.in_evaluation:
  2544. assert self.evaluation_config is None
  2545. return None
  2546. evaluation_config = self.evaluation_config
  2547. # Already an AlgorithmConfig -> copy and use as-is.
  2548. if isinstance(evaluation_config, AlgorithmConfig):
  2549. eval_config_obj = evaluation_config.copy(copy_frozen=False)
  2550. # Create unfrozen copy of self to be used as the to-be-returned eval
  2551. # AlgorithmConfig.
  2552. else:
  2553. eval_config_obj = self.copy(copy_frozen=False)
  2554. # Update with evaluation override settings:
  2555. eval_config_obj.update_from_dict(evaluation_config or {})
  2556. # Switch on the `in_evaluation` flag and remove `evaluation_config`
  2557. # (set to None).
  2558. eval_config_obj.in_evaluation = True
  2559. eval_config_obj.evaluation_config = None
  2560. # Evaluation duration unit: episodes.
  2561. # Switch on `complete_episode` rollouts. Also, make sure
  2562. # rollout fragments are short so we never have more than one
  2563. # episode in one rollout.
  2564. if self.evaluation_duration_unit == "episodes":
  2565. eval_config_obj.batch_mode = "complete_episodes"
  2566. eval_config_obj.rollout_fragment_length = 1
  2567. # Evaluation duration unit: timesteps.
  2568. # - Set `batch_mode=truncate_episodes` so we don't perform rollouts
  2569. # strictly along episode borders.
2570. # - Set `rollout_fragment_length` such that desired steps are divided
  2571. # equally amongst workers or - in "auto" duration mode - set it
  2572. # to a reasonably small number (10), such that a single `sample()`
  2573. # call doesn't take too much time and we can stop evaluation as soon
  2574. # as possible after the train step is completed.
  2575. else:
  2576. eval_config_obj.batch_mode = "truncate_episodes"
  2577. eval_config_obj.rollout_fragment_length = (
  2578. 10
  2579. if self.evaluation_duration == "auto"
  2580. else int(
  2581. math.ceil(
  2582. self.evaluation_duration / (self.evaluation_num_workers or 1)
  2583. )
  2584. )
  2585. )
  2586. return eval_config_obj
  2587. def get_multi_agent_setup(
  2588. self,
  2589. *,
  2590. policies: Optional[MultiAgentPolicyConfigDict] = None,
  2591. env: Optional[EnvType] = None,
  2592. spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
  2593. default_policy_class: Optional[Type[Policy]] = None,
  2594. ) -> Tuple[MultiAgentPolicyConfigDict, Callable[[PolicyID, SampleBatchType], bool]]:
  2595. r"""Compiles complete multi-agent config (dict) from the information in `self`.
  2596. Infers the observation- and action spaces, the policy classes, and the policy's
  2597. configs. The returned `MultiAgentPolicyConfigDict` is fully unified and strictly
  2598. maps PolicyIDs to complete PolicySpec objects (with all their fields not-None).
  2599. Examples:
  2600. .. testcode::
  2601. import gymnasium as gym
  2602. from ray.rllib.algorithms.ppo import PPOConfig
  2603. config = (
  2604. PPOConfig()
  2605. .environment("CartPole-v1")
  2606. .framework("torch")
  2607. .multi_agent(policies={"pol1", "pol2"}, policies_to_train=["pol1"])
  2608. )
  2609. policy_dict, is_policy_to_train = config.get_multi_agent_setup(
  2610. env=gym.make("CartPole-v1"))
  2611. is_policy_to_train("pol1")
  2612. is_policy_to_train("pol2")
  2613. Args:
  2614. policies: An optional multi-agent `policies` dict, mapping policy IDs
  2615. to PolicySpec objects. If not provided, will use `self.policies`
  2616. instead. Note that the `policy_class`, `observation_space`, and
  2617. `action_space` properties in these PolicySpecs may be None and must
  2618. therefore be inferred here.
  2619. env: An optional env instance, from which to infer the different spaces for
  2620. the different policies. If not provided, will try to infer from
  2621. `spaces`. Otherwise from `self.observation_space` and
2622. `self.action_space`. If no information on spaces can be inferred, will
  2623. raise an error.
  2624. spaces: Optional dict mapping policy IDs to tuples of 1) observation space
  2625. and 2) action space that should be used for the respective policy.
  2626. These spaces were usually provided by an already instantiated remote
  2627. EnvRunner (usually a RolloutWorker). If not provided, will try to infer
  2628. from `env`. Otherwise from `self.observation_space` and
  2629. `self.action_space`. If no information on spaces can be inferred, will
  2630. raise an error.
  2631. default_policy_class: The Policy class to use should a PolicySpec have its
  2632. policy_class property set to None.
  2633. Returns:
  2634. A tuple consisting of 1) a MultiAgentPolicyConfigDict and 2) a
  2635. `is_policy_to_train(PolicyID, SampleBatchType) -> bool` callable.
  2636. Raises:
2637. ValueError: In case no spaces can be inferred for the policy/ies.
2638. ValueError: In case two agents in the env map to the same PolicyID
2639. (according to `self.policy_mapping_fn`), but have different action- or
2640. observation spaces according to the inferred space information.
  2641. """
  2642. policies = copy.deepcopy(policies or self.policies)
  2643. # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy
  2644. # automatically via empty PolicySpec (will make RLlib infer observation- and
  2645. # action spaces as well as the Policy's class).
  2646. if isinstance(policies, (set, list, tuple)):
  2647. policies = {pid: PolicySpec() for pid in policies}
  2648. # Try extracting spaces from env or from given spaces dict.
  2649. env_obs_space = None
  2650. env_act_space = None
  2651. # Env is a ray.remote: Get spaces via its (automatically added)
  2652. # `_get_spaces()` method.
  2653. if isinstance(env, ray.actor.ActorHandle):
  2654. env_obs_space, env_act_space = ray.get(env._get_spaces.remote())
  2655. # Normal env (gym.Env or MultiAgentEnv): These should have the
  2656. # `observation_space` and `action_space` properties.
  2657. elif env is not None:
  2658. # `env` is a gymnasium.vector.Env.
  2659. if hasattr(env, "single_observation_space") and isinstance(
  2660. env.single_observation_space, gym.Space
  2661. ):
  2662. env_obs_space = env.single_observation_space
  2663. # `env` is a gymnasium.Env.
  2664. elif hasattr(env, "observation_space") and isinstance(
  2665. env.observation_space, gym.Space
  2666. ):
  2667. env_obs_space = env.observation_space
  2668. # `env` is a gymnasium.vector.Env.
  2669. if hasattr(env, "single_action_space") and isinstance(
  2670. env.single_action_space, gym.Space
  2671. ):
  2672. env_act_space = env.single_action_space
  2673. # `env` is a gymnasium.Env.
  2674. elif hasattr(env, "action_space") and isinstance(
  2675. env.action_space, gym.Space
  2676. ):
  2677. env_act_space = env.action_space
  2678. # Last resort: Try getting the env's spaces from the spaces
  2679. # dict's special __env__ key.
  2680. if spaces is not None:
  2681. if env_obs_space is None:
  2682. env_obs_space = spaces.get("__env__", [None])[0]
  2683. if env_act_space is None:
  2684. env_act_space = spaces.get("__env__", [None, None])[1]
  2685. # Check each defined policy ID and unify its spec.
  2686. for pid, policy_spec in policies.copy().items():
  2687. # Convert to PolicySpec if plain list/tuple.
  2688. if not isinstance(policy_spec, PolicySpec):
  2689. policies[pid] = policy_spec = PolicySpec(*policy_spec)
  2690. # Infer policy classes for policies dict, if not provided (None).
  2691. if policy_spec.policy_class is None and default_policy_class is not None:
  2692. policies[pid].policy_class = default_policy_class
  2693. # In case - somehow - an old gym Space made it to here, convert it
  2694. # to the corresponding gymnasium space.
  2695. if old_gym and isinstance(policy_spec.observation_space, old_gym.Space):
  2696. policies[
  2697. pid
  2698. ].observation_space = convert_old_gym_space_to_gymnasium_space(
  2699. policy_spec.observation_space
  2700. )
  2701. # Infer observation space.
  2702. elif policy_spec.observation_space is None:
  2703. if spaces is not None and pid in spaces:
  2704. obs_space = spaces[pid][0]
  2705. elif env_obs_space is not None:
  2706. # Multi-agent case AND different agents have different spaces:
  2707. # Need to reverse map spaces (for the different agents) to certain
  2708. # policy IDs.
  2709. if (
  2710. isinstance(env, MultiAgentEnv)
  2711. and hasattr(env, "_obs_space_in_preferred_format")
  2712. and env._obs_space_in_preferred_format
  2713. ):
  2714. obs_space = None
  2715. mapping_fn = self.policy_mapping_fn
  2716. one_obs_space = next(iter(env_obs_space.values()))
  2717. # If all obs spaces are the same anyways, just use the first
  2718. # single-agent space.
  2719. if all(s == one_obs_space for s in env_obs_space.values()):
  2720. obs_space = one_obs_space
  2721. # Otherwise, we have to match the policy ID with all possible
  2722. # agent IDs and find the agent ID that matches.
  2723. elif mapping_fn:
  2724. for aid in env.get_agent_ids():
  2725. # Match: Assign spaces for this agentID to the PolicyID.
  2726. if mapping_fn(aid, None, worker=None) == pid:
  2727. # Make sure, different agents that map to the same
  2728. # policy don't have different spaces.
  2729. if (
  2730. obs_space is not None
  2731. and env_obs_space[aid] != obs_space
  2732. ):
  2733. raise ValueError(
  2734. "Two agents in your environment map to the "
  2735. "same policyID (as per your `policy_mapping"
  2736. "_fn`), however, these agents also have "
  2737. "different observation spaces!"
  2738. )
  2739. obs_space = env_obs_space[aid]
  2740. # Otherwise, just use env's obs space as-is.
  2741. else:
  2742. obs_space = env_obs_space
  2743. # Space given directly in config.
  2744. elif self.observation_space:
  2745. obs_space = self.observation_space
  2746. else:
  2747. raise ValueError(
  2748. "`observation_space` not provided in PolicySpec for "
  2749. f"{pid} and env does not have an observation space OR "
  2750. "no spaces received from other workers' env(s) OR no "
  2751. "`observation_space` specified in config!"
  2752. )
  2753. policies[pid].observation_space = obs_space
  2754. # In case - somehow - an old gym Space made it to here, convert it
  2755. # to the corresponding gymnasium space.
  2756. if old_gym and isinstance(policy_spec.action_space, old_gym.Space):
  2757. policies[pid].action_space = convert_old_gym_space_to_gymnasium_space(
  2758. policy_spec.action_space
  2759. )
  2760. # Infer action space.
  2761. elif policy_spec.action_space is None:
  2762. if spaces is not None and pid in spaces:
  2763. act_space = spaces[pid][1]
  2764. elif env_act_space is not None:
  2765. # Multi-agent case AND different agents have different spaces:
  2766. # Need to reverse map spaces (for the different agents) to certain
  2767. # policy IDs.
  2768. if (
  2769. isinstance(env, MultiAgentEnv)
  2770. and hasattr(env, "_action_space_in_preferred_format")
  2771. and env._action_space_in_preferred_format
  2772. ):
  2773. act_space = None
  2774. mapping_fn = self.policy_mapping_fn
  2775. one_act_space = next(iter(env_act_space.values()))
  2776. # If all action spaces are the same anyways, just use the first
  2777. # single-agent space.
  2778. if all(s == one_act_space for s in env_act_space.values()):
  2779. act_space = one_act_space
  2780. # Otherwise, we have to match the policy ID with all possible
  2781. # agent IDs and find the agent ID that matches.
  2782. elif mapping_fn:
  2783. for aid in env.get_agent_ids():
  2784. # Match: Assign spaces for this AgentID to the PolicyID.
  2785. if mapping_fn(aid, None, worker=None) == pid:
  2786. # Make sure, different agents that map to the same
  2787. # policy don't have different spaces.
  2788. if (
  2789. act_space is not None
  2790. and env_act_space[aid] != act_space
  2791. ):
  2792. raise ValueError(
  2793. "Two agents in your environment map to the "
  2794. "same policyID (as per your `policy_mapping"
  2795. "_fn`), however, these agents also have "
  2796. "different action spaces!"
  2797. )
  2798. act_space = env_act_space[aid]
  2799. # Otherwise, just use env's action space as-is.
  2800. else:
  2801. act_space = env_act_space
  2802. elif self.action_space:
  2803. act_space = self.action_space
  2804. else:
  2805. raise ValueError(
  2806. "`action_space` not provided in PolicySpec for "
  2807. f"{pid} and env does not have an action space OR "
  2808. "no spaces received from other workers' env(s) OR no "
  2809. "`action_space` specified in config!"
  2810. )
  2811. policies[pid].action_space = act_space
  2812. # Create entire AlgorithmConfig object from the provided override.
  2813. # If None, use {} as override.
  2814. if not isinstance(policies[pid].config, AlgorithmConfig):
  2815. assert policies[pid].config is None or isinstance(
  2816. policies[pid].config, dict
  2817. )
  2818. policies[pid].config = self.copy(copy_frozen=False).update_from_dict(
  2819. policies[pid].config or {}
  2820. )
  2821. # If container given, construct a simple default callable returning True
  2822. # if the PolicyID is found in the list/set of IDs.
  2823. if self.policies_to_train is not None and not callable(self.policies_to_train):
  2824. pols = set(self.policies_to_train)
  2825. def is_policy_to_train(pid, batch=None):
  2826. return pid in pols
  2827. else:
  2828. is_policy_to_train = self.policies_to_train
  2829. return policies, is_policy_to_train
  2830. # TODO: Move this to those algorithms that really need this, which is currently
  2831. # only A2C and PG.
  2832. def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
  2833. """Detects mismatches for `train_batch_size` vs `rollout_fragment_length`.
2834. Only applicable for algorithms whose train_batch_size should be directly
  2835. dependent on rollout_fragment_length (synchronous sampling, on-policy PG algos).
  2836. If rollout_fragment_length != "auto", makes sure that the product of
  2837. `rollout_fragment_length` x `num_rollout_workers` x `num_envs_per_worker`
2838. roughly (within 10%) matches the provided `train_batch_size`. Otherwise, raises
2839. an error asking the user to set rollout_fragment_length to `auto` or to a
2840. matching value.
  2841. Also, only checks this if `train_batch_size` > 0 (DDPPO sets this
  2842. to -1 to auto-calculate the actual batch size later).
  2843. Raises:
  2844. ValueError: If there is a mismatch between user provided
  2845. `rollout_fragment_length` and `train_batch_size`.
  2846. """
  2847. if (
  2848. self.rollout_fragment_length != "auto"
  2849. and not self.in_evaluation
  2850. and self.train_batch_size > 0
  2851. ):
  2852. min_batch_size = (
  2853. max(self.num_rollout_workers, 1)
  2854. * self.num_envs_per_worker
  2855. * self.rollout_fragment_length
  2856. )
  2857. batch_size = min_batch_size
  2858. while batch_size < self.train_batch_size:
  2859. batch_size += min_batch_size
  2860. if (
  2861. batch_size - self.train_batch_size > 0.1 * self.train_batch_size
  2862. or batch_size - min_batch_size - self.train_batch_size
  2863. > (0.1 * self.train_batch_size)
  2864. ):
  2865. suggested_rollout_fragment_length = self.train_batch_size // (
  2866. self.num_envs_per_worker * (self.num_rollout_workers or 1)
  2867. )
  2868. raise ValueError(
  2869. f"Your desired `train_batch_size` ({self.train_batch_size}) or a "
  2870. "value 10% off of that cannot be achieved with your other "
  2871. f"settings (num_rollout_workers={self.num_rollout_workers}; "
  2872. f"num_envs_per_worker={self.num_envs_per_worker}; "
  2873. f"rollout_fragment_length={self.rollout_fragment_length})! "
  2874. "Try setting `rollout_fragment_length` to 'auto' OR "
  2875. f"{suggested_rollout_fragment_length}."
  2876. )
  2877. def get_torch_compile_learner_config(self):
  2878. """Returns the TorchCompileConfig to use on learners."""
  2879. from ray.rllib.core.rl_module.torch.torch_compile_config import (
  2880. TorchCompileConfig,
  2881. )
  2882. return TorchCompileConfig(
  2883. torch_dynamo_backend=self.torch_compile_learner_dynamo_backend,
  2884. torch_dynamo_mode=self.torch_compile_learner_dynamo_mode,
  2885. )
  2886. def get_torch_compile_worker_config(self):
  2887. """Returns the TorchCompileConfig to use on workers."""
  2888. from ray.rllib.core.rl_module.torch.torch_compile_config import (
  2889. TorchCompileConfig,
  2890. )
  2891. return TorchCompileConfig(
  2892. torch_dynamo_backend=self.torch_compile_worker_dynamo_backend,
  2893. torch_dynamo_mode=self.torch_compile_worker_dynamo_mode,
  2894. )
  2895. def get_default_rl_module_spec(self) -> ModuleSpec:
  2896. """Returns the RLModule spec to use for this algorithm.
  2897. Override this method in the sub-class to return the RLModule spec given
  2898. the input framework.
  2899. Returns:
  2900. The ModuleSpec (SingleAgentRLModuleSpec or MultiAgentRLModuleSpec) to use
  2901. for this algorithm's RLModule.
  2902. """
  2903. raise NotImplementedError
  2904. def get_default_learner_class(self) -> Union[Type["Learner"], str]:
  2905. """Returns the Learner class to use for this algorithm.
  2906. Override this method in the sub-class to return the Learner class type given
  2907. the input framework.
  2908. Returns:
  2909. The Learner class to use for this algorithm either as a class type or as
  2910. a string (e.g. ray.rllib.core.learner.testing.torch.BC).
  2911. """
  2912. raise NotImplementedError
  2913. def get_marl_module_spec(
  2914. self,
  2915. *,
  2916. policy_dict: Dict[str, PolicySpec],
  2917. single_agent_rl_module_spec: Optional[SingleAgentRLModuleSpec] = None,
  2918. ) -> MultiAgentRLModuleSpec:
  2919. """Returns the MultiAgentRLModule spec based on the given policy spec dict.
  2920. policy_dict could be a partial dict of the policies that we need to turn into
  2921. an equivalent multi-agent RLModule spec.
  2922. Args:
  2923. policy_dict: The policy spec dict. Using this dict, we can determine the
  2924. inferred values for observation_space, action_space, and config for
  2925. each policy. If the module spec does not have these values specified,
2926. they will get auto-filled with the values obtained from the policy
2927. spec dict. Here we are relying on the policy's logic for inferring these
2928. values from other sources of information (e.g. the environment).
  2929. single_agent_rl_module_spec: The SingleAgentRLModuleSpec to use for
  2930. constructing a MultiAgentRLModuleSpec. If None, the already
  2931. configured spec (`self.rl_module_spec`) or the default ModuleSpec for
  2932. this algorithm (`self.get_default_rl_module_spec()`) will be used.
  2933. """
  2934. # TODO (Kourosh): When we replace policy entirely there will be no need for
  2935. # this function to map policy_dict to marl_module_specs anymore. The module
  2936. # spec will be directly given by the user or inferred from env and spaces.
  2937. # TODO (Kourosh): Raise an error if the config is not frozen (validated)
  2938. # If the module is single-agent convert it to multi-agent spec
  2939. # The default ModuleSpec (might be multi-agent or single-agent).
  2940. default_rl_module_spec = self.get_default_rl_module_spec()
  2941. # The currently configured ModuleSpec (might be multi-agent or single-agent).
  2942. # If None, use the default one.
  2943. current_rl_module_spec = self.rl_module_spec or default_rl_module_spec
  2944. # Algorithm is currently setup as a single-agent one.
  2945. if isinstance(current_rl_module_spec, SingleAgentRLModuleSpec):
  2946. # Use either the provided `single_agent_rl_module_spec` (a
  2947. # SingleAgentRLModuleSpec), the currently configured one of this
  2948. # AlgorithmConfig object, or the default one.
  2949. single_agent_rl_module_spec = (
  2950. single_agent_rl_module_spec or current_rl_module_spec
  2951. )
  2952. # Now construct the proper MultiAgentRLModuleSpec.
  2953. marl_module_spec = MultiAgentRLModuleSpec(
  2954. module_specs={
  2955. k: copy.deepcopy(single_agent_rl_module_spec)
  2956. for k in policy_dict.keys()
  2957. },
  2958. )
  2959. # Algorithm is currently setup as a multi-agent one.
  2960. else:
  2961. # The user currently has a MultiAgentSpec setup (either via
  2962. # self.rl_module_spec or the default spec of this AlgorithmConfig).
  2963. assert isinstance(current_rl_module_spec, MultiAgentRLModuleSpec)
  2964. # Default is single-agent but the user has provided a multi-agent spec
  2965. # so the use-case is multi-agent.
  2966. if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec):
  2967. # The individual (single-agent) module specs are defined by the user
  2968. # in the currently setup MultiAgentRLModuleSpec -> Use that
  2969. # SingleAgentRLModuleSpec.
  2970. if isinstance(
  2971. current_rl_module_spec.module_specs, SingleAgentRLModuleSpec
  2972. ):
  2973. single_agent_spec = single_agent_rl_module_spec or (
  2974. current_rl_module_spec.module_specs
  2975. )
  2976. module_specs = {
  2977. k: copy.deepcopy(single_agent_spec) for k in policy_dict.keys()
  2978. }
  2979. # The individual (single-agent) module specs have not been configured
  2980. # via this AlgorithmConfig object -> Use provided single-agent spec or
2981. # the default spec (which is also a SingleAgentRLModuleSpec in this
  2982. # case).
  2983. else:
  2984. single_agent_spec = (
  2985. single_agent_rl_module_spec or default_rl_module_spec
  2986. )
  2987. module_specs = {
  2988. k: copy.deepcopy(
  2989. current_rl_module_spec.module_specs.get(
  2990. k, single_agent_spec
  2991. )
  2992. )
  2993. for k in policy_dict.keys()
  2994. }
  2995. # Now construct the proper MultiAgentRLModuleSpec.
  2996. # We need to infer the multi-agent class from `current_rl_module_spec`
  2997. # and fill in the module_specs dict.
  2998. marl_module_spec = current_rl_module_spec.__class__(
  2999. marl_module_class=current_rl_module_spec.marl_module_class,
  3000. module_specs=module_specs,
  3001. modules_to_load=current_rl_module_spec.modules_to_load,
  3002. load_state_path=current_rl_module_spec.load_state_path,
  3003. )
  3004. # Default is multi-agent and user wants to override it -> Don't use the
  3005. # default.
  3006. else:
3007. # User has given an override SingleAgentRLModuleSpec -> Use this to
  3008. # construct the individual RLModules within the MultiAgentRLModuleSpec.
  3009. if single_agent_rl_module_spec is not None:
  3010. pass
  3011. # User has NOT provided an override SingleAgentRLModuleSpec.
  3012. else:
  3013. # But the currently setup multi-agent spec has a SingleAgentRLModule
  3014. # spec defined -> Use that to construct the individual RLModules
  3015. # within the MultiAgentRLModuleSpec.
  3016. if isinstance(
  3017. current_rl_module_spec.module_specs, SingleAgentRLModuleSpec
  3018. ):
3019. # The individual module specs are not given; instead, a single
3020. # SingleAgentRLModuleSpec is given, to be re-used for all modules.
  3021. single_agent_rl_module_spec = (
  3022. current_rl_module_spec.module_specs
  3023. )
  3024. # The currently setup multi-agent spec has NO
  3025. # SingleAgentRLModuleSpec in it -> Error (there is no way we can
  3026. # infer this information from anywhere at this point).
  3027. else:
  3028. raise ValueError(
  3029. "We have a MultiAgentRLModuleSpec "
  3030. f"({current_rl_module_spec}), but no "
  3031. "`SingleAgentRLModuleSpec`s to compile the individual "
  3032. "RLModules' specs! Use "
  3033. "`AlgorithmConfig.get_marl_module_spec("
  3034. "policy_dict=.., single_agent_rl_module_spec=..)`."
  3035. )
  3036. # Now construct the proper MultiAgentRLModuleSpec.
  3037. marl_module_spec = current_rl_module_spec.__class__(
  3038. marl_module_class=current_rl_module_spec.marl_module_class,
  3039. module_specs={
  3040. k: copy.deepcopy(single_agent_rl_module_spec)
  3041. for k in policy_dict.keys()
  3042. },
  3043. modules_to_load=current_rl_module_spec.modules_to_load,
  3044. load_state_path=current_rl_module_spec.load_state_path,
  3045. )
3046. # Make sure that policy_dict and marl_module_spec have the same keys.
  3047. if set(policy_dict.keys()) != set(marl_module_spec.module_specs.keys()):
  3048. raise ValueError(
  3049. "Policy dict and module spec have different keys! \n"
  3050. f"policy_dict keys: {list(policy_dict.keys())} \n"
  3051. f"module_spec keys: {list(marl_module_spec.module_specs.keys())}"
  3052. )
  3053. # Fill in the missing values from the specs that we already have. By combining
  3054. # PolicySpecs and the default RLModuleSpec.
  3055. for module_id in policy_dict:
  3056. policy_spec = policy_dict[module_id]
  3057. module_spec = marl_module_spec.module_specs[module_id]
  3058. if module_spec.module_class is None:
  3059. if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec):
  3060. module_spec.module_class = default_rl_module_spec.module_class
  3061. elif isinstance(
  3062. default_rl_module_spec.module_specs, SingleAgentRLModuleSpec
  3063. ):
  3064. module_class = default_rl_module_spec.module_specs.module_class
3065. # This should already be checked in validate(), but we check it
3066. # again here just in case.
  3067. if module_class is None:
  3068. raise ValueError(
  3069. "The default rl_module spec cannot have an empty "
  3070. "module_class under its SingleAgentRLModuleSpec."
  3071. )
  3072. module_spec.module_class = module_class
  3073. elif module_id in default_rl_module_spec.module_specs:
  3074. module_spec.module_class = default_rl_module_spec.module_specs[
  3075. module_id
  3076. ].module_class
  3077. else:
  3078. raise ValueError(
  3079. f"Module class for module {module_id} cannot be inferred. "
  3080. f"It is neither provided in the rl_module_spec that "
  3081. "is passed in nor in the default module spec used in "
  3082. "the algorithm."
  3083. )
  3084. if module_spec.catalog_class is None:
  3085. if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec):
  3086. module_spec.catalog_class = default_rl_module_spec.catalog_class
  3087. elif isinstance(
  3088. default_rl_module_spec.module_specs, SingleAgentRLModuleSpec
  3089. ):
  3090. catalog_class = default_rl_module_spec.module_specs.catalog_class
  3091. module_spec.catalog_class = catalog_class
  3092. elif module_id in default_rl_module_spec.module_specs:
  3093. module_spec.catalog_class = default_rl_module_spec.module_specs[
  3094. module_id
  3095. ].catalog_class
  3096. else:
  3097. raise ValueError(
  3098. f"Catalog class for module {module_id} cannot be inferred. "
  3099. f"It is neither provided in the rl_module_spec that "
  3100. "is passed in nor in the default module spec used in "
  3101. "the algorithm."
  3102. )
  3103. if module_spec.observation_space is None:
  3104. module_spec.observation_space = policy_spec.observation_space
  3105. if module_spec.action_space is None:
  3106. module_spec.action_space = policy_spec.action_space
  3107. if module_spec.model_config_dict is None:
  3108. module_spec.model_config_dict = policy_spec.config.get("model", {})
  3109. return marl_module_spec
  3110. def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig:
  3111. if not self._is_frozen:
  3112. raise ValueError(
  3113. "Cannot call `get_learner_group_config()` on an unfrozen "
  3114. "AlgorithmConfig! Please call `AlgorithmConfig.freeze()` first."
  3115. )
  3116. config = (
  3117. LearnerGroupConfig()
  3118. .module(module_spec)
  3119. .learner(
  3120. learner_class=self.learner_class,
  3121. learner_hyperparameters=self.get_learner_hyperparameters(),
  3122. )
  3123. .resources(
  3124. num_learner_workers=self.num_learner_workers,
  3125. num_cpus_per_learner_worker=(
  3126. self.num_cpus_per_learner_worker
  3127. if not self.num_gpus_per_learner_worker
  3128. else 0
  3129. ),
  3130. num_gpus_per_learner_worker=self.num_gpus_per_learner_worker,
  3131. local_gpu_idx=self.local_gpu_idx,
  3132. )
  3133. )
  3134. if self.framework_str == "torch":
  3135. config.framework(
  3136. torch_compile=self.torch_compile_learner,
  3137. torch_compile_cfg=self.get_torch_compile_learner_config(),
  3138. torch_compile_what_to_compile=self.torch_compile_learner_what_to_compile, # noqa: E501
  3139. )
  3140. elif self.framework_str == "tf2":
  3141. config.framework(eager_tracing=self.eager_tracing)
  3142. return config
  3143. def get_learner_hyperparameters(self) -> LearnerHyperparameters:
  3144. """Returns a new LearnerHyperparameters instance for the respective Learner.
  3145. The LearnerHyperparameters is a dataclass containing only those config settings
  3146. from AlgorithmConfig that are used by the algorithm's specific Learner
  3147. sub-class. They allow distributing only those settings relevant for learning
  3148. across a set of learner workers (instead of having to distribute the entire
  3149. AlgorithmConfig object).
  3150. Note that LearnerHyperparameters should always be derived directly from a
  3151. AlgorithmConfig object's own settings and considered frozen/read-only.
  3152. Returns:
  3153. A LearnerHyperparameters instance for the respective Learner.
  3154. """
  3155. # Compile the per-module learner hyperparameter instances (if applicable).
  3156. per_module_learner_hp_overrides = {}
  3157. if self.algorithm_config_overrides_per_module:
  3158. for (
  3159. module_id,
  3160. overrides,
  3161. ) in self.algorithm_config_overrides_per_module.items():
  3162. # Copy this AlgorithmConfig object (unfreeze copy), update copy from
  3163. # the provided override dict for this module_id, then
  3164. # create a new LearnerHyperparameter object from this altered
  3165. # AlgorithmConfig.
  3166. config_for_module = self.copy(copy_frozen=False).update_from_dict(
  3167. overrides
  3168. )
  3169. config_for_module.algorithm_config_overrides_per_module = None
  3170. per_module_learner_hp_overrides[
  3171. module_id
  3172. ] = config_for_module.get_learner_hyperparameters()
  3173. return LearnerHyperparameters(
  3174. learning_rate=self.lr,
  3175. grad_clip=self.grad_clip,
  3176. grad_clip_by=self.grad_clip_by,
  3177. _per_module_overrides=per_module_learner_hp_overrides,
  3178. seed=self.seed,
  3179. )

    def __setattr__(self, key, value):
        """Gatekeeper in case we are in frozen state and need to error."""
        # If we are frozen, do not allow to set any attributes anymore.
        if hasattr(self, "_is_frozen") and self._is_frozen:
            # TODO: Remove `simple_optimizer` entirely.
            #  Remove need to set `worker_index` in RolloutWorker's c'tor.
            if key not in ["simple_optimizer", "worker_index", "_is_frozen"]:
                raise AttributeError(
                    f"Cannot set attribute ({key}) of an already frozen "
                    "AlgorithmConfig!"
                )
        super().__setattr__(key, value)

    def __getitem__(self, item):
        """Shim method to still support accessing properties by key lookup.

        This way, an AlgorithmConfig object can still be used as if it were a dict,
        e.g. by Ray Tune.

        Examples:
            >>> from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
            >>> config = AlgorithmConfig()
            >>> print(config["lr"])
            0.001
        """
        # TODO: Uncomment this once all algorithms use AlgorithmConfigs under the
        #  hood (as well as Ray Tune).
        # if log_once("algo_config_getitem"):
        #     logger.warning(
        #         "AlgorithmConfig objects should NOT be used as dict! "
        #         f"Try accessing `{item}` directly as a property."
        #     )

        # In case the user accesses an "old" key, e.g. "num_workers", translate it
        # to the correct property name.
        item = self._translate_special_keys(item)
        return getattr(self, item)

    def __setitem__(self, key, value):
        # TODO: Remove comments once all methods/functions only support
        #  AlgorithmConfigs and there is no more ambiguity anywhere in the code
        #  on whether an AlgorithmConfig or an old python config dict is used.
        # raise AttributeError(
        #     "AlgorithmConfig objects should not have their values set like dicts"
        #     f"(`config['{key}'] = {value}`), "
        #     f"but via setting their properties directly (config.{prop} = {value})."
        # )
        if key == "multiagent":
            raise AttributeError(
                "Cannot set `multiagent` key in an AlgorithmConfig!\nTry setting "
                "the multi-agent components of your AlgorithmConfig object via the "
                "`multi_agent()` method and its arguments.\nE.g. `config.multi_agent("
                "policies=.., policy_mapping_fn=.., policies_to_train=..)`."
            )
        super().__setattr__(key, value)

    def __contains__(self, item) -> bool:
        """Shim method to help pretend we are a dict."""
        prop = self._translate_special_keys(item, warn_deprecated=False)
        return hasattr(self, prop)

    def get(self, key, default=None):
        """Shim method to help pretend we are a dict."""
        prop = self._translate_special_keys(key, warn_deprecated=False)
        return getattr(self, prop, default)

    def pop(self, key, default=None):
        """Shim method to help pretend we are a dict."""
        return self.get(key, default)

    def keys(self):
        """Shim method to help pretend we are a dict."""
        return self.to_dict().keys()

    def values(self):
        """Shim method to help pretend we are a dict."""
        return self.to_dict().values()

    def items(self):
        """Shim method to help pretend we are a dict."""
        return self.to_dict().items()
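
    # Illustrative dict-shim usage (old dict-style access keeps working):
    #
    #   config = AlgorithmConfig()
    #   config["lr"]                  # same as config.lr
    #   config["num_workers"]         # translated to config.num_rollout_workers
    #   config.get("gamma", 0.99)     # dict-style access with a default
    #   "train_batch_size" in config  # True
    #   list(config.keys())           # keys of config.to_dict()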

    @staticmethod
    def _serialize_dict(config):
        # Serialize classes to classpaths:
        config["callbacks"] = serialize_type(config["callbacks"])
        config["sample_collector"] = serialize_type(config["sample_collector"])
        if isinstance(config["env"], type):
            config["env"] = serialize_type(config["env"])
        if "replay_buffer_config" in config and (
            isinstance(config["replay_buffer_config"].get("type"), type)
        ):
            config["replay_buffer_config"]["type"] = serialize_type(
                config["replay_buffer_config"]["type"]
            )
        if isinstance(config["exploration_config"].get("type"), type):
            config["exploration_config"]["type"] = serialize_type(
                config["exploration_config"]["type"]
            )
        if isinstance(config["model"].get("custom_model"), type):
            config["model"]["custom_model"] = serialize_type(
                config["model"]["custom_model"]
            )

        # List'ify `policies`, iff a set or tuple (these types are not JSON'able).
        ma_config = config.get("multiagent")
        if ma_config is not None:
            if isinstance(ma_config.get("policies"), (set, tuple)):
                ma_config["policies"] = list(ma_config["policies"])
            # Do NOT serialize functions/lambdas.
            if ma_config.get("policy_mapping_fn"):
                ma_config["policy_mapping_fn"] = NOT_SERIALIZABLE
            if ma_config.get("policies_to_train"):
                ma_config["policies_to_train"] = NOT_SERIALIZABLE
        # However, if these "multiagent" settings have been provided directly
        # on the top level (as they should be), we override the settings under
        # "multiagent". Note that the "multiagent" key should no longer be used
        # anyway.
        if isinstance(config.get("policies"), (set, tuple)):
            config["policies"] = list(config["policies"])
        # Do NOT serialize functions/lambdas.
        if config.get("policy_mapping_fn"):
            config["policy_mapping_fn"] = NOT_SERIALIZABLE
        if config.get("policies_to_train"):
            config["policies_to_train"] = NOT_SERIALIZABLE

        return config
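
    # Illustrative effect of `_serialize_dict()` (names below are hypothetical):
    #
    #   {"callbacks": MyCallbacks, "policy_mapping_fn": <lambda>, ...}
    #   becomes
    #   {"callbacks": "my_module.MyCallbacks",
    #    "policy_mapping_fn": NOT_SERIALIZABLE, ...}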

    @staticmethod
    def _translate_special_keys(key: str, warn_deprecated: bool = True) -> str:
        # Handle special key (str) -> `AlgorithmConfig.[some_property]` cases.
        if key == "callbacks":
            key = "callbacks_class"
        elif key == "create_env_on_driver":
            key = "create_env_on_local_worker"
        elif key == "custom_eval_function":
            key = "custom_evaluation_function"
        elif key == "framework":
            key = "framework_str"
        elif key == "input":
            key = "input_"
        elif key == "lambda":
            key = "lambda_"
        elif key == "num_cpus_for_driver":
            key = "num_cpus_for_local_worker"
        elif key == "num_workers":
            key = "num_rollout_workers"

        # Deprecated keys.
        if warn_deprecated:
            if key == "collect_metrics_timeout":
                deprecation_warning(
                    old="collect_metrics_timeout",
                    new="metrics_episode_collection_timeout_s",
                    error=True,
                )
            elif key == "metrics_smoothing_episodes":
                deprecation_warning(
                    old="config.metrics_smoothing_episodes",
                    new="config.metrics_num_episodes_for_smoothing",
                    error=True,
                )
            elif key == "min_iter_time_s":
                deprecation_warning(
                    old="config.min_iter_time_s",
                    new="config.min_time_s_per_iteration",
                    error=True,
                )
            elif key == "min_time_s_per_reporting":
                deprecation_warning(
                    old="config.min_time_s_per_reporting",
                    new="config.min_time_s_per_iteration",
                    error=True,
                )
            elif key == "min_sample_timesteps_per_reporting":
                deprecation_warning(
                    old="config.min_sample_timesteps_per_reporting",
                    new="config.min_sample_timesteps_per_iteration",
                    error=True,
                )
            elif key == "min_train_timesteps_per_reporting":
                deprecation_warning(
                    old="config.min_train_timesteps_per_reporting",
                    new="config.min_train_timesteps_per_iteration",
                    error=True,
                )
            elif key == "timesteps_per_iteration":
                deprecation_warning(
                    old="config.timesteps_per_iteration",
                    new="`config.min_sample_timesteps_per_iteration` OR "
                    "`config.min_train_timesteps_per_iteration`",
                    error=True,
                )
            elif key == "evaluation_num_episodes":
                deprecation_warning(
                    old="config.evaluation_num_episodes",
                    new="`config.evaluation_duration` and "
                    "`config.evaluation_duration_unit=episodes`",
                    error=True,
                )

        return key
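
    # Illustrative key translations performed by `_translate_special_keys()`:
    #
    #   AlgorithmConfig._translate_special_keys("num_workers")  # -> "num_rollout_workers"
    #   AlgorithmConfig._translate_special_keys("lambda")       # -> "lambda_"
    #   # Hard-deprecated keys raise via `deprecation_warning(error=True)`:
    #   AlgorithmConfig._translate_special_keys("timesteps_per_iteration")  # raises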

    def _check_if_correct_nn_framework_installed(self, _tf1, _tf, _torch):
        """Check if tf/torch experiment is running and tf/torch installed."""
        if self.framework_str in {"tf", "tf2"}:
            if not (_tf1 or _tf):
                raise ImportError(
                    (
                        "TensorFlow was specified as the framework to use (via `config."
                        "framework([tf|tf2])`)! However, no installation was "
                        "found. You can install TensorFlow via `pip install tensorflow`."
                    )
                )
        elif self.framework_str == "torch":
            if not _torch:
                raise ImportError(
                    (
                        "PyTorch was specified as the framework to use (via `config."
                        "framework('torch')`)! However, no installation was found. You "
                        "can install PyTorch via `pip install torch`."
                    )
                )

    def _resolve_tf_settings(self, _tf1, _tfv):
        """Check and resolve tf settings."""
        if _tf1 and self.framework_str == "tf2":
            if self.framework_str == "tf2" and _tfv < 2:
                raise ValueError(
                    "You configured `framework`=tf2, but your installed "
                    "pip tf-version is < 2.0! Make sure your TensorFlow "
                    "version is >= 2.x."
                )
            if not _tf1.executing_eagerly():
                _tf1.enable_eager_execution()
            # Recommend setting tracing to True for speedups.
            logger.info(
                f"Executing eagerly (framework='{self.framework_str}'),"
                f" with eager_tracing={self.eager_tracing}. For "
                "production workloads, make sure to set eager_tracing=True"
                " in order to match the speed of tf-static-graph "
                "(framework='tf'). For debugging purposes, "
                "`eager_tracing=False` is the best choice."
            )
        # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
        # enabling eager tracing for similar speed.
        elif _tf1 and self.framework_str == "tf":
            logger.info(
                "Your framework setting is 'tf', meaning you are using "
                "static-graph mode. Set framework='tf2' to enable eager "
                "execution with tf2.x. You may also then want to set "
                "eager_tracing=True in order to reach similar execution "
                "speed as with static-graph mode."
            )

    @property
    @Deprecated(
        old="AlgorithmConfig.multiagent['[some key]']",
        new="AlgorithmConfig.[some key]",
        error=True,
    )
    def multiagent(self):
        pass

    @property
    @Deprecated(new="AlgorithmConfig.rollouts(num_rollout_workers=..)", error=True)
    def num_workers(self):
        pass