rollout_worker.py 84 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054
  1. import copy
  2. import importlib.util
  3. import logging
  4. import os
  5. import platform
  6. import threading
  7. from collections import defaultdict
  8. from types import FunctionType
  9. from typing import (
  10. TYPE_CHECKING,
  11. Any,
  12. Callable,
  13. Container,
  14. Dict,
  15. List,
  16. Optional,
  17. Set,
  18. Tuple,
  19. Type,
  20. Union,
  21. )
  22. import numpy as np
  23. import tree # pip install dm_tree
  24. from gymnasium.spaces import Discrete, MultiDiscrete, Space
  25. import ray
  26. from ray import ObjectRef
  27. from ray import cloudpickle as pickle
  28. from ray.rllib.connectors.util import (
  29. create_connectors_for_policy,
  30. maybe_get_filters_for_syncing,
  31. )
  32. from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
  33. from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
  34. from ray.rllib.env.env_context import EnvContext
  35. from ray.rllib.env.env_runner import EnvRunner
  36. from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
  37. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  38. from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
  39. from ray.rllib.evaluation.metrics import RolloutMetrics
  40. from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
  41. from ray.rllib.models import ModelCatalog
  42. from ray.rllib.models.preprocessors import Preprocessor
  43. from ray.rllib.offline import (
  44. D4RLReader,
  45. DatasetReader,
  46. DatasetWriter,
  47. InputReader,
  48. IOContext,
  49. JsonReader,
  50. JsonWriter,
  51. MixedInput,
  52. NoopOutput,
  53. OutputWriter,
  54. ShuffledInput,
  55. )
  56. from ray.rllib.policy.policy import Policy, PolicySpec
  57. from ray.rllib.policy.policy_map import PolicyMap
  58. from ray.rllib.policy.sample_batch import (
  59. DEFAULT_POLICY_ID,
  60. MultiAgentBatch,
  61. concat_samples,
  62. convert_ma_batch_to_sample_batch,
  63. )
  64. from ray.rllib.policy.torch_policy import TorchPolicy
  65. from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
  66. from ray.rllib.utils import check_env, force_list
  67. from ray.rllib.utils.annotations import DeveloperAPI, override
  68. from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
  69. from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
  70. from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
  71. from ray.rllib.utils.filter import Filter, NoFilter, get_filter
  72. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  73. from ray.rllib.utils.from_config import from_config
  74. from ray.rllib.utils.policy import create_policy_for_framework, validate_policy_id
  75. from ray.rllib.utils.sgd import do_minibatch_sgd
  76. from ray.rllib.utils.tf_run_builder import _TFRunBuilder
  77. from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
  78. from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
  79. from ray.rllib.utils.typing import (
  80. AgentID,
  81. EnvCreator,
  82. EnvType,
  83. ModelGradients,
  84. ModelWeights,
  85. MultiAgentPolicyConfigDict,
  86. PartialAlgorithmConfigDict,
  87. PolicyID,
  88. PolicyState,
  89. SampleBatchType,
  90. T,
  91. )
  92. from ray.tune.registry import registry_contains_input, registry_get_input
  93. from ray.util.annotations import PublicAPI
  94. from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once
  95. from ray.util.iter import ParallelIteratorWorker
  96. if TYPE_CHECKING:
  97. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  98. from ray.rllib.algorithms.callbacks import DefaultCallbacks # noqa
  99. from ray.rllib.evaluation.episode import Episode
  100. tf1, tf, tfv = try_import_tf()
  101. torch, _ = try_import_torch()
  102. logger = logging.getLogger(__name__)
  103. # Handle to the current rollout worker, which will be set to the most recently
  104. # created RolloutWorker in this process. This can be helpful to access in
  105. # custom env or policy classes for debugging or advanced use cases.
  106. _global_worker: Optional["RolloutWorker"] = None
  107. @DeveloperAPI
  108. def get_global_worker() -> "RolloutWorker":
  109. """Returns a handle to the active rollout worker in this process."""
  110. global _global_worker
  111. return _global_worker
  112. def _update_env_seed_if_necessary(
  113. env: EnvType, seed: int, worker_idx: int, vector_idx: int
  114. ):
  115. """Set a deterministic random seed on environment.
  116. NOTE: this may not work with remote environments (issue #18154).
  117. """
  118. if seed is None:
  119. return
  120. # A single RL job is unlikely to have more than 10K
  121. # rollout workers.
  122. max_num_envs_per_workers: int = 1000
  123. assert (
  124. worker_idx < max_num_envs_per_workers
  125. ), "Too many envs per worker. Random seeds may collide."
  126. computed_seed: int = worker_idx * max_num_envs_per_workers + vector_idx + seed
  127. # Gymnasium.env.
  128. # This will silently fail for most Farama-foundation gymnasium environments.
  129. # (they do nothing and return None per default)
  130. if not hasattr(env, "reset"):
  131. if log_once("env_has_no_reset_method"):
  132. logger.info(f"Env {env} doesn't have a `reset()` method. Cannot seed.")
  133. else:
  134. try:
  135. env.reset(seed=computed_seed)
  136. except Exception:
  137. logger.info(
  138. f"Env {env} doesn't support setting a seed via its `reset()` "
  139. "method! Implement this method as `reset(self, *, seed=None, "
  140. "options=None)` for it to abide to the correct API. Cannot seed."
  141. )
  142. @DeveloperAPI
  143. class RolloutWorker(ParallelIteratorWorker, EnvRunner):
  144. """Common experience collection class.
  145. This class wraps a policy instance and an environment class to
  146. collect experiences from the environment. You can create many replicas of
  147. this class as Ray actors to scale RL training.
  148. This class supports vectorized and multi-agent policy evaluation (e.g.,
  149. VectorEnv, MultiAgentEnv, etc.)
  150. Examples:
  151. >>> # Create a rollout worker and using it to collect experiences.
  152. >>> import gymnasium as gym
  153. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  154. >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
  155. >>> worker = RolloutWorker( # doctest: +SKIP
  156. ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
  157. ... default_policy_class=PGTF1Policy) # doctest: +SKIP
  158. >>> print(worker.sample()) # doctest: +SKIP
  159. SampleBatch({
  160. "obs": [[...]], "actions": [[...]], "rewards": [[...]],
  161. "terminateds": [[...]], "truncateds": [[...]], "new_obs": [[...]]})
  162. >>> # Creating a multi-agent rollout worker
  163. >>> from gymnasium.spaces import Discrete, Box
  164. >>> import random
  165. >>> MultiAgentTrafficGrid = ... # doctest: +SKIP
  166. >>> worker = RolloutWorker( # doctest: +SKIP
  167. ... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
  168. ... config=AlgorithmConfig().multi_agent(
  169. ... policies={ # doctest: +SKIP
  170. ... # Use an ensemble of two policies for car agents
  171. ... "car_policy1": # doctest: +SKIP
  172. ... (PGTFPolicy, Box(...), Discrete(...),
  173. ... AlgorithmConfig.overrides(gamma=0.99)),
  174. ... "car_policy2": # doctest: +SKIP
  175. ... (PGTFPolicy, Box(...), Discrete(...),
  176. ... AlgorithmConfig.overrides(gamma=0.95)),
  177. ... # Use a single shared policy for all traffic lights
  178. ... "traffic_light_policy":
  179. ... (PGTFPolicy, Box(...), Discrete(...), {}),
  180. ... },
  181. ... policy_mapping_fn=(
  182. ... lambda agent_id, episode, **kwargs:
  183. ... random.choice(["car_policy1", "car_policy2"])
  184. ... if agent_id.startswith("car_") else "traffic_light_policy"),
  185. ... ),
  186. .. )
  187. >>> print(worker.sample()) # doctest: +SKIP
  188. MultiAgentBatch({
  189. "car_policy1": SampleBatch(...),
  190. "car_policy2": SampleBatch(...),
  191. "traffic_light_policy": SampleBatch(...)})
  192. """
  193. def __init__(
  194. self,
  195. *,
  196. env_creator: EnvCreator,
  197. validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
  198. config: Optional["AlgorithmConfig"] = None,
  199. worker_index: int = 0,
  200. num_workers: Optional[int] = None,
  201. recreated_worker: bool = False,
  202. log_dir: Optional[str] = None,
  203. spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
  204. default_policy_class: Optional[Type[Policy]] = None,
  205. dataset_shards: Optional[List[ray.data.Dataset]] = None,
  206. # Deprecated: This is all specified in `config` anyways.
  207. tf_session_creator=DEPRECATED_VALUE, # Use config.tf_session_options instead.
  208. ):
  209. """Initializes a RolloutWorker instance.
  210. Args:
  211. env_creator: Function that returns a gym.Env given an EnvContext
  212. wrapped configuration.
  213. validate_env: Optional callable to validate the generated
  214. environment (only on worker=0).
  215. worker_index: For remote workers, this should be set to a
  216. non-zero and unique value. This index is passed to created envs
  217. through EnvContext so that envs can be configured per worker.
  218. recreated_worker: Whether this worker is a recreated one. Workers are
  219. recreated by an Algorithm (via WorkerSet) in case
  220. `recreate_failed_workers=True` and one of the original workers (or an
  221. already recreated one) has failed. They don't differ from original
  222. workers other than the value of this flag (`self.recreated_worker`).
  223. log_dir: Directory where logs can be placed.
  224. spaces: An optional space dict mapping policy IDs
  225. to (obs_space, action_space)-tuples. This is used in case no
  226. Env is created on this RolloutWorker.
  227. """
  228. # Deprecated args.
  229. if tf_session_creator != DEPRECATED_VALUE:
  230. deprecation_warning(
  231. old="RolloutWorker(.., tf_session_creator=.., ..)",
  232. new="config.framework(tf_session_args={..}); "
  233. "RolloutWorker(config=config, ..)",
  234. error=True,
  235. )
  236. self._original_kwargs: dict = locals().copy()
  237. del self._original_kwargs["self"]
  238. global _global_worker
  239. _global_worker = self
  240. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  241. # Default config needed?
  242. if config is None or isinstance(config, dict):
  243. config = AlgorithmConfig().update_from_dict(config or {})
  244. # Freeze config, so no one else can alter it from here on.
  245. config.freeze()
  246. # Set extra python env variables before calling super constructor.
  247. if config.extra_python_environs_for_driver and worker_index == 0:
  248. for key, value in config.extra_python_environs_for_driver.items():
  249. os.environ[key] = str(value)
  250. elif config.extra_python_environs_for_worker and worker_index > 0:
  251. for key, value in config.extra_python_environs_for_worker.items():
  252. os.environ[key] = str(value)
  253. def gen_rollouts():
  254. while True:
  255. yield self.sample()
  256. ParallelIteratorWorker.__init__(self, gen_rollouts, False)
  257. EnvRunner.__init__(self, config=config)
  258. self.num_workers = (
  259. num_workers if num_workers is not None else self.config.num_rollout_workers
  260. )
  261. # In case we are reading from distributed datasets, store the shards here
  262. # and pick our shard by our worker-index.
  263. self._ds_shards = dataset_shards
  264. self.worker_index: int = worker_index
  265. # Lock to be able to lock this entire worker
  266. # (via `self.lock()` and `self.unlock()`).
  267. # This might be crucial to prevent a race condition in case
  268. # `config.policy_states_are_swappable=True` and you are using an Algorithm
  269. # with a learner thread. In this case, the thread might update a policy
  270. # that is being swapped (during the update) by the Algorithm's
  271. # training_step's `RolloutWorker.get_weights()` call (to sync back the
  272. # new weights to all remote workers).
  273. self._lock = threading.Lock()
  274. if (
  275. tf1
  276. and (config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
  277. # This eager check is necessary for certain all-framework tests
  278. # that use tf's eager_mode() context generator.
  279. and not tf1.executing_eagerly()
  280. ):
  281. tf1.enable_eager_execution()
  282. if self.config.log_level:
  283. logging.getLogger("ray.rllib").setLevel(self.config.log_level)
  284. if self.worker_index > 1:
  285. disable_log_once_globally() # only need 1 worker to log
  286. elif self.config.log_level == "DEBUG":
  287. enable_periodic_logging()
  288. env_context = EnvContext(
  289. self.config.env_config,
  290. worker_index=self.worker_index,
  291. vector_index=0,
  292. num_workers=self.num_workers,
  293. remote=self.config.remote_worker_envs,
  294. recreated_worker=recreated_worker,
  295. )
  296. self.env_context = env_context
  297. self.config: AlgorithmConfig = config
  298. self.callbacks: DefaultCallbacks = self.config.callbacks_class()
  299. self.recreated_worker: bool = recreated_worker
  300. # Setup current policy_mapping_fn. Start with the one from the config, which
  301. # might be None in older checkpoints (nowadays AlgorithmConfig has a proper
  302. # default for this); Need to cover this situation via the backup lambda here.
  303. self.policy_mapping_fn = (
  304. lambda agent_id, episode, worker, **kw: DEFAULT_POLICY_ID
  305. )
  306. self.set_policy_mapping_fn(self.config.policy_mapping_fn)
  307. self.env_creator: EnvCreator = env_creator
  308. # Resolve possible auto-fragment length.
  309. configured_rollout_fragment_length = self.config.get_rollout_fragment_length(
  310. worker_index=self.worker_index
  311. )
  312. self.total_rollout_fragment_length: int = (
  313. configured_rollout_fragment_length * self.config.num_envs_per_worker
  314. )
  315. self.preprocessing_enabled: bool = not config._disable_preprocessor_api
  316. self.last_batch: Optional[SampleBatchType] = None
  317. self.global_vars: dict = {
  318. # TODO(sven): Make this per-policy!
  319. "timestep": 0,
  320. # Counter for performed gradient updates per policy in `self.policy_map`.
  321. # Allows for compiling metrics on the off-policy'ness of an update given
  322. # that the number of gradient updates of the sampling policies are known
  323. # to the learner (and can be compared to the learner version of the same
  324. # policy).
  325. "num_grad_updates_per_policy": defaultdict(int),
  326. }
  327. # If seed is provided, add worker index to it and 10k iff evaluation worker.
  328. self.seed = (
  329. None
  330. if self.config.seed is None
  331. else self.config.seed
  332. + self.worker_index
  333. + self.config.in_evaluation * 10000
  334. )
  335. # Update the global seed for numpy/random/tf-eager/torch if we are not
  336. # the local worker, otherwise, this was already done in the Algorithm
  337. # object itself.
  338. if self.worker_index > 0:
  339. update_global_seed_if_necessary(self.config.framework_str, self.seed)
  340. # A single environment provided by the user (via config.env). This may
  341. # also remain None.
  342. # 1) Create the env using the user provided env_creator. This may
  343. # return a gym.Env (incl. MultiAgentEnv), an already vectorized
  344. # VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
  345. # 2) Wrap - if applicable - with Atari/rendering wrappers.
  346. # 3) Seed the env, if necessary.
  347. # 4) Vectorize the existing single env by creating more clones of
  348. # this env and wrapping it with the RLlib BaseEnv class.
  349. self.env = self.make_sub_env_fn = None
  350. # Create a (single) env for this worker.
  351. if not (
  352. self.worker_index == 0
  353. and self.num_workers > 0
  354. and not self.config.create_env_on_local_worker
  355. ):
  356. # Run the `env_creator` function passing the EnvContext.
  357. self.env = env_creator(copy.deepcopy(self.env_context))
  358. clip_rewards = self.config.clip_rewards
  359. if self.env is not None:
  360. # Validate environment (general validation function).
  361. if not self.config.disable_env_checking:
  362. check_env(self.env, self.config)
  363. # Custom validation function given, typically a function attribute of the
  364. # Algorithm.
  365. if validate_env is not None:
  366. validate_env(self.env, self.env_context)
  367. # We can't auto-wrap a BaseEnv.
  368. if isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)):
  369. def wrap(env):
  370. return env
  371. # Atari type env and "deepmind" preprocessor pref.
  372. elif is_atari(self.env) and self.config.preprocessor_pref == "deepmind":
  373. # Deepmind wrappers already handle all preprocessing.
  374. self.preprocessing_enabled = False
  375. # If clip_rewards not explicitly set to False, switch it
  376. # on here (clip between -1.0 and 1.0).
  377. if self.config.clip_rewards is None:
  378. clip_rewards = True
  379. # Framestacking is used.
  380. use_framestack = self.config.model.get("framestack") is True
  381. def wrap(env):
  382. env = wrap_deepmind(
  383. env,
  384. dim=self.config.model.get("dim"),
  385. framestack=use_framestack,
  386. noframeskip=self.config.env_config.get("frameskip", 0) == 1,
  387. )
  388. return env
  389. elif self.config.preprocessor_pref is None:
  390. # Only turn off preprocessing
  391. self.preprocessing_enabled = False
  392. def wrap(env):
  393. return env
  394. else:
  395. def wrap(env):
  396. return env
  397. # Wrap env through the correct wrapper.
  398. self.env: EnvType = wrap(self.env)
  399. # Ideally, we would use the same make_sub_env() function below
  400. # to create self.env, but wrap(env) and self.env has a cyclic
  401. # dependency on each other right now, so we would settle on
  402. # duplicating the random seed setting logic for now.
  403. _update_env_seed_if_necessary(self.env, self.seed, self.worker_index, 0)
  404. # Call custom callback function `on_sub_environment_created`.
  405. self.callbacks.on_sub_environment_created(
  406. worker=self,
  407. sub_environment=self.env,
  408. env_context=self.env_context,
  409. )
  410. self.make_sub_env_fn = self._get_make_sub_env_fn(
  411. env_creator, env_context, validate_env, wrap, self.seed
  412. )
  413. self.spaces = spaces
  414. self.default_policy_class = default_policy_class
  415. self.policy_dict, self.is_policy_to_train = self.config.get_multi_agent_setup(
  416. env=self.env,
  417. spaces=self.spaces,
  418. default_policy_class=self.default_policy_class,
  419. )
  420. self.policy_map: Optional[PolicyMap] = None
  421. # TODO(jungong) : clean up after non-connector env_runner is fully deprecated.
  422. self.preprocessors: Dict[PolicyID, Preprocessor] = None
  423. # Check available number of GPUs.
  424. num_gpus = (
  425. self.config.num_gpus
  426. if self.worker_index == 0
  427. else self.config.num_gpus_per_worker
  428. )
  429. # This is only for the old API where local_worker was responsible for learning
  430. if not self.config._enable_learner_api:
  431. # Error if we don't find enough GPUs.
  432. if (
  433. ray.is_initialized()
  434. and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE
  435. and not config._fake_gpus
  436. ):
  437. devices = []
  438. if self.config.framework_str in ["tf2", "tf"]:
  439. devices = get_tf_gpu_devices()
  440. elif self.config.framework_str == "torch":
  441. devices = list(range(torch.cuda.device_count()))
  442. if len(devices) < num_gpus:
  443. raise RuntimeError(
  444. ERR_MSG_NO_GPUS.format(len(devices), devices)
  445. + HOWTO_CHANGE_CONFIG
  446. )
  447. # Warn, if running in local-mode and actual GPUs (not faked) are
  448. # requested.
  449. elif (
  450. ray.is_initialized()
  451. and ray._private.worker._mode() == ray._private.worker.LOCAL_MODE
  452. and num_gpus > 0
  453. and not self.config._fake_gpus
  454. ):
  455. logger.warning(
  456. "You are running ray with `local_mode=True`, but have "
  457. f"configured {num_gpus} GPUs to be used! In local mode, "
  458. f"Policies are placed on the CPU and the `num_gpus` setting "
  459. f"is ignored."
  460. )
  461. self.filters: Dict[PolicyID, Filter] = defaultdict(NoFilter)
  462. # if RLModule API is enabled, marl_module_spec holds the specs of the RLModules
  463. self.marl_module_spec = None
  464. self._update_policy_map(policy_dict=self.policy_dict)
  465. # Update Policy's view requirements from Model, only if Policy directly
  466. # inherited from base `Policy` class. At this point here, the Policy
  467. # must have it's Model (if any) defined and ready to output an initial
  468. # state.
  469. for pol in self.policy_map.values():
  470. if not pol._model_init_state_automatically_added and not pol.config.get(
  471. "_enable_rl_module_api", False
  472. ):
  473. pol._update_model_view_requirements_from_init_state()
  474. self.multiagent: bool = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
  475. if self.multiagent and self.env is not None:
  476. if not isinstance(
  477. self.env,
  478. (BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv, ray.actor.ActorHandle),
  479. ):
  480. raise ValueError(
  481. f"Have multiple policies {self.policy_map}, but the "
  482. f"env {self.env} is not a subclass of BaseEnv, "
  483. f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!"
  484. )
  485. if self.worker_index == 0:
  486. logger.info("Built filter map: {}".format(self.filters))
  487. # This RolloutWorker has no env.
  488. if self.env is None:
  489. self.async_env = None
  490. # Use a custom env-vectorizer and call it providing self.env.
  491. elif "custom_vector_env" in self.config:
  492. self.async_env = self.config.custom_vector_env(self.env)
  493. # Default: Vectorize self.env via the make_sub_env function. This adds
  494. # further clones of self.env and creates a RLlib BaseEnv (which is
  495. # vectorized under the hood).
  496. else:
  497. # Always use vector env for consistency even if num_envs_per_worker=1.
  498. self.async_env: BaseEnv = convert_to_base_env(
  499. self.env,
  500. make_env=self.make_sub_env_fn,
  501. num_envs=self.config.num_envs_per_worker,
  502. remote_envs=self.config.remote_worker_envs,
  503. remote_env_batch_wait_ms=self.config.remote_env_batch_wait_ms,
  504. worker=self,
  505. restart_failed_sub_environments=(
  506. self.config.restart_failed_sub_environments
  507. ),
  508. )
  509. # `truncate_episodes`: Allow a batch to contain more than one episode
  510. # (fragments) and always make the batch `rollout_fragment_length`
  511. # long.
  512. rollout_fragment_length_for_sampler = configured_rollout_fragment_length
  513. if self.config.batch_mode == "truncate_episodes":
  514. pack = True
  515. # `complete_episodes`: Never cut episodes and sampler will return
  516. # exactly one (complete) episode per poll.
  517. else:
  518. assert self.config.batch_mode == "complete_episodes"
  519. rollout_fragment_length_for_sampler = float("inf")
  520. pack = False
  521. # Create the IOContext for this worker.
  522. self.io_context: IOContext = IOContext(
  523. log_dir, self.config, self.worker_index, self
  524. )
  525. render = False
  526. if self.config.render_env is True and (
  527. self.num_workers == 0 or self.worker_index == 1
  528. ):
  529. render = True
  530. if self.env is None:
  531. self.sampler = None
  532. elif self.config.sample_async:
  533. self.sampler = AsyncSampler(
  534. worker=self,
  535. env=self.async_env,
  536. clip_rewards=clip_rewards,
  537. rollout_fragment_length=rollout_fragment_length_for_sampler,
  538. count_steps_by=self.config.count_steps_by,
  539. callbacks=self.callbacks,
  540. multiple_episodes_in_batch=pack,
  541. normalize_actions=self.config.normalize_actions,
  542. clip_actions=self.config.clip_actions,
  543. observation_fn=self.config.observation_fn,
  544. sample_collector_class=self.config.sample_collector,
  545. render=render,
  546. )
  547. # Start the Sampler thread.
  548. self.sampler.start()
  549. else:
  550. self.sampler = SyncSampler(
  551. worker=self,
  552. env=self.async_env,
  553. clip_rewards=clip_rewards,
  554. rollout_fragment_length=rollout_fragment_length_for_sampler,
  555. count_steps_by=self.config.count_steps_by,
  556. callbacks=self.callbacks,
  557. multiple_episodes_in_batch=pack,
  558. normalize_actions=self.config.normalize_actions,
  559. clip_actions=self.config.clip_actions,
  560. observation_fn=self.config.observation_fn,
  561. sample_collector_class=self.config.sample_collector,
  562. render=render,
  563. )
  564. self.input_reader: InputReader = self._get_input_creator_from_config()(
  565. self.io_context
  566. )
  567. self.output_writer: OutputWriter = self._get_output_creator_from_config()(
  568. self.io_context
  569. )
  570. # The current weights sequence number (version). May remain None for when
  571. # not tracking weights versions.
  572. self.weights_seq_no: Optional[int] = None
  573. logger.debug(
  574. "Created rollout worker with env {} ({}), policies {}".format(
  575. self.async_env, self.env, self.policy_map
  576. )
  577. )
  578. @override(EnvRunner)
  579. def assert_healthy(self):
  580. is_healthy = self.policy_map and self.input_reader and self.output_writer
  581. assert is_healthy, (
  582. f"RolloutWorker {self} (idx={self.worker_index}; "
  583. f"num_workers={self.num_workers}) not healthy!"
  584. )
  585. @override(EnvRunner)
  586. def sample(self, **kwargs) -> SampleBatchType:
  587. """Returns a batch of experience sampled from this worker.
  588. This method must be implemented by subclasses.
  589. Returns:
  590. A columnar batch of experiences (e.g., tensors) or a MultiAgentBatch.
  591. Examples:
  592. >>> import gymnasium as gym
  593. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  594. >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
  595. >>> worker = RolloutWorker( # doctest: +SKIP
  596. ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
  597. ... default_policy_class=PGTF1Policy, # doctest: +SKIP
  598. ... config=AlgorithmConfig(), # doctest: +SKIP
  599. ... )
  600. >>> print(worker.sample()) # doctest: +SKIP
  601. SampleBatch({"obs": [...], "action": [...], ...})
  602. """
  603. if self.config.fake_sampler and self.last_batch is not None:
  604. return self.last_batch
  605. elif self.input_reader is None:
  606. raise ValueError(
  607. "RolloutWorker has no `input_reader` object! "
  608. "Cannot call `sample()`. You can try setting "
  609. "`create_env_on_driver` to True."
  610. )
  611. if log_once("sample_start"):
  612. logger.info(
  613. "Generating sample batch of size {}".format(
  614. self.total_rollout_fragment_length
  615. )
  616. )
  617. batches = [self.input_reader.next()]
  618. steps_so_far = (
  619. batches[0].count
  620. if self.config.count_steps_by == "env_steps"
  621. else batches[0].agent_steps()
  622. )
  623. # In truncate_episodes mode, never pull more than 1 batch per env.
  624. # This avoids over-running the target batch size.
  625. if (
  626. self.config.batch_mode == "truncate_episodes"
  627. and not self.config.offline_sampling
  628. ):
  629. max_batches = self.config.num_envs_per_worker
  630. else:
  631. max_batches = float("inf")
  632. while steps_so_far < self.total_rollout_fragment_length and (
  633. len(batches) < max_batches
  634. ):
  635. batch = self.input_reader.next()
  636. steps_so_far += (
  637. batch.count
  638. if self.config.count_steps_by == "env_steps"
  639. else batch.agent_steps()
  640. )
  641. batches.append(batch)
  642. batch = concat_samples(batches)
  643. self.callbacks.on_sample_end(worker=self, samples=batch)
  644. # Always do writes prior to compression for consistency and to allow
  645. # for better compression inside the writer.
  646. self.output_writer.write(batch)
  647. if log_once("sample_end"):
  648. logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch)))
  649. if self.config.compress_observations:
  650. batch.compress(bulk=self.config.compress_observations == "bulk")
  651. if self.config.fake_sampler:
  652. self.last_batch = batch
  653. return batch
  654. @ray.method(num_returns=2)
  655. def sample_with_count(self) -> Tuple[SampleBatchType, int]:
  656. """Same as sample() but returns the count as a separate value.
  657. Returns:
  658. A columnar batch of experiences (e.g., tensors) and the
  659. size of the collected batch.
  660. Examples:
  661. >>> import gymnasium as gym
  662. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  663. >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
  664. >>> worker = RolloutWorker( # doctest: +SKIP
  665. ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
  666. ... default_policy_class=PGTFPolicy) # doctest: +SKIP
  667. >>> print(worker.sample_with_count()) # doctest: +SKIP
  668. (SampleBatch({"obs": [...], "action": [...], ...}), 3)
  669. """
  670. batch = self.sample()
  671. return batch, batch.count
  672. def learn_on_batch(self, samples: SampleBatchType) -> Dict:
  673. """Update policies based on the given batch.
  674. This is the equivalent to apply_gradients(compute_gradients(samples)),
  675. but can be optimized to avoid pulling gradients into CPU memory.
  676. Args:
  677. samples: The SampleBatch or MultiAgentBatch to learn on.
  678. Returns:
  679. Dictionary of extra metadata from compute_gradients().
  680. Examples:
  681. >>> import gymnasium as gym
  682. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  683. >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
  684. >>> worker = RolloutWorker( # doctest: +SKIP
  685. ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
  686. ... default_policy_class=PGTF1Policy) # doctest: +SKIP
  687. >>> batch = worker.sample() # doctest: +SKIP
  688. >>> info = worker.learn_on_batch(samples) # doctest: +SKIP
  689. """
  690. if log_once("learn_on_batch"):
  691. logger.info(
  692. "Training on concatenated sample batches:\n\n{}\n".format(
  693. summarize(samples)
  694. )
  695. )
  696. info_out = {}
  697. if isinstance(samples, MultiAgentBatch):
  698. builders = {}
  699. to_fetch = {}
  700. for pid, batch in samples.policy_batches.items():
  701. if self.is_policy_to_train is not None and not self.is_policy_to_train(
  702. pid, samples
  703. ):
  704. continue
  705. # Decompress SampleBatch, in case some columns are compressed.
  706. batch.decompress_if_needed()
  707. policy = self.policy_map[pid]
  708. tf_session = policy.get_session()
  709. if tf_session and hasattr(policy, "_build_learn_on_batch"):
  710. builders[pid] = _TFRunBuilder(tf_session, "learn_on_batch")
  711. to_fetch[pid] = policy._build_learn_on_batch(builders[pid], batch)
  712. else:
  713. info_out[pid] = policy.learn_on_batch(batch)
  714. info_out.update({pid: builders[pid].get(v) for pid, v in to_fetch.items()})
  715. else:
  716. if self.is_policy_to_train is None or self.is_policy_to_train(
  717. DEFAULT_POLICY_ID, samples
  718. ):
  719. info_out.update(
  720. {
  721. DEFAULT_POLICY_ID: self.policy_map[
  722. DEFAULT_POLICY_ID
  723. ].learn_on_batch(samples)
  724. }
  725. )
  726. if log_once("learn_out"):
  727. logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
  728. return info_out
  729. def sample_and_learn(
  730. self,
  731. expected_batch_size: int,
  732. num_sgd_iter: int,
  733. sgd_minibatch_size: str,
  734. standardize_fields: List[str],
  735. ) -> Tuple[dict, int]:
  736. """Sample and batch and learn on it.
  737. This is typically used in combination with distributed allreduce.
  738. Args:
  739. expected_batch_size: Expected number of samples to learn on.
  740. num_sgd_iter: Number of SGD iterations.
  741. sgd_minibatch_size: SGD minibatch size.
  742. standardize_fields: List of sample fields to normalize.
  743. Returns:
  744. A tuple consisting of a dictionary of extra metadata returned from
  745. the policies' `learn_on_batch()` and the number of samples
  746. learned on.
  747. """
  748. batch = self.sample()
  749. assert batch.count == expected_batch_size, (
  750. "Batch size possibly out of sync between workers, expected:",
  751. expected_batch_size,
  752. "got:",
  753. batch.count,
  754. )
  755. logger.info(
  756. "Executing distributed minibatch SGD "
  757. "with epoch size {}, minibatch size {}".format(
  758. batch.count, sgd_minibatch_size
  759. )
  760. )
  761. info = do_minibatch_sgd(
  762. batch,
  763. self.policy_map,
  764. self,
  765. num_sgd_iter,
  766. sgd_minibatch_size,
  767. standardize_fields,
  768. )
  769. return info, batch.count
  770. def compute_gradients(
  771. self,
  772. samples: SampleBatchType,
  773. single_agent: bool = None,
  774. ) -> Tuple[ModelGradients, dict]:
  775. """Returns a gradient computed w.r.t the specified samples.
  776. Uses the Policy's/ies' compute_gradients method(s) to perform the
  777. calculations. Skips policies that are not trainable as per
  778. `self.is_policy_to_train()`.
  779. Args:
  780. samples: The SampleBatch or MultiAgentBatch to compute gradients
  781. for using this worker's trainable policies.
  782. Returns:
  783. In the single-agent case, a tuple consisting of ModelGradients and
  784. info dict of the worker's policy.
  785. In the multi-agent case, a tuple consisting of a dict mapping
  786. PolicyID to ModelGradients and a dict mapping PolicyID to extra
  787. metadata info.
  788. Note that the first return value (grads) can be applied as is to a
  789. compatible worker using the worker's `apply_gradients()` method.
  790. Examples:
  791. >>> import gymnasium as gym
  792. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  793. >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
  794. >>> worker = RolloutWorker( # doctest: +SKIP
  795. ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
  796. ... default_policy_class=PGTF1Policy) # doctest: +SKIP
  797. >>> batch = worker.sample() # doctest: +SKIP
  798. >>> grads, info = worker.compute_gradients(samples) # doctest: +SKIP
  799. """
  800. if log_once("compute_gradients"):
  801. logger.info("Compute gradients on:\n\n{}\n".format(summarize(samples)))
  802. if single_agent is True:
  803. samples = convert_ma_batch_to_sample_batch(samples)
  804. grad_out, info_out = self.policy_map[DEFAULT_POLICY_ID].compute_gradients(
  805. samples
  806. )
  807. info_out["batch_count"] = samples.count
  808. return grad_out, info_out
  809. # Treat everything as is multi-agent.
  810. samples = samples.as_multi_agent()
  811. # Calculate gradients for all policies.
  812. grad_out, info_out = {}, {}
  813. if self.config.framework_str == "tf":
  814. for pid, batch in samples.policy_batches.items():
  815. if self.is_policy_to_train is not None and not self.is_policy_to_train(
  816. pid, samples
  817. ):
  818. continue
  819. policy = self.policy_map[pid]
  820. builder = _TFRunBuilder(policy.get_session(), "compute_gradients")
  821. grad_out[pid], info_out[pid] = policy._build_compute_gradients(
  822. builder, batch
  823. )
  824. grad_out = {k: builder.get(v) for k, v in grad_out.items()}
  825. info_out = {k: builder.get(v) for k, v in info_out.items()}
  826. else:
  827. for pid, batch in samples.policy_batches.items():
  828. if self.is_policy_to_train is not None and not self.is_policy_to_train(
  829. pid, samples
  830. ):
  831. continue
  832. grad_out[pid], info_out[pid] = self.policy_map[pid].compute_gradients(
  833. batch
  834. )
  835. info_out["batch_count"] = samples.count
  836. if log_once("grad_out"):
  837. logger.info("Compute grad info:\n\n{}\n".format(summarize(info_out)))
  838. return grad_out, info_out
  839. def apply_gradients(
  840. self,
  841. grads: Union[ModelGradients, Dict[PolicyID, ModelGradients]],
  842. ) -> None:
  843. """Applies the given gradients to this worker's models.
  844. Uses the Policy's/ies' apply_gradients method(s) to perform the
  845. operations.
  846. Args:
  847. grads: Single ModelGradients (single-agent case) or a dict
  848. mapping PolicyIDs to the respective model gradients
  849. structs.
  850. Examples:
  851. >>> import gymnasium as gym
  852. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  853. >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
  854. >>> worker = RolloutWorker( # doctest: +SKIP
  855. ... env_creator=lambda _: gym.make("CartPole-v1"), # doctest: +SKIP
  856. ... default_policy_class=PGTF1Policy) # doctest: +SKIP
  857. >>> samples = worker.sample() # doctest: +SKIP
  858. >>> grads, info = worker.compute_gradients(samples) # doctest: +SKIP
  859. >>> worker.apply_gradients(grads) # doctest: +SKIP
  860. """
  861. if log_once("apply_gradients"):
  862. logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
  863. # Grads is a dict (mapping PolicyIDs to ModelGradients).
  864. # Multi-agent case.
  865. if isinstance(grads, dict):
  866. for pid, g in grads.items():
  867. if self.is_policy_to_train is None or self.is_policy_to_train(
  868. pid, None
  869. ):
  870. self.policy_map[pid].apply_gradients(g)
  871. # Grads is a ModelGradients type. Single-agent case.
  872. elif self.is_policy_to_train is None or self.is_policy_to_train(
  873. DEFAULT_POLICY_ID, None
  874. ):
  875. self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
  876. def get_metrics(self) -> List[RolloutMetrics]:
  877. """Returns the thus-far collected metrics from this worker's rollouts.
  878. Returns:
  879. List of RolloutMetrics collected thus-far.
  880. """
  881. # Get metrics from sampler (if any).
  882. if self.sampler is not None:
  883. out = self.sampler.get_metrics()
  884. else:
  885. out = []
  886. return out
  887. def foreach_env(self, func: Callable[[EnvType], T]) -> List[T]:
  888. """Calls the given function with each sub-environment as arg.
  889. Args:
  890. func: The function to call for each underlying
  891. sub-environment (as only arg).
  892. Returns:
  893. The list of return values of all calls to `func([env])`.
  894. """
  895. if self.async_env is None:
  896. return []
  897. envs = self.async_env.get_sub_environments()
  898. # Empty list (not implemented): Call function directly on the
  899. # BaseEnv.
  900. if not envs:
  901. return [func(self.async_env)]
  902. # Call function on all underlying (vectorized) sub environments.
  903. else:
  904. return [func(e) for e in envs]
  905. def foreach_env_with_context(
  906. self, func: Callable[[EnvType, EnvContext], T]
  907. ) -> List[T]:
  908. """Calls given function with each sub-env plus env_ctx as args.
  909. Args:
  910. func: The function to call for each underlying
  911. sub-environment and its EnvContext (as the args).
  912. Returns:
  913. The list of return values of all calls to `func([env, ctx])`.
  914. """
  915. if self.async_env is None:
  916. return []
  917. envs = self.async_env.get_sub_environments()
  918. # Empty list (not implemented): Call function directly on the
  919. # BaseEnv.
  920. if not envs:
  921. return [func(self.async_env, self.env_context)]
  922. # Call function on all underlying (vectorized) sub environments.
  923. else:
  924. ret = []
  925. for i, e in enumerate(envs):
  926. ctx = self.env_context.copy_with_overrides(vector_index=i)
  927. ret.append(func(e, ctx))
  928. return ret
  929. def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Optional[Policy]:
  930. """Return policy for the specified id, or None.
  931. Args:
  932. policy_id: ID of the policy to return. None for DEFAULT_POLICY_ID
  933. (in the single agent case).
  934. Returns:
  935. The policy under the given ID (or None if not found).
  936. """
  937. return self.policy_map.get(policy_id)
  938. def add_policy(
  939. self,
  940. policy_id: PolicyID,
  941. policy_cls: Optional[Type[Policy]] = None,
  942. policy: Optional[Policy] = None,
  943. *,
  944. observation_space: Optional[Space] = None,
  945. action_space: Optional[Space] = None,
  946. config: Optional[PartialAlgorithmConfigDict] = None,
  947. policy_state: Optional[PolicyState] = None,
  948. policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None,
  949. policies_to_train: Optional[
  950. Union[Container[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
  951. ] = None,
  952. module_spec: Optional[SingleAgentRLModuleSpec] = None,
  953. ) -> Policy:
  954. """Adds a new policy to this RolloutWorker.
  955. Args:
  956. policy_id: ID of the policy to add.
  957. policy_cls: The Policy class to use for constructing the new Policy.
  958. Note: Only one of `policy_cls` or `policy` must be provided.
  959. policy: The Policy instance to add to this algorithm.
  960. Note: Only one of `policy_cls` or `policy` must be provided.
  961. observation_space: The observation space of the policy to add.
  962. action_space: The action space of the policy to add.
  963. config: The config overrides for the policy to add.
  964. policy_state: Optional state dict to apply to the new
  965. policy instance, right after its construction.
  966. policy_mapping_fn: An optional (updated) policy mapping function
  967. to use from here on. Note that already ongoing episodes will
  968. not change their mapping but will use the old mapping till
  969. the end of the episode.
  970. policies_to_train: An optional container of policy IDs to be
  971. trained or a callable taking PolicyID and - optionally -
  972. SampleBatchType and returning a bool (trainable or not?).
  973. If None, will keep the existing setup in place.
  974. Policies, whose IDs are not in the list (or for which the
  975. callable returns False) will not be updated.
  976. module_spec: In the new RLModule API we need to pass in the module_spec for
  977. the new module that is supposed to be added. Knowing the policy spec is
  978. not sufficient.
  979. Returns:
  980. The newly added policy.
  981. Raises:
  982. ValueError: If both `policy_cls` AND `policy` are provided.
  983. KeyError: If the given `policy_id` already exists in this worker's
  984. PolicyMap.
  985. """
  986. validate_policy_id(policy_id, error=False)
  987. if module_spec is not None and not self.config._enable_rl_module_api:
  988. raise ValueError(
  989. "If you pass in module_spec to the policy, the RLModule API needs "
  990. "to be enabled."
  991. )
  992. if policy_id in self.policy_map:
  993. raise KeyError(
  994. f"Policy ID '{policy_id}' already exists in policy map! "
  995. "Make sure you use a Policy ID that has not been taken yet."
  996. " Policy IDs that are already in your policy map: "
  997. f"{list(self.policy_map.keys())}"
  998. )
  999. if (policy_cls is None) == (policy is None):
  1000. raise ValueError(
  1001. "Only one of `policy_cls` or `policy` must be provided to "
  1002. "RolloutWorker.add_policy()!"
  1003. )
  1004. if policy is None:
  1005. policy_dict_to_add, _ = self.config.get_multi_agent_setup(
  1006. policies={
  1007. policy_id: PolicySpec(
  1008. policy_cls, observation_space, action_space, config
  1009. )
  1010. },
  1011. env=self.env,
  1012. spaces=self.spaces,
  1013. default_policy_class=self.default_policy_class,
  1014. )
  1015. else:
  1016. policy_dict_to_add = {
  1017. policy_id: PolicySpec(
  1018. type(policy),
  1019. policy.observation_space,
  1020. policy.action_space,
  1021. policy.config,
  1022. )
  1023. }
  1024. self.policy_dict.update(policy_dict_to_add)
  1025. self._update_policy_map(
  1026. policy_dict=policy_dict_to_add,
  1027. policy=policy,
  1028. policy_states={policy_id: policy_state},
  1029. single_agent_rl_module_spec=module_spec,
  1030. )
  1031. self.set_policy_mapping_fn(policy_mapping_fn)
  1032. if policies_to_train is not None:
  1033. self.set_is_policy_to_train(policies_to_train)
  1034. return self.policy_map[policy_id]
  1035. def remove_policy(
  1036. self,
  1037. *,
  1038. policy_id: PolicyID = DEFAULT_POLICY_ID,
  1039. policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
  1040. policies_to_train: Optional[
  1041. Union[Container[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
  1042. ] = None,
  1043. ) -> None:
  1044. """Removes a policy from this RolloutWorker.
  1045. Args:
  1046. policy_id: ID of the policy to be removed. None for
  1047. DEFAULT_POLICY_ID.
  1048. policy_mapping_fn: An optional (updated) policy mapping function
  1049. to use from here on. Note that already ongoing episodes will
  1050. not change their mapping but will use the old mapping till
  1051. the end of the episode.
  1052. policies_to_train: An optional container of policy IDs to be
  1053. trained or a callable taking PolicyID and - optionally -
  1054. SampleBatchType and returning a bool (trainable or not?).
  1055. If None, will keep the existing setup in place.
  1056. Policies, whose IDs are not in the list (or for which the
  1057. callable returns False) will not be updated.
  1058. """
  1059. if policy_id not in self.policy_map:
  1060. raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
  1061. del self.policy_map[policy_id]
  1062. del self.preprocessors[policy_id]
  1063. self.set_policy_mapping_fn(policy_mapping_fn)
  1064. if policies_to_train is not None:
  1065. self.set_is_policy_to_train(policies_to_train)
  1066. def set_policy_mapping_fn(
  1067. self,
  1068. policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None,
  1069. ) -> None:
  1070. """Sets `self.policy_mapping_fn` to a new callable (if provided).
  1071. Args:
  1072. policy_mapping_fn: The new mapping function to use. If None,
  1073. will keep the existing mapping function in place.
  1074. """
  1075. if policy_mapping_fn is not None:
  1076. self.policy_mapping_fn = policy_mapping_fn
  1077. if not callable(self.policy_mapping_fn):
  1078. raise ValueError("`policy_mapping_fn` must be a callable!")
  1079. def set_is_policy_to_train(
  1080. self,
  1081. is_policy_to_train: Union[
  1082. Container[PolicyID], Callable[[PolicyID, Optional[SampleBatchType]], bool]
  1083. ],
  1084. ) -> None:
  1085. """Sets `self.is_policy_to_train()` to a new callable.
  1086. Args:
  1087. is_policy_to_train: A container of policy IDs to be
  1088. trained or a callable taking PolicyID and - optionally -
  1089. SampleBatchType and returning a bool (trainable or not?).
  1090. If None, will keep the existing setup in place.
  1091. Policies, whose IDs are not in the list (or for which the
  1092. callable returns False) will not be updated.
  1093. """
  1094. # If container given, construct a simple default callable returning True
  1095. # if the PolicyID is found in the list/set of IDs.
  1096. if not callable(is_policy_to_train):
  1097. assert isinstance(is_policy_to_train, (list, set, tuple)), (
  1098. "ERROR: `is_policy_to_train`must be a [list|set|tuple] or a "
  1099. "callable taking PolicyID and SampleBatch and returning "
  1100. "True|False (trainable or not?)."
  1101. )
  1102. pols = set(is_policy_to_train)
  1103. def is_policy_to_train(pid, batch=None):
  1104. return pid in pols
  1105. self.is_policy_to_train = is_policy_to_train
  1106. @PublicAPI(stability="alpha")
  1107. def get_policies_to_train(
  1108. self, batch: Optional[SampleBatchType] = None
  1109. ) -> Set[PolicyID]:
  1110. """Returns all policies-to-train, given an optional batch.
  1111. Loops through all policies currently in `self.policy_map` and checks
  1112. the return value of `self.is_policy_to_train(pid, batch)`.
  1113. Args:
  1114. batch: An optional SampleBatchType for the
  1115. `self.is_policy_to_train(pid, [batch]?)` check.
  1116. Returns:
  1117. The set of currently trainable policy IDs, given the optional
  1118. `batch`.
  1119. """
  1120. return {
  1121. pid
  1122. for pid in self.policy_map.keys()
  1123. if self.is_policy_to_train is None or self.is_policy_to_train(pid, batch)
  1124. }
  1125. def for_policy(
  1126. self,
  1127. func: Callable[[Policy, Optional[Any]], T],
  1128. policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID,
  1129. **kwargs,
  1130. ) -> T:
  1131. """Calls the given function with the specified policy as first arg.
  1132. Args:
  1133. func: The function to call with the policy as first arg.
  1134. policy_id: The PolicyID of the policy to call the function with.
  1135. Keyword Args:
  1136. kwargs: Additional kwargs to be passed to the call.
  1137. Returns:
  1138. The return value of the function call.
  1139. """
  1140. return func(self.policy_map[policy_id], **kwargs)
  1141. def foreach_policy(
  1142. self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs
  1143. ) -> List[T]:
  1144. """Calls the given function with each (policy, policy_id) tuple.
  1145. Args:
  1146. func: The function to call with each (policy, policy ID) tuple.
  1147. Keyword Args:
  1148. kwargs: Additional kwargs to be passed to the call.
  1149. Returns:
  1150. The list of return values of all calls to
  1151. `func([policy, pid, **kwargs])`.
  1152. """
  1153. return [func(policy, pid, **kwargs) for pid, policy in self.policy_map.items()]
  1154. def foreach_policy_to_train(
  1155. self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs
  1156. ) -> List[T]:
  1157. """
  1158. Calls the given function with each (policy, policy_id) tuple.
  1159. Only those policies/IDs will be called on, for which
  1160. `self.is_policy_to_train()` returns True.
  1161. Args:
  1162. func: The function to call with each (policy, policy ID) tuple,
  1163. for only those policies that `self.is_policy_to_train`
  1164. returns True.
  1165. Keyword Args:
  1166. kwargs: Additional kwargs to be passed to the call.
  1167. Returns:
  1168. The list of return values of all calls to
  1169. `func([policy, pid, **kwargs])`.
  1170. """
  1171. return [
  1172. # Make sure to only iterate over keys() and not items(). Iterating over
  1173. # items will access policy_map elements even for pids that we do not need,
  1174. # i.e. those that are not in policy_to_train. Access to policy_map elements
  1175. # can cause disk access for policies that were offloaded to disk. Since
  1176. # these policies will be skipped in the for-loop accessing them is
  1177. # unnecessary, making subsequent disk access unnecessary.
  1178. func(self.policy_map[pid], pid, **kwargs)
  1179. for pid in self.policy_map.keys()
  1180. if self.is_policy_to_train is None or self.is_policy_to_train(pid, None)
  1181. ]
  1182. def sync_filters(self, new_filters: dict) -> None:
  1183. """Changes self's filter to given and rebases any accumulated delta.
  1184. Args:
  1185. new_filters: Filters with new state to update local copy.
  1186. """
  1187. assert all(k in new_filters for k in self.filters)
  1188. for k in self.filters:
  1189. self.filters[k].sync(new_filters[k])
  1190. def get_filters(self, flush_after: bool = False) -> Dict:
  1191. """Returns a snapshot of filters.
  1192. Args:
  1193. flush_after: Clears the filter buffer state.
  1194. Returns:
  1195. Dict for serializable filters
  1196. """
  1197. return_filters = {}
  1198. for k, f in self.filters.items():
  1199. return_filters[k] = f.as_serializable()
  1200. if flush_after:
  1201. f.reset_buffer()
  1202. return return_filters
  1203. @override(EnvRunner)
  1204. def get_state(self) -> dict:
  1205. filters = self.get_filters(flush_after=True)
  1206. policy_states = {}
  1207. for pid in self.policy_map.keys():
  1208. # If required by the user, only capture policies that are actually
  1209. # trainable. Otherwise, capture all policies (for saving to disk).
  1210. if (
  1211. not self.config.checkpoint_trainable_policies_only
  1212. or self.is_policy_to_train is None
  1213. or self.is_policy_to_train(pid)
  1214. ):
  1215. policy_states[pid] = self.policy_map[pid].get_state()
  1216. return {
  1217. # List all known policy IDs here for convenience. When an Algorithm gets
  1218. # restored from a checkpoint, it will not have access to the list of
  1219. # possible IDs as each policy is stored in its own sub-dir
  1220. # (see "policy_states").
  1221. "policy_ids": list(self.policy_map.keys()),
  1222. # Note that this field will not be stored in the algorithm checkpoint's
  1223. # state file, but each policy will get its own state file generated in
  1224. # a sub-dir within the algo's checkpoint dir.
  1225. "policy_states": policy_states,
  1226. # Also store current mapping fn and which policies to train.
  1227. "policy_mapping_fn": self.policy_mapping_fn,
  1228. "is_policy_to_train": self.is_policy_to_train,
  1229. # TODO: Filters will be replaced by connectors.
  1230. "filters": filters,
  1231. }
  1232. @override(EnvRunner)
  1233. def set_state(self, state: dict) -> None:
  1234. # Backward compatibility (old checkpoints' states would have the local
  1235. # worker state as a bytes object, not a dict).
  1236. if isinstance(state, bytes):
  1237. state = pickle.loads(state)
  1238. # TODO: Once filters are handled by connectors, get rid of the "filters"
  1239. # key in `state` entirely (will be part of the policies then).
  1240. self.sync_filters(state["filters"])
  1241. connector_enabled = self.config.enable_connectors
  1242. # Support older checkpoint versions (< 1.0), in which the policy_map
  1243. # was stored under the "state" key, not "policy_states".
  1244. policy_states = (
  1245. state["policy_states"] if "policy_states" in state else state["state"]
  1246. )
  1247. for pid, policy_state in policy_states.items():
  1248. # If - for some reason - we have an invalid PolicyID in the state,
  1249. # this might be from an older checkpoint (pre v1.0). Just warn here.
  1250. validate_policy_id(pid, error=False)
  1251. if pid not in self.policy_map:
  1252. spec = policy_state.get("policy_spec", None)
  1253. if spec is None:
  1254. logger.warning(
  1255. f"PolicyID '{pid}' was probably added on-the-fly (not"
  1256. " part of the static `multagent.policies` config) and"
  1257. " no PolicySpec objects found in the pickled policy "
  1258. f"state. Will not add `{pid}`, but ignore it for now."
  1259. )
  1260. else:
  1261. policy_spec = (
  1262. PolicySpec.deserialize(spec)
  1263. if connector_enabled or isinstance(spec, dict)
  1264. else spec
  1265. )
  1266. self.add_policy(
  1267. policy_id=pid,
  1268. policy_cls=policy_spec.policy_class,
  1269. observation_space=policy_spec.observation_space,
  1270. action_space=policy_spec.action_space,
  1271. config=policy_spec.config,
  1272. )
  1273. if pid in self.policy_map:
  1274. self.policy_map[pid].set_state(policy_state)
  1275. # Also restore mapping fn and which policies to train.
  1276. if "policy_mapping_fn" in state:
  1277. self.set_policy_mapping_fn(state["policy_mapping_fn"])
  1278. if state.get("is_policy_to_train") is not None:
  1279. self.set_is_policy_to_train(state["is_policy_to_train"])
  1280. def get_weights(
  1281. self,
  1282. policies: Optional[Container[PolicyID]] = None,
  1283. ) -> Dict[PolicyID, ModelWeights]:
  1284. """Returns each policies' model weights of this worker.
  1285. Args:
  1286. policies: List of PolicyIDs to get the weights from.
  1287. Use None for all policies.
  1288. Returns:
  1289. Dict mapping PolicyIDs to ModelWeights.
  1290. Examples:
  1291. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  1292. >>> # Create a RolloutWorker.
  1293. >>> worker = ... # doctest: +SKIP
  1294. >>> weights = worker.get_weights() # doctest: +SKIP
  1295. >>> print(weights) # doctest: +SKIP
  1296. {"default_policy": {"layer1": array(...), "layer2": ...}}
  1297. """
  1298. if policies is None:
  1299. policies = list(self.policy_map.keys())
  1300. policies = force_list(policies)
  1301. return {
  1302. # Make sure to only iterate over keys() and not items(). Iterating over
  1303. # items will access policy_map elements even for pids that we do not need,
  1304. # i.e. those that are not in policies. Access to policy_map elements can
  1305. # cause disk access for policies that were offloaded to disk. Since these
  1306. # policies will be skipped in the for-loop accessing them is unnecessary,
  1307. # making subsequent disk access unnecessary.
  1308. pid: self.policy_map[pid].get_weights()
  1309. for pid in self.policy_map.keys()
  1310. if pid in policies
  1311. }
  1312. def set_weights(
  1313. self,
  1314. weights: Dict[PolicyID, ModelWeights],
  1315. global_vars: Optional[Dict] = None,
  1316. weights_seq_no: Optional[int] = None,
  1317. ) -> None:
  1318. """Sets each policies' model weights of this worker.
  1319. Args:
  1320. weights: Dict mapping PolicyIDs to the new weights to be used.
  1321. global_vars: An optional global vars dict to set this
  1322. worker to. If None, do not update the global_vars.
  1323. weights_seq_no: If needed, a sequence number for the weights version
  1324. can be passed into this method. If not None, will store this seq no
  1325. (in self.weights_seq_no) and in future calls - if the seq no did not
  1326. change wrt. the last call - will ignore the call to save on performance.
  1327. Examples:
  1328. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  1329. >>> # Create a RolloutWorker.
  1330. >>> worker = ... # doctest: +SKIP
  1331. >>> weights = worker.get_weights() # doctest: +SKIP
  1332. >>> # Set `global_vars` (timestep) as well.
  1333. >>> worker.set_weights(weights, {"timestep": 42}) # doctest: +SKIP
  1334. """
  1335. # Only update our weights, if no seq no given OR given seq no is different
  1336. # from ours.
  1337. if weights_seq_no is None or weights_seq_no != self.weights_seq_no:
  1338. # If per-policy weights are object refs, `ray.get()` them first.
  1339. if weights and isinstance(next(iter(weights.values())), ObjectRef):
  1340. actual_weights = ray.get(list(weights.values()))
  1341. weights = {
  1342. pid: actual_weights[i] for i, pid in enumerate(weights.keys())
  1343. }
  1344. for pid, w in weights.items():
  1345. self.policy_map[pid].set_weights(w)
  1346. self.weights_seq_no = weights_seq_no
  1347. if global_vars:
  1348. self.set_global_vars(global_vars)
  1349. def get_global_vars(self) -> dict:
  1350. """Returns the current `self.global_vars` dict of this RolloutWorker.
  1351. Returns:
  1352. The current `self.global_vars` dict of this RolloutWorker.
  1353. Examples:
  1354. >>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
  1355. >>> # Create a RolloutWorker.
  1356. >>> worker = ... # doctest: +SKIP
  1357. >>> global_vars = worker.get_global_vars() # doctest: +SKIP
  1358. >>> print(global_vars) # doctest: +SKIP
  1359. {"timestep": 424242}
  1360. """
  1361. return self.global_vars
  1362. def set_global_vars(
  1363. self,
  1364. global_vars: dict,
  1365. policy_ids: Optional[List[PolicyID]] = None,
  1366. ) -> None:
  1367. """Updates this worker's and all its policies' global vars.
  1368. Updates are done using the dict's update method.
  1369. Args:
  1370. global_vars: The global_vars dict to update the `self.global_vars` dict
  1371. from.
  1372. policy_ids: Optional list of Policy IDs to update. If None, will update all
  1373. policies on the to-be-updated workers.
  1374. Examples:
  1375. >>> worker = ... # doctest: +SKIP
  1376. >>> global_vars = worker.set_global_vars( # doctest: +SKIP
  1377. ... {"timestep": 4242})
  1378. """
  1379. # Handle per-policy values.
  1380. global_vars_copy = global_vars.copy()
  1381. gradient_updates_per_policy = global_vars_copy.pop(
  1382. "num_grad_updates_per_policy", {}
  1383. )
  1384. self.global_vars["num_grad_updates_per_policy"].update(
  1385. gradient_updates_per_policy
  1386. )
  1387. # Only update explicitly provided policies or those that that are being
  1388. # trained, in order to avoid superfluous access of policies, which might have
  1389. # been offloaded to the object store.
  1390. # Important b/c global vars are constantly being updated.
  1391. for pid in policy_ids if policy_ids is not None else self.policy_map.keys():
  1392. if self.is_policy_to_train is None or self.is_policy_to_train(pid, None):
  1393. self.policy_map[pid].on_global_var_update(
  1394. dict(
  1395. global_vars_copy,
  1396. # If count is None, Policy won't update the counter.
  1397. **{"num_grad_updates": gradient_updates_per_policy.get(pid)},
  1398. )
  1399. )
  1400. # Update all other global vars.
  1401. self.global_vars.update(global_vars_copy)
  1402. @override(EnvRunner)
  1403. def stop(self) -> None:
  1404. """Releases all resources used by this RolloutWorker."""
  1405. # If we have an env -> Release its resources.
  1406. if self.env is not None:
  1407. self.async_env.stop()
  1408. # In case we have-an AsyncSampler, kill its sampling thread.
  1409. if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler):
  1410. self.sampler.shutdown = True
  1411. # Close all policies' sessions (if tf static graph).
  1412. for policy in self.policy_map.cache.values():
  1413. sess = policy.get_session()
  1414. # Closes the tf session, if any.
  1415. if sess is not None:
  1416. sess.close()
  1417. def lock(self) -> None:
  1418. """Locks this RolloutWorker via its own threading.Lock."""
  1419. self._lock.acquire()
  1420. def unlock(self) -> None:
  1421. """Unlocks this RolloutWorker via its own threading.Lock."""
  1422. self._lock.release()
  1423. def setup_torch_data_parallel(
  1424. self, url: str, world_rank: int, world_size: int, backend: str
  1425. ) -> None:
  1426. """Join a torch process group for distributed SGD."""
  1427. logger.info(
  1428. "Joining process group, url={}, world_rank={}, "
  1429. "world_size={}, backend={}".format(url, world_rank, world_size, backend)
  1430. )
  1431. torch.distributed.init_process_group(
  1432. backend=backend, init_method=url, rank=world_rank, world_size=world_size
  1433. )
  1434. for pid, policy in self.policy_map.items():
  1435. if not isinstance(policy, (TorchPolicy, TorchPolicyV2)):
  1436. raise ValueError(
  1437. "This policy does not support torch distributed", policy
  1438. )
  1439. policy.distributed_world_size = world_size
  1440. def creation_args(self) -> dict:
  1441. """Returns the kwargs dict used to create this worker."""
  1442. return self._original_kwargs
  1443. @DeveloperAPI
  1444. def get_host(self) -> str:
  1445. """Returns the hostname of the process running this evaluator."""
  1446. return platform.node()
  1447. @DeveloperAPI
  1448. def get_node_ip(self) -> str:
  1449. """Returns the IP address of the node that this worker runs on."""
  1450. return ray.util.get_node_ip_address()
  1451. @DeveloperAPI
  1452. def find_free_port(self) -> int:
  1453. """Finds a free port on the node that this worker runs on."""
  1454. from ray.air._internal.util import find_free_port
  1455. return find_free_port()
  1456. def _update_policy_map(
  1457. self,
  1458. *,
  1459. policy_dict: MultiAgentPolicyConfigDict,
  1460. policy: Optional[Policy] = None,
  1461. policy_states: Optional[Dict[PolicyID, PolicyState]] = None,
  1462. single_agent_rl_module_spec: Optional[SingleAgentRLModuleSpec] = None,
  1463. ) -> None:
  1464. """Updates the policy map (and other stuff) on this worker.
  1465. It performs the following:
  1466. 1. It updates the observation preprocessors and updates the policy_specs
  1467. with the postprocessed observation_spaces.
  1468. 2. It updates the policy_specs with the complete algorithm_config (merged
  1469. with the policy_spec's config).
  1470. 3. If needed it will update the self.marl_module_spec on this worker
  1471. 3. It updates the policy map with the new policies
  1472. 4. It updates the filter dict
  1473. 5. It calls the on_create_policy() hook of the callbacks on the newly added
  1474. policies.
  1475. Args:
  1476. policy_dict: The policy dict to update the policy map with.
  1477. policy: The policy to update the policy map with.
  1478. policy_states: The policy states to update the policy map with.
  1479. single_agent_rl_module_spec: The SingleAgentRLModuleSpec to add to the
  1480. MultiAgentRLModuleSpec. If None, the config's
  1481. `get_default_rl_module_spec` method's output will be used to create
  1482. the policy with.
  1483. """
  1484. # Update the input policy dict with the postprocessed observation spaces and
  1485. # merge configs. Also updates the preprocessor dict.
  1486. updated_policy_dict = self._get_complete_policy_specs_dict(policy_dict)
  1487. # Use the updated policy dict to create the marl_module_spec if necessary
  1488. if self.config._enable_rl_module_api:
  1489. spec = self.config.get_marl_module_spec(
  1490. policy_dict=updated_policy_dict,
  1491. single_agent_rl_module_spec=single_agent_rl_module_spec,
  1492. )
  1493. if self.marl_module_spec is None:
  1494. # this is the first time, so we should create the marl_module_spec
  1495. self.marl_module_spec = spec
  1496. else:
  1497. # This is adding a new policy, so we need call add_modules on the
  1498. # module_specs of returned spec.
  1499. self.marl_module_spec.add_modules(spec.module_specs)
  1500. # Add __marl_module_spec key into the config so that the policy can access
  1501. # it.
  1502. updated_policy_dict = self._update_policy_dict_with_marl_module(
  1503. updated_policy_dict
  1504. )
  1505. # Builds the self.policy_map dict
  1506. self._build_policy_map(
  1507. policy_dict=updated_policy_dict,
  1508. policy=policy,
  1509. policy_states=policy_states,
  1510. )
  1511. # Initialize the filter dict
  1512. self._update_filter_dict(updated_policy_dict)
  1513. # Call callback policy init hooks (only if the added policy did not exist
  1514. # before).
  1515. if policy is None:
  1516. self._call_callbacks_on_create_policy()
  1517. if self.worker_index == 0:
  1518. logger.info(f"Built policy map: {self.policy_map}")
  1519. logger.info(f"Built preprocessor map: {self.preprocessors}")
  1520. def _get_complete_policy_specs_dict(
  1521. self, policy_dict: MultiAgentPolicyConfigDict
  1522. ) -> MultiAgentPolicyConfigDict:
  1523. """Processes the policy dict and creates a new copy with the processed attrs.
  1524. This processes the observation_space and prepares them for passing to rl module
  1525. construction. It also merges the policy configs with the algorithm config.
  1526. During this processing, we will also construct the preprocessors dict.
  1527. """
  1528. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  1529. updated_policy_dict = copy.deepcopy(policy_dict)
  1530. # If our preprocessors dict does not exist yet, create it here.
  1531. self.preprocessors = self.preprocessors or {}
  1532. # Loop through given policy-dict and add each entry to our map.
  1533. for name, policy_spec in sorted(updated_policy_dict.items()):
  1534. logger.debug("Creating policy for {}".format(name))
  1535. # Policy brings its own complete AlgorithmConfig -> Use it for this policy.
  1536. if isinstance(policy_spec.config, AlgorithmConfig):
  1537. merged_conf = policy_spec.config
  1538. else:
  1539. # Update the general config with the specific config
  1540. # for this particular policy.
  1541. merged_conf: "AlgorithmConfig" = self.config.copy(copy_frozen=False)
  1542. merged_conf.update_from_dict(policy_spec.config or {})
  1543. # Update num_workers and worker_index.
  1544. merged_conf.worker_index = self.worker_index
  1545. # Preprocessors.
  1546. obs_space = policy_spec.observation_space
  1547. # Initialize preprocessor for this policy to None.
  1548. self.preprocessors[name] = None
  1549. if self.preprocessing_enabled:
  1550. # Policies should deal with preprocessed (automatically flattened)
  1551. # observations if preprocessing is enabled.
  1552. preprocessor = ModelCatalog.get_preprocessor_for_space(
  1553. obs_space,
  1554. merged_conf.model,
  1555. include_multi_binary=self.config.get(
  1556. "_enable_rl_module_api", False
  1557. ),
  1558. )
  1559. # Original observation space should be accessible at
  1560. # obs_space.original_space after this step.
  1561. if preprocessor is not None:
  1562. obs_space = preprocessor.observation_space
  1563. if not merged_conf.enable_connectors:
  1564. # If connectors are not enabled, rollout worker will handle
  1565. # the running of these preprocessors.
  1566. self.preprocessors[name] = preprocessor
  1567. policy_spec.config = merged_conf
  1568. policy_spec.observation_space = obs_space
  1569. return updated_policy_dict
  1570. def _update_policy_dict_with_marl_module(
  1571. self, policy_dict: MultiAgentPolicyConfigDict
  1572. ) -> MultiAgentPolicyConfigDict:
  1573. for name, policy_spec in policy_dict.items():
  1574. policy_spec.config["__marl_module_spec"] = self.marl_module_spec
  1575. return policy_dict
  1576. def _build_policy_map(
  1577. self,
  1578. *,
  1579. policy_dict: MultiAgentPolicyConfigDict,
  1580. policy: Optional[Policy] = None,
  1581. policy_states: Optional[Dict[PolicyID, PolicyState]] = None,
  1582. ) -> None:
  1583. """Adds the given policy_dict to `self.policy_map`.
  1584. Args:
  1585. policy_dict: The MultiAgentPolicyConfigDict to be added to this
  1586. worker's PolicyMap.
  1587. policy: If the policy to add already exists, user can provide it here.
  1588. policy_states: Optional dict from PolicyIDs to PolicyStates to
  1589. restore the states of the policies being built.
  1590. """
  1591. # If our policy_map does not exist yet, create it here.
  1592. self.policy_map = self.policy_map or PolicyMap(
  1593. capacity=self.config.policy_map_capacity,
  1594. policy_states_are_swappable=self.config.policy_states_are_swappable,
  1595. )
  1596. # Loop through given policy-dict and add each entry to our map.
  1597. for name, policy_spec in sorted(policy_dict.items()):
  1598. # Create the actual policy object.
  1599. if policy is None:
  1600. new_policy = create_policy_for_framework(
  1601. policy_id=name,
  1602. policy_class=get_tf_eager_cls_if_necessary(
  1603. policy_spec.policy_class, policy_spec.config
  1604. ),
  1605. merged_config=policy_spec.config,
  1606. observation_space=policy_spec.observation_space,
  1607. action_space=policy_spec.action_space,
  1608. worker_index=self.worker_index,
  1609. seed=self.seed,
  1610. )
  1611. else:
  1612. new_policy = policy
  1613. # Maybe torch compile an RLModule.
  1614. if self.config.get("_enable_rl_module_api", False) and self.config.get(
  1615. "torch_compile_worker"
  1616. ):
  1617. if self.config.framework_str != "torch":
  1618. raise ValueError("Attempting to compile a non-torch RLModule.")
  1619. rl_module = getattr(new_policy, "model", None)
  1620. if rl_module is not None:
  1621. compile_config = self.config.get_torch_compile_worker_config()
  1622. rl_module.compile(compile_config)
  1623. self.policy_map[name] = new_policy
  1624. restore_states = (policy_states or {}).get(name, None)
  1625. # Set the state of the newly created policy before syncing filters, etc.
  1626. if restore_states:
  1627. new_policy.set_state(restore_states)
  1628. def _update_filter_dict(self, policy_dict: MultiAgentPolicyConfigDict) -> None:
  1629. """Updates the filter dict for the given policy_dict."""
  1630. for name, policy_spec in sorted(policy_dict.items()):
  1631. new_policy = self.policy_map[name]
  1632. if policy_spec.config.enable_connectors:
  1633. # Note(jungong) : We should only create new connectors for the
  1634. # policy iff we are creating a new policy from scratch. i.e,
  1635. # we should NOT create new connectors when we already have the
  1636. # policy object created before this function call or have the
  1637. # restoring states from the caller.
  1638. # Also note that we cannot just check the existence of connectors
  1639. # to decide whether we should create connectors because we may be
  1640. # restoring a policy that has 0 connectors configured.
  1641. if (
  1642. new_policy.agent_connectors is None
  1643. or new_policy.action_connectors is None
  1644. ):
  1645. # TODO(jungong) : revisit this. It will be nicer to create
  1646. # connectors as the last step of Policy.__init__().
  1647. create_connectors_for_policy(new_policy, policy_spec.config)
  1648. maybe_get_filters_for_syncing(self, name)
  1649. else:
  1650. filter_shape = tree.map_structure(
  1651. lambda s: (
  1652. None
  1653. if isinstance(s, (Discrete, MultiDiscrete)) # noqa
  1654. else np.array(s.shape)
  1655. ),
  1656. new_policy.observation_space_struct,
  1657. )
  1658. self.filters[name] = get_filter(
  1659. policy_spec.config.observation_filter,
  1660. filter_shape,
  1661. )
  1662. def _call_callbacks_on_create_policy(self):
  1663. """Calls the on_create_policy callback for each policy in the policy map."""
  1664. for name, policy in self.policy_map.items():
  1665. self.callbacks.on_create_policy(policy_id=name, policy=policy)
  1666. def _get_input_creator_from_config(self):
  1667. def valid_module(class_path):
  1668. if (
  1669. isinstance(class_path, str)
  1670. and not os.path.isfile(class_path)
  1671. and "." in class_path
  1672. ):
  1673. module_path, class_name = class_path.rsplit(".", 1)
  1674. try:
  1675. spec = importlib.util.find_spec(module_path)
  1676. if spec is not None:
  1677. return True
  1678. except (ModuleNotFoundError, ValueError):
  1679. print(
  1680. f"module {module_path} not found while trying to get "
  1681. f"input {class_path}"
  1682. )
  1683. return False
  1684. # A callable returning an InputReader object to use.
  1685. if isinstance(self.config.input_, FunctionType):
  1686. return self.config.input_
  1687. # Use RLlib's Sampler classes (SyncSampler or AsynchSampler, depending
  1688. # on `config.sample_async` setting).
  1689. elif self.config.input_ == "sampler":
  1690. return lambda ioctx: ioctx.default_sampler_input()
  1691. # Ray Dataset input -> Use `config.input_config` to construct DatasetReader.
  1692. elif self.config.input_ == "dataset":
  1693. assert self._ds_shards is not None
  1694. # Input dataset shards should have already been prepared.
  1695. # We just need to take the proper shard here.
  1696. return lambda ioctx: DatasetReader(
  1697. self._ds_shards[self.worker_index], ioctx
  1698. )
  1699. # Dict: Mix of different input methods with different ratios.
  1700. elif isinstance(self.config.input_, dict):
  1701. return lambda ioctx: ShuffledInput(
  1702. MixedInput(self.config.input_, ioctx), self.config.shuffle_buffer_size
  1703. )
  1704. # A pre-registered input descriptor (str).
  1705. elif isinstance(self.config.input_, str) and registry_contains_input(
  1706. self.config.input_
  1707. ):
  1708. return registry_get_input(self.config.input_)
  1709. # D4RL input.
  1710. elif "d4rl" in self.config.input_:
  1711. env_name = self.config.input_.split(".")[-1]
  1712. return lambda ioctx: D4RLReader(env_name, ioctx)
  1713. # Valid python module (class path) -> Create using `from_config`.
  1714. elif valid_module(self.config.input_):
  1715. return lambda ioctx: ShuffledInput(
  1716. from_config(self.config.input_, ioctx=ioctx)
  1717. )
  1718. # JSON file or list of JSON files -> Use JsonReader (shuffled).
  1719. else:
  1720. return lambda ioctx: ShuffledInput(
  1721. JsonReader(self.config.input_, ioctx), self.config.shuffle_buffer_size
  1722. )
  1723. def _get_output_creator_from_config(self):
  1724. if isinstance(self.config.output, FunctionType):
  1725. return self.config.output
  1726. elif self.config.output is None:
  1727. return lambda ioctx: NoopOutput()
  1728. elif self.config.output == "dataset":
  1729. return lambda ioctx: DatasetWriter(
  1730. ioctx, compress_columns=self.config.output_compress_columns
  1731. )
  1732. elif self.config.output == "logdir":
  1733. return lambda ioctx: JsonWriter(
  1734. ioctx.log_dir,
  1735. ioctx,
  1736. max_file_size=self.config.output_max_file_size,
  1737. compress_columns=self.config.output_compress_columns,
  1738. )
  1739. else:
  1740. return lambda ioctx: JsonWriter(
  1741. self.config.output,
  1742. ioctx,
  1743. max_file_size=self.config.output_max_file_size,
  1744. compress_columns=self.config.output_compress_columns,
  1745. )
  1746. def _get_make_sub_env_fn(
  1747. self, env_creator, env_context, validate_env, env_wrapper, seed
  1748. ):
  1749. config = self.config
  1750. def _make_sub_env_local(vector_index):
  1751. # Used to created additional environments during environment
  1752. # vectorization.
  1753. # Create the env context (config dict + meta-data) for
  1754. # this particular sub-env within the vectorized one.
  1755. env_ctx = env_context.copy_with_overrides(vector_index=vector_index)
  1756. # Create the sub-env.
  1757. env = env_creator(env_ctx)
  1758. # Validate first.
  1759. if not config.disable_env_checking:
  1760. try:
  1761. check_env(env, config)
  1762. except Exception as e:
  1763. logger.warning(
  1764. "We've added a module for checking environments that "
  1765. "are used in experiments. Your env may not be set up"
  1766. "correctly. You can disable env checking for now by setting "
  1767. "`disable_env_checking` to True in your experiment config "
  1768. "dictionary. You can run the environment checking module "
  1769. "standalone by calling ray.rllib.utils.check_env(env)."
  1770. )
  1771. raise e
  1772. # Custom validation function given by user.
  1773. if validate_env is not None:
  1774. validate_env(env, env_ctx)
  1775. # Use our wrapper, defined above.
  1776. env = env_wrapper(env)
  1777. # Make sure a deterministic random seed is set on
  1778. # all the sub-environments if specified.
  1779. _update_env_seed_if_necessary(
  1780. env, seed, env_context.worker_index, vector_index
  1781. )
  1782. return env
  1783. if not env_context.remote:
  1784. def _make_sub_env_remote(vector_index):
  1785. sub_env = _make_sub_env_local(vector_index)
  1786. self.callbacks.on_sub_environment_created(
  1787. worker=self,
  1788. sub_environment=sub_env,
  1789. env_context=env_context.copy_with_overrides(
  1790. worker_index=env_context.worker_index,
  1791. vector_index=vector_index,
  1792. remote=False,
  1793. ),
  1794. )
  1795. return sub_env
  1796. return _make_sub_env_remote
  1797. else:
  1798. return _make_sub_env_local