123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- """
- SlateQ (Reinforcement Learning for Recommendation)
- ==================================================
- This file defines the trainer class for the SlateQ algorithm from the
- `"Reinforcement Learning for Slate-based Recommender Systems: A Tractable
- Decomposition and Practical Methodology" <https://arxiv.org/abs/1905.12767>`_
- paper.
- See `slateq_torch_policy.py` for the definition of the policy. Currently, only
- PyTorch is supported. The algorithm is written and tested for Google's RecSim
- environment (https://github.com/google-research/recsim).
- """
- import logging
- from typing import List, Type
- from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy
- from ray.rllib.agents.trainer import with_common_config
- from ray.rllib.agents.trainer_template import build_trainer
- from ray.rllib.evaluation.worker_set import WorkerSet
- from ray.rllib.examples.policy.random_policy import RandomPolicy
- from ray.rllib.execution.concurrency_ops import Concurrently
- from ray.rllib.execution.metric_ops import StandardMetricsReporting
- from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
- from ray.rllib.execution.rollout_ops import ParallelRollouts
- from ray.rllib.execution.train_ops import TrainOneStep
- from ray.rllib.policy.policy import Policy
- from ray.rllib.utils.deprecation import DEPRECATED_VALUE
- from ray.rllib.utils.typing import TrainerConfigDict
- from ray.util.iter import LocalIterator
# Module-level logger (standard RLlib convention; available for debug output).
logger = logging.getLogger(__name__)
# Defines all SlateQ document-selection strategies implemented.
# `validate_config` rejects any `slateq_strategy` value not listed here.
ALL_SLATEQ_STRATEGIES = [
    # RANDOM: Randomly select documents for slates.
    "RANDOM",
    # MYOP: Select documents that maximize user click probabilities. This is
    # a myopic strategy and ignores long term rewards. This is equivalent to
    # setting a zero discount rate for future rewards.
    "MYOP",
    # SARSA: Use the SlateQ SARSA learning algorithm.
    "SARSA",
    # QL: Use the SlateQ Q-learning algorithm.
    "QL",
]
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Model ===
    # Dense-layer setup for each of the advantage branch and the value branch
    # in a dueling architecture.
    "hiddens": [256, 64, 16],
    # Collect whole episodes per rollout. The SARSA strategy requires this
    # setting (enforced in `validate_config`).
    "batch_mode": "complete_episodes",

    # === Deep Learning Framework Settings ===
    # Currently, only PyTorch is supported.
    "framework": "torch",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        # Timesteps over which to anneal epsilon.
        "epsilon_timesteps": 10000,
    },
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },
    # Minimum env steps to optimize for per train call. This value does
    # not affect learning, only the length of iterations.
    "timesteps_per_iteration": 1000,

    # === Replay buffer ===
    # Size of the replay buffer (deprecated; set to a sentinel here — use
    # `replay_buffer_config["capacity"]` below instead). Note that if
    # async_updates is set, then each worker will have a replay buffer of
    # this size.
    "buffer_size": DEPRECATED_VALUE,
    "replay_buffer_config": {
        # Replay buffer class to use.
        "type": "LocalReplayBuffer",
        # Max number of timesteps to hold in the buffer.
        "capacity": 50000,
    },
    # The number of contiguous environment steps to replay at once. This may
    # be set to greater than 1 to support recurrent models.
    "replay_sequence_length": 1,
    # Whether to LZ4 compress observations.
    "compress_observations": False,
    # If set, this will fix the ratio of replayed from a buffer and learned on
    # timesteps to sampled from an environment and stored in the replay buffer
    # timesteps. Otherwise, the replay will proceed at the native ratio
    # determined by (train_batch_size / rollout_fragment_length).
    "training_intensity": None,

    # === Optimization ===
    # Learning rate for the Adam optimizer of the user choice model.
    "lr_choice_model": 1e-2,
    # Learning rate for the Adam optimizer of the Q model.
    "lr_q_model": 1e-2,
    # Adam epsilon hyperparameter.
    "adam_epsilon": 1e-8,
    # If not None, clip gradients during optimization at this value.
    "grad_clip": 40,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 1000,
    # Size of a batch sampled from replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span.
    "min_iter_time_s": 1,

    # === SlateQ specific options ===
    # Learning method used by the slateq policy. Choose from: RANDOM,
    # MYOP (myopic), SARSA, QL (Q-Learning). See ALL_SLATEQ_STRATEGIES.
    "slateq_strategy": "QL",
    # User/doc embedding size for the RecSim environment.
    "recsim_embedding_size": 20,
})
# __sphinx_doc_end__
# yapf: enable
def validate_config(config: TrainerConfigDict) -> None:
    """Check a SlateQ config for unsupported or inconsistent settings.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Raises:
        ValueError: If more than one GPU is requested, the framework is not
            PyTorch, the strategy is unknown, or the SARSA strategy is
            combined with a batch_mode other than "complete_episodes".
    """
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for SlateQ!")
    if config["framework"] != "torch":
        raise ValueError("SlateQ only runs on PyTorch")
    strategy = config["slateq_strategy"]
    if strategy not in ALL_SLATEQ_STRATEGIES:
        raise ValueError(f"Unknown slateq_strategy: {strategy}.")
    # SARSA learns on-policy from episode transitions, hence the batch-mode
    # restriction.
    if strategy == "SARSA" and config["batch_mode"] != "complete_episodes":
        raise ValueError(
            "For SARSA strategy, batch_mode must be 'complete_episodes'")
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Build the distributed dataflow of the SlateQ algorithm.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    assert "local_replay_buffer" in kwargs, (
        "SlateQ execution plan requires a local replay buffer.")
    replay_buffer = kwargs["local_replay_buffer"]

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Branch (1): sample the environment and append experiences to the
    # local replay buffer. Calling next() on store_op drives sampling.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=replay_buffer))

    # Branch (2): draw batches back out of the buffer; each one is handed
    # to TrainOneStep for a single SGD update.
    replay_op = Replay(local_buffer=replay_buffer) \
        .for_each(TrainOneStep(workers))

    if config["slateq_strategy"] == "RANDOM":
        # The RANDOM strategy learns nothing, so only rollouts are needed.
        train_op = rollouts
    else:
        # Alternate deterministically between (1) and (2). Only branch (2)
        # is surfaced (output_indexes=[1]) because training metrics exist
        # only after a train step has run.
        train_op = Concurrently(
            [store_op, replay_op],
            mode="round_robin",
            output_indexes=[1],
            round_robin_weights=calculate_round_robin_weights(config))

    return StandardMetricsReporting(train_op, workers, config)
def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]:
    """Compute round-robin weights for the rollout (store) and train steps.

    With no `training_intensity` configured (None or 0), both steps get
    equal weight. Otherwise, the train step is weighted so that the
    realized replayed-to-sampled timestep ratio matches the requested
    intensity.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        List[float]: ``[store_weight, train_weight]`` for Concurrently().
    """
    intensity = config["training_intensity"]
    if not intensity:
        return [1, 1]
    # Ratio achieved at equal weights, e.g. 32 / 4 -> native ratio of 8.0.
    native_ratio = (
        config["train_batch_size"] / config["rollout_fragment_length"])
    # `training_intensity` is expressed as (steps_replayed / steps_sampled),
    # so scale the train weight relative to the native ratio.
    return [1, intensity / native_ratio]
def get_policy_class(config: TrainerConfigDict) -> Type[Policy]:
    """Policy class picker function.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Type[Policy]: RandomPolicy for the "RANDOM" strategy; otherwise
            the SlateQTorchPolicy used by all learning strategies.
    """
    strategy = config["slateq_strategy"]
    return RandomPolicy if strategy == "RANDOM" else SlateQTorchPolicy
# Assemble the SlateQ Trainer from the pieces defined above: default config,
# per-strategy policy picker, config validator, and execution plan.
SlateQTrainer = build_trainer(
    name="SlateQ",
    get_policy_class=get_policy_class,
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
    execution_plan=execution_plan)
|