  1. """
  2. SlateQ (Reinforcement Learning for Recommendation)
  3. ==================================================
  4. This file defines the trainer class for the SlateQ algorithm from the
  5. `"Reinforcement Learning for Slate-based Recommender Systems: A Tractable
  6. Decomposition and Practical Methodology" <https://arxiv.org/abs/1905.12767>`_
  7. paper.
  8. See `slateq_torch_policy.py` for the definition of the policy. Currently, only
  9. PyTorch is supported. The algorithm is written and tested for Google's RecSim
  10. environment (https://github.com/google-research/recsim).
  11. """
import logging
from typing import List, Type

from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

# Defines all SlateQ strategies implemented.
ALL_SLATEQ_STRATEGIES = [
    # RANDOM: Randomly select documents for slates.
    "RANDOM",
    # MYOP: Select documents that maximize user click probabilities. This is
    # a myopic strategy and ignores long term rewards. This is equivalent to
    # setting a zero discount rate for future rewards.
    "MYOP",
    # SARSA: Use the SlateQ SARSA learning algorithm.
    "SARSA",
    # QL: Use the SlateQ Q-learning algorithm.
    "QL",
]

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Model ===
    # Dense-layer setup for each of the advantage branch and the value branch
    # in a dueling architecture.
    "hiddens": [256, 64, 16],
    # Use complete episodes as batches (required by the SARSA strategy).
    "batch_mode": "complete_episodes",

    # === Deep Learning Framework Settings ===
    # Currently, only PyTorch is supported.
    "framework": "torch",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000,  # Timesteps over which to anneal epsilon.
    },
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },
    # Minimum env steps to optimize for per train call. This value does
    # not affect learning, only the length of iterations.
    "timesteps_per_iteration": 1000,

    # === Replay buffer ===
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    "buffer_size": DEPRECATED_VALUE,
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "capacity": 50000,
    },
    # The number of contiguous environment steps to replay at once. This may
    # be set to greater than 1 to support recurrent models.
    "replay_sequence_length": 1,
    # Whether to LZ4-compress observations.
    "compress_observations": False,
    # If set, this fixes the ratio of timesteps replayed from the buffer and
    # trained on to timesteps sampled from the environment and stored in the
    # replay buffer. Otherwise, replay proceeds at the native ratio determined
    # by (train_batch_size / rollout_fragment_length).
    "training_intensity": None,

    # === Optimization ===
    # Learning rate for the Adam optimizer of the user choice model.
    "lr_choice_model": 1e-2,
    # Learning rate for the Adam optimizer of the Q model.
    "lr_q_model": 1e-2,
    # Adam epsilon hyperparameter.
    "adam_epsilon": 1e-8,
    # If not None, clip gradients during optimization at this value.
    "grad_clip": 40,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 1000,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent reporting frequency from going lower than this time span.
    "min_time_s_per_reporting": 1,

    # === SlateQ specific options ===
    # Learning method used by the SlateQ policy. Choose from: RANDOM,
    # MYOP (myopic), SARSA, QL (Q-learning).
    "slateq_strategy": "QL",
    # User/doc embedding size for the RecSim environment.
    "recsim_embedding_size": 20,
})
# __sphinx_doc_end__
# yapf: enable
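
# A minimal override sketch (illustrative only, not part of the library):
# when a SlateQTrainer is constructed, the user-provided config dict is merged
# on top of DEFAULT_CONFIG, so only the keys that differ from the defaults
# need to be given, e.g.:
#
#     config = {
#         "slateq_strategy": "SARSA",
#         # SARSA requires complete episodes (see validate_config below).
#         "batch_mode": "complete_episodes",
#         "lr_q_model": 5e-3,
#     }
#     trainer = SlateQTrainer(config=config, env="...")  # a registered RecSim env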


def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]:
    """Calculate the round robin weights for the rollout and train steps."""
    if not config["training_intensity"]:
        return [1, 1]
    # e.g., 32 / 4 -> native ratio of 8.0
    native_ratio = (
        config["train_batch_size"] / config["rollout_fragment_length"])
    # Training intensity is specified in terms of
    # (steps_replayed / steps_sampled), so adjust for the native ratio.
    weights = [1, config["training_intensity"] / native_ratio]
    return weights
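
# Worked example (illustrative numbers, following the in-function comment
# above): with train_batch_size=32 and rollout_fragment_length=4, the native
# ratio is 32 / 4 = 8.0; a training_intensity of 16 then yields round robin
# weights [1, 16 / 8.0] = [1, 2], i.e. the Concurrently op in the execution
# plan runs two replay/train steps for every sampling step.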


class SlateQTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for SlateQ!")
        if config["framework"] != "torch":
            raise ValueError("SlateQ only runs on PyTorch")
        if config["slateq_strategy"] not in ALL_SLATEQ_STRATEGIES:
            raise ValueError("Unknown slateq_strategy: "
                             f"{config['slateq_strategy']}.")
        if config["slateq_strategy"] == "SARSA":
            if config["batch_mode"] != "complete_episodes":
                raise ValueError("For SARSA strategy, batch_mode must be "
                                 "'complete_episodes'")

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> \
            Type[Policy]:
        if config["slateq_strategy"] == "RANDOM":
            return RandomPolicy
        else:
            return SlateQTorchPolicy

    @staticmethod
    @override(Trainer)
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        assert "local_replay_buffer" in kwargs, (
            "SlateQ execution plan requires a local replay buffer.")

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # We execute the following steps concurrently:
        # (1) Generate rollouts and store them in our local replay buffer.
        # Calling next() on store_op drives this.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"]))

        # (2) Read and train on experiences from the replay buffer. Every batch
        # returned from the Replay() iterator is passed to TrainOneStep to
        # take an SGD step.
        replay_op = Replay(local_buffer=kwargs["local_replay_buffer"]) \
            .for_each(TrainOneStep(workers))

        if config["slateq_strategy"] != "RANDOM":
            # Alternate deterministically between (1) and (2). Only return the
            # output of (2), since training metrics are not available until (2)
            # runs.
            train_op = Concurrently(
                [store_op, replay_op],
                mode="round_robin",
                output_indexes=[1],
                round_robin_weights=calculate_round_robin_weights(config))
        else:
            # No training is needed for the RANDOM strategy.
            train_op = rollouts

        return StandardMetricsReporting(train_op, workers, config)
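

if __name__ == "__main__":
    # Minimal usage sketch (not part of the trainer API). It assumes that a
    # RecSim-based environment has been registered with Tune under the
    # placeholder name "my_recsim_env"; substitute a real, registered RecSim
    # env id before running.
    import ray

    ray.init()
    trainer = SlateQTrainer(
        config={"slateq_strategy": "QL", "num_workers": 0},
        env="my_recsim_env",
    )
    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()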