r2d2.py

import logging
from typing import Optional, Type

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.dqn import DQN, DQNConfig
from ray.rllib.algorithms.r2d2.r2d2_tf_policy import R2D2TFPolicy
from ray.rllib.algorithms.r2d2.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    Deprecated,
    ALGO_DEPRECATION_WARNING,
)

logger = logging.getLogger(__name__)

class R2D2Config(DQNConfig):
    r"""Defines a configuration class from which an R2D2 Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
        >>> config = R2D2Config()
        >>> print(config.h_function_epsilon)  # doctest: +SKIP
        >>> replay_config = config.replay_buffer_config.update(
        >>>     {
        >>>         "capacity": 1000000,
        >>>         "replay_burn_in": 20,
        >>>     }
        >>> )
        >>> config.training(replay_buffer_config=replay_config)\  # doctest: +SKIP
        >>>       .resources(num_gpus=1)\
        >>>       .rollouts(num_rollout_workers=30)\
        >>>       .environment("CartPole-v1")
        >>> algo = R2D2(config=config)  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP

    Example:
        >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
        >>> from ray import air
        >>> from ray import tune
        >>> config = R2D2Config()
        >>> config.training(train_batch_size=tune.grid_search([256, 64]))
        >>> config.environment(env="CartPole-v1")
        >>> tune.Tuner(  # doctest: +SKIP
        ...     "R2D2",
        ...     run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
        ...     param_space=config.to_dict()
        ... ).fit()

    Example:
        >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
        >>> config = R2D2Config()
        >>> print(config.exploration_config)  # doctest: +SKIP
        >>> explore_config = config.exploration_config.update(
        >>>     {
        >>>         "initial_epsilon": 1.0,
        >>>         "final_epsilon": 0.1,
        >>>         "epsilon_timesteps": 200000,
        >>>     }
        >>> )
        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
        >>>       .exploration(exploration_config=explore_config)

    Example:
        >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
        >>> config = R2D2Config()
        >>> print(config.exploration_config)  # doctest: +SKIP
        >>> explore_config = config.exploration_config.update(
        >>>     {
        >>>         "type": "SoftQ",
        >>>         "temperature": [1.0],
        >>>     }
        >>> )
        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
        >>>       .exploration(exploration_config=explore_config)
    """

    def __init__(self, algo_class=None):
        """Initializes an R2D2Config instance."""
        super().__init__(algo_class=algo_class or R2D2)

        # fmt: off
        # __sphinx_doc_begin__
        # R2D2-specific settings:
        self.zero_init_states = True
        self.use_h_function = True
        self.h_function_epsilon = 1e-3

        # R2D2 settings overriding DQN ones:
        # .training()
        self.adam_epsilon = 1e-3
        self.lr = 1e-4
        self.gamma = 0.997
        self.train_batch_size = 1000
        self.target_network_update_freq = 1000
        self.training_intensity = 150
        # R2D2 is using a buffer that stores sequences.
        self.replay_buffer_config = {
            "type": "MultiAgentReplayBuffer",
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
            # Size of the replay buffer (in sequences, not timesteps).
            "capacity": 100000,
            # This algorithm learns on sequences. We therefore require the replay
            # buffer to slice sampled batches into sequences before replay. How
            # sequences are sliced depends on the parameters
            # `replay_sequence_length`, `replay_burn_in`, and
            # `replay_zero_init_states`.
            "storage_unit": "sequences",
            # Set automatically: The number of contiguous environment steps to
            # replay at once. Will be calculated via
            # `model->max_seq_len + burn_in`.
            # Do not set this to any valid value!
            "replay_sequence_length": -1,
            # If > 0, use the `replay_burn_in` first steps of each replay-sampled
            # sequence (starting either from all 0.0-values if `zero_init_state=True`
            # or from the already stored values) to calculate an even more accurate
            # initial state for the actual sequence (starting after this burn-in
            # window). In the burn-in case, the actual length of the sequence
            # used for the loss calculation is `n - replay_burn_in` time steps
            # (n=the LSTM's/attention net's max_seq_len).
            "replay_burn_in": 0,
        }
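        # Hypothetical override (not part of the defaults above): to switch to a
        # prioritized buffer, one could e.g. pass
        #   replay_buffer_config={"type": "MultiAgentPrioritizedReplayBuffer",
        #                         "prioritized_replay_alpha": 0.6}
        # to `config.training()`.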

        # .rollouts()
        self.num_rollout_workers = 2
        self.batch_mode = "complete_episodes"
        # fmt: on
        # __sphinx_doc_end__

        self.burn_in = DEPRECATED_VALUE

    def training(
        self,
        *,
        zero_init_states: Optional[bool] = NotProvided,
        use_h_function: Optional[bool] = NotProvided,
        h_function_epsilon: Optional[float] = NotProvided,
        **kwargs,
    ) -> "R2D2Config":
        """Sets the training related configuration.

        Args:
            zero_init_states: If True, assume a zero-initialized state input (no
                matter where in the episode the sequence is located).
                If False, store the initial states along with each SampleBatch, use
                it (as initial state when running through the network for training),
                and update that initial state during training (from the internal
                state outputs of the immediately preceding sequence).
            use_h_function: Whether to use the h-function from the paper [1] to
                scale target values in the R2D2 loss function:
                h(x) = sign(x)(sqrt(|x| + 1) - 1) + epsilon * x
            h_function_epsilon: The epsilon parameter from the R2D2 loss function
                (only used if `use_h_function`=True).

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if zero_init_states is not NotProvided:
            self.zero_init_states = zero_init_states
        if use_h_function is not NotProvided:
            self.use_h_function = use_h_function
        if h_function_epsilon is not NotProvided:
            self.h_function_epsilon = h_function_epsilon

        return self
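
    # A minimal sketch (not part of the original module) of the invertible value
    # rescaling that `use_h_function` / `h_function_epsilon` control, written with
    # plain Python floats for illustration:
    #
    #     import math
    #
    #     def h(x: float, epsilon: float = 1e-3) -> float:
    #         return math.copysign(1.0, x) * (math.sqrt(abs(x) + 1.0) - 1.0) + epsilon * x
    #
    #     h(10.0)  # ~2.327: large targets are squashed before entering the TD error.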

    @override(DQNConfig)
    def validate(self) -> None:
        # Call super's validation method.
        super().validate()

        if (
            not self.in_evaluation
            and self.replay_buffer_config.get("replay_sequence_length", -1) != -1
        ):
            raise ValueError(
                "`replay_sequence_length` is calculated automatically to be "
                "model->max_seq_len + burn_in!"
            )
        # Add the `burn_in` to the Model's max_seq_len.
        # Set the replay sequence length to the max_seq_len of the model.
        self.replay_buffer_config["replay_sequence_length"] = (
            self.replay_buffer_config["replay_burn_in"] + self.model["max_seq_len"]
        )
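        # For example (hypothetical values): with `model["max_seq_len"]=20` and
        # `replay_burn_in=40`, sequences of length 60 are stored and replayed.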

        if self.batch_mode != "complete_episodes":
            raise ValueError("`batch_mode` must be 'complete_episodes'!")


@Deprecated(
    old="rllib/algorithms/r2d2/",
    new="rllib_contrib/r2d2/",
    help=ALGO_DEPRECATION_WARNING,
    error=False,
)
class R2D2(DQN):
    """Recurrent Experience Replay in Distributed Reinforcement Learning (R2D2).

    Algorithm defining the distributed R2D2 algorithm.
    See `r2d2_[tf|torch]_policy.py` for the definition of the policies.

    [1] Recurrent Experience Replay in Distributed Reinforcement Learning -
    S Kapturowski, G Ostrovski, J Quan, R Munos, W Dabney - 2019, DeepMind

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#\
    recurrent-replay-distributed-dqn-r2d2
    """

    @classmethod
    @override(DQN)
    def get_default_config(cls) -> AlgorithmConfig:
        return R2D2Config()

    @classmethod
    @override(DQN)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        if config["framework"] == "torch":
            return R2D2TorchPolicy
        else:
            return R2D2TFPolicy
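

# Minimal usage sketch (not part of the original module). Assumes a Ray version
# that still ships R2D2 here (the algorithm has since moved to `rllib_contrib`),
# and uses an LSTM model since R2D2 trains on sequences; all values below are
# illustrative, not recommended settings.
if __name__ == "__main__":
    config = (
        R2D2Config()
        .environment("CartPole-v1")
        .rollouts(num_rollout_workers=2)
        .training(
            model={"use_lstm": True, "max_seq_len": 20},
            zero_init_states=True,
            h_function_epsilon=1e-3,
        )
    )
    algo = config.build()
    print(algo.train())  # Runs one training iteration and returns a result dict.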