  1. """
  2. SlateQ (Reinforcement Learning for Recommendation)
  3. ==================================================
  4. This file defines the trainer class for the SlateQ algorithm from the
  5. `"Reinforcement Learning for Slate-based Recommender Systems: A Tractable
  6. Decomposition and Practical Methodology" <https://arxiv.org/abs/1905.12767>`_
  7. paper.
  8. See `slateq_torch_policy.py` for the definition of the policy. Currently, only
  9. PyTorch is supported. The algorithm is written and tested for Google's RecSim
  10. environment (https://github.com/google-research/recsim).
  11. """
import logging
from typing import List, Type

from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

# Defines all SlateQ strategies implemented.
ALL_SLATEQ_STRATEGIES = [
    # RANDOM: Randomly select documents for slates.
    "RANDOM",
    # MYOP: Select documents that maximize user click probabilities. This is
    # a myopic strategy and ignores long-term rewards. This is equivalent to
    # setting a zero discount rate for future rewards.
    "MYOP",
    # SARSA: Use the SlateQ SARSA learning algorithm.
    "SARSA",
    # QL: Use the SlateQ Q-learning algorithm.
    "QL",
]

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Model ===
    # Dense-layer setup for each of the advantage branch and the value branch
    # in a dueling architecture.
    "hiddens": [256, 64, 16],
    # Use "complete_episodes" batch mode (required by the SARSA strategy).
    "batch_mode": "complete_episodes",

    # === Deep Learning Framework Settings ===
    # Currently, only PyTorch is supported.
    "framework": "torch",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000,  # Timesteps over which to anneal epsilon.
    },
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },
    # Minimum env steps to optimize for per train call. This value does
    # not affect learning, only the length of iterations.
    "timesteps_per_iteration": 1000,

    # === Replay buffer ===
    # Deprecated: Use `replay_buffer_config["capacity"]` instead.
    "buffer_size": DEPRECATED_VALUE,
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    "replay_buffer_config": {
        "type": "LocalReplayBuffer",
        "capacity": 50000,
    },
    # The number of contiguous environment steps to replay at once. This may
    # be set to greater than 1 to support recurrent models.
    "replay_sequence_length": 1,
    # Whether to LZ4-compress observations.
    "compress_observations": False,
    # If set, this will fix the ratio of replayed from a buffer and learned on
    # timesteps to sampled from an environment and stored in the replay buffer
    # timesteps. Otherwise, the replay will proceed at the native ratio
    # determined by (train_batch_size / rollout_fragment_length).
    "training_intensity": None,

    # === Optimization ===
    # Learning rate for the Adam optimizer of the user choice model.
    "lr_choice_model": 1e-2,
    # Learning rate for the Adam optimizer of the Q model.
    "lr_q_model": 1e-2,
    # Adam epsilon hyperparameter.
    "adam_epsilon": 1e-8,
    # If not None, clip gradients during optimization at this value.
    "grad_clip": 40,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 1000,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span.
    "min_iter_time_s": 1,

    # === SlateQ specific options ===
    # Learning method used by the SlateQ policy. Choose from: RANDOM,
    # MYOP (myopic), SARSA, QL (Q-Learning).
    "slateq_strategy": "QL",
    # User/doc embedding size for the RecSim environment.
    "recsim_embedding_size": 20,
})
# __sphinx_doc_end__
# yapf: enable

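
# Sketch of overriding the defaults above (e.g. to switch strategies); the
# keys used here all exist in DEFAULT_CONFIG:
#
#     config = DEFAULT_CONFIG.copy()
#     config["slateq_strategy"] = "SARSA"
#     # SARSA requires complete episodes (enforced in validate_config below).
#     config["batch_mode"] = "complete_episodes"
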
def validate_config(config: TrainerConfigDict) -> None:
    """Validates the Trainer's config dict."""
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for SlateQ!")

    if config["framework"] != "torch":
        raise ValueError("SlateQ only runs on PyTorch")

    if config["slateq_strategy"] not in ALL_SLATEQ_STRATEGIES:
        raise ValueError("Unknown slateq_strategy: "
                         f"{config['slateq_strategy']}.")

    if config["slateq_strategy"] == "SARSA":
        if config["batch_mode"] != "complete_episodes":
            raise ValueError(
                "For SARSA strategy, batch_mode must be 'complete_episodes'")

def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Execution plan of the SlateQ algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    assert "local_replay_buffer" in kwargs, (
        "SlateQ execution plan requires a local replay buffer.")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer. Calling
    #     next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"]))

    # (2) Read and train on experiences from the replay buffer. Every batch
    #     returned from the LocalReplay() iterator is passed to TrainOneStep to
    #     take an SGD step.
    replay_op = Replay(local_buffer=kwargs["local_replay_buffer"]) \
        .for_each(TrainOneStep(workers))

    if config["slateq_strategy"] != "RANDOM":
        # Alternate deterministically between (1) and (2). Only return the
        # output of (2) since training metrics are not available until (2)
        # runs.
        train_op = Concurrently(
            [store_op, replay_op],
            mode="round_robin",
            output_indexes=[1],
            round_robin_weights=calculate_round_robin_weights(config))
    else:
        # No training is needed for the RANDOM strategy.
        train_op = rollouts

    return StandardMetricsReporting(train_op, workers, config)

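
# Dataflow sketch for the non-RANDOM strategies (purely illustrative, mirroring
# execution_plan above):
#
#     ParallelRollouts --> StoreToReplayBuffer        (op 1: sampling)
#     Replay --> TrainOneStep                         (op 2: learning)
#     Concurrently([op 1, op 2], mode="round_robin", output_indexes=[1])
#         --> StandardMetricsReporting --> iterator over training metrics
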
def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]:
    """Calculates the round robin weights for the rollout and train steps."""
    if not config["training_intensity"]:
        return [1, 1]

    # e.g., 32 / 4 -> native ratio of 8.0.
    native_ratio = (
        config["train_batch_size"] / config["rollout_fragment_length"])

    # Training intensity is specified in terms of
    # (steps_replayed / steps_sampled), so adjust for the native ratio.
    weights = [1, config["training_intensity"] / native_ratio]
    return weights

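
# Worked example (using the numbers from the in-function comment, not this
# file's defaults): train_batch_size=32 and rollout_fragment_length=4 give a
# native ratio of 8.0. A training_intensity of 16 then yields weights [1, 2],
# i.e. two replay/train steps for every sampling step in the round robin.
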
def get_policy_class(config: TrainerConfigDict) -> Type[Policy]:
    """Policy class picker function.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Type[Policy]: The Policy class to use with SlateQTrainer.
    """
    if config["slateq_strategy"] == "RANDOM":
        return RandomPolicy
    else:
        return SlateQTorchPolicy

SlateQTrainer = build_trainer(
    name="SlateQ",
    get_policy_class=get_policy_class,
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
    execution_plan=execution_plan)
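
# Sketch of direct (non-Tune) usage. The env name "RecSim-v1" is an assumed,
# pre-registered RecSim environment; replace it with whatever env you have
# registered yourself.
#
#     import ray
#
#     ray.init()
#     trainer = SlateQTrainer(config={"env": "RecSim-v1"})
#     for _ in range(10):
#         results = trainer.train()
#         print(results["episode_reward_mean"])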