# impala.py
import logging
from typing import Optional, Type

import ray
from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
from ray.rllib.execution.common import (STEPS_TRAINED_COUNTER,
                                        STEPS_TRAINED_THIS_ITER_COUNTER,
                                        _get_global_vars, _get_shared_metrics)
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.typing import PartialTrainerConfigDict, TrainerConfigDict
from ray.tune.utils.placement_groups import PlacementGroupFactory

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # V-trace params (see vtrace_tf/torch.py).
    "vtrace": True,
    "vtrace_clip_rho_threshold": 1.0,
    "vtrace_clip_pg_rho_threshold": 1.0,
    # If True, drop the last timestep for the vtrace calculations, such that
    # all data goes into the calculations as [B x T-1] (+ the bootstrap
    # value). This is the default and legacy RLlib behavior; however, it
    # could potentially have a destabilizing effect on learning, especially
    # in sparse-reward or reward-at-goal environments.
    # Set this to False to not drop the last timestep.
    "vtrace_drop_last_ts": True,
    # System params.
    #
    # == Overview of data flow in IMPALA ==
    # 1. Policy evaluation in parallel across `num_workers` actors produces
    #    batches of size `rollout_fragment_length * num_envs_per_worker`.
    # 2. If enabled, the replay buffer stores and produces batches of size
    #    `rollout_fragment_length * num_envs_per_worker`.
    # 3. If enabled, the minibatch ring buffer stores and replays batches of
    #    size `train_batch_size` up to `num_sgd_iter` times per batch.
    # 4. The learner thread executes data parallel SGD across `num_gpus` GPUs
    #    on batches of size `train_batch_size`.
    #
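    # Example (with the defaults below and `num_envs_per_worker=1`): each
    # rollout worker returns batches of 50 * 1 = 50 timesteps, and roughly
    # 10 such batches get concatenated into one train batch of 500 timesteps
    # for the learner.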
    "rollout_fragment_length": 50,
    "train_batch_size": 500,
    "min_time_s_per_reporting": 10,
    "num_workers": 2,
    # Number of GPUs the learner should use.
    "num_gpus": 1,
    # For each stack of multi-GPU towers, how many slots should we reserve
    # for parallel data loading? Set this to >1 to load data into GPUs in
    # parallel. This will increase GPU memory usage proportionally with the
    # number of stacks.
    # Example:
    # 2 GPUs and `num_multi_gpu_tower_stacks=3`:
    # - One tower stack consists of 2 GPUs, each with a copy of the
    #   model/graph.
    # - Each of the stacks will create 3 slots for batch data on each of its
    #   GPUs, increasing memory requirements on each GPU by 3x.
    # - This enables us to preload data into these stacks while another stack
    #   is performing gradient calculations.
    "num_multi_gpu_tower_stacks": 1,
    # How many train batches should be retained for minibatching. This conf
    # only has an effect if `num_sgd_iter > 1`.
    "minibatch_buffer_size": 1,
    # Number of passes to make over each train batch.
    "num_sgd_iter": 1,
    # Set >0 to enable experience replay. Saved samples will be replayed with
    # a p:1 proportion to new data samples.
    "replay_proportion": 0.0,
    # Number of sample batches to store for replay. The number of transitions
    # saved total will be (replay_buffer_num_slots * rollout_fragment_length).
    "replay_buffer_num_slots": 0,
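    # Example: `replay_proportion=2.0` and `replay_buffer_num_slots=100`:
    # on average, two replayed batches are mixed in for every new sample
    # batch, drawn from a buffer holding up to
    # 100 * `rollout_fragment_length` transitions.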
    # Max queue size for train batches feeding into the learner.
    "learner_queue_size": 16,
    # Wait for train batches to be available in minibatch buffer queue
    # this many seconds. This may need to be increased e.g. when training
    # with a slow environment.
    "learner_queue_timeout": 300,
    # Level of queuing for sampling.
    "max_sample_requests_in_flight_per_worker": 2,
    # Max number of workers to broadcast one set of weights to.
    "broadcast_interval": 1,
    # Use n (`num_aggregation_workers`) extra Actors for multi-level
    # aggregation of the data produced by the m RolloutWorkers
    # (`num_workers`). Note that n should be much smaller than m.
    # This can make sense if ingesting >2GB/s of samples, or if
    # the data requires decompression.
    "num_aggregation_workers": 0,

    # Learning params.
    "grad_clip": 40.0,
    # Either "adam" or "rmsprop".
    "opt_type": "adam",
    "lr": 0.0005,
    "lr_schedule": None,
    # `opt_type=rmsprop` settings.
    "decay": 0.99,
    "momentum": 0.0,
    "epsilon": 0.1,
    # Balancing the three losses.
    "vf_loss_coeff": 0.5,
    "entropy_coeff": 0.01,
    "entropy_coeff_schedule": None,
    # Set this to True to have two separate optimizers optimize the policy
    # and value networks.
    "_separate_vf_optimizer": False,
    # If `_separate_vf_optimizer` is True, define a separate learning rate
    # for the value network.
    "_lr_vf": 0.0005,
    # Callback for APPO to use to update KL, target network periodically.
    # The input to the callback is the learner fetches dict.
    "after_train_step": None,

    # DEPRECATED:
    "num_data_loader_buffers": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable


def make_learner_thread(local_worker, config):
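    """Creates the learner thread that trains on batches from the in-queue.

    If `simple_optimizer` is False, a MultiGPULearnerThread is used, which
    pre-loads train batches into `num_multi_gpu_tower_stacks` tower stacks
    across `num_gpus` GPUs; otherwise, a single-device LearnerThread is used.
    """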
    if not config["simple_optimizer"]:
        logger.info(
            "Enabling multi-GPU mode, {} GPUs, {} parallel tower-stacks".
            format(config["num_gpus"], config["num_multi_gpu_tower_stacks"]))
        num_stacks = config["num_multi_gpu_tower_stacks"]
        buffer_size = config["minibatch_buffer_size"]
        if num_stacks < buffer_size:
            logger.warning(
                "In multi-GPU mode you should have at least as many "
                "multi-GPU tower stacks (to load data into on one device) as "
                "you have stack-index slots in the buffer! You have "
                f"configured {num_stacks} stacks and a buffer of size "
                f"{buffer_size}. Setting "
                f"`minibatch_buffer_size={num_stacks}`.")
            config["minibatch_buffer_size"] = num_stacks

        learner_thread = MultiGPULearnerThread(
            local_worker,
            num_gpus=config["num_gpus"],
            lr=config["lr"],
            train_batch_size=config["train_batch_size"],
            num_multi_gpu_tower_stacks=config["num_multi_gpu_tower_stacks"],
            num_sgd_iter=config["num_sgd_iter"],
            learner_queue_size=config["learner_queue_size"],
            learner_queue_timeout=config["learner_queue_timeout"])
    else:
        learner_thread = LearnerThread(
            local_worker,
            minibatch_buffer_size=config["minibatch_buffer_size"],
            num_sgd_iter=config["num_sgd_iter"],
            learner_queue_size=config["learner_queue_size"],
            learner_queue_timeout=config["learner_queue_timeout"])
    return learner_thread


def gather_experiences_directly(workers, config):
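    """Gathers experiences from the rollout workers without aggregation.

    Collects batches asynchronously from all rollout workers, optionally
    mixes in replayed samples, and concatenates the results into train
    batches of (at least) `train_batch_size` timesteps.
    """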
    rollouts = ParallelRollouts(
        workers,
        mode="async",
        num_async=config["max_sample_requests_in_flight_per_worker"])

    # Augment with replay and concat to desired train batch size.
    train_batches = rollouts \
        .for_each(lambda batch: batch.decompress_if_needed()) \
        .for_each(MixInReplay(
            num_slots=config["replay_buffer_num_slots"],
            replay_proportion=config["replay_proportion"])) \
        .flatten() \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            ))

    return train_batches


# Update worker weights as they finish generating experiences.
class BroadcastUpdateLearnerWeights:
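    """Broadcasts new weights to the rollout workers as they return batches.

    Applied to each (actor, batch) tuple in the enqueue sub-flow: every
    `broadcast_interval` batches (and only if the learner thread has updated
    the weights in the meantime), the local worker's weights are re-put into
    the object store; the current weights reference is then pushed to the
    sampling actor along with the global vars.
    """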
    def __init__(self, learner_thread, workers, broadcast_interval):
        self.learner_thread = learner_thread
        self.steps_since_broadcast = 0
        self.broadcast_interval = broadcast_interval
        self.workers = workers
        self.weights = workers.local_worker().get_weights()

    def __call__(self, item):
        actor, batch = item
        self.steps_since_broadcast += 1
        if (self.steps_since_broadcast >= self.broadcast_interval
                and self.learner_thread.weights_updated):
            self.weights = ray.put(self.workers.local_worker().get_weights())
            self.steps_since_broadcast = 0
            self.learner_thread.weights_updated = False
            # Update metrics.
            metrics = _get_shared_metrics()
            metrics.counters["num_weight_broadcasts"] += 1
        actor.set_weights.remote(self.weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())


class ImpalaTrainer(Trainer):
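    """Trainer for the Importance Weighted Actor-Learner Architecture (IMPALA).

    A minimal usage sketch (assumes a local Ray instance and the Gym env
    "CartPole-v0"; adjust the config overrides as needed):

    >>> import ray
    >>> from ray.rllib.agents.impala import ImpalaTrainer
    >>> ray.init()
    >>> trainer = ImpalaTrainer(env="CartPole-v0",
    ...                         config={"num_workers": 2, "num_gpus": 0})
    >>> results = trainer.train()  # Run one training iteration.
    """
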
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(Trainer)
    def get_default_policy_class(self, config: PartialTrainerConfigDict) -> \
            Optional[Type[Policy]]:
        if config["framework"] == "torch":
            if config["vtrace"]:
                from ray.rllib.agents.impala.vtrace_torch_policy import \
                    VTraceTorchPolicy
                return VTraceTorchPolicy
            else:
                from ray.rllib.agents.a3c.a3c_torch_policy import \
                    A3CTorchPolicy
                return A3CTorchPolicy
        else:
            if config["vtrace"]:
                return VTraceTFPolicy
            else:
                from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
                return A3CTFPolicy
    @override(Trainer)
    def validate_config(self, config):
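        """Validates and, where possible, auto-corrects IMPALA's config."""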
        # Call the super class' validation method first.
        super().validate_config(config)

        # Check the IMPALA specific config.
        if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
            deprecation_warning(
                "num_data_loader_buffers",
                "num_multi_gpu_tower_stacks",
                error=False)
            config["num_multi_gpu_tower_stacks"] = \
                config["num_data_loader_buffers"]

        if config["entropy_coeff"] < 0.0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")

        # Check whether worker to aggregation-worker ratio makes sense.
        if config["num_aggregation_workers"] > config["num_workers"]:
            raise ValueError(
                "`num_aggregation_workers` must be smaller than or equal "
                "to `num_workers`! Aggregation makes no sense otherwise.")
        elif config["num_aggregation_workers"] > \
                config["num_workers"] / 2:
            logger.warning(
                "`num_aggregation_workers` should be significantly smaller "
                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
                "less.")

        # If two separate optimizers/loss terms are used for tf, we must also
        # set `_tf_policy_handles_more_than_one_loss` to True.
        if config["_separate_vf_optimizer"] is True:
            # Only supported for tf so far.
            # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
            #  models to return separate sets of weights in order to create
            #  the different torch optimizers).
            if config["framework"] not in ["tf", "tf2", "tfe"]:
                raise ValueError(
                    "`_separate_vf_optimizer` only supported for tf so far!")
            if config["_tf_policy_handles_more_than_one_loss"] is False:
                logger.warning(
                    "`_tf_policy_handles_more_than_one_loss` must be set to "
                    "True, for TFPolicy to support more than one loss "
                    "term/optimizer! Auto-setting it to True.")
                config["_tf_policy_handles_more_than_one_loss"] = True
    @staticmethod
    @override(Trainer)
    def execution_plan(workers, config, **kwargs):
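        """Defines the distributed dataflow of IMPALA.

        Rollout workers (optionally funneled through aggregation workers)
        feed train batches into the learner thread's in-queue, while a
        concurrent dequeue op reads learner results from the out-queue to
        update the trained-steps counters and report metrics.
        """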
        assert len(kwargs) == 0, (
            "IMPALA execution_plan does NOT take any additional parameters")

        if config["num_aggregation_workers"] > 0:
            train_batches = gather_experiences_tree_aggregation(
                workers, config)
        else:
            train_batches = gather_experiences_directly(workers, config)

        # Start the learner thread.
        learner_thread = make_learner_thread(workers.local_worker(), config)
        learner_thread.start()

        # This sub-flow sends experiences to the learner.
        enqueue_op = train_batches \
            .for_each(Enqueue(learner_thread.inqueue))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            enqueue_op = enqueue_op.zip_with_source_actor() \
                .for_each(BroadcastUpdateLearnerWeights(
                    learner_thread, workers,
                    broadcast_interval=config["broadcast_interval"]))

        def record_steps_trained(item):
            count, fetches = item
            metrics = _get_shared_metrics()
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
            metrics.counters[STEPS_TRAINED_COUNTER] += count
            return item

        # This sub-flow updates the steps trained counter based on learner
        # output.
        dequeue_op = Dequeue(
            learner_thread.outqueue, check=learner_thread.is_alive) \
            .for_each(record_steps_trained)

        merged_op = Concurrently(
            [enqueue_op, dequeue_op], mode="async", output_indexes=[1])

        # Callback for APPO to use to update KL, target network periodically.
        # The input to the callback is the learner fetches dict.
        if config["after_train_step"]:
            merged_op = merged_op.for_each(lambda t: t[1]).for_each(
                config["after_train_step"](workers, config))

        return StandardMetricsReporting(merged_op, workers, config) \
            .for_each(learner_thread.add_learner_metrics)
    @classmethod
    @override(Trainer)
    def default_resource_request(cls, config):
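        """Returns IMPALA's resource needs as a PlacementGroupFactory.

        One bundle for the driver (learner) plus any aggregation workers,
        one bundle per rollout worker, and, if evaluation is enabled, one
        bundle per remote evaluation worker.
        """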
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[{
                # Driver + Aggregation Workers:
                # Force to be on same node to maximize data bandwidth
                # between aggregation workers and the learner (driver).
                # Aggregation workers tree-aggregate experiences collected
                # from RolloutWorkers (n rollout workers map to m
                # aggregation workers, where m < n) and always use 1 CPU
                # each.
                "CPU": cf["num_cpus_for_driver"] +
                cf["num_aggregation_workers"],
                "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
            }] + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                } for _ in range(cf["num_workers"])
            ] + ([
                {
                    # Evaluation (remote) workers.
                    # Note: The local eval worker is located on the driver
                    # CPU, or is not even created if there are > 0 remote
                    # eval workers.
                    "CPU": eval_config.get("num_cpus_per_worker",
                                           cf["num_cpus_per_worker"]),
                    "GPU": eval_config.get("num_gpus_per_worker",
                                           cf["num_gpus_per_worker"]),
                } for _ in range(cf["evaluation_num_workers"])
            ] if cf["evaluation_interval"] else []),
            strategy=config.get("placement_strategy", "PACK"))