rollout_worker.py

import copy
import gym
from gym.spaces import Box, Discrete, MultiDiscrete, Space
import logging
import numpy as np
import platform
import os
import tree  # pip install dm_tree
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
    TYPE_CHECKING, Union

import ray
from ray import ObjectRef
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.utils import record_env_wrapper
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
    OffPolicyEstimate
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import force_list, merge_dicts
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.error import EnvError, ERR_MSG_NO_GPUS, \
    HOWTO_CHANGE_CONFIG
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvCreator, \
    EnvType, ModelConfigDict, ModelGradients, ModelWeights, \
    MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \
    PolicyState, SampleBatchType, T
from ray.util.debug import log_once, disable_log_once_globally, \
    enable_periodic_logging
from ray.util.iter import ParallelIteratorWorker

if TYPE_CHECKING:
    from ray.rllib.evaluation.episode import Episode
    from ray.rllib.evaluation.observation_function import ObservationFunction
    from ray.rllib.agents.callbacks import DefaultCallbacks  # noqa

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)

# Handle to the current rollout worker, which will be set to the most recently
# created RolloutWorker in this process. This can be helpful to access in
# custom env or policy classes for debugging or advanced use cases.
_global_worker: "RolloutWorker" = None


@DeveloperAPI
def get_global_worker() -> "RolloutWorker":
    """Returns a handle to the active rollout worker in this process."""

    global _global_worker
    return _global_worker


def _update_env_seed_if_necessary(env: EnvType, seed: int, worker_idx: int,
                                  vector_idx: int):
    """Set a deterministic random seed on environment.

    NOTE: this may not work with remote environments (issue #18154).
    """
    if not seed:
        return

    # A single RL job is unlikely to have more than 10K
    # rollout workers.
    max_num_envs_per_workers: int = 1000
    assert worker_idx < max_num_envs_per_workers, \
        "Too many envs per worker. Random seeds may collide."
    computed_seed: int = (
        worker_idx * max_num_envs_per_workers + vector_idx + seed)

    # Gym.env.
    # This will silently fail for most OpenAI gyms
    # (they do nothing and return None per default)
    if not hasattr(env, "seed"):
        logger.info("Env doesn't support env.seed(): {}".format(env))
    else:
        env.seed(computed_seed)
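
# Example with hypothetical values: for seed=42, worker_idx=2, vector_idx=3,
# the env above would be seeded with 2 * 1000 + 3 + 42 = 2045, so each
# (worker, sub-env) pair receives a distinct, reproducible seed.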

@DeveloperAPI
class RolloutWorker(ParallelIteratorWorker):
    """Common experience collection class.

    This class wraps a policy instance and an environment class to
    collect experiences from the environment. You can create many replicas of
    this class as Ray actors to scale RL training.

    This class supports vectorized and multi-agent policy evaluation (e.g.,
    VectorEnv, MultiAgentEnv, etc.)

    Examples:
        >>> # Create a rollout worker and use it to collect experiences.
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: gym.make("CartPole-v0"),
        ...   policy_spec=PGTFPolicy)
        >>> print(worker.sample())
        SampleBatch({
            "obs": [[...]], "actions": [[...]], "rewards": [[...]],
            "dones": [[...]], "new_obs": [[...]]})

        >>> # Creating a multi-agent rollout worker
        >>> worker = RolloutWorker(
        ...   env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
        ...   policy_spec={
        ...       # Use an ensemble of two policies for car agents
        ...       "car_policy1":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
        ...       "car_policy2":
        ...         (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
        ...       # Use a single shared policy for all traffic lights
        ...       "traffic_light_policy":
        ...         (PGTFPolicy, Box(...), Discrete(...), {}),
        ...   },
        ...   policy_mapping_fn=lambda agent_id, episode, **kwargs:
        ...     random.choice(["car_policy1", "car_policy2"])
        ...     if agent_id.startswith("car_") else "traffic_light_policy")
        >>> print(worker.sample())
        MultiAgentBatch({
            "car_policy1": SampleBatch(...),
            "car_policy2": SampleBatch(...),
            "traffic_light_policy": SampleBatch(...)})
    """

    @DeveloperAPI
    @classmethod
    def as_remote(cls,
                  num_cpus: Optional[int] = None,
                  num_gpus: Optional[Union[int, float]] = None,
                  memory: Optional[int] = None,
                  object_store_memory: Optional[int] = None,
                  resources: Optional[dict] = None) -> type:
        """Returns RolloutWorker class as a `@ray.remote` using given options.

        The returned class can then be used to instantiate ray actors.

        Args:
            num_cpus: The number of CPUs to allocate for the remote actor.
            num_gpus: The number of GPUs to allocate for the remote actor.
                This could be a fraction as well.
            memory: The heap memory request for the remote actor.
            object_store_memory: The object store memory for the remote actor.
            resources: The default custom resources to allocate for the remote
                actor.

        Returns:
            The `@ray.remote` decorated RolloutWorker class.
        """
        return ray.remote(
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            memory=memory,
            object_store_memory=object_store_memory,
            resources=resources)(cls)
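
    # Usage sketch (assumes Ray is initialized and that `env_creator` and
    # `policy_spec` are defined elsewhere; the names are illustrative only):
    #
    #   RemoteRolloutWorker = RolloutWorker.as_remote(num_cpus=1)
    #   worker = RemoteRolloutWorker.remote(
    #       env_creator=env_creator, policy_spec=policy_spec)
    #   batch = ray.get(worker.sample.remote())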

    @DeveloperAPI
    def __init__(
            self,
            *,
            env_creator: EnvCreator,
            validate_env: Optional[Callable[[EnvType, EnvContext],
                                            None]] = None,
            policy_spec: Optional[Union[type, Dict[PolicyID,
                                                   PolicySpec]]] = None,
            policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
                                                 PolicyID]] = None,
            policies_to_train: Optional[List[PolicyID]] = None,
            tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
            rollout_fragment_length: int = 100,
            count_steps_by: str = "env_steps",
            batch_mode: str = "truncate_episodes",
            episode_horizon: Optional[int] = None,
            preprocessor_pref: str = "deepmind",
            sample_async: bool = False,
            compress_observations: bool = False,
            num_envs: int = 1,
            observation_fn: Optional["ObservationFunction"] = None,
            observation_filter: str = "NoFilter",
            clip_rewards: Optional[Union[bool, float]] = None,
            normalize_actions: bool = True,
            clip_actions: bool = False,
            env_config: Optional[EnvConfigDict] = None,
            model_config: Optional[ModelConfigDict] = None,
            policy_config: Optional[PartialTrainerConfigDict] = None,
            worker_index: int = 0,
            num_workers: int = 0,
            record_env: Union[bool, str] = False,
            log_dir: Optional[str] = None,
            log_level: Optional[str] = None,
            callbacks: Type["DefaultCallbacks"] = None,
            input_creator: Callable[[
                IOContext
            ], InputReader] = lambda ioctx: ioctx.default_sampler_input(),
            input_evaluation: List[str] = frozenset([]),
            output_creator: Callable[
                [IOContext], OutputWriter] = lambda ioctx: NoopOutput(),
            remote_worker_envs: bool = False,
            remote_env_batch_wait_ms: int = 0,
            soft_horizon: bool = False,
            no_done_at_end: bool = False,
            seed: int = None,
            extra_python_environs: Optional[dict] = None,
            fake_sampler: bool = False,
            spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
            policy=None,
            monitor_path=None,
    ):
  208. """Initializes a RolloutWorker instance.
  209. Args:
  210. env_creator: Function that returns a gym.Env given an EnvContext
  211. wrapped configuration.
  212. validate_env: Optional callable to validate the generated
  213. environment (only on worker=0).
  214. policy_spec: The MultiAgentPolicyConfigDict mapping policy IDs
  215. (str) to PolicySpec's or a single policy class to use.
  216. If a dict is specified, then we are in multi-agent mode and a
  217. policy_mapping_fn can also be set (if not, will map all agents
  218. to DEFAULT_POLICY_ID).
  219. policy_mapping_fn: A callable that maps agent ids to policy ids in
  220. multi-agent mode. This function will be called each time a new
  221. agent appears in an episode, to bind that agent to a policy
  222. for the duration of the episode. If not provided, will map all
  223. agents to DEFAULT_POLICY_ID.
  224. policies_to_train: Optional list of policies to train, or None
  225. for all policies.
  226. tf_session_creator: A function that returns a TF session.
  227. This is optional and only useful with TFPolicy.
  228. rollout_fragment_length: The target number of steps
  229. (maesured in `count_steps_by`) to include in each sample
  230. batch returned from this worker.
  231. count_steps_by: The unit in which to count fragment
  232. lengths. One of env_steps or agent_steps.
  233. batch_mode: One of the following batch modes:
  234. - "truncate_episodes": Each call to sample() will return a
  235. batch of at most `rollout_fragment_length * num_envs` in size.
  236. The batch will be exactly `rollout_fragment_length * num_envs`
  237. in size if postprocessing does not change batch sizes. Episodes
  238. may be truncated in order to meet this size requirement.
  239. - "complete_episodes": Each call to sample() will return a
  240. batch of at least `rollout_fragment_length * num_envs` in
  241. size. Episodes will not be truncated, but multiple episodes
  242. may be packed within one batch to meet the batch size. Note
  243. that when `num_envs > 1`, episode steps will be buffered
  244. until the episode completes, and hence batches may contain
  245. significant amounts of off-policy data.
  246. episode_horizon: Horizon at which to stop episodes (even if the
  247. environment itself has not retured a "done" signal).
  248. preprocessor_pref: Whether to use RLlib preprocessors
  249. ("rllib") or deepmind ("deepmind"), when applicable.
  250. sample_async: Whether to compute samples asynchronously in
  251. the background, which improves throughput but can cause samples
  252. to be slightly off-policy.
  253. compress_observations: If true, compress the observations.
  254. They can be decompressed with rllib/utils/compression.
  255. num_envs: If more than one, will create multiple envs
  256. and vectorize the computation of actions. This has no effect if
  257. if the env already implements VectorEnv.
  258. observation_fn: Optional multi-agent observation function.
  259. observation_filter: Name of observation filter to use.
  260. clip_rewards: True for clipping rewards to [-1.0, 1.0] prior
  261. to experience postprocessing. None: Clip for Atari only.
  262. float: Clip to [-clip_rewards; +clip_rewards].
  263. normalize_actions: Whether to normalize actions to the
  264. action space's bounds.
  265. clip_actions: Whether to clip action values to the range
  266. specified by the policy action space.
  267. env_config: Config to pass to the env creator.
  268. model_config: Config to use when creating the policy model.
  269. policy_config: Config to pass to the
  270. policy. In the multi-agent case, this config will be merged
  271. with the per-policy configs specified by `policy_spec`.
  272. worker_index: For remote workers, this should be set to a
  273. non-zero and unique value. This index is passed to created envs
  274. through EnvContext so that envs can be configured per worker.
  275. num_workers: For remote workers, how many workers altogether
  276. have been created?
  277. record_env: Write out episode stats and videos
  278. using gym.wrappers.Monitor to this directory if specified. If
  279. True, use the default output dir in ~/ray_results/.... If
  280. False, do not record anything.
  281. log_dir: Directory where logs can be placed.
  282. log_level: Set the root log level on creation.
  283. callbacks: Custom sub-class of
  284. DefaultCallbacks for training/policy/rollout-worker callbacks.
  285. input_creator: Function that returns an InputReader object for
  286. loading previous generated experiences.
  287. input_evaluation: How to evaluate the policy
  288. performance. This only makes sense to set when the input is
  289. reading offline data. The possible values include:
  290. - "is": the step-wise importance sampling estimator.
  291. - "wis": the weighted step-wise is estimator.
  292. - "simulation": run the environment in the background, but
  293. use this data for evaluation only and never for learning.
  294. output_creator: Function that returns an OutputWriter object for
  295. saving generated experiences.
  296. remote_worker_envs: If using num_envs_per_worker > 1,
  297. whether to create those new envs in remote processes instead of
  298. in the current process. This adds overheads, but can make sense
  299. if your envs are expensive to step/reset (e.g., for StarCraft).
  300. Use this cautiously, overheads are significant!
  301. remote_env_batch_wait_ms: Timeout that remote workers
  302. are waiting when polling environments. 0 (continue when at
  303. least one env is ready) is a reasonable default, but optimal
  304. value could be obtained by measuring your environment
  305. step / reset and model inference perf.
  306. soft_horizon: Calculate rewards but don't reset the
  307. environment when the horizon is hit.
  308. no_done_at_end: Ignore the done=True at the end of the
  309. episode and instead record done=False.
  310. seed: Set the seed of both np and tf to this value to
  311. to ensure each remote worker has unique exploration behavior.
  312. extra_python_environs: Extra python environments need to be set.
  313. fake_sampler: Use a fake (inf speed) sampler for testing.
  314. spaces: An optional space dict mapping policy IDs
  315. to (obs_space, action_space)-tuples. This is used in case no
  316. Env is created on this RolloutWorker.
  317. policy: Obsoleted arg. Use `policy_spec` instead.
  318. monitor_path: Obsoleted arg. Use `record_env` instead.
  319. """
        # Deprecated args.
        if policy is not None:
            deprecation_warning("policy", "policy_spec", error=False)
            policy_spec = policy
        assert policy_spec is not None, \
            "Must provide `policy_spec` when creating RolloutWorker!"

        # Do quick translation into MultiAgentPolicyConfigDict.
        if not isinstance(policy_spec, dict):
            policy_spec = {
                DEFAULT_POLICY_ID: PolicySpec(policy_class=policy_spec)
            }
        policy_spec = {
            pid: spec if isinstance(spec, PolicySpec) else PolicySpec(*spec)
            for pid, spec in policy_spec.copy().items()
        }

        if monitor_path is not None:
            deprecation_warning("monitor_path", "record_env", error=False)
            record_env = monitor_path

        self._original_kwargs: dict = locals().copy()
        del self._original_kwargs["self"]

        global _global_worker
        _global_worker = self

        # Set extra python environment variables first.
        if extra_python_environs:
            for key, value in extra_python_environs.items():
                os.environ[key] = str(value)

        def gen_rollouts():
            while True:
                yield self.sample()

        ParallelIteratorWorker.__init__(self, gen_rollouts, False)

        policy_config = policy_config or {}
        if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
                # This eager check is necessary for certain all-framework
                # tests that use tf's eager_mode() context generator.
                and not tf1.executing_eagerly()):
            tf1.enable_eager_execution()

        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)

        if worker_index > 1:
            disable_log_once_globally()  # only need 1 worker to log
        elif log_level == "DEBUG":
            enable_periodic_logging()

        env_context = EnvContext(
            env_config or {},
            worker_index=worker_index,
            vector_index=0,
            num_workers=num_workers,
        )
        self.env_context = env_context
        self.policy_config: PartialTrainerConfigDict = policy_config
        if callbacks:
            self.callbacks: "DefaultCallbacks" = callbacks()
        else:
            from ray.rllib.agents.callbacks import DefaultCallbacks  # noqa
            self.callbacks: DefaultCallbacks = DefaultCallbacks()
        self.worker_index: int = worker_index
        self.num_workers: int = num_workers
        model_config: ModelConfigDict = \
            model_config or self.policy_config.get("model") or {}

        # Default policy mapping fn is to always return DEFAULT_POLICY_ID,
        # independent of the agent ID and the episode passed in.
        self.policy_mapping_fn = \
            lambda agent_id, episode, worker, **kwargs: DEFAULT_POLICY_ID
        # If provided, set it here.
        self.set_policy_mapping_fn(policy_mapping_fn)

        self.env_creator: EnvCreator = env_creator
        self.rollout_fragment_length: int = rollout_fragment_length * num_envs
        self.count_steps_by: str = count_steps_by
        self.batch_mode: str = batch_mode
        self.compress_observations: bool = compress_observations
        self.preprocessing_enabled: bool = False \
            if policy_config.get("_disable_preprocessor_api") else True
        self.observation_filter = observation_filter
        self.last_batch: Optional[SampleBatchType] = None
        self.global_vars: Optional[dict] = None
        self.fake_sampler: bool = fake_sampler

        # Update the global seed for numpy/random/tf-eager/torch if we are not
        # the local worker, otherwise, this was already done in the Trainer
        # object itself.
        if self.worker_index > 0:
            update_global_seed_if_necessary(
                policy_config.get("framework"), seed)

        # A single environment provided by the user (via config.env). This may
        # also remain None.
        # 1) Create the env using the user provided env_creator. This may
        #    return a gym.Env (incl. MultiAgentEnv), an already vectorized
        #    VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
        # 2) Wrap - if applicable - with Atari/recording/rendering wrappers.
        # 3) Seed the env, if necessary.
        # 4) Vectorize the existing single env by creating more clones of
        #    this env and wrapping it with the RLlib BaseEnv class.
        self.env = None

        # Create a (single) env for this worker.
        if not (worker_index == 0 and num_workers > 0
                and not policy_config.get("create_env_on_driver")):
            # Run the `env_creator` function passing the EnvContext.
            self.env = env_creator(copy.deepcopy(self.env_context))

        if self.env is not None:
            # Validate environment (general validation function).
            _validate_env(self.env, env_context=self.env_context)
            # Custom validation function given.
            if validate_env is not None:
                validate_env(self.env, self.env_context)
            # We can't auto-wrap a BaseEnv.
            if isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)):

                def wrap(env):
                    return env

            # Atari type env and "deepmind" preprocessor pref.
            elif is_atari(self.env) and \
                    not model_config.get("custom_preprocessor") and \
                    preprocessor_pref == "deepmind":
                # Deepmind wrappers already handle all preprocessing.
                self.preprocessing_enabled = False

                # If clip_rewards not explicitly set to False, switch it
                # on here (clip between -1.0 and 1.0).
                if clip_rewards is None:
                    clip_rewards = True

                # Framestacking is used.
                use_framestack = model_config.get("framestack") is True

                def wrap(env):
                    env = wrap_deepmind(
                        env,
                        dim=model_config.get("dim"),
                        framestack=use_framestack)
                    env = record_env_wrapper(env, record_env, log_dir,
                                             policy_config)
                    return env

            # gym.Env -> Wrap with gym Monitor.
            else:

                def wrap(env):
                    return record_env_wrapper(env, record_env, log_dir,
                                              policy_config)

            # Wrap env through the correct wrapper.
            self.env: EnvType = wrap(self.env)
            # Ideally, we would use the same make_sub_env() function below
            # to create self.env, but wrap(env) and self.env have a cyclic
            # dependency on each other right now, so we settle for
            # duplicating the random seed setting logic for now.
            _update_env_seed_if_necessary(self.env, seed, worker_index, 0)

        def make_sub_env(vector_index):
            # Used to create additional environments during environment
            # vectorization.

            # Create the env context (config dict + meta-data) for
            # this particular sub-env within the vectorized one.
            env_ctx = env_context.copy_with_overrides(
                worker_index=worker_index,
                vector_index=vector_index,
                remote=remote_worker_envs)
            # Create the sub-env.
            env = env_creator(env_ctx)
            # Validate first.
            _validate_env(env, env_context=env_ctx)
            # Custom validation function given by user.
            if validate_env is not None:
                validate_env(env, env_ctx)
            # Use our wrapper, defined above.
            env = wrap(env)

            # Make sure a deterministic random seed is set on
            # all the sub-environments if specified.
            _update_env_seed_if_necessary(env, seed, worker_index,
                                          vector_index)
            return env

        self.make_sub_env_fn = make_sub_env
        self.spaces = spaces

        self.policy_dict = _determine_spaces_for_multi_agent_dict(
            policy_spec,
            self.env,
            spaces=self.spaces,
            policy_config=policy_config)

        # List of IDs of those policies that should be trained.
        # By default, these are all policies found in `self.policy_dict`.
        self.policies_to_train: List[PolicyID] = policies_to_train or list(
            self.policy_dict.keys())
        self.set_policies_to_train(self.policies_to_train)

        self.policy_map: PolicyMap = None
        self.preprocessors: Dict[PolicyID, Preprocessor] = None

        # Check available number of GPUs.
        num_gpus = policy_config.get("num_gpus", 0) if \
            self.worker_index == 0 else \
            policy_config.get("num_gpus_per_worker", 0)
        # Error if we don't find enough GPUs.
        if ray.is_initialized() and \
                ray.worker._mode() != ray.worker.LOCAL_MODE and \
                not policy_config.get("_fake_gpus"):

            devices = []
            if policy_config.get("framework") in ["tf2", "tf", "tfe"]:
                devices = get_tf_gpu_devices()
            elif policy_config.get("framework") == "torch":
                devices = list(range(torch.cuda.device_count()))

            if len(devices) < num_gpus:
                raise RuntimeError(
                    ERR_MSG_NO_GPUS.format(len(devices), devices) +
                    HOWTO_CHANGE_CONFIG)
        # Warn, if running in local-mode and actual GPUs (not faked) are
        # requested.
        elif ray.is_initialized() and \
                ray.worker._mode() == ray.worker.LOCAL_MODE and \
                num_gpus > 0 and not policy_config.get("_fake_gpus"):
            logger.warning(
                "You are running ray with `local_mode=True`, but have "
                f"configured {num_gpus} GPUs to be used! In local mode, "
                f"Policies are placed on the CPU and the `num_gpus` setting "
                f"is ignored.")

        self._build_policy_map(
            self.policy_dict,
            policy_config,
            session_creator=tf_session_creator,
            seed=seed)

        # Update Policy's view requirements from Model, only if Policy
        # directly inherited from base `Policy` class. At this point here, the
        # Policy must have its Model (if any) defined and ready to output an
        # initial state.
        for pol in self.policy_map.values():
            if not pol._model_init_state_automatically_added:
                pol._update_model_view_requirements_from_init_state()

        self.multiagent: bool = set(
            self.policy_map.keys()) != {DEFAULT_POLICY_ID}
        if self.multiagent and self.env is not None:
            if not isinstance(self.env,
                              (BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv,
                               ray.actor.ActorHandle)):
                raise ValueError(
                    f"Have multiple policies {self.policy_map}, but the "
                    f"env {self.env} is not a subclass of BaseEnv, "
                    f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!")

        self.filters: Dict[PolicyID, Filter] = {}
        for (policy_id, policy) in self.policy_map.items():
            filter_shape = tree.map_structure(
                lambda s: (None if isinstance(  # noqa
                    s, (Discrete, MultiDiscrete)) else np.array(s.shape)),
                policy.observation_space_struct)
            self.filters[policy_id] = get_filter(self.observation_filter,
                                                 filter_shape)
        if self.worker_index == 0:
            logger.info("Built filter map: {}".format(self.filters))

        # Vectorize environment, if any.
        self.num_envs: int = num_envs
        # This RolloutWorker has no env.
        if self.env is None:
            self.async_env = None
        # Use a custom env-vectorizer and call it providing self.env.
        elif "custom_vector_env" in policy_config:
            self.async_env = policy_config["custom_vector_env"](self.env)
        # Default: Vectorize self.env via the make_sub_env function. This adds
        # further clones of self.env and creates an RLlib BaseEnv (which is
        # vectorized under the hood).
        else:
            # Always use vector env for consistency even if num_envs = 1.
            self.async_env: BaseEnv = convert_to_base_env(
                self.env,
                make_env=self.make_sub_env_fn,
                num_envs=num_envs,
                remote_envs=remote_worker_envs,
                remote_env_batch_wait_ms=remote_env_batch_wait_ms,
            )

        # `truncate_episodes`: Allow a batch to contain more than one episode
        # (fragments) and always make the batch `rollout_fragment_length`
        # long.
        if self.batch_mode == "truncate_episodes":
            pack = True
        # `complete_episodes`: Never cut episodes; the sampler will return
        # exactly one (complete) episode per poll.
        elif self.batch_mode == "complete_episodes":
            rollout_fragment_length = float("inf")
            pack = False
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))

        # Create the IOContext for this worker.
        self.io_context: IOContext = IOContext(log_dir, policy_config,
                                               worker_index, self)
        self.reward_estimators: List[OffPolicyEstimator] = []
        for method in input_evaluation:
            if method == "simulation":
                logger.warning(
                    "Requested 'simulation' input evaluation method: "
                    "will discard all sampler outputs and keep only metrics.")
                sample_async = True
            elif method == "is":
                ise = ImportanceSamplingEstimator.\
                    create_from_io_context(self.io_context)
                self.reward_estimators.append(ise)
            elif method == "wis":
                wise = WeightedImportanceSamplingEstimator.\
                    create_from_io_context(self.io_context)
                self.reward_estimators.append(wise)
            else:
                raise ValueError(
                    "Unknown evaluation method: {}".format(method))

        render = False
        if policy_config.get("render_env") is True and \
                (num_workers == 0 or worker_index == 1):
            render = True

        if self.env is None:
            self.sampler = None
        elif sample_async:
            self.sampler = AsyncSampler(
                worker=self,
                env=self.async_env,
                clip_rewards=clip_rewards,
                rollout_fragment_length=rollout_fragment_length,
                count_steps_by=count_steps_by,
                callbacks=self.callbacks,
                horizon=episode_horizon,
                multiple_episodes_in_batch=pack,
                normalize_actions=normalize_actions,
                clip_actions=clip_actions,
                blackhole_outputs="simulation" in input_evaluation,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                sample_collector_class=policy_config.get("sample_collector"),
                render=render,
            )
            # Start the Sampler thread.
            self.sampler.start()
        else:
            self.sampler = SyncSampler(
                worker=self,
                env=self.async_env,
                clip_rewards=clip_rewards,
                rollout_fragment_length=rollout_fragment_length,
                count_steps_by=count_steps_by,
                callbacks=self.callbacks,
                horizon=episode_horizon,
                multiple_episodes_in_batch=pack,
                normalize_actions=normalize_actions,
                clip_actions=clip_actions,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                sample_collector_class=policy_config.get("sample_collector"),
                render=render,
            )

        self.input_reader: InputReader = input_creator(self.io_context)
        self.output_writer: OutputWriter = output_creator(self.io_context)

        logger.debug(
            "Created rollout worker with env {} ({}), policies {}".format(
                self.async_env, self.env, self.policy_map))

    @DeveloperAPI
    def sample(self) -> SampleBatchType:
        """Returns a batch of experience sampled from this worker.

        Returns:
            A columnar batch of experiences (e.g., tensors).

        Examples:
            >>> print(worker.sample())
            SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
        """
        if self.fake_sampler and self.last_batch is not None:
            return self.last_batch
        elif self.input_reader is None:
            raise ValueError("RolloutWorker has no `input_reader` object! "
                             "Cannot call `sample()`. You can try setting "
                             "`create_env_on_driver` to True.")

        if log_once("sample_start"):
            logger.info("Generating sample batch of size {}".format(
                self.rollout_fragment_length))

        batches = [self.input_reader.next()]
        steps_so_far = batches[0].count if \
            self.count_steps_by == "env_steps" else \
            batches[0].agent_steps()

        # In truncate_episodes mode, never pull more than 1 batch per env.
        # This avoids over-running the target batch size.
        if self.batch_mode == "truncate_episodes":
            max_batches = self.num_envs
        else:
            max_batches = float("inf")

        while (steps_so_far < self.rollout_fragment_length
               and len(batches) < max_batches):
            batch = self.input_reader.next()
            steps_so_far += batch.count if \
                self.count_steps_by == "env_steps" else \
                batch.agent_steps()
            batches.append(batch)
        batch = batches[0].concat_samples(batches) if len(batches) > 1 else \
            batches[0]

        self.callbacks.on_sample_end(worker=self, samples=batch)

        # Always do writes prior to compression for consistency and to allow
        # for better compression inside the writer.
        self.output_writer.write(batch)

        # Do off-policy estimation, if needed.
        if self.reward_estimators:
            for sub_batch in batch.split_by_episode():
                for estimator in self.reward_estimators:
                    estimator.process(sub_batch)

        if log_once("sample_end"):
            logger.info("Completed sample batch:\n\n{}\n".format(
                summarize(batch)))

        if self.compress_observations:
            batch.compress(bulk=self.compress_observations == "bulk")

        if self.fake_sampler:
            self.last_batch = batch
        return batch
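
    # Usage sketch: with batch_mode="truncate_episodes" and
    # rollout_fragment_length=100 (the defaults), a single call typically
    # returns about 100 * num_envs env steps:
    #
    #   batch = worker.sample()
    #   print(batch.count)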

    @DeveloperAPI
    @ray.method(num_returns=2)
    def sample_with_count(self) -> Tuple[SampleBatchType, int]:
        """Same as sample() but returns the count as a separate value.

        Returns:
            A columnar batch of experiences (e.g., tensors) and the
            size of the collected batch.

        Examples:
            >>> print(worker.sample_with_count())
            (SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...}), 3)
        """
        batch = self.sample()
        return batch, batch.count

    @DeveloperAPI
    def learn_on_batch(self, samples: SampleBatchType) -> Dict:
        """Update policies based on the given batch.

        This is equivalent to apply_gradients(compute_gradients(samples)),
        but can be optimized to avoid pulling gradients into CPU memory.

        Args:
            samples: The SampleBatch or MultiAgentBatch to learn on.

        Returns:
            Dictionary of extra metadata from compute_gradients().

        Examples:
            >>> batch = worker.sample()
            >>> info = worker.learn_on_batch(batch)
        """
        if log_once("learn_on_batch"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(samples)))

        if isinstance(samples, MultiAgentBatch):
            info_out = {}
            builders = {}
            to_fetch = {}
            for pid, batch in samples.policy_batches.items():
                if pid not in self.policies_to_train:
                    continue
                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()
                policy = self.policy_map[pid]
                tf_session = policy.get_session()
                if tf_session and hasattr(policy, "_build_learn_on_batch"):
                    builders[pid] = TFRunBuilder(tf_session, "learn_on_batch")
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builders[pid], batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
            info_out.update(
                {pid: builders[pid].get(v)
                 for pid, v in to_fetch.items()})
        else:
            info_out = {
                DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
                .learn_on_batch(samples)
            }
        if log_once("learn_out"):
            logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
        return info_out
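
    # Usage sketch (hypothetical local training loop on a single worker):
    #
    #   for _ in range(10):
    #       batch = worker.sample()
    #       info = worker.learn_on_batch(batch)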

    def sample_and_learn(self, expected_batch_size: int, num_sgd_iter: int,
                         sgd_minibatch_size: int,
                         standardize_fields: List[str]) -> Tuple[dict, int]:
        """Samples a batch and learns on it.

        This is typically used in combination with distributed allreduce.

        Args:
            expected_batch_size: Expected number of samples to learn on.
            num_sgd_iter: Number of SGD iterations.
            sgd_minibatch_size: SGD minibatch size.
            standardize_fields: List of sample fields to normalize.

        Returns:
            A tuple consisting of a dictionary of extra metadata returned from
            the policies' `learn_on_batch()` and the number of samples
            learned on.
        """
        batch = self.sample()
        assert batch.count == expected_batch_size, \
            ("Batch size possibly out of sync between workers, expected:",
             expected_batch_size, "got:", batch.count)
        logger.info("Executing distributed minibatch SGD "
                    "with epoch size {}, minibatch size {}".format(
                        batch.count, sgd_minibatch_size))
        info = do_minibatch_sgd(batch, self.policy_map, self, num_sgd_iter,
                                sgd_minibatch_size, standardize_fields)
        return info, batch.count

    @DeveloperAPI
    def compute_gradients(
            self, samples: SampleBatchType) -> Tuple[ModelGradients, dict]:
        """Returns a gradient computed w.r.t. the specified samples.

        Uses the Policy's/ies' compute_gradients method(s) to perform the
        calculations.

        Args:
            samples: The SampleBatch or MultiAgentBatch to compute gradients
                for using this worker's policies.

        Returns:
            In the single-agent case, a tuple consisting of ModelGradients and
            info dict of the worker's policy.
            In the multi-agent case, a tuple consisting of a dict mapping
            PolicyID to ModelGradients and a dict mapping PolicyID to extra
            metadata info.
            Note that the first return value (grads) can be applied as is to a
            compatible worker using the worker's `apply_gradients()` method.

        Examples:
            >>> batch = worker.sample()
            >>> grads, info = worker.compute_gradients(batch)
        """
        if log_once("compute_gradients"):
            logger.info("Compute gradients on:\n\n{}\n".format(
                summarize(samples)))
        # MultiAgentBatch -> Calculate gradients for all policies.
        if isinstance(samples, MultiAgentBatch):
            grad_out, info_out = {}, {}
            if self.policy_config.get("framework") == "tf":
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    policy = self.policy_map[pid]
                    builder = TFRunBuilder(policy.get_session(),
                                           "compute_gradients")
                    grad_out[pid], info_out[pid] = (
                        policy._build_compute_gradients(builder, batch))
                grad_out = {k: builder.get(v) for k, v in grad_out.items()}
                info_out = {k: builder.get(v) for k, v in info_out.items()}
            else:
                for pid, batch in samples.policy_batches.items():
                    if pid not in self.policies_to_train:
                        continue
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(batch))
        # SampleBatch -> Calculate gradients for the default policy.
        else:
            grad_out, info_out = (
                self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))

        info_out["batch_count"] = samples.count
        if log_once("grad_out"):
            logger.info("Compute grad info:\n\n{}\n".format(
                summarize(info_out)))
        return grad_out, info_out

    @DeveloperAPI
    def apply_gradients(
            self,
            grads: Union[ModelGradients, Dict[PolicyID, ModelGradients]],
    ) -> None:
        """Applies the given gradients to this worker's models.

        Uses the Policy's/ies' apply_gradients method(s) to perform the
        operations.

        Args:
            grads: Single ModelGradients (single-agent case) or a dict
                mapping PolicyIDs to the respective model gradients
                structs.

        Examples:
            >>> samples = worker.sample()
            >>> grads, info = worker.compute_gradients(samples)
            >>> worker.apply_gradients(grads)
        """
        if log_once("apply_gradients"):
            logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
        # Grads is a dict (mapping PolicyIDs to ModelGradients).
        # Multi-agent case.
        if isinstance(grads, dict):
            for pid, g in grads.items():
                if pid in self.policies_to_train:
                    self.policy_map[pid].apply_gradients(g)
        # Grads is a ModelGradients type. Single-agent case.
        elif DEFAULT_POLICY_ID in self.policies_to_train:
            self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
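
    # Usage sketch (assumes `workers` is a list of local RolloutWorker
    # instances built with compatible policies): compute gradients on one
    # worker and apply them to all of them:
    #
    #   grads, _ = workers[0].compute_gradients(workers[0].sample())
    #   for w in workers:
    #       w.apply_gradients(grads)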

    @DeveloperAPI
    def get_metrics(self) -> List[Union[RolloutMetrics, OffPolicyEstimate]]:
        """Returns the thus-far collected metrics from this worker's rollouts.

        Returns:
            List of RolloutMetrics and/or OffPolicyEstimate objects
            collected thus-far.
        """
        # Get metrics from sampler (if any).
        if self.sampler is not None:
            out = self.sampler.get_metrics()
        else:
            out = []
        # Get metrics from our reward-estimators (if any).
        for m in self.reward_estimators:
            out.extend(m.get_metrics())

        return out

    @DeveloperAPI
    def foreach_env(self, func: Callable[[EnvType], T]) -> List[T]:
        """Calls the given function with each sub-environment as arg.

        Args:
            func: The function to call for each underlying
                sub-environment (as only arg).

        Returns:
            The list of return values of all calls to `func([env])`.
        """
        if self.async_env is None:
            return []

        envs = self.async_env.get_sub_environments()
        # Empty list (not implemented): Call function directly on the
        # BaseEnv.
        if not envs:
            return [func(self.async_env)]
        # Call function on all underlying (vectorized) sub environments.
        else:
            return [func(e) for e in envs]

    @DeveloperAPI
    def foreach_env_with_context(
            self, func: Callable[[EnvType, EnvContext], T]) -> List[T]:
        """Calls given function with each sub-env plus env_ctx as args.

        Args:
            func: The function to call for each underlying
                sub-environment and its EnvContext (as the args).

        Returns:
            The list of return values of all calls to `func([env, ctx])`.
        """
        if self.async_env is None:
            return []

        envs = self.async_env.get_sub_environments()
        # Empty list (not implemented): Call function directly on the
        # BaseEnv.
        if not envs:
            return [func(self.async_env, self.env_context)]
        # Call function on all underlying (vectorized) sub environments.
        else:
            ret = []
            for i, e in enumerate(envs):
                ctx = self.env_context.copy_with_overrides(vector_index=i)
                ret.append(func(e, ctx))
            return ret

    @DeveloperAPI
    def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
            Optional[Policy]:
        """Return policy for the specified id, or None.

        Args:
            policy_id: ID of the policy to return. None for DEFAULT_POLICY_ID
                (in the single agent case).

        Returns:
            The policy under the given ID (or None if not found).
        """
        return self.policy_map.get(policy_id)

    @DeveloperAPI
    def add_policy(
            self,
            *,
            policy_id: PolicyID,
            policy_cls: Type[Policy],
            observation_space: Optional[Space] = None,
            action_space: Optional[Space] = None,
            config: Optional[PartialTrainerConfigDict] = None,
            policy_state: Optional[PolicyState] = None,
            policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
                                                 PolicyID]] = None,
            policies_to_train: Optional[List[PolicyID]] = None,
    ) -> Policy:
        """Adds a new policy to this RolloutWorker.

        Args:
            policy_id: ID of the policy to add.
            policy_cls: The Policy class to use for constructing the new
                Policy.
            observation_space: The observation space of the policy to add.
            action_space: The action space of the policy to add.
            config: The config overrides for the policy to add.
            policy_state: Optional state dict to apply to the new
                policy instance, right after its construction.
            policy_mapping_fn: An optional (updated) policy mapping function
                to use from here on. Note that already ongoing episodes will
                not change their mapping but will use the old mapping till
                the end of the episode.
            policies_to_train: An optional list of policy IDs to be trained.
                If None, will keep the existing list in place. Policies
                whose IDs are not in the list will not be updated.

        Returns:
            The newly added policy.

        Raises:
            KeyError: If the given `policy_id` already exists in this worker's
                PolicyMap.
        """
        if policy_id in self.policy_map:
            raise KeyError(f"Policy ID '{policy_id}' already in policy map!")
        policy_dict_to_add = _determine_spaces_for_multi_agent_dict(
            {
                policy_id: PolicySpec(policy_cls, observation_space,
                                      action_space, config or {})
            },
            self.env,
            spaces=self.spaces,
            policy_config=self.policy_config,
        )
        self.policy_dict.update(policy_dict_to_add)
        self._build_policy_map(
            policy_dict_to_add,
            self.policy_config,
            seed=self.policy_config.get("seed"))
        new_policy = self.policy_map[policy_id]
        # Set the state of the newly created policy.
        if policy_state:
            new_policy.set_state(policy_state)

        self.filters[policy_id] = get_filter(
            self.observation_filter, new_policy.observation_space.shape)

        self.set_policy_mapping_fn(policy_mapping_fn)
        self.set_policies_to_train(policies_to_train)

        return new_policy
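
    # Usage sketch (hypothetical on-the-fly policy addition; "extra_policy"
    # and the mapping function are illustrative only):
    #
    #   worker.add_policy(
    #       policy_id="extra_policy",
    #       policy_cls=type(worker.get_policy()),
    #       policy_mapping_fn=lambda agent_id, episode, **kw: "extra_policy",
    #   )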

    @DeveloperAPI
    def remove_policy(
            self,
            *,
            policy_id: PolicyID = DEFAULT_POLICY_ID,
            policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
            policies_to_train: Optional[List[PolicyID]] = None,
    ) -> None:
        """Removes a policy from this RolloutWorker.

        Args:
            policy_id: ID of the policy to be removed. Defaults to
                DEFAULT_POLICY_ID.
            policy_mapping_fn: An optional (updated) policy mapping function
                to use from here on. Note that already ongoing episodes will
                not change their mapping but will use the old mapping till
                the end of the episode.
            policies_to_train: An optional list of policy IDs to be trained.
                If None, will keep the existing list in place. Policies
                whose IDs are not in the list will not be updated.
        """
        if policy_id not in self.policy_map:
            raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
        del self.policy_map[policy_id]
        del self.preprocessors[policy_id]
        self.set_policy_mapping_fn(policy_mapping_fn)
        self.set_policies_to_train(policies_to_train)

    @DeveloperAPI
    def set_policy_mapping_fn(
            self,
            policy_mapping_fn: Optional[Callable[[AgentID, "Episode"],
                                                 PolicyID]] = None,
    ) -> None:
        """Sets `self.policy_mapping_fn` to a new callable (if provided).

        Args:
            policy_mapping_fn: The new mapping function to use. If None,
                will keep the existing mapping function in place.
        """
        if policy_mapping_fn is not None:
            self.policy_mapping_fn = policy_mapping_fn
            if not callable(self.policy_mapping_fn):
                raise ValueError("`policy_mapping_fn` must be a callable!")

    @DeveloperAPI
    def set_policies_to_train(
            self, policies_to_train: Optional[List[PolicyID]] = None) -> None:
        """Sets `self.policies_to_train` to a new list of PolicyIDs.

        Args:
            policies_to_train: The new list of policy IDs to train with.
                If None, will keep the existing list in place.
        """
        if policies_to_train is not None:
            self.policies_to_train = policies_to_train

    @DeveloperAPI
    def for_policy(self,
                   func: Callable[[Policy, Optional[Any]], T],
                   policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID,
                   **kwargs) -> T:
        """Calls the given function with the specified policy as first arg.

        Args:
            func: The function to call with the policy as first arg.
            policy_id: The PolicyID of the policy to call the function with.

        Keyword Args:
            kwargs: Additional kwargs to be passed to the call.

        Returns:
            The return value of the function call.
        """
        return func(self.policy_map[policy_id], **kwargs)

    @DeveloperAPI
    def foreach_policy(self,
                       func: Callable[[Policy, PolicyID, Optional[Any]], T],
                       **kwargs) -> List[T]:
        """Calls the given function with each (policy, policy_id) tuple.

        Args:
            func: The function to call with each (policy, policy ID) tuple.

        Keyword Args:
            kwargs: Additional kwargs to be passed to the call.

        Returns:
            The list of return values of all calls to
            `func([policy, pid, **kwargs])`.
        """
        return [
            func(policy, pid, **kwargs)
            for pid, policy in self.policy_map.items()
        ]
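
    # Usage sketch: apply a function to every policy on this worker, e.g. to
    # list the policy IDs currently held in the policy map:
    #
    #   policy_ids = worker.foreach_policy(lambda policy, pid: pid)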

    @DeveloperAPI
    def foreach_trainable_policy(
            self, func: Callable[[Policy, PolicyID, Optional[Any]], T],
            **kwargs) -> List[T]:
        """Calls the given function with each (policy, policy_id) tuple.

        Only those policies whose IDs can be found in
        `self.policies_to_train` will be called on.

        Args:
            func: The function to call with each (policy, policy ID) tuple,
                for only those policies that are in `self.policies_to_train`.

        Keyword Args:
            kwargs: Additional kwargs to be passed to the call.

        Returns:
            The list of return values of all calls to
            `func([policy, pid, **kwargs])`.
        """
        return [
            func(policy, pid, **kwargs)
            for pid, policy in self.policy_map.items()
            if pid in self.policies_to_train
        ]

    @DeveloperAPI
    def sync_filters(self, new_filters: dict) -> None:
        """Changes self's filter to given and rebases any accumulated delta.

        Args:
            new_filters: Filters with new state to update local copy.
        """
        assert all(k in new_filters for k in self.filters)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])

    @DeveloperAPI
    def get_filters(self, flush_after: bool = False) -> Dict:
        """Returns a snapshot of filters.

        Args:
            flush_after: Clears the filter buffer state.

        Returns:
            Dict for serializable filters.
        """
        return_filters = {}
        for k, f in self.filters.items():
            return_filters[k] = f.as_serializable()
            if flush_after:
                f.clear_buffer()
        return return_filters

    @DeveloperAPI
    def save(self) -> bytes:
        """Serializes this RolloutWorker's current state and returns it.

        Returns:
            The current state of this RolloutWorker as a serialized, pickled
            byte sequence.
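
        Examples:
            An illustrative sketch: persist the worker state to disk
            ("worker_state.pkl" is an arbitrary, made-up file name):

            >>> state = worker.save()
            >>> with open("worker_state.pkl", "wb") as f:
            ...     f.write(state)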
  1144. """
  1145. filters = self.get_filters(flush_after=True)
  1146. state = {}
  1147. policy_specs = {}
  1148. for pid in self.policy_map:
  1149. state[pid] = self.policy_map[pid].get_state()
  1150. policy_specs[pid] = self.policy_map.policy_specs[pid]
  1151. return pickle.dumps({
  1152. "filters": filters,
  1153. "state": state,
  1154. "policy_specs": policy_specs,
  1155. })

    @DeveloperAPI
    def restore(self, objs: bytes) -> None:
        """Restores this RolloutWorker's state from a sequence of bytes.

        Args:
            objs: The byte sequence to restore this worker's state from.

        Examples:
            >>> state = worker.save()
            >>> new_worker = RolloutWorker(...)
            >>> new_worker.restore(state)
        """
        objs = pickle.loads(objs)
        self.sync_filters(objs["filters"])
        for pid, state in objs["state"].items():
            if pid not in self.policy_map:
                pol_spec = objs.get("policy_specs", {}).get(pid)
                if not pol_spec:
                    logger.warning(
                        f"PolicyID '{pid}' was probably added on-the-fly (not"
                        " part of the static `multiagent.policies` config) and"
                        " no PolicySpec objects found in the pickled policy "
                        f"state. Will not add `{pid}`, but ignore it for now.")
                else:
                    self.add_policy(
                        policy_id=pid,
                        policy_cls=pol_spec.policy_class,
                        observation_space=pol_spec.observation_space,
                        action_space=pol_spec.action_space,
                        config=pol_spec.config,
                    )
            else:
                self.policy_map[pid].set_state(state)

    @DeveloperAPI
    def get_weights(
            self,
            policies: Optional[List[PolicyID]] = None,
    ) -> Dict[PolicyID, ModelWeights]:
        """Returns each policy's model weights of this worker.

        Args:
            policies: List of PolicyIDs to get the weights from.
                Use None for all policies.

        Returns:
            Dict mapping PolicyIDs to ModelWeights.

        Examples:
            >>> weights = worker.get_weights()
            >>> print(weights)
            {"default_policy": {"layer1": array(...), "layer2": ...}}
        """
        if policies is None:
            policies = list(self.policy_map.keys())
        policies = force_list(policies)

        return {
            pid: policy.get_weights()
            for pid, policy in self.policy_map.items() if pid in policies
        }

    @DeveloperAPI
    def set_weights(self,
                    weights: Dict[PolicyID, ModelWeights],
                    global_vars: Optional[Dict] = None) -> None:
        """Sets each policy's model weights of this worker.

        Args:
            weights: Dict mapping PolicyIDs to the new weights to be used.
            global_vars: An optional global vars dict to set this
                worker to. If None, do not update the global_vars.

        Examples:
            >>> weights = worker.get_weights()
            >>> # Set `global_vars` (timestep) as well.
            >>> worker.set_weights(weights, {"timestep": 42})
        """
        # If per-policy weights are object refs, `ray.get()` them first.
        if weights and isinstance(next(iter(weights.values())), ObjectRef):
            actual_weights = ray.get(list(weights.values()))
            weights = {
                pid: actual_weights[i]
                for i, pid in enumerate(weights.keys())
            }

        for pid, w in weights.items():
            self.policy_map[pid].set_weights(w)
        if global_vars:
            self.set_global_vars(global_vars)

    @DeveloperAPI
    def get_global_vars(self) -> dict:
        """Returns the current global_vars dict of this worker.

        Returns:
            The current global_vars dict of this worker.

        Examples:
            >>> global_vars = worker.get_global_vars()
            >>> print(global_vars)
            {"timestep": 424242}
        """
        return self.global_vars

    @DeveloperAPI
    def set_global_vars(self, global_vars: dict) -> None:
        """Updates this worker's and all its policies' global vars.

        Args:
            global_vars: The new global_vars dict.

        Examples:
            >>> worker.set_global_vars({"timestep": 4242})
        """
        self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
        self.global_vars = global_vars

    @DeveloperAPI
    def stop(self) -> None:
        """Releases all resources used by this RolloutWorker."""

        # If we have an env -> Release its resources.
        if self.env is not None:
            self.async_env.stop()
        # Close all policies' sessions (if tf static graph).
        for policy in self.policy_map.values():
            sess = policy.get_session()
            # Closes the tf session, if any.
            if sess is not None:
                sess.close()

    @DeveloperAPI
    def apply(
            self,
            func: Callable[["RolloutWorker", Optional[Any], Optional[Any]], T],
            *args, **kwargs) -> T:
        """Calls the given function with this rollout worker instance.

        Useful for when the RolloutWorker class has been converted into an
        ActorHandle and the user needs to execute some functionality (e.g.
        add a property) on the underlying policy object.

        Args:
            func: The function to call, with this RolloutWorker as first
                argument, followed by args, and kwargs.
            args: Optional additional args to pass to the function call.
            kwargs: Optional additional kwargs to pass to the function call.

        Returns:
            The return value of the function call.
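
        Examples:
            An illustrative sketch; if this worker has been turned into a
            Ray ActorHandle, the equivalent call would go through
            `worker.apply.remote(...)`:

            >>> # Fetch this worker's global vars via a generic callable.
            >>> global_vars = worker.apply(lambda w: w.get_global_vars())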
  1284. """
  1285. return func(self, *args, **kwargs)

    def setup_torch_data_parallel(self, url: str, world_rank: int,
                                  world_size: int, backend: str) -> None:
        """Join a torch process group for distributed SGD."""
        logger.info("Joining process group, url={}, world_rank={}, "
                    "world_size={}, backend={}".format(url, world_rank,
                                                       world_size, backend))
        torch.distributed.init_process_group(
            backend=backend,
            init_method=url,
            rank=world_rank,
            world_size=world_size)

        for pid, policy in self.policy_map.items():
            if not isinstance(policy, TorchPolicy):
                raise ValueError(
                    "This policy does not support torch distributed", policy)
            policy.distributed_world_size = world_size

    @DeveloperAPI
    def creation_args(self) -> dict:
        """Returns the kwargs dict used to create this worker."""
        return self._original_kwargs

    @DeveloperAPI
    def get_host(self) -> str:
        """Returns the hostname of the process running this evaluator."""
        return platform.node()

    @DeveloperAPI
    def get_node_ip(self) -> str:
        """Returns the IP address of the node that this worker runs on."""
        return ray.util.get_node_ip_address()

    @DeveloperAPI
    def find_free_port(self) -> int:
        """Finds a free port on the node that this worker runs on."""
        from ray.util.ml_utils.util import find_free_port
        return find_free_port()

    def __del__(self):
        """If this worker is deleted, clears all resources used by it."""

        # In case we have an AsyncSampler, kill its sampling thread.
        if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler):
            self.sampler.shutdown = True

    def _build_policy_map(
            self,
            policy_dict: MultiAgentPolicyConfigDict,
            policy_config: PartialTrainerConfigDict,
            session_creator: Optional[Callable[[], "tf1.Session"]] = None,
            seed: Optional[int] = None,
    ) -> None:
        """Adds the given policy_dict to `self.policy_map`.

        Args:
            policy_dict: The MultiAgentPolicyConfigDict to be added to this
                worker's PolicyMap.
            policy_config: The general policy config to use. May be updated
                by individual policy config overrides in the given
                multi-agent `policy_dict`.
            session_creator: A callable that creates a tf session
                (if applicable).
            seed: An optional random seed to pass to PolicyMap's
                constructor.
        """
        ma_config = policy_config.get("multiagent", {})

        # If our policy_map does not exist yet, create it here.
        self.policy_map = self.policy_map or PolicyMap(
            worker_index=self.worker_index,
            num_workers=self.num_workers,
            capacity=ma_config.get("policy_map_capacity"),
            path=ma_config.get("policy_map_cache"),
            policy_config=policy_config,
            session_creator=session_creator,
            seed=seed,
        )
        # If our preprocessors dict does not exist yet, create it here.
        self.preprocessors = self.preprocessors or {}

        # Loop through given policy-dict and add each entry to our map.
        for name, (orig_cls, obs_space, act_space,
                   conf) in sorted(policy_dict.items()):
            logger.debug("Creating policy for {}".format(name))
            # Update the general policy_config with the specific config
            # for this particular policy.
            merged_conf = merge_dicts(policy_config, conf or {})
            # Update num_workers and worker_index.
            merged_conf["num_workers"] = self.num_workers
            merged_conf["worker_index"] = self.worker_index
            # Preprocessors.
            if self.preprocessing_enabled:
                preprocessor = ModelCatalog.get_preprocessor_for_space(
                    obs_space, merged_conf.get("model"))
                self.preprocessors[name] = preprocessor
                if preprocessor is not None:
                    obs_space = preprocessor.observation_space
            else:
                self.preprocessors[name] = None
            # Create the actual policy object.
            self.policy_map.create_policy(name, orig_cls, obs_space,
                                          act_space, conf, merged_conf)

        if self.worker_index == 0:
            logger.info(f"Built policy map: {self.policy_map}")
            logger.info(f"Built preprocessor map: {self.preprocessors}")

    @Deprecated(
        new="Trainer.get_policy().export_model([export_dir], [onnx]?)",
        error=False)
    def export_policy_model(self,
                            export_dir: str,
                            policy_id: PolicyID = DEFAULT_POLICY_ID,
                            onnx: Optional[int] = None):
        self.policy_map[policy_id].export_model(export_dir, onnx=onnx)

    @Deprecated(
        new="Trainer.get_policy().import_model_from_h5([import_file])",
        error=False)
    def import_policy_model_from_h5(self,
                                    import_file: str,
                                    policy_id: PolicyID = DEFAULT_POLICY_ID):
        self.policy_map[policy_id].import_model_from_h5(import_file)

    @Deprecated(
        new="Trainer.get_policy().export_checkpoint([export_dir], "
        "[filename]?)",
        error=False)
    def export_policy_checkpoint(self,
                                 export_dir: str,
                                 filename_prefix: str = "model",
                                 policy_id: PolicyID = DEFAULT_POLICY_ID):
        self.policy_map[policy_id].export_checkpoint(export_dir,
                                                     filename_prefix)


def _determine_spaces_for_multi_agent_dict(
        multi_agent_dict: MultiAgentPolicyConfigDict,
        env: Optional[EnvType] = None,
        spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
        policy_config: Optional[PartialTrainerConfigDict] = None,
) -> MultiAgentPolicyConfigDict:
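    """Fills in missing observation- and action spaces in a multi-agent dict.

    For each PolicySpec whose observation- or action space is None, the space
    is looked up in this order: the per-policy entry in `spaces`, the env's
    own spaces (or the `spaces` dict's special "__env__" entry), and finally
    `policy_config`. Raises a ValueError if no space can be determined.
    """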
    policy_config = policy_config or {}

    # Try extracting spaces from env or from given spaces dict.
    env_obs_space = None
    env_act_space = None

    # Env is a ray.remote: Get spaces via its (automatically added)
    # `_get_spaces()` method.
    if isinstance(env, ray.actor.ActorHandle):
        env_obs_space, env_act_space = ray.get(env._get_spaces.remote())
    # Normal env (gym.Env or MultiAgentEnv): These should have the
    # `observation_space` and `action_space` properties.
    elif env is not None:
        if hasattr(env, "observation_space") and isinstance(
                env.observation_space, gym.Space):
            env_obs_space = env.observation_space
        if hasattr(env, "action_space") and isinstance(env.action_space,
                                                       gym.Space):
            env_act_space = env.action_space
    # Last resort: Try getting the env's spaces from the spaces
    # dict's special __env__ key.
    if spaces is not None:
        if env_obs_space is None:
            env_obs_space = spaces.get("__env__", [None])[0]
        if env_act_space is None:
            env_act_space = spaces.get("__env__", [None, None])[1]

    for pid, policy_spec in multi_agent_dict.copy().items():
        if policy_spec.observation_space is None:
            if spaces is not None and pid in spaces:
                obs_space = spaces[pid][0]
            elif env_obs_space is not None:
                obs_space = env_obs_space
            elif policy_config.get("observation_space"):
                obs_space = policy_config["observation_space"]
            else:
                raise ValueError(
                    "`observation_space` not provided in PolicySpec for "
                    f"{pid} and env does not have an observation space OR "
                    "no spaces received from other workers' env(s) OR no "
                    "`observation_space` specified in config!")
            multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
                observation_space=obs_space)

        if policy_spec.action_space is None:
            if spaces is not None and pid in spaces:
                act_space = spaces[pid][1]
            elif env_act_space is not None:
                act_space = env_act_space
            elif policy_config.get("action_space"):
                act_space = policy_config["action_space"]
            else:
                raise ValueError(
                    "`action_space` not provided in PolicySpec for "
                    f"{pid} and env does not have an action space OR "
                    "no spaces received from other workers' env(s) OR no "
                    "`action_space` specified in config!")
            multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
                action_space=act_space)
    return multi_agent_dict


def _validate_env(env: EnvType, env_context: EnvContext = None):
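    """Sanity-checks the given (sub-)env.

    Raises an EnvError if `env` is not one of the supported types, or if a
    plain gym.Env returns a reset() observation that is not contained in its
    `observation_space`. Envs with space attributes but an unknown type only
    trigger a warning.
    """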
    # Base message used when validating this sub-env (at its vector index).
    msg = f"Validating sub-env at vector index={env_context.vector_index} ..."

    allowed_types = [
        gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle
    ]
    if not any(isinstance(env, tpe) for tpe in allowed_types):
        # Allow this as a special case (assumed gym.Env).
        # TODO: Disallow this early-out. Everything should conform to a few
        #  supported classes, i.e. gym.Env/MultiAgentEnv/etc...
        if hasattr(env, "observation_space") and hasattr(env, "action_space"):
            logger.warning(msg + f" (warning; invalid env-type={type(env)})")
            return
        else:
            logger.warning(msg + " (NOT OK)")
            raise EnvError(
                "Returned env should be an instance of gym.Env (incl. "
                "MultiAgentEnv), ExternalEnv, VectorEnv, or BaseEnv. "
                f"The provided env creator function returned {env} "
                f"(type={type(env)}).")

    # Do some test runs with the provided env.
    if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
        # Make sure the gym.Env has the two space attributes properly set.
        assert hasattr(env, "observation_space") and hasattr(
            env, "action_space")
        # Get a dummy observation by resetting the env.
        dummy_obs = env.reset()
        # Convert lists to np.ndarrays.
        if type(dummy_obs) is list and isinstance(env.observation_space, Box):
            dummy_obs = np.array(dummy_obs)
        # Ignore float32/float64 diffs.
        if isinstance(env.observation_space, Box) and \
                env.observation_space.dtype != dummy_obs.dtype:
            dummy_obs = dummy_obs.astype(env.observation_space.dtype)
        # Check, if observation is ok (part of the observation space). If not,
        # error.
        if not env.observation_space.contains(dummy_obs):
            logger.warning(msg + " (NOT OK)")
            raise EnvError(
                f"Env's `observation_space` {env.observation_space} does not "
                f"contain returned observation after a reset ({dummy_obs})!")

    # Log that everything is ok.
    logger.info(msg + " (ok)")