ddpg.py

import logging
from typing import Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Twin Delayed DDPG (TD3) and Soft Actor-Critic (SAC) tricks ===
    # TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
    # In addition to settings below, you can use "exploration_noise_type" and
    # "exploration_gauss_act_noise" to get IID Gaussian exploration noise
    # instead of OU exploration noise.
    # twin Q-net
    "twin_q": False,
    # delayed policy update
    "policy_delay": 1,
    # target policy smoothing
    # (this also replaces OU exploration noise with IID Gaussian exploration
    # noise, for now)
    "smooth_target_policy": False,
    # gaussian stddev of target action noise for smoothing
    "target_noise": 0.2,
    # target noise limit (bound)
    "target_noise_clip": 0.5,

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_duration": 10,

    # === Model ===
    # Apply a state preprocessor with spec given by the "model" config option
    # (like other RL algorithms). This is mostly useful if you have a weird
    # observation shape, like an image. Disabled by default.
    "use_state_preprocessor": False,
    # Postprocess the policy network model output with these hidden layers. If
    # use_state_preprocessor is False, then these will be the *only* hidden
    # layers in the network.
    "actor_hiddens": [400, 300],
    # Hidden layers activation of the postprocessing stage of the policy
    # network.
    "actor_hidden_activation": "relu",
    # Postprocess the critic network model output with these hidden layers;
    # again, if use_state_preprocessor is True, then the state will be
    # preprocessed by the model specified with the "model" config option first.
    "critic_hiddens": [400, 300],
    # Hidden layers activation of the postprocessing stage of the critic.
    "critic_hidden_activation": "relu",
    # N-step Q learning.
    "n_step": 1,

    # === Exploration ===
    "exploration_config": {
        # DDPG uses OrnsteinUhlenbeck (stateful) noise to be added to NN-output
        # actions (after a possible pure random phase of n timesteps).
        "type": "OrnsteinUhlenbeckNoise",
        # For how many timesteps should we return completely random actions,
        # before we start adding (scaled) noise?
        "random_timesteps": 1000,
        # The OU-base scaling factor to always apply to action-added noise.
        "ou_base_scale": 0.1,
        # The OU theta param.
        "ou_theta": 0.15,
        # The OU sigma param.
        "ou_sigma": 0.2,
        # The initial noise scaling factor.
        "initial_scale": 1.0,
        # The final noise scaling factor.
        "final_scale": 0.02,
        # Timesteps over which to anneal scale (from initial to final values).
        "scale_timesteps": 10000,
    },
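    # For orientation only: the parameters above drive a standard
    # Ornstein-Uhlenbeck process. A rough sketch of one noise step (the exact
    # update lives in the OrnsteinUhlenbeckNoise exploration class) is:
    #     ou_state += ou_theta * (0.0 - ou_state) + ou_sigma * N(0, 1)
    #     noise = ou_base_scale * scale * ou_state
    # where `scale` is annealed from `initial_scale` to `final_scale` over
    # `scale_timesteps`.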
    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Extra configuration that disables exploration.
    "evaluation_config": {
        "explore": False
    },

    # === Replay buffer ===
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    "buffer_size": DEPRECATED_VALUE,
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "capacity": 50000,
    },
    # Set this to True, if you want the contents of your buffer(s) to be
    # stored in any saved checkpoints as well.
    # Warnings will be created if:
    # - This is True AND restoring from a checkpoint that contains no buffer
    #   data.
    # - This is False AND restoring from a checkpoint that does contain
    #   buffer data.
    "store_buffer_in_checkpoints": False,
    # If True, a prioritized replay buffer will be used.
    "prioritized_replay": True,
    # Alpha parameter for prioritized replay buffer.
    "prioritized_replay_alpha": 0.6,
    # Beta parameter for sampling from prioritized replay buffer.
    "prioritized_replay_beta": 0.4,
    # Time steps over which the beta parameter is annealed.
    "prioritized_replay_beta_annealing_timesteps": 20000,
    # Final value of beta.
    "final_prioritized_replay_beta": 0.4,
    # Epsilon to add to the TD errors when updating priorities.
    "prioritized_replay_eps": 1e-6,
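    # For reference, the prioritized-replay parameters above follow the
    # standard scheme (Schaul et al., 2016): a transition's priority is
    # roughly
    #     p_i = (|td_error_i| + prioritized_replay_eps) ** alpha
    # it is sampled with probability proportional to p_i, and the resulting
    # bias is corrected with importance weights
    #     w_i = (1 / (N * P(i))) ** beta
    # where beta is annealed from `prioritized_replay_beta` towards
    # `final_prioritized_replay_beta` over the annealing timesteps.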
    # Whether to LZ4 compress observations.
    "compress_observations": False,
    # The intensity with which to update the model (vs collecting samples from
    # the env). If None, uses the "natural" value of:
    # `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
    # `num_envs_per_worker`).
    # If provided, will make sure that the ratio between ts inserted into and
    # sampled from the buffer matches the given value.
    # Example:
    #   training_intensity=1000.0
    #   train_batch_size=250 rollout_fragment_length=1
    #   num_workers=1 (or 0) num_envs_per_worker=1
    #   -> natural value = 250 / 1 = 250.0
    #   -> will make sure that replay+train op will be executed 4x as
    #      often as rollout+insert op (4 * 250 = 1000).
    # See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
    "training_intensity": None,

    # === Optimization ===
    # Learning rate for the critic (Q-function) optimizer.
    "critic_lr": 1e-3,
    # Learning rate for the actor (policy) optimizer.
    "actor_lr": 1e-3,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 0,
    # Update the target by \tau * policy + (1-\tau) * target_policy.
    "tau": 0.002,
    # If True, use Huber loss instead of squared loss for the critic network.
    # Conventionally, there is no need to clip gradients when using a Huber
    # loss.
    "use_huber": False,
    # Threshold of the Huber loss.
    "huber_threshold": 1.0,
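    # (For reference, the Huber loss with threshold d is 0.5 * x**2 for
    # |x| <= d and d * (|x| - 0.5 * d) otherwise, i.e. quadratic near zero and
    # linear in the tails, which is why gradient clipping is usually not
    # needed with it.)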
    # Weights for L2 regularization.
    "l2_reg": 1e-6,
    # If not None, clip gradients during optimization at this value.
    "grad_clip": None,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1500,
    # Update the replay buffer with this many samples at once. Note that this
    # setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 1,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 256,

    # === Parallelism ===
    # Number of workers used for collecting samples. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent reporting frequency from going lower than this time span.
    "min_time_s_per_reporting": 1,
})
# __sphinx_doc_end__
# yapf: enable

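# Example (illustrative sketch only, not part of this module): the TD3 tricks
# exposed at the top of DEFAULT_CONFIG can be enabled by overriding a handful
# of keys. The env name below is just a placeholder; RLlib also ships a
# dedicated TD3 trainer with similar defaults pre-set.
#
#     from ray.rllib.agents.ddpg import DDPGTrainer
#
#     td3_style_config = {
#         "env": "Pendulum-v1",
#         "twin_q": True,
#         "policy_delay": 2,
#         "smooth_target_policy": True,
#     }
#     trainer = DDPGTrainer(config=td3_style_config)

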
class DDPGTrainer(SimpleQTrainer):
    @classmethod
    @override(SimpleQTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(SimpleQTrainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.agents.ddpg.ddpg_torch_policy import \
                DDPGTorchPolicy
            return DDPGTorchPolicy
        else:
            return DDPGTFPolicy

    @override(SimpleQTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["model"]["custom_model"]:
            logger.warning(
                "Setting use_state_preprocessor=True since a custom model "
                "was specified.")
            config["use_state_preprocessor"] = True

        if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
            raise ValueError("`grad_clip` value must be > 0.0!")

        if config["exploration_config"]["type"] == "ParameterNoise":
            if config["batch_mode"] != "complete_episodes":
                logger.warning(
                    "ParameterNoise Exploration requires `batch_mode` to be "
                    "'complete_episodes'. Setting "
                    "batch_mode=complete_episodes.")
                config["batch_mode"] = "complete_episodes"

        if config.get("prioritized_replay"):
            if config["multiagent"]["replay_mode"] == "lockstep":
                raise ValueError("Prioritized replay is not supported when "
                                 "replay_mode=lockstep.")
        else:
            if config.get("worker_side_prioritization"):
                raise ValueError(
                    "Worker side prioritization is not supported when "
                    "prioritized_replay=False.")
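

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the original
    # module). Assumes a local Ray installation and that the gym environment
    # "Pendulum-v1" is available; adjust the env name to your gym version.
    import ray

    ray.init()
    trainer = DDPGTrainer(config={"env": "Pendulum-v1", "num_workers": 0})
    for _ in range(3):
        result = trainer.train()
        print("iter", result["training_iteration"],
              "episode_reward_mean", result.get("episode_reward_mean"))
    trainer.stop()
    ray.shutdown()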