- """PyTorch policy class used for SlateQ"""
- from typing import Dict, List, Sequence, Tuple
- import gym
- import numpy as np
- import ray
- from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
- from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
- TorchDistributionWrapper)
- from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.policy_template import build_policy_class
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.framework import try_import_torch
- from ray.rllib.utils.typing import (ModelConfigDict, TensorType,
- TrainerConfigDict)
- torch, nn = try_import_torch()
- F = None
- if nn:
- F = nn.functional

class QValueModel(nn.Module):
    """The Q-value model for SlateQ."""

    def __init__(self, embedding_size: int, q_hiddens: Sequence[int]):
        super().__init__()

        # Construct the hidden layers.
        layers = []
        ins = 2 * embedding_size
        for n in q_hiddens:
            layers.append(nn.Linear(ins, n))
            layers.append(nn.LeakyReLU())
            ins = n
        layers.append(nn.Linear(ins, 1))
        self.layers = nn.Sequential(*layers)

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model.

        Args:
            user (TensorType): User embedding of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): q_values of shape (batch_size, num_docs + 1).
        """
        batch_size, num_docs, embedding_size = doc.shape
        doc_flat = doc.view((batch_size * num_docs, embedding_size))
        # Align the user rows with doc_flat's batch-major row ordering
        # ([b0d0, b0d1, ..., b1d0, ...]); a plain `user.repeat(num_docs, 1)`
        # would pair users with the wrong docs whenever batch_size > 1.
        user_repeated = user.repeat_interleave(num_docs, dim=0)
        x = torch.cat([user_repeated, doc_flat], dim=1)
        x = self.layers(x)
        # As in Google's SlateQ implementation in RecSim, the Q-value of the
        # no-click option is forced to zero.
        x_no_click = torch.zeros((batch_size, 1), device=x.device)
        return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1)
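
# A minimal shape-check sketch for QValueModel (the sizes here are
# hypothetical, chosen only for illustration):
#
#   q_model = QValueModel(embedding_size=4, q_hiddens=[16, 16])
#   user = torch.rand(2, 4)        # (batch_size=2, embedding_size=4)
#   doc = torch.rand(2, 5, 4)      # (batch_size=2, num_docs=5, embedding_size=4)
#   q_values = q_model(user, doc)  # (2, 6): one Q per doc, plus the no-click Q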

class UserChoiceModel(nn.Module):
    r"""The user choice model for SlateQ.

    This class implements a multinomial logit model for predicting user
    clicks. Under this model, the probability of clicking document :math:`i`
    (vs. clicking nothing) is:

    .. math::
        p(\text{click}_i) \propto \exp(\beta \cdot \text{affinity}(u, d_i)),
        \qquad
        p(\text{no\_click}) \propto \exp(\text{score\_no\_click})
    """

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(0., dtype=torch.float))
        self.score_no_click = nn.Parameter(torch.tensor(0., dtype=torch.float))

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user choice model.

        This function outputs user click scores for candidate documents. The
        exponentials of these scores are proportional to the user's click
        probabilities. The scores are returned unnormalized because only a
        subset of the candidate documents will be selected and shown to the
        user.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): logits of shape (batch_size, num_docs + 1),
                where the last dimension represents no_click.
        """
        batch_size = user.shape[0]
        # Affinity of each user with each of its candidate docs.
        s = torch.einsum("be,bde->bd", user, doc)
        s = s * self.beta
        s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1)
        return s
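
# A minimal sketch of turning the raw choice scores into probabilities
# (hypothetical sizes; the policy itself keeps the scores unnormalized):
#
#   choice_model = UserChoiceModel()
#   user = torch.rand(2, 4)           # (batch_size, embedding_size)
#   doc = torch.rand(2, 5, 4)         # (batch_size, num_docs, embedding_size)
#   scores = choice_model(user, doc)  # (2, 6); last column is no_click
#   probs = F.softmax(scores, dim=1)  # multinomial-logit click probabilities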

class SlateQModel(TorchModelV2, nn.Module):
    """The SlateQ model class.

    It includes both the user choice model and the Q-value model.
    """

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            model_config: ModelConfigDict,
            name: str,
            *,
            embedding_size: int,
            q_hiddens: Sequence[int],
    ):
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self,
            obs_space,
            action_space,
            # This required parameter (num_outputs) seems redundant: it has
            # no real impact, and can be set arbitrarily. TODO: fix this.
            num_outputs=0,
            model_config=model_config,
            name=name)
        self.choice_model = UserChoiceModel()
        self.q_model = QValueModel(embedding_size, q_hiddens)
        self.slate_size = len(action_space.nvec)
    def choose_slate(self, user: TensorType,
                     doc: TensorType) -> Tuple[TensorType, TensorType]:
        """Build a slate by selecting from candidate documents.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size,
                num_docs, embedding_size).

        Returns:
            slates_selected (TensorType): Indices of documents selected for
                the slate, with shape (batch_size, slate_size).
            best_slate_q_value (TensorType): The Q-value of the selected
                slate, with shape (batch_size).
        """
        # Step 1: compute item scores (proportional to click probabilities).
        # raw_scores.shape=[batch_size, num_docs+1]
        raw_scores = self.choice_model(user, doc)
        # max_raw_scores.shape=[batch_size, 1]
        max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
        # Subtract the per-row max before exponentiating, for numerical
        # stability (the standard softmax trick).
        scores = torch.exp(raw_scores - max_raw_scores)
        scores_doc = scores[:, :-1]  # shape=[batch_size, num_docs]
        scores_no_click = scores[:, [-1]]  # shape=[batch_size, 1]

        # Step 2: calculate the item-wise Q values.
        # q_values.shape=[batch_size, num_docs+1]
        q_values = self.q_model(user, doc)
        q_values_doc = q_values[:, :-1]  # shape=[batch_size, num_docs]
        q_values_no_click = q_values[:, [-1]]  # shape=[batch_size, 1]

        # Step 3: construct all possible slates.
        _, num_docs, _ = doc.shape
        indices = torch.arange(num_docs, dtype=torch.long, device=doc.device)
        # slates.shape = [num_slates, slate_size]
        slates = torch.combinations(indices, r=self.slate_size)
        num_slates, _ = slates.shape

        # Step 4: calculate slate Q values.
        batch_size, _ = q_values_doc.shape
        # slate_decomp_q_values.shape: [batch_size, num_slates, slate_size]
        slate_decomp_q_values = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=q_values_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_scores.shape: [batch_size, num_slates, slate_size]
        slate_scores = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=scores_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_q_values.shape: [batch_size, num_slates]
        slate_q_values = ((slate_decomp_q_values * slate_scores).sum(dim=2) +
                          (q_values_no_click * scores_no_click)) / (
                              slate_scores.sum(dim=2) + scores_no_click)
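        # For reference, the expression above is the SlateQ decomposition
        #
        #   Q(s, A) = (sum_{i in A} v(s, i) * Q(s, i) + v_nc * Q_nc)
        #             / (sum_{i in A} v(s, i) + v_nc)
        #
        # where v(s, i) = exp(choice score) is the unnormalized click
        # likelihood of item i and "nc" denotes the no-click option; i.e.,
        # Q(s, A) is the expectation of the item-level Q-values under the
        # user choice model.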
        # Step 5: find the slate that maximizes the Q value.
        best_slate_q_value, max_idx = torch.max(slate_q_values, dim=1)
        # slates_selected.shape: [batch_size, slate_size]
        slates_selected = slates[max_idx]
        return slates_selected, best_slate_q_value
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        # user.shape: [batch_size, embedding_size]
        user = input_dict[SampleBatch.OBS]["user"]
        # doc.shape: [batch_size, num_docs, embedding_size]
        doc = torch.cat([
            val.unsqueeze(1)
            for val in input_dict[SampleBatch.OBS]["doc"].values()
        ], 1)
        slates_selected, _ = self.choose_slate(user, doc)
        state_out = []
        return slates_selected, state_out

def build_slateq_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build models for SlateQ.

    Args:
        policy (Policy): The policy, which will use the model for
            optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The trainer's config dict.

    Returns:
        Tuple consisting of 1) the SlateQModel and 2) the TorchCategorical
        distribution class.
    """
    model = SlateQModel(
        obs_space,
        action_space,
        model_config=config["model"],
        name="slateq_model",
        embedding_size=config["recsim_embedding_size"],
        q_hiddens=config["hiddens"],
    )
    return model, TorchCategorical
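
# Note: the TorchCategorical class returned above only satisfies the
# model-and-distribution interface; actual action selection bypasses it and
# goes through `action_sampler_fn` below, which calls `choose_slate` directly.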

def build_slateq_losses(policy: Policy, model: SlateQModel, _,
                        train_batch: SampleBatch) -> List[TensorType]:
    """Constructs the losses for SlateQPolicy.

    Args:
        policy (Policy): The Policy to calculate the losses for.
        model (ModelV2): The Model to calculate the losses for.
        train_batch (SampleBatch): The training data.

    Returns:
        List[TensorType]: The choice-model loss and the Q-value loss, matched
            one-to-one with the optimizers from `build_slateq_optimizers`.
    """
    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS],
        policy.observation_space,
        tensorlib=torch)
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # actions.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS],
        policy.observation_space,
        tensorlib=torch)

    # Step 1: Build the user choice model loss.
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size))
    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.beta.item(),
    #       model.choice_model.score_no_click.item())
    # Step 2: Build the Q-value loss.
    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    if learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size))
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(
                q_values * scores, dim=1) / torch.sum(
                    scores, dim=1)
            next_q_values[dones] = 0.0
    elif learning_strategy == "MYOP":
        # Myopic strategy: ignore future rewards entirely.
        next_q_values = 0.
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            _, next_q_values = model.choose_slate(next_user, next_doc)
        next_q_values[dones] = 0.0
    else:
        raise ValueError(
            f"Unknown slateq_strategy: {learning_strategy}. "
            "Must be one of 'SARSA', 'MYOP', or 'QL'.")

    # target_q_values.shape: [batch_size]
    target_q_values = next_q_values + train_batch["rewards"]

    # q_values.shape: [batch_size, slate_size+1]
    q_values = model.q_model(user, selected_doc)
    # raw_scores.shape: [batch_size, slate_size+1]
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    # The slate Q-value is the expectation of the item-level Q-values under
    # the user choice model. q_values.shape=[batch_size]
    q_values = torch.sum(
        q_values * scores, dim=1) / torch.sum(
            scores, dim=1)

    q_value_loss = nn.MSELoss()(q_values, target_q_values)
    return [choice_loss, q_value_loss]

def build_slateq_optimizers(policy: Policy, config: TrainerConfigDict
                            ) -> List["torch.optim.Optimizer"]:
    optimizer_choice = torch.optim.Adam(
        policy.model.choice_model.parameters(), lr=config["lr_choice_model"])
    optimizer_q_value = torch.optim.Adam(
        policy.model.q_model.parameters(),
        lr=config["lr_q_model"],
        eps=config["adam_epsilon"])
    return [optimizer_choice, optimizer_q_value]
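
# Note: RLlib's TorchPolicy pairs the i-th loss returned by
# `build_slateq_losses` with the i-th optimizer returned above, so the choice
# model and the Q-value model are trained independently, each with its own
# learning rate.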

def action_sampler_fn(policy: Policy, model: SlateQModel, input_dict, state,
                      explore, timestep):
    """Determine which action (slate) to take."""
    # First, we transform the observation into its unflattened form.
    obs = restore_original_dimensions(
        input_dict[SampleBatch.CUR_OBS],
        policy.observation_space,
        tensorlib=torch)

    # user.shape: [batch_size(=1), embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size(=1), num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)

    selected_slates, _ = model.choose_slate(user, doc)

    action = selected_slates
    logp = None
    state_out = []
    return action, logp, state_out

def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
                                              batch: SampleBatch,
                                              other_agent=None,
                                              episode=None) -> SampleBatch:
    """Add next_actions to SampleBatch for SARSA training."""
    if policy.config["slateq_strategy"] == "SARSA":
        if not batch["dones"][-1]:
            raise RuntimeError(
                "Expected a complete episode in each sample batch. "
                f"But this batch is not: {batch}.")
        batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)
    return batch
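
# A small illustration of the np.roll shift above (hypothetical values):
#
#   actions      = np.array([[0, 1], [2, 3], [4, 5]])
#   next_actions = np.roll(actions, -1, axis=0)
#   # -> [[2, 3], [4, 5], [0, 1]]
#
# The wrapped-around last row pairs the terminal step with the episode's
# first action, but it never affects training, because next_q_values is
# zeroed out for done timesteps in the loss.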

SlateQTorchPolicy = build_policy_class(
    name="SlateQTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,

    # Build model, loss functions, and optimizers.
    make_model_and_action_dist=build_slateq_model_and_distribution,
    optimizer_fn=build_slateq_optimizers,
    loss_fn=build_slateq_losses,

    # Define how to act.
    action_sampler_fn=action_sampler_fn,

    # Post-process the sampled batch data.
    postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
)
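
# A minimal usage sketch. This assumes the accompanying trainer module
# (`ray.rllib.agents.slateq.slateq`, which provides DEFAULT_CONFIG above)
# exposes a SlateQTrainer, and that a RecSim-based env id has been registered
# (both names below are illustrative):
#
#   from ray.rllib.agents.slateq import SlateQTrainer
#   trainer = SlateQTrainer(
#       config={"slateq_strategy": "QL"}, env="RecSim-v1")
#   print(trainer.train())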