import logging
from typing import Type

from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG as \
    DQN_DEFAULT_CONFIG
from ray.rllib.agents.dqn.r2d2_tf_policy import R2D2TFPolicy
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
    DQN_DEFAULT_CONFIG,  # See keys in dqn.py, which are also supported.
    {
        # Learning rate for the Adam optimizer.
        "lr": 1e-4,
        # Discount factor.
        "gamma": 0.997,
        # Train batch size (in number of single timesteps).
        "train_batch_size": 64 * 20,
        # Adam epsilon hyperparameter.
        "adam_epsilon": 1e-3,
        # Run in parallel by default.
        "num_workers": 2,
        # Batch mode must be complete_episodes.
        "batch_mode": "complete_episodes",
        # If True, assume a zero-initialized state input (no matter where in
        # the episode the sequence is located).
        # If False, store the initial states along with each SampleBatch, use
        # it (as initial state when running through the network for training),
        # and update that initial state during training (from the internal
        # state outputs of the immediately preceding sequence).
        "zero_init_states": True,
        # If > 0, use the `burn_in` first steps of each replay-sampled sequence
        # (starting either from all 0.0-values if `zero_init_states=True` or
        # from the already stored values) to calculate more accurate
        # initial states for the actual sequence (starting after this burn-in
        # window). In the burn-in case, the actual length of the sequence
        # used for loss calculation is `n - burn_in` time steps
        # (n=LSTM's/attention net's max_seq_len).
        "burn_in": 0,
        # Whether to use the h-function from the paper [1] to scale target
        # values in the R2D2 loss function:
        # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x
        # (an illustrative sketch of h and its inverse follows this config).
        "use_h_function": True,
        # The epsilon parameter from the R2D2 loss function (only used
        # if `use_h_function`=True).
        "h_function_epsilon": 1e-3,
        # === Hyperparameters from the paper [1] ===
        # Size of the replay buffer (in sequences, not timesteps).
        "buffer_size": 100000,
        # If True, a prioritized replay buffer will be used.
        "prioritized_replay": False,
        # Set automatically: The number of contiguous environment steps to
        # replay at once. Will be calculated via
        # model->max_seq_len + burn_in.
        # Do not set this to any valid value!
        "replay_sequence_length": -1,
        # Update the target network every `target_network_update_freq` steps.
        "target_network_update_freq": 2500,
    },
    _allow_unknown_configs=True,
)
# __sphinx_doc_end__
# yapf: enable
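
# The `use_h_function` comment above refers to the invertible value-rescaling
# function from the R2D2 paper [1]. The following is only an illustrative
# NumPy sketch of h and its inverse; the names `_h_function`/`_h_inverse` are
# made up for this example, and the actual framework-specific implementation
# lives in `r2d2_[tf|torch]_policy.py`.
import numpy as np


def _h_function(x, epsilon=1e-3):
    """h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x."""
    return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + epsilon * x


def _h_inverse(x, epsilon=1e-3):
    """Closed-form inverse of h.

    In the paper [1], targets take the form h(r + gamma * h^-1(q_target)).
    """
    return np.sign(x) * (
        ((np.sqrt(1.0 + 4.0 * epsilon * (np.abs(x) + 1.0 + epsilon)) - 1.0) /
         (2.0 * epsilon))**2 - 1.0)
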
# Build an R2D2 trainer, which uses the framework-specific Policy class
# returned by `get_default_policy_class()` below.
class R2D2Trainer(DQNTrainer):
    """Recurrent Experience Replay in Distrib. Reinforcement Learning (R2D2).

    Trainer defining the distributed R2D2 algorithm.
    See `r2d2_[tf|torch]_policy.py` for the definition of the policies.

    [1] Recurrent Experience Replay in Distributed Reinforcement Learning -
        S Kapturowski, G Ostrovski, J Quan, R Munos, W Dabney - 2019, DeepMind

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#\
    recurrent-replay-distributed-dqn-r2d2
    """
    @classmethod
    @override(DQNTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return R2D2_DEFAULT_CONFIG

    @override(DQNTrainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            return R2D2TorchPolicy
        else:
            return R2D2TFPolicy

    @override(DQNTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Checks and updates the config based on settings.

        Rewrites `replay_sequence_length` to take burn-in and
        max_seq_len truncation into account.
        """
        # Call super's validation method.
        super().validate_config(config)

        if config["replay_sequence_length"] != -1:
            raise ValueError(
                "`replay_sequence_length` is calculated automatically to be "
                "model->max_seq_len + burn_in!")
        # Set the replay sequence length to burn_in + the model's max_seq_len.
        config["replay_sequence_length"] = \
            config["burn_in"] + config["model"]["max_seq_len"]

        if config.get("batch_mode") != "complete_episodes":
            raise ValueError("`batch_mode` must be 'complete_episodes'!")
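
# Minimal usage sketch (illustrative only): the environment ("CartPole-v0"),
# the LSTM model settings, and the `burn_in` value below are assumptions made
# for this example, not defaults of this module. Only overrides need to be
# passed; everything else falls back to R2D2_DEFAULT_CONFIG.
if __name__ == "__main__":
    import ray

    ray.init()
    example_config = {
        "num_workers": 1,
        "model": {
            "use_lstm": True,
            "max_seq_len": 20,
        },
        # With burn_in=4, validate_config() sets
        # replay_sequence_length = 4 + 20 = 24.
        "burn_in": 4,
    }
    trainer = R2D2Trainer(config=example_config, env="CartPole-v0")
    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])
    ray.shutdown()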