simple_q.py

  1. """
  2. Simple Q-Learning
  3. =================
  4. This module provides a basic implementation of the DQN algorithm without any
  5. optimizations.
  6. This file defines the distributed Trainer class for the Simple Q algorithm.
  7. See `simple_q_[tf|torch]_policy.py` for the definition of the policy loss.
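
Example (a minimal usage sketch, not part of the original file; assumes a
Gym environment id such as "CartPole-v0" that RLlib can resolve):

    >>> from ray.rllib.agents.dqn import SimpleQTrainer
    >>> trainer = SimpleQTrainer(env="CartPole-v0",
    ...                          config={"framework": "torch"})
    >>> results = trainer.train()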
  8. """
import logging
from typing import Optional, Type

from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
    UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000,  # Timesteps over which to anneal epsilon.

        # For soft_q, use:
        # "exploration_config" = {
        #   "type": "SoftQ"
        #   "temperature": [float, e.g. 1.0]
        # }
    },
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },
    # Minimum env steps to optimize for per train call. This value does
    # not affect learning, only the length of iterations.
    "timesteps_per_iteration": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    "buffer_size": DEPRECATED_VALUE,
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "capacity": 50000,
    },
    # Set this to True, if you want the contents of your buffer(s) to be
    # stored in any saved checkpoints as well.
    # Warnings will be created if:
    # - This is True AND restoring from a checkpoint that contains no buffer
    #   data.
    # - This is False AND restoring from a checkpoint that does contain
    #   buffer data.
    "store_buffer_in_checkpoints": False,
    # The number of contiguous environment steps to replay at once. This may
    # be set to greater than 1 to support recurrent models.
    "replay_sequence_length": 1,

    # === Optimization ===
    # Learning rate for adam optimizer
    "lr": 5e-4,
    # Learning rate schedule
    # In the format of [[timestep, value], [timestep, value], ...]
    # A schedule should normally start from timestep 0.
    "lr_schedule": None,
    # Adam epsilon hyper parameter
    "adam_epsilon": 1e-8,
    # If not None, clip gradients during optimization at this value
    "grad_clip": 40,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a batch sampled from replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers used to collect samples. This only makes sense to
    # increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
  94. "num_workers": 0,
  95. # Prevent reporting frequency from going lower than this time span.
  96. "min_time_s_per_reporting": 1,
  97. })
  98. # __sphinx_doc_end__
  99. # yapf: enable
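
# Usage sketch (not part of the original file): a user-supplied config dict is
# merged into DEFAULT_CONFIG by the Trainer, so only the keys to override need
# to be passed, e.g. switching to the SoftQ exploration mentioned above:
#
#     trainer = SimpleQTrainer(
#         env="CartPole-v0",
#         config={
#             "exploration_config": {"type": "SoftQ", "temperature": 1.0},
#             "lr": 1e-3,
#         })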


class SimpleQTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Checks and updates the config based on settings."""
        # Call super's validation method.
        super().validate_config(config)

        if config["exploration_config"]["type"] == "ParameterNoise":
            if config["batch_mode"] != "complete_episodes":
                logger.warning(
                    "ParameterNoise Exploration requires `batch_mode` to be "
                    "'complete_episodes'. Setting batch_mode="
                    "complete_episodes.")
                config["batch_mode"] = "complete_episodes"
            if config.get("noisy", False):
                raise ValueError(
                    "ParameterNoise Exploration and `noisy` network cannot be"
                    " used at the same time!")

        if config.get("prioritized_replay"):
            if config["multiagent"]["replay_mode"] == "lockstep":
                raise ValueError("Prioritized replay is not supported when "
                                 "replay_mode=lockstep.")
            elif config.get("replay_sequence_length", 0) > 1:
                raise ValueError("Prioritized replay is not supported when "
                                 "replay_sequence_length > 1.")
        else:
            if config.get("worker_side_prioritization"):
                raise ValueError(
                    "Worker side prioritization is not supported when "
                    "prioritized_replay=False.")

        # Multi-agent mode and multi-GPU optimizer.
        if config["multiagent"]["policies"] and \
                not config["simple_optimizer"]:
            logger.info(
                "In multi-agent mode, policies will be optimized sequentially"
                " by the multi-GPU optimizer. Consider setting "
                "`simple_optimizer=True` if this doesn't work for you.")

    @override(Trainer)
    def get_default_policy_class(
            self, config: TrainerConfigDict) -> Optional[Type[Policy]]:
        if config["framework"] == "torch":
            return SimpleQTorchPolicy
        else:
            return SimpleQTFPolicy
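
    # Illustrative note (not part of the original file): `"framework": "torch"`
    # selects SimpleQTorchPolicy; any other framework value (e.g. "tf", "tf2")
    # falls through to SimpleQTFPolicy via the branch above.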

    @staticmethod
    @override(Trainer)
    def execution_plan(workers, config, **kwargs):
        assert "local_replay_buffer" in kwargs, (
            "GenericOffPolicy execution plan requires a local replay buffer.")

        local_replay_buffer = kwargs["local_replay_buffer"]

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # (1) Generate rollouts and store them in our local replay buffer.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=local_replay_buffer))

        if config["simple_optimizer"]:
            train_step_op = TrainOneStep(workers)
        else:
            train_step_op = MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["train_batch_size"],
                num_sgd_iter=1,
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"])

        # (2) Read and train on experiences from the replay buffer.
        replay_op = Replay(local_buffer=local_replay_buffer) \
            .for_each(train_step_op) \
            .for_each(UpdateTargetNetwork(
                workers, config["target_network_update_freq"]))

        # Alternate deterministically between (1) and (2).
        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1])

        return StandardMetricsReporting(train_op, workers, config)
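
# Informal sketch (not part of the original file) of what one round of the
# execution plan above does, written as a plain loop; `store_to_buffer`,
# `sample_from_buffer` and `maybe_update_target` are hypothetical placeholder
# names, not RLlib APIs:
#
#     while True:
#         batch = next(rollouts)                  # ParallelRollouts
#         store_to_buffer(batch)                  # StoreToReplayBuffer
#         replayed = sample_from_buffer()         # Replay
#         results = train_step_op(replayed)       # TrainOneStep / MultiGPU...
#         maybe_update_target(results)            # UpdateTargetNetwork
#         yield results                           # output_indexes=[1]: only
#                                                 # the replay/train branch is
#                                                 # reported as metrics.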