impala.py

import copy
import dataclasses
from functools import partial
import logging
import platform
import queue
import random
from typing import Callable, List, Optional, Set, Tuple, Type, Union

import numpy as np
import tree  # pip install dm_tree

import ray
from ray import ObjectRef
from ray.rllib import SampleBatch
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.impala.impala_learner import (
    ImpalaLearnerHyperparameters,
    _reduce_impala_results,
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.evaluation.worker_set import handle_remote_call_result_errors
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import concat_samples
from ray.rllib.utils.actor_manager import (
    FaultAwareApply,
    FaultTolerantActorManager,
    RemoteCallResults,
)
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    deprecation_warning,
)
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_TRAINED,
    NUM_SYNCH_WORKER_WEIGHTS,
    NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS,
    SYNCH_WORKER_WEIGHTS_TIMER,
    SAMPLE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import (
    PartialAlgorithmConfigDict,
    PolicyID,
    ResultDict,
    SampleBatchType,
)
from ray.tune.execution.placement_groups import PlacementGroupFactory

logger = logging.getLogger(__name__)


class ImpalaConfig(AlgorithmConfig):
    """Defines a configuration class from which an Impala can be built.

    Example:
        >>> from ray.rllib.algorithms.impala import ImpalaConfig
        >>> config = ImpalaConfig()
        >>> config = config.training(lr=0.0003, train_batch_size=512)  # doctest: +SKIP
        >>> config = config.resources(num_gpus=4)  # doctest: +SKIP
        >>> config = config.rollouts(num_rollout_workers=64)  # doctest: +SKIP
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build(env="CartPole-v1")  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP

    Example:
        >>> from ray.rllib.algorithms.impala import ImpalaConfig
        >>> from ray import air
        >>> from ray import tune
        >>> config = ImpalaConfig()
        >>> # Print out some default values.
        >>> print(config.vtrace)  # doctest: +SKIP
        >>> # Update the config object.
        >>> config = config.training(  # doctest: +SKIP
        ...     lr=tune.grid_search([0.0001, 0.0003]), grad_clip=20.0
        ... )
        >>> # Set the config object's env.
        >>> config = config.environment(env="CartPole-v1")  # doctest: +SKIP
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.Tuner(  # doctest: +SKIP
        ...     "IMPALA",
        ...     run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
        ...     param_space=config.to_dict(),
        ... ).fit()
    """

    def __init__(self, algo_class=None):
        """Initializes an ImpalaConfig instance."""
        super().__init__(algo_class=algo_class or Impala)

        # fmt: off
        # __sphinx_doc_begin__

        # IMPALA specific settings:
        self.vtrace = True
        self.vtrace_clip_rho_threshold = 1.0
        self.vtrace_clip_pg_rho_threshold = 1.0
        self.num_multi_gpu_tower_stacks = 1
        self.minibatch_buffer_size = 1
        self.num_sgd_iter = 1
        self.replay_proportion = 0.0
        self.replay_buffer_num_slots = 0
        self.learner_queue_size = 16
        self.learner_queue_timeout = 300
        self.max_requests_in_flight_per_aggregator_worker = 2
        self.timeout_s_sampler_manager = 0.0
        self.timeout_s_aggregator_manager = 0.0
        self.broadcast_interval = 1
        self.num_aggregation_workers = 0
        self.grad_clip = 40.0
        # Note: Only when using _enable_learner_api=True can the clipping mode be
        # configured by the user. On the old API stack, RLlib will always clip by
        # global_norm, no matter the value of `grad_clip_by`.
        self.grad_clip_by = "global_norm"
        self.opt_type = "adam"
        self.lr_schedule = None
        self.decay = 0.99
        self.momentum = 0.0
        self.epsilon = 0.1
        self.vf_loss_coeff = 0.5
        self.entropy_coeff = 0.01
        self.entropy_coeff_schedule = None
        self._separate_vf_optimizer = False
        self._lr_vf = 0.0005
        self.after_train_step = None

        # Override some of AlgorithmConfig's default values with IMPALA-specific
        # values.
        self.rollout_fragment_length = 50
        self.train_batch_size = 500
        self._minibatch_size = "auto"
        self.num_rollout_workers = 2
        self.num_gpus = 1
        self.lr = 0.0005
        self.min_time_s_per_iteration = 10
        self._tf_policy_handles_more_than_one_loss = True
        self.exploration_config = {
            # The Exploration class to use. In the simplest case, this is the name
            # (str) of any class present in the `rllib.utils.exploration` package.
            # You can also provide the python class directly or the full location
            # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
            # EpsilonGreedy").
            "type": "StochasticSampling",
            # Add constructor kwargs here (if any).
        }
        # __sphinx_doc_end__
        # fmt: on

        # Deprecated values.
        self.num_data_loader_buffers = DEPRECATED_VALUE
        self.vtrace_drop_last_ts = DEPRECATED_VALUE

    @override(AlgorithmConfig)
    def training(
        self,
        *,
        vtrace: Optional[bool] = NotProvided,
        vtrace_clip_rho_threshold: Optional[float] = NotProvided,
        vtrace_clip_pg_rho_threshold: Optional[float] = NotProvided,
        gamma: Optional[float] = NotProvided,
        num_multi_gpu_tower_stacks: Optional[int] = NotProvided,
        minibatch_buffer_size: Optional[int] = NotProvided,
        minibatch_size: Optional[Union[int, str]] = NotProvided,
        num_sgd_iter: Optional[int] = NotProvided,
        replay_proportion: Optional[float] = NotProvided,
        replay_buffer_num_slots: Optional[int] = NotProvided,
        learner_queue_size: Optional[int] = NotProvided,
        learner_queue_timeout: Optional[float] = NotProvided,
        max_requests_in_flight_per_aggregator_worker: Optional[int] = NotProvided,
        timeout_s_sampler_manager: Optional[float] = NotProvided,
        timeout_s_aggregator_manager: Optional[float] = NotProvided,
        broadcast_interval: Optional[int] = NotProvided,
        num_aggregation_workers: Optional[int] = NotProvided,
        grad_clip: Optional[float] = NotProvided,
        opt_type: Optional[str] = NotProvided,
        lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
        decay: Optional[float] = NotProvided,
        momentum: Optional[float] = NotProvided,
        epsilon: Optional[float] = NotProvided,
        vf_loss_coeff: Optional[float] = NotProvided,
        entropy_coeff: Optional[float] = NotProvided,
        entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
        _separate_vf_optimizer: Optional[bool] = NotProvided,
        _lr_vf: Optional[float] = NotProvided,
        after_train_step: Optional[Callable[[dict], None]] = NotProvided,
        # Deprecated.
        vtrace_drop_last_ts=None,
        **kwargs,
    ) -> "ImpalaConfig":
        """Sets the training related configuration.

        Args:
            vtrace: V-trace params (see vtrace_tf/torch.py).
            vtrace_clip_rho_threshold: Clipping threshold for the V-trace importance
                weights (rhos) used when computing the value targets.
            vtrace_clip_pg_rho_threshold: Clipping threshold for the V-trace
                importance weights (rhos) used when computing the policy gradient
                advantages.
            gamma: Float specifying the discount factor of the Markov Decision
                process.
            num_multi_gpu_tower_stacks: For each stack of multi-GPU towers, how many
                slots should we reserve for parallel data loading? Set this to >1 to
                load data into GPUs in parallel. This will increase GPU memory usage
                proportionally with the number of stacks.
                Example:
                2 GPUs and `num_multi_gpu_tower_stacks=3`:
                - One tower stack consists of 2 GPUs, each with a copy of the
                model/graph.
                - Each of the stacks will create 3 slots for batch data on each of its
                GPUs, increasing memory requirements on each GPU by 3x.
                - This enables us to preload data into these stacks while another stack
                is performing gradient calculations.
            minibatch_buffer_size: How many train batches should be retained for
                minibatching. This setting only has an effect if `num_sgd_iter > 1`.
            minibatch_size: The size of minibatches that are trained over during
                each SGD iteration. If "auto", will use the same value as
                `train_batch_size`.
                Note that this setting only has an effect if `_enable_learner_api=True`
                and it must be a multiple of `rollout_fragment_length` or
                `sequence_length` and smaller than or equal to `train_batch_size`.
            num_sgd_iter: Number of passes to make over each train batch.
            replay_proportion: Set >0 to enable experience replay. Saved samples will
                be replayed with a p:1 proportion to new data samples.
            replay_buffer_num_slots: Number of sample batches to store for replay.
                The number of transitions saved total will be
                (replay_buffer_num_slots * rollout_fragment_length).
            learner_queue_size: Max queue size for train batches feeding into the
                learner.
            learner_queue_timeout: Wait for train batches to be available in minibatch
                buffer queue this many seconds. This may need to be increased e.g. when
                training with a slow environment.
            max_requests_in_flight_per_aggregator_worker: Level of queuing for replay
                aggregator operations (if using aggregator workers).
            timeout_s_sampler_manager: The timeout for waiting for sampling results
                for workers -- typically if this is too low, the manager won't be able
                to retrieve ready sampling results.
            timeout_s_aggregator_manager: The timeout for waiting for replay worker
                results -- typically if this is too low, the manager won't be able to
                retrieve ready replay requests.
            broadcast_interval: Number of training step calls before weights are
                broadcast to rollout workers that are sampled during any iteration.
            num_aggregation_workers: Use n (`num_aggregation_workers`) extra Actors for
                multi-level aggregation of the data produced by the m RolloutWorkers
                (`num_workers`). Note that n should be much smaller than m.
                This can make sense if ingesting >2GB/s of samples, or if
                the data requires decompression.
            grad_clip: If specified, clip the global norm of gradients by this amount.
            opt_type: Either "adam" or "rmsprop".
            lr_schedule: Learning rate schedule. In the format of
                [[timestep, lr-value], [timestep, lr-value], ...]
                Intermediary timesteps will be assigned to interpolated learning rate
                values. A schedule should normally start from timestep 0.
            decay: Decay setting for the RMSProp optimizer, in case `opt_type=rmsprop`.
            momentum: Momentum setting for the RMSProp optimizer, in case
                `opt_type=rmsprop`.
            epsilon: Epsilon setting for the RMSProp optimizer, in case
                `opt_type=rmsprop`.
            vf_loss_coeff: Coefficient for the value function term in the loss
                function.
            entropy_coeff: Coefficient for the entropy regularizer term in the loss
                function.
            entropy_coeff_schedule: Decay schedule for the entropy regularizer.
            _separate_vf_optimizer: Set this to true to have two separate optimizers
                optimize the policy and value networks.
            _lr_vf: If _separate_vf_optimizer is True, define a separate learning rate
                for the value network.
            after_train_step: Callback for APPO to use to update KL, target network
                periodically. The input to the callback is the learner fetches dict.

        Returns:
            This updated AlgorithmConfig object.
        """
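
        # Usage sketch (hypothetical values): a call like
        #   config.training(replay_proportion=0.5, replay_buffer_num_slots=100)
        # enables a small amount of experience replay on top of V-trace, while
        # leaving `minibatch_size="auto"` (one SGD pass over the full train batch).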

        if vtrace_drop_last_ts is not None:
            deprecation_warning(
                old="vtrace_drop_last_ts",
                help="The v-trace operations in RLlib have been enhanced and we are "
                "now using proper value bootstrapping at the end of each "
                "trajectory, such that no timesteps in our loss functions have to "
                "be dropped anymore.",
                error=True,
            )

        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if vtrace is not NotProvided:
            self.vtrace = vtrace
        if vtrace_clip_rho_threshold is not NotProvided:
            self.vtrace_clip_rho_threshold = vtrace_clip_rho_threshold
        if vtrace_clip_pg_rho_threshold is not NotProvided:
            self.vtrace_clip_pg_rho_threshold = vtrace_clip_pg_rho_threshold
        if num_multi_gpu_tower_stacks is not NotProvided:
            self.num_multi_gpu_tower_stacks = num_multi_gpu_tower_stacks
        if minibatch_buffer_size is not NotProvided:
            self.minibatch_buffer_size = minibatch_buffer_size
        if num_sgd_iter is not NotProvided:
            self.num_sgd_iter = num_sgd_iter
        if replay_proportion is not NotProvided:
            self.replay_proportion = replay_proportion
        if replay_buffer_num_slots is not NotProvided:
            self.replay_buffer_num_slots = replay_buffer_num_slots
        if learner_queue_size is not NotProvided:
            self.learner_queue_size = learner_queue_size
        if learner_queue_timeout is not NotProvided:
            self.learner_queue_timeout = learner_queue_timeout
        if broadcast_interval is not NotProvided:
            self.broadcast_interval = broadcast_interval
        if num_aggregation_workers is not NotProvided:
            self.num_aggregation_workers = num_aggregation_workers
        if max_requests_in_flight_per_aggregator_worker is not NotProvided:
            self.max_requests_in_flight_per_aggregator_worker = (
                max_requests_in_flight_per_aggregator_worker
            )
        if timeout_s_sampler_manager is not NotProvided:
            self.timeout_s_sampler_manager = timeout_s_sampler_manager
        if timeout_s_aggregator_manager is not NotProvided:
            self.timeout_s_aggregator_manager = timeout_s_aggregator_manager
        if grad_clip is not NotProvided:
            self.grad_clip = grad_clip
        if opt_type is not NotProvided:
            self.opt_type = opt_type
        if lr_schedule is not NotProvided:
            self.lr_schedule = lr_schedule
        if decay is not NotProvided:
            self.decay = decay
        if momentum is not NotProvided:
            self.momentum = momentum
        if epsilon is not NotProvided:
            self.epsilon = epsilon
        if vf_loss_coeff is not NotProvided:
            self.vf_loss_coeff = vf_loss_coeff
        if entropy_coeff is not NotProvided:
            self.entropy_coeff = entropy_coeff
        if entropy_coeff_schedule is not NotProvided:
            self.entropy_coeff_schedule = entropy_coeff_schedule
        if _separate_vf_optimizer is not NotProvided:
            self._separate_vf_optimizer = _separate_vf_optimizer
        if _lr_vf is not NotProvided:
            self._lr_vf = _lr_vf
        if after_train_step is not NotProvided:
            self.after_train_step = after_train_step
        if gamma is not NotProvided:
            self.gamma = gamma
        if minibatch_size is not NotProvided:
            self._minibatch_size = minibatch_size

        return self

    @override(AlgorithmConfig)
    def validate(self) -> None:
        # Call the super class' validation method first.
        super().validate()

        if self.num_data_loader_buffers != DEPRECATED_VALUE:
            deprecation_warning(
                "num_data_loader_buffers", "num_multi_gpu_tower_stacks", error=True
            )

        # Entropy coeff schedule checking.
        if self._enable_learner_api:
            if self.entropy_coeff_schedule is not None:
                raise ValueError(
                    "`entropy_coeff_schedule` is deprecated and must be None! Use the "
                    "`entropy_coeff` setting to setup a schedule."
                )
            Scheduler.validate(
                fixed_value_or_schedule=self.entropy_coeff,
                setting_name="entropy_coeff",
                description="entropy coefficient",
            )
        elif isinstance(self.entropy_coeff, float) and self.entropy_coeff < 0.0:
            raise ValueError("`entropy_coeff` must be >= 0.0")

        # Check whether worker to aggregation-worker ratio makes sense.
        if self.num_aggregation_workers > self.num_rollout_workers:
            raise ValueError(
                "`num_aggregation_workers` must be smaller than or equal to "
                "`num_rollout_workers`! Aggregation makes no sense otherwise."
            )
        elif self.num_aggregation_workers > self.num_rollout_workers / 2:
            logger.warning(
                "`num_aggregation_workers` should be significantly smaller "
                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
                "less."
            )

        # If two separate optimizers/loss terms are used for tf, we must also set
        # `_tf_policy_handles_more_than_one_loss` to True.
        if self._separate_vf_optimizer is True:
            # Only supported in tf on the old API stack.
            if self.framework_str not in ["tf", "tf2"]:
                raise ValueError(
                    "`_separate_vf_optimizer` is only supported for tf so far!"
                )
            if self._tf_policy_handles_more_than_one_loss is False:
                raise ValueError(
                    "`_tf_policy_handles_more_than_one_loss` must be set to "
                    "True, for TFPolicy to support more than one loss "
                    "term/optimizer! Try setting config.training("
                    "_tf_policy_handles_more_than_one_loss=True)."
                )

        # Learner API specific checks.
        if self._enable_learner_api:
            if not (
                (self.minibatch_size % self.rollout_fragment_length == 0)
                and self.minibatch_size <= self.train_batch_size
            ):
                raise ValueError(
                    f"`minibatch_size` ({self._minibatch_size}) must either be 'auto' "
                    "or a multiple of `rollout_fragment_length` "
                    f"({self.rollout_fragment_length}) while at the same time smaller "
                    f"than or equal to `train_batch_size` ({self.train_batch_size})!"
                )

    @override(AlgorithmConfig)
    def get_learner_hyperparameters(self) -> ImpalaLearnerHyperparameters:
        base_hps = super().get_learner_hyperparameters()
        learner_hps = ImpalaLearnerHyperparameters(
            rollout_frag_or_episode_len=self.get_rollout_fragment_length(),
            discount_factor=self.gamma,
            entropy_coeff=self.entropy_coeff,
            vf_loss_coeff=self.vf_loss_coeff,
            vtrace_clip_rho_threshold=self.vtrace_clip_rho_threshold,
            vtrace_clip_pg_rho_threshold=self.vtrace_clip_pg_rho_threshold,
            **dataclasses.asdict(base_hps),
        )
        # TODO: We currently do not use the `recurrent_seq_len` property anyways.
        #  We should re-think the handling of RNN/SEQ_LENs/etc.. once we start
        #  supporting them in RLModules and then revisit this check here.
        #  Also, such a check should be moved into `IMPALAConfig.validate()`.
        assert (learner_hps.rollout_frag_or_episode_len is None) != (
            learner_hps.recurrent_seq_len is None
        ), (
            "One of `rollout_frag_or_episode_len` or `recurrent_seq_len` must be not "
            "None in ImpalaLearnerHyperparameters!"
        )

        return learner_hps

    # TODO (sven): Make these get_... methods all read-only @properties instead.
    def get_replay_ratio(self) -> float:
        """Returns replay ratio (between 0.0 and 1.0) based off self.replay_proportion.

        Formula: ratio = 1 / proportion
        """
        return (1 / self.replay_proportion) if self.replay_proportion > 0 else 0.0
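    # For example, `replay_proportion=2.0` yields `get_replay_ratio()` = 1 / 2.0
    # = 0.5, whereas the default `replay_proportion=0.0` disables replay (0.0).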

    @property
    def minibatch_size(self):
        # If 'auto', use the train_batch_size (meaning each SGD iter is a single pass
        # through the entire train batch). Otherwise, use user provided setting.
        return (
            self.train_batch_size
            if self._minibatch_size == "auto"
            else self._minibatch_size
        )
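    # Example with the IMPALA defaults above (`rollout_fragment_length=50`,
    # `train_batch_size=500`): `minibatch_size="auto"` trains over the full
    # 500-timestep batch each SGD iteration, while e.g. `minibatch_size=100`
    # (a multiple of 50 and <= 500, as enforced in `validate()`) splits it into
    # five minibatches per `num_sgd_iter` pass.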

    @override(AlgorithmConfig)
    def get_default_learner_class(self):
        if self.framework_str == "torch":
            from ray.rllib.algorithms.impala.torch.impala_torch_learner import (
                ImpalaTorchLearner,
            )

            return ImpalaTorchLearner
        elif self.framework_str == "tf2":
            from ray.rllib.algorithms.impala.tf.impala_tf_learner import ImpalaTfLearner

            return ImpalaTfLearner
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. "
                "Use either 'torch' or 'tf2'."
            )

    @override(AlgorithmConfig)
    def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec:
        if self.framework_str == "tf2":
            from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule

            return SingleAgentRLModuleSpec(
                module_class=PPOTfRLModule, catalog_class=PPOCatalog
            )
        elif self.framework_str == "torch":
            from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
                PPOTorchRLModule,
            )

            return SingleAgentRLModuleSpec(
                module_class=PPOTorchRLModule, catalog_class=PPOCatalog
            )
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. "
                "Use either 'torch' or 'tf2'."
            )


def make_learner_thread(local_worker, config):
    if not config["simple_optimizer"]:
        logger.info(
            "Enabling multi-GPU mode, {} GPUs, {} parallel tower-stacks".format(
                config["num_gpus"], config["num_multi_gpu_tower_stacks"]
            )
        )
        num_stacks = config["num_multi_gpu_tower_stacks"]
        buffer_size = config["minibatch_buffer_size"]
        if num_stacks < buffer_size:
            logger.warning(
                "In multi-GPU mode you should have at least as many "
                "multi-GPU tower stacks (to load data into on one device) as "
                "you have stack-index slots in the buffer! You have "
                f"configured {num_stacks} stacks and a buffer of size "
                f"{buffer_size}. Setting "
                f"`minibatch_buffer_size={num_stacks}`."
            )
            config["minibatch_buffer_size"] = num_stacks

        learner_thread = MultiGPULearnerThread(
            local_worker,
            num_gpus=config["num_gpus"],
            lr=config["lr"],
            train_batch_size=config["train_batch_size"],
            num_multi_gpu_tower_stacks=config["num_multi_gpu_tower_stacks"],
            num_sgd_iter=config["num_sgd_iter"],
            learner_queue_size=config["learner_queue_size"],
            learner_queue_timeout=config["learner_queue_timeout"],
        )
    else:
        learner_thread = LearnerThread(
            local_worker,
            minibatch_buffer_size=config["minibatch_buffer_size"],
            num_sgd_iter=config["num_sgd_iter"],
            learner_queue_size=config["learner_queue_size"],
            learner_queue_timeout=config["learner_queue_timeout"],
        )
    return learner_thread
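
# Note: `make_learner_thread()` above is only used on the old API stack (i.e. when
# `config._enable_learner_api` is False); see `Impala.setup()` below, where either a
# `MultiGPULearnerThread` (multi-GPU mode) or a plain `LearnerThread` gets started.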


class Impala(Algorithm):
    """Importance weighted actor/learner architecture (IMPALA) Algorithm.

    == Overview of data flow in IMPALA ==
    1. Policy evaluation in parallel across `num_workers` actors produces
       batches of size `rollout_fragment_length * num_envs_per_worker`.
    2. If enabled, the replay buffer stores and produces batches of size
       `rollout_fragment_length * num_envs_per_worker`.
    3. If enabled, the minibatch ring buffer stores and replays batches of
       size `train_batch_size` up to `num_sgd_iter` times per batch.
    4. The learner thread executes data parallel SGD across `num_gpus` GPUs
       on batches of size `train_batch_size`.
    """
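
    # Minimal usage sketch (hypothetical; assumes a registered Gym env):
    #   config = ImpalaConfig().environment(env="CartPole-v1")
    #   config = config.rollouts(num_rollout_workers=2)
    #   algo = config.build()
    #   print(algo.train())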

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        return ImpalaConfig()

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        if not config["vtrace"]:
            raise ValueError(
                "IMPALA with the learner API does not support non-VTrace setups!"
            )

        if config["framework"] == "torch":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_torch_policy import (
                    ImpalaTorchPolicy,
                )

                return ImpalaTorchPolicy
            else:
                from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

                return A3CTorchPolicy
        elif config["framework"] == "tf":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import (
                    ImpalaTF1Policy,
                )

                return ImpalaTF1Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy
        else:
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import (
                    ImpalaTF2Policy,
                )

                return ImpalaTF2Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy

    @override(Algorithm)
    def setup(self, config: AlgorithmConfig):
        super().setup(config)

        # Queue of batches to be sent to the Learner.
        self.batches_to_place_on_learner = []
        # Create extra aggregation workers and assign each rollout worker to
        # one of them.
        self.batch_being_built = []
        if self.config.num_aggregation_workers > 0:
            # This spawns `num_aggregation_workers` actors that aggregate
            # experiences coming from RolloutWorkers in parallel. We force
            # colocation on the same node (localhost) to maximize data bandwidth
            # between them and the learner.
            localhost = platform.node()
            assert localhost != "", (
                "ERROR: Cannot determine local node name! "
                "`platform.node()` returned empty string."
            )
            all_co_located = create_colocated_actors(
                actor_specs=[
                    # (class, args, kwargs={}, count=1)
                    (
                        AggregatorWorker,
                        [
                            self.config,
                        ],
                        {},
                        self.config.num_aggregation_workers,
                    )
                ],
                node=localhost,
            )
            aggregator_workers = [
                actor for actor_groups in all_co_located for actor in actor_groups
            ]
            self._aggregator_actor_manager = FaultTolerantActorManager(
                aggregator_workers,
                max_remote_requests_in_flight_per_actor=(
                    self.config.max_requests_in_flight_per_aggregator_worker
                ),
            )
            self._timeout_s_aggregator_manager = (
                self.config.timeout_s_aggregator_manager
            )
        else:
            # Create our local mixin buffer if the num of aggregation workers is 0.
            self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
                capacity=(
                    self.config.replay_buffer_num_slots
                    if self.config.replay_buffer_num_slots > 0
                    else 1
                ),
                replay_ratio=self.config.get_replay_ratio(),
                replay_mode=ReplayMode.LOCKSTEP,
            )
            self._aggregator_actor_manager = None

        # This variable is used to keep track of the statistics from the most recent
        # update of the learner group.
        self._results = {}
        self._timeout_s_sampler_manager = self.config.timeout_s_sampler_manager

        if not self.config._enable_learner_api:
            # Create and start the learner thread.
            self._learner_thread = make_learner_thread(
                self.workers.local_worker(), self.config
            )
            self._learner_thread.start()

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        # First, check whether our learner thread is still healthy.
        if not self.config._enable_learner_api and not self._learner_thread.is_alive():
            raise RuntimeError("The learner thread died while training!")

        use_tree_aggregation = (
            self._aggregator_actor_manager
            and self._aggregator_actor_manager.num_healthy_actors() > 0
        )

        # Get sampled SampleBatches from our workers (by ray references if we use
        # tree-aggregation).
        unprocessed_sample_batches = self.get_samples_from_workers(
            return_object_refs=use_tree_aggregation,
        )
        # Tag workers that actually produced ready sample batches this iteration.
        # Those workers will have to get updated at the end of the iteration.
        workers_that_need_updates = {
            worker_id for worker_id, _ in unprocessed_sample_batches
        }

        # Send the collected batches (still object refs) to our aggregation workers.
        if use_tree_aggregation:
            batches = self.process_experiences_tree_aggregation(
                unprocessed_sample_batches
            )
        # Resolve collected batches here on local process (using the mixin buffer).
        else:
            batches = self.process_experiences_directly(unprocessed_sample_batches)

        # Increase sampling counters now that we have the actual SampleBatches on
        # the local process (and can measure their sizes).
        for batch in batches:
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.count
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()

        # Concatenate single batches into batches of size `train_batch_size`.
        self.concatenate_batches_and_pre_queue(batches)

        # Using the Learner API. Call `update()` on our LearnerGroup object with
        # all collected batches.
        if self.config._enable_learner_api:
            train_results = self.learn_on_processed_samples()
            additional_results = self.learner_group.additional_update(
                module_ids_to_update=set(train_results.keys()) - {ALL_MODULES},
                timestep=self._counters[
                    NUM_ENV_STEPS_TRAINED
                    if self.config.count_steps_by == "env_steps"
                    else NUM_AGENT_STEPS_TRAINED
                ],
                # TODO (sven): Feels hacked, but solves the problem of algos inheriting
                #  from IMPALA (like APPO). In the old stack, we didn't have this
                #  problem b/c IMPALA didn't need to call any additional update methods
                #  as the entropy- and lr-schedules were handled by
                #  `Policy.on_global_var_update()`.
                **self._get_additional_update_kwargs(train_results),
            )
            for key, res in additional_results.items():
                if key in train_results:
                    train_results[key].update(res)
        else:
            # Move train batches (of size `train_batch_size`) onto learner queue.
            self.place_processed_samples_on_learner_thread_queue()
            # Extract most recent train results from learner thread.
            train_results = self.process_trained_results()

        # Sync worker weights (only those policies that were actually updated).
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            if self.config._enable_learner_api:
                if train_results:
                    pids = list(set(train_results.keys()) - {ALL_MODULES})
                    self.update_workers_from_learner_group(
                        workers_that_need_updates=workers_that_need_updates,
                        policy_ids=pids,
                    )
            else:
                pids = list(train_results.keys())
                self.update_workers_if_necessary(
                    workers_that_need_updates=workers_that_need_updates,
                    policy_ids=pids,
                )

        # With a training step done, try to bring any aggregators back to life
        # if necessary.
        # Aggregation workers are stateless, so we do not need to restore any
        # state here.
        if self._aggregator_actor_manager:
            self._aggregator_actor_manager.probe_unhealthy_actors(
                timeout_seconds=self.config.worker_health_probe_timeout_s,
                mark_healthy=True,
            )

        if self.config._enable_learner_api:
            if train_results:
                # Store the most recent result and return it if no new result is
                # available. This keeps backwards compatibility with the old
                # training stack / results reporting stack. This is necessary
                # any time we develop an asynchronous algorithm.
                self._results = train_results
            return self._results
        else:
            return train_results

    @classmethod
    @override(Algorithm)
    def default_resource_request(
        cls,
        config: Union[AlgorithmConfig, PartialAlgorithmConfigDict],
    ):
        if isinstance(config, AlgorithmConfig):
            cf: ImpalaConfig = config
        else:
            cf: ImpalaConfig = cls.get_default_config().update_from_dict(config)

        eval_config = cf.get_evaluation_config_object()

        bundles = (
            [
                {
                    # Driver + Aggregation Workers:
                    # Force to be on same node to maximize data bandwidth
                    # between aggregation workers and the learner (driver).
                    # Aggregation workers tree-aggregate experiences collected
                    # from RolloutWorkers (n rollout workers map to m
                    # aggregation workers, where m < n) and always use 1 CPU
                    # each.
                    "CPU": cf.num_cpus_for_local_worker + cf.num_aggregation_workers,
                    "GPU": 0 if cf._fake_gpus else cf.num_gpus,
                }
            ]
            + [
                {
                    # RolloutWorkers.
                    "CPU": cf.num_cpus_per_worker,
                    "GPU": cf.num_gpus_per_worker,
                    **cf.custom_resources_per_worker,
                }
                for _ in range(cf.num_rollout_workers)
            ]
            + (
                [
                    {
                        # Evaluation (remote) workers.
                        # Note: The local eval worker is located on the driver
                        # CPU or not even created iff >0 eval workers.
                        "CPU": eval_config.num_cpus_per_worker,
                        "GPU": eval_config.num_gpus_per_worker,
                        **eval_config.custom_resources_per_worker,
                    }
                    for _ in range(cf.evaluation_num_workers)
                ]
                if cf.evaluation_interval
                else []
            )
        )
        # TODO(avnishn): Remove this once we have a way to extend placement group
        #  factories.
        if cf._enable_learner_api:
            # Resources for the Algorithm.
            learner_bundles = cls._get_learner_bundles(cf)
            bundles += learner_bundles

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=bundles,
            strategy=cf.placement_strategy,
        )

    def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]):
        """Concatenate batches that are being returned from rollout workers.

        Args:
            batches: Batches of experiences from rollout workers.
        """

        def aggregate_into_larger_batch():
            if (
                sum(b.count for b in self.batch_being_built)
                >= self.config.train_batch_size
            ):
                batch_to_add = concat_samples(self.batch_being_built)
                self.batches_to_place_on_learner.append(batch_to_add)
                self.batch_being_built = []

        for batch in batches:
            self.batch_being_built.append(batch)
            aggregate_into_larger_batch()
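
    # With the IMPALA defaults (`rollout_fragment_length=50`, one env per worker,
    # `train_batch_size=500`), roughly ten rollout fragments are concatenated
    # before one train batch is appended to `batches_to_place_on_learner`.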

    def get_samples_from_workers(
        self,
        return_object_refs: Optional[bool] = False,
    ) -> List[Tuple[int, Union[ObjectRef, SampleBatchType]]]:
        """Get samples from rollout workers for training.

        Args:
            return_object_refs: If True, return ObjectRefs instead of the samples
                directly. This is useful when using aggregator workers so that data
                collected on rollout workers is directly dereferenced on the
                aggregator workers instead of first in the driver and then on the
                aggregator workers.

        Returns:
            A list of tuples of (worker_index, sample batch or ObjectRef to a sample
            batch).
        """
        with self._timers[SAMPLE_TIMER]:
            # Sample from healthy remote workers by default. If there is no healthy
            # worker (either because they have all died, or because there was none to
            # begin with), check if the local_worker exists. If the local worker has
            # an env_instance (either because there are no remote workers or
            # self.config.create_env_on_local_worker == True), then sample from the
            # local worker. Otherwise just return an empty list.
            if self.workers.num_healthy_remote_workers() > 0:
                # Perform asynchronous sampling on all (remote) rollout workers.
                self.workers.foreach_worker_async(
                    lambda worker: worker.sample(),
                    healthy_only=True,
                )
                sample_batches: List[
                    Tuple[int, ObjectRef]
                ] = self.workers.fetch_ready_async_reqs(
                    timeout_seconds=self._timeout_s_sampler_manager,
                    return_obj_refs=return_object_refs,
                )
            elif (
                self.workers.local_worker()
                and self.workers.local_worker().async_env is not None
            ):
                # Sampling from the local worker.
                sample_batch = self.workers.local_worker().sample()
                if return_object_refs:
                    sample_batch = ray.put(sample_batch)
                sample_batches = [(0, sample_batch)]
            else:
                # Not much we can do. Return empty list and wait.
                return []

        return sample_batches

    def learn_on_processed_samples(self) -> ResultDict:
        """Update the learner group with the latest batch of processed samples.

        Returns:
            Aggregated results from the learner group after an update is completed.
        """
        # There are batches on the queue -> Send them all to the learner group.
        if self.batches_to_place_on_learner:
            batches = self.batches_to_place_on_learner[:]
            self.batches_to_place_on_learner.clear()
            # If there are no learner workers and learning is directly on the driver,
            # then we can't do async updates, so we need to block.
            blocking = self.config.num_learner_workers == 0
            results = []
            for batch in batches:
                if blocking:
                    result = self.learner_group.update(
                        batch,
                        reduce_fn=_reduce_impala_results,
                        num_iters=self.config.num_sgd_iter,
                        minibatch_size=self.config.minibatch_size,
                    )
                    results = [result]
                else:
                    results = self.learner_group.async_update(
                        batch,
                        reduce_fn=_reduce_impala_results,
                        num_iters=self.config.num_sgd_iter,
                        minibatch_size=self.config.minibatch_size,
                    )

                for r in results:
                    self._counters[NUM_ENV_STEPS_TRAINED] += r[ALL_MODULES].pop(
                        NUM_ENV_STEPS_TRAINED
                    )
                    self._counters[NUM_AGENT_STEPS_TRAINED] += r[ALL_MODULES].pop(
                        NUM_AGENT_STEPS_TRAINED
                    )

            self._counters.update(self.learner_group.get_in_queue_stats())
            # If there are results, reduce-mean over each individual value and return.
            if results:
                return tree.map_structure(lambda *x: np.mean(x), *results)

        # Nothing on the queue -> Don't send requests to learner group
        # or no results ready (from previous `self.learner_group.update()` calls) for
        # reducing.
        return {}

    def place_processed_samples_on_learner_thread_queue(self) -> None:
        """Place processed samples on the learner queue for training.

        NOTE: This method is called if self.config._enable_learner_api is False.
        """
        while self.batches_to_place_on_learner:
            batch = self.batches_to_place_on_learner[0]
            try:
                # Setting block = True prevents the learner thread,
                # the main thread, and the gpu loader threads from
                # thrashing when there are more samples than the
                # learner can reasonably process.
                # see https://github.com/ray-project/ray/pull/26581#issuecomment-1187877674  # noqa
                self._learner_thread.inqueue.put(batch, block=True)
                self.batches_to_place_on_learner.pop(0)
                self._counters["num_samples_added_to_queue"] += (
                    batch.agent_steps()
                    if self.config.count_steps_by == "agent_steps"
                    else batch.count
                )
            except queue.Full:
                self._counters["num_times_learner_queue_full"] += 1

    def process_trained_results(self) -> ResultDict:
        """Process training results that are output by the learner thread.

        NOTE: This method is called if self.config._enable_learner_api is False.

        Returns:
            Aggregated results from the learner thread after an update is completed.
        """
        # Get learner outputs/stats from output queue.
        num_env_steps_trained = 0
        num_agent_steps_trained = 0
        learner_infos = []
        # Loop through output queue and update our counts.
        for _ in range(self._learner_thread.outqueue.qsize()):
            (
                env_steps,
                agent_steps,
                learner_results,
            ) = self._learner_thread.outqueue.get(timeout=0.001)
            num_env_steps_trained += env_steps
            num_agent_steps_trained += agent_steps
            if learner_results:
                learner_infos.append(learner_results)
        # Nothing new happened since last time, use the same learner stats.
        if not learner_infos:
            final_learner_info = copy.deepcopy(self._learner_thread.learner_info)
        # Accumulate learner stats using the `LearnerInfoBuilder` utility.
        else:
            builder = LearnerInfoBuilder()
            for info in learner_infos:
                builder.add_learn_on_batch_results_multi_agent(info)
            final_learner_info = builder.finalize()

        # Update the steps trained counters.
        self._counters[NUM_ENV_STEPS_TRAINED] += num_env_steps_trained
        self._counters[NUM_AGENT_STEPS_TRAINED] += num_agent_steps_trained

        return final_learner_info

    def process_experiences_directly(
        self,
        worker_to_sample_batches: List[Tuple[int, SampleBatch]],
    ) -> List[SampleBatchType]:
        """Process sample batches directly on the driver, for training.

        Args:
            worker_to_sample_batches: List of (worker_id, sample_batch) tuples.

        Returns:
            Batches that have been processed by the mixin buffer.
        """
        batches = [b for _, b in worker_to_sample_batches]
        processed_batches = []

        for batch in batches:
            assert not isinstance(
                batch, ObjectRef
            ), "process_experiences_directly can not handle ObjectRefs. "
            batch = batch.decompress_if_needed()
            self.local_mixin_buffer.add(batch)
            batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
            if batch:
                processed_batches.append(batch)

        return processed_batches

    def process_experiences_tree_aggregation(
        self,
        worker_to_sample_batches_refs: List[Tuple[int, ObjectRef]],
    ) -> List[SampleBatchType]:
        """Process sample batches using tree aggregation workers.

        Args:
            worker_to_sample_batches_refs: List of (worker_id, sample_batch_ref)
                tuples.

        NOTE: This will provide a speedup when sample batches have been compressed,
        and the decompression can happen on the aggregation workers in parallel to
        the training.

        Returns:
            Batches that have been processed by the mixin buffers on the aggregation
            workers.
        """

        def _process_episodes(actor, batch):
            return actor.process_episodes(ray.get(batch))

        for _, batch in worker_to_sample_batches_refs:
            assert isinstance(batch, ObjectRef), (
                "For efficiency, process_experiences_tree_aggregation should "
                f"be given ObjectRefs instead of {type(batch)}."
            )
            # Randomly pick an aggregation worker to process this batch.
            aggregator_id = random.choice(
                self._aggregator_actor_manager.healthy_actor_ids()
            )
            calls_placed = self._aggregator_actor_manager.foreach_actor_async(
                partial(_process_episodes, batch=batch),
                remote_actor_ids=[aggregator_id],
            )
            if calls_placed <= 0:
                self._counters["num_times_no_aggregation_worker_available"] += 1

        waiting_processed_sample_batches: RemoteCallResults = (
            self._aggregator_actor_manager.fetch_ready_async_reqs(
                timeout_seconds=self._timeout_s_aggregator_manager,
            )
        )
        handle_remote_call_result_errors(
            waiting_processed_sample_batches,
            self.config.ignore_worker_failures,
        )

        return [b.get() for b in waiting_processed_sample_batches.ignore_errors()]
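
    # Note: Because the aggregation calls above are asynchronous, the batches
    # returned here may stem from requests issued in earlier `training_step()`
    # calls; only results that are ready within `timeout_s_aggregator_manager`
    # get fetched.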

    def update_workers_from_learner_group(
        self,
        workers_that_need_updates: Set[int],
        policy_ids: Optional[List[PolicyID]] = None,
    ):
        """Updates all RolloutWorkers that require updating.

        Updates only if NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS has
        been reached and the worker has sent samples in this iteration. Also only
        updates those policies whose IDs are given via `policy_ids` (if None, update
        all policies).

        Args:
            workers_that_need_updates: Set of worker IDs that need to be updated.
            policy_ids: Optional list of Policy IDs to update. If None, will update all
                policies on the to-be-updated workers.
        """
        # Only need to update workers if there are remote workers.
        self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] += 1
        if (
            self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS]
            >= self.config.broadcast_interval
            and workers_that_need_updates
        ):
            self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0
            self._counters[NUM_SYNCH_WORKER_WEIGHTS] += 1
            weights = self.learner_group.get_weights(policy_ids)
            if self.config.num_rollout_workers == 0:
                worker = self.workers.local_worker()
                worker.set_weights(weights)
            else:
                weights_ref = ray.put(weights)
                self.workers.foreach_worker(
                    func=lambda w: w.set_weights(ray.get(weights_ref)),
                    local_worker=False,
                    remote_worker_ids=list(workers_that_need_updates),
                    timeout_seconds=0,  # Don't wait for the workers to finish.
                )
                # If we have a local worker that we sample from in addition to
                # our remote workers, we need to update its weights as well.
                if self.config.create_env_on_local_worker:
                    self.workers.local_worker().set_weights(weights)

    def update_workers_if_necessary(
        self,
        workers_that_need_updates: Set[int],
        policy_ids: Optional[List[PolicyID]] = None,
    ) -> None:
        """Updates all RolloutWorkers that require updating.

        Updates only if NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS has
        been reached and the worker has sent samples in this iteration. Also only
        updates those policies whose IDs are given via `policy_ids` (if None, update
        all policies).

        Args:
            workers_that_need_updates: Set of worker IDs that need to be updated.
            policy_ids: Optional list of Policy IDs to update. If None, will update all
                policies on the to-be-updated workers.
        """
        local_worker = self.workers.local_worker()
        # Update global vars of the local worker.
        if self.config.policy_states_are_swappable:
            local_worker.lock()
        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_TRAINED],
            "num_grad_updates_per_policy": {
                pid: local_worker.policy_map[pid].num_grad_updates
                for pid in policy_ids or []
            },
        }
        local_worker.set_global_vars(global_vars, policy_ids=policy_ids)
        if self.config.policy_states_are_swappable:
            local_worker.unlock()

        # Only need to update workers if there are remote workers.
        self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] += 1
        if (
            self.workers.num_remote_workers() > 0
            and self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS]
            >= self.config.broadcast_interval
            and workers_that_need_updates
        ):
            if self.config.policy_states_are_swappable:
                local_worker.lock()
            weights = local_worker.get_weights(policy_ids)
            if self.config.policy_states_are_swappable:
                local_worker.unlock()
            weights = ray.put(weights)

            self._learner_thread.policy_ids_updated.clear()
            self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0
            self._counters[NUM_SYNCH_WORKER_WEIGHTS] += 1
            self.workers.foreach_worker(
                func=lambda w: w.set_weights(ray.get(weights), global_vars),
                local_worker=False,
                remote_worker_ids=list(workers_that_need_updates),
                timeout_seconds=0,  # Don't wait for the workers to finish.
            )

    def _get_additional_update_kwargs(self, train_results: dict) -> dict:
        """Returns the kwargs to `LearnerGroup.additional_update()`.

        Should be overridden by subclasses to specify wanted/needed kwargs for
        their own implementation of `Learner.additional_update_for_module()`.
        """
        return {}
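
    # For example, APPO (which subclasses Impala) overrides this hook to pass
    # KL-related stats from `train_results` into its additional updates; plain
    # IMPALA needs no extra kwargs.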

    @override(Algorithm)
    def _compile_iteration_results(self, *args, **kwargs):
        result = super()._compile_iteration_results(*args, **kwargs)
        if not self.config._enable_learner_api:
            result = self._learner_thread.add_learner_metrics(
                result, overwrite_learner_info=False
            )
        return result


@ray.remote(num_cpus=0, max_restarts=-1)
class AggregatorWorker(FaultAwareApply):
    """A worker for doing tree aggregation of collected episodes."""

    def __init__(self, config: AlgorithmConfig):
        self.config = config
        self._mixin_buffer = MixInMultiAgentReplayBuffer(
            capacity=(
                self.config.replay_buffer_num_slots
                if self.config.replay_buffer_num_slots > 0
                else 1
            ),
            replay_ratio=self.config.get_replay_ratio(),
            replay_mode=ReplayMode.LOCKSTEP,
        )

    def process_episodes(self, batch: SampleBatchType) -> SampleBatchType:
        batch = batch.decompress_if_needed()
        self._mixin_buffer.add(batch)
        processed_batches = self._mixin_buffer.replay(_ALL_POLICIES)
        return processed_batches

    def get_host(self) -> str:
        return platform.node()
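
# Note: `AggregatorWorker` actors are created (colocated on the driver's node) in
# `Impala.setup()`. They are stateless apart from their mixin buffer, which is why
# `training_step()` can simply mark recovered actors healthy again without
# restoring any state.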