qmix.py

from typing import Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None
    "mixer": "qmix",
    # Size of the mixing network embedding
    "mixing_embed_dim": 32,
    # Whether to use Double_Q learning
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.01,
        # Timesteps over which to anneal epsilon.
        "epsilon_timesteps": 40000,

        # For soft_q, use:
        # "exploration_config" = {
        #   "type": "SoftQ"
        #   "temperature": [float, e.g. 1.0]
        # }
    },

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_duration": 10,
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },

    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": 1000,

    # === Optimization ===
    # Learning rate for RMSProp optimizer.
    "lr": 0.0005,
    # RMSProp alpha.
    "optim_alpha": 0.99,
    # RMSProp epsilon.
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value.
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of rollout workers to use for collecting samples. This only makes
    # sense to increase if your environment is particularly slow to sample, or
    # if you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent reporting frequency from going lower than this time span.
    "min_time_s_per_reporting": 1,
    # === Model ===
    "model": {
        "lstm_cell_size": 64,
        "max_seq_len": 999999,
    },

    # Only torch supported so far.
    "framework": "torch",
})
# __sphinx_doc_end__
# yapf: enable
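
# Illustrative sketch only, not part of the original module: overriding any of
# the keys above is done by passing a partial config dict to the trainer (or
# to Tune), which is then merged into DEFAULT_CONFIG. The values below are
# example assumptions, not recommended settings.
_EXAMPLE_USER_OVERRIDES = {
    # Widen the mixing network embedding.
    "mixing_embed_dim": 64,
    # Train on larger batches sampled from the replay buffer.
    "train_batch_size": 128,
    # Keep more sample batches in the (batch-based) replay buffer.
    "buffer_size": 5000,
}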


class QMixTrainer(SimpleQTrainer):
    @classmethod
    @override(SimpleQTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(SimpleQTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["framework"] != "torch":
            raise ValueError(
                "Only `framework=torch` supported so far for QMixTrainer!")

    @override(SimpleQTrainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        return QMixTorchPolicy

    @staticmethod
    @override(SimpleQTrainer)
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        assert len(kwargs) == 0, (
            "QMIX execution_plan does NOT take any additional parameters")

        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        replay_buffer = SimpleReplayBuffer(config["buffer_size"])

        # (1) Store incoming rollouts in the local replay buffer.
        store_op = rollouts \
            .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

        # (2) Replay stored batches, train on them, and periodically update
        # the target network.
        train_op = Replay(local_buffer=replay_buffer) \
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )) \
            .for_each(TrainOneStep(workers)) \
            .for_each(UpdateTargetNetwork(
                workers, config["target_network_update_freq"]))

        # Alternate between storing and training, reporting only train stats.
        merged_op = Concurrently(
            [store_op, train_op], mode="round_robin", output_indexes=[1])

        return StandardMetricsReporting(merged_op, workers, config)
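

# The block below is an illustrative usage sketch added to this listing, not
# part of the original module. It assumes an environment named
# "grouped_twostep" has already been registered via
# `ray.tune.registry.register_env`; QMIX expects a MultiAgentEnv whose agents
# are grouped (e.g. with `MultiAgentEnv.with_agent_groups`) so that each group
# exposes Tuple observation/action spaces.
if __name__ == "__main__":
    import ray
    from ray import tune

    ray.init()
    tune.run(
        "QMIX",  # registered name of this trainer
        stop={"timesteps_total": 50000},
        config={
            "env": "grouped_twostep",  # hypothetical env name, see note above
            "mixer": "qmix",
            "num_workers": 0,
            "framework": "torch",
        },
    )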