# a2c.py -- A2C (synchronous Advantage Actor-Critic) trainer for RLlib.
import math

from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \
    A3CTrainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import ComputeGradients, AverageGradients, \
    ApplyGradients, MultiGPUTrainOneStep, TrainOneStep
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
  14. A2C_DEFAULT_CONFIG = merge_dicts(
  15. A3C_CONFIG,
  16. {
  17. "rollout_fragment_length": 20,
  18. "min_time_s_per_reporting": 10,
  19. "sample_async": False,
  20. # A2C supports microbatching, in which we accumulate gradients over
  21. # batch of this size until the train batch size is reached. This allows
  22. # training with batch sizes much larger than can fit in GPU memory.
  23. # To enable, set this to a value less than the train batch size.
  24. "microbatch_size": None,
  25. },
  26. )
  27. class A2CTrainer(A3CTrainer):
  28. @classmethod
  29. @override(A3CTrainer)
  30. def get_default_config(cls) -> TrainerConfigDict:
  31. return A2C_DEFAULT_CONFIG
  32. @staticmethod
  33. @override(Trainer)
  34. def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
  35. **kwargs) -> LocalIterator[dict]:
  36. assert len(kwargs) == 0, (
  37. "A2C execution_plan does NOT take any additional parameters")
  38. rollouts = ParallelRollouts(workers, mode="bulk_sync")
  39. if config["microbatch_size"]:
  40. num_microbatches = math.ceil(
  41. config["train_batch_size"] / config["microbatch_size"])
  42. # In microbatch mode, we want to compute gradients on experience
  43. # microbatches, average a number of these microbatches, and then
  44. # apply the averaged gradient in one SGD step. This conserves GPU
  45. # memory, allowing for extremely large experience batches to be
  46. # used.
  47. train_op = (
  48. rollouts.combine(
  49. ConcatBatches(
  50. min_batch_size=config["microbatch_size"],
  51. count_steps_by=config["multiagent"]["count_steps_by"]))
  52. .for_each(ComputeGradients(workers)) # (grads, info)
  53. .batch(num_microbatches) # List[(grads, info)]
  54. .for_each(AverageGradients()) # (avg_grads, info)
  55. .for_each(ApplyGradients(workers)))
  56. else:
  57. # In normal mode, we execute one SGD step per each train batch.
  58. if config["simple_optimizer"]:
  59. train_step_op = TrainOneStep(workers)
  60. else:
  61. train_step_op = MultiGPUTrainOneStep(
  62. workers=workers,
  63. sgd_minibatch_size=config["train_batch_size"],
  64. num_sgd_iter=1,
  65. num_gpus=config["num_gpus"],
  66. _fake_gpus=config["_fake_gpus"])
  67. train_op = rollouts.combine(
  68. ConcatBatches(
  69. min_batch_size=config["train_batch_size"],
  70. count_steps_by=config["multiagent"][
  71. "count_steps_by"])).for_each(train_step_op)
  72. return StandardMetricsReporting(train_op, workers, config)