# slateq_torch_policy.py

  1. """PyTorch policy class used for SlateQ"""
  2. from typing import Dict, List, Sequence, Tuple
  3. import gym
  4. import numpy as np
  5. import ray
  6. from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
  7. from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
  8. TorchDistributionWrapper)
  9. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  10. from ray.rllib.policy.policy import Policy
  11. from ray.rllib.policy.policy_template import build_policy_class
  12. from ray.rllib.policy.sample_batch import SampleBatch
  13. from ray.rllib.utils.framework import try_import_torch
  14. from ray.rllib.utils.typing import (ModelConfigDict, TensorType,
  15. TrainerConfigDict)
  16. torch, nn = try_import_torch()
  17. F = None
  18. if nn:
  19. F = nn.functional


class QValueModel(nn.Module):
    """The Q-value model for SlateQ"""

    def __init__(self, embedding_size: int, q_hiddens: Sequence[int]):
        super().__init__()

        # Construct hidden layers.
        layers = []
        ins = 2 * embedding_size
        for n in q_hiddens:
            layers.append(nn.Linear(ins, n))
            layers.append(nn.LeakyReLU())
            ins = n
        layers.append(nn.Linear(ins, 1))
        self.layers = nn.Sequential(*layers)

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model.

        Args:
            user (TensorType): User embedding of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): q_values of shape (batch_size, num_docs + 1).
        """
        batch_size, num_docs, embedding_size = doc.shape
        doc_flat = doc.view((batch_size * num_docs, embedding_size))
        # Repeat each user embedding once per candidate doc, so that row
        # b * num_docs + d pairs user[b] with doc[b, d] (matching doc_flat).
        user_repeated = user.repeat_interleave(num_docs, dim=0)
        x = torch.cat([user_repeated, doc_flat], dim=1)
        x = self.layers(x)
        # Similar to Google's SlateQ implementation in RecSim, we force the
        # Q-values to zeros if there are no clicks.
        x_no_click = torch.zeros((batch_size, 1), device=x.device)
        return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1)
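

# Illustrative usage sketch (not part of the original RLlib file): QValueModel
# scores each candidate doc for a user and appends a trailing "no click" column
# that is fixed to zero.  All sizes below are arbitrary example values.
def _example_q_value_model_usage():
    batch_size, num_docs, embedding_size = 4, 10, 20
    q_model = QValueModel(embedding_size, q_hiddens=(256, 64))
    user = torch.rand(batch_size, embedding_size)
    doc = torch.rand(batch_size, num_docs, embedding_size)
    q_values = q_model(user, doc)
    # One Q-value per doc, plus the no-click slot (always zero) at the end.
    assert q_values.shape == (batch_size, num_docs + 1)
    assert torch.all(q_values[:, -1] == 0.0)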


class UserChoiceModel(nn.Module):
    r"""The user choice model for SlateQ

    This class implements a multinomial logit model for predicting user
    clicks. Under this model, the click probability of candidate document
    :math:`i` is proportional to

    .. math::
        \exp(\text{beta} \cdot \text{doc_user_affinity}_i),

    while the probability of clicking on none of the documents is
    proportional to :math:`\exp(\text{score_no_click})`.
    """

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(0., dtype=torch.float))
        self.score_no_click = nn.Parameter(torch.tensor(0., dtype=torch.float))

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user choice model.

        This function outputs user click scores for candidate documents. The
        exponentials of these scores are proportional to user click
        probabilities. The scores are returned unnormalized because only some
        of the documents will be selected and shown to the user.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): logits of shape (batch_size, num_docs + 1),
                where the last dimension represents no_click.
        """
        batch_size = user.shape[0]
        # Affinity between the user and each doc (dot product of embeddings).
        s = torch.einsum("be,bde->bd", user, doc)
        s = s * self.beta
        s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1)
        return s
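

# Illustrative usage sketch (not part of the original RLlib file): turning the
# unnormalized choice scores for a slate into click probabilities with a
# softmax.  The last column is the "no click" option; all sizes are arbitrary.
def _example_user_choice_probabilities():
    batch_size, slate_size, embedding_size = 2, 5, 20
    choice_model = UserChoiceModel()
    user = torch.rand(batch_size, embedding_size)
    slate_docs = torch.rand(batch_size, slate_size, embedding_size)
    scores = choice_model(user, slate_docs)  # shape: [batch_size, slate_size+1]
    click_probs = F.softmax(scores, dim=1)
    assert click_probs.shape == (batch_size, slate_size + 1)
    assert torch.allclose(click_probs.sum(dim=1), torch.ones(batch_size))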


class SlateQModel(TorchModelV2, nn.Module):
    """The SlateQ model class

    It includes both the user choice model and the Q-value model.
    """

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            model_config: ModelConfigDict,
            name: str,
            *,
            embedding_size: int,
            q_hiddens: Sequence[int],
    ):
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self,
            obs_space,
            action_space,
            # This required parameter (num_outputs) seems redundant: it has no
            # real impact, and can be set arbitrarily. TODO: fix this.
            num_outputs=0,
            model_config=model_config,
            name=name)
        self.choice_model = UserChoiceModel()
        self.q_model = QValueModel(embedding_size, q_hiddens)
        self.slate_size = len(action_space.nvec)

    def choose_slate(self, user: TensorType,
                     doc: TensorType) -> Tuple[TensorType, TensorType]:
        """Build a slate by selecting from candidate documents.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size,
                num_docs, embedding_size).

        Returns:
            slates_selected (TensorType): Indices of documents selected for
                the slate, with shape (batch_size, slate_size).
            best_slate_q_value (TensorType): The Q-value of the selected slate,
                with shape (batch_size).
        """
        # Step 1: compute item scores (their exponentials are proportional to
        # click probabilities).
        # raw_scores.shape=[batch_size, num_docs+1]
        raw_scores = self.choice_model(user, doc)
        # max_raw_scores.shape=[batch_size, 1]
        max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
        # Subtract the per-row max score before exponentiating, for numerical
        # stability.
        scores = torch.exp(raw_scores - max_raw_scores)
        scores_doc = scores[:, :-1]  # shape=[batch_size, num_docs]
        scores_no_click = scores[:, [-1]]  # shape=[batch_size, 1]

        # Step 2: calculate the item-wise Q values.
        # q_values.shape=[batch_size, num_docs+1]
        q_values = self.q_model(user, doc)
        q_values_doc = q_values[:, :-1]  # shape=[batch_size, num_docs]
        q_values_no_click = q_values[:, [-1]]  # shape=[batch_size, 1]

        # Step 3: construct all possible slates.
        _, num_docs, _ = doc.shape
        indices = torch.arange(num_docs, dtype=torch.long, device=doc.device)
        # slates.shape = [num_slates, slate_size]
        slates = torch.combinations(indices, r=self.slate_size)
        num_slates, _ = slates.shape

        # Step 4: calculate slate Q values.
        batch_size, _ = q_values_doc.shape
        # slate_decomp_q_values.shape: [batch_size, num_slates, slate_size]
        slate_decomp_q_values = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=q_values_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_scores.shape: [batch_size, num_slates, slate_size]
        slate_scores = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=scores_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_q_values.shape: [batch_size, num_slates]
        slate_q_values = ((slate_decomp_q_values * slate_scores).sum(dim=2) +
                          (q_values_no_click * scores_no_click)) / (
                              slate_scores.sum(dim=2) + scores_no_click)

        # Step 5: find the slate that maximizes the Q value.
        best_slate_q_value, max_idx = torch.max(slate_q_values, dim=1)
        # slates_selected.shape: [batch_size, slate_size]
        slates_selected = slates[max_idx]
        return slates_selected, best_slate_q_value

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        # user.shape: [batch_size, embedding_size]
        user = input_dict[SampleBatch.OBS]["user"]
        # doc.shape: [batch_size, num_docs, embedding_size]
        doc = torch.cat([
            val.unsqueeze(1)
            for val in input_dict[SampleBatch.OBS]["doc"].values()
        ], 1)
        slates_selected, _ = self.choose_slate(user, doc)
        state_out = []
        return slates_selected, state_out
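

# Illustrative sketch (not part of the original RLlib file): the slate Q-value
# computed in choose_slate() is a choice-score-weighted average of the item
# Q-values, with the no-click option contributing zero Q-value but non-zero
# weight in the denominator.  The toy numbers below are arbitrary.
def _example_slate_q_decomposition():
    item_q = torch.tensor([[1.0, 2.0, 3.0]])  # Q-values of 3 docs in one slate
    item_v = torch.tensor([[0.2, 0.3, 0.5]])  # exp(choice scores) of those docs
    v_no_click = torch.tensor([[1.0]])        # exp(score_no_click)
    slate_q = (item_q * item_v).sum(dim=1, keepdim=True) / (
        item_v.sum(dim=1, keepdim=True) + v_no_click)
    # (1*0.2 + 2*0.3 + 3*0.5) / (0.2 + 0.3 + 0.5 + 1.0) = 2.3 / 2.0 = 1.15
    assert torch.allclose(slate_q, torch.tensor([[1.15]]))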


def build_slateq_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build models for SlateQ.

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        Tuple consisting of 1) the SlateQModel and 2) the TorchCategorical
            action distribution class.
    """
    model = SlateQModel(
        obs_space,
        action_space,
        model_config=config["model"],
        name="slateq_model",
        embedding_size=config["recsim_embedding_size"],
        q_hiddens=config["hiddens"],
    )
    return model, TorchCategorical
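

# Illustrative sketch (not part of the original RLlib file): SlateQModel expects
# a MultiDiscrete action space with one dimension per slate position; the slate
# size is derived from action_space.nvec.  The spaces and sizes below are
# arbitrary example values, not RLlib defaults.
def _example_build_slateq_model():
    action_space = gym.spaces.MultiDiscrete([10, 10])  # slate of 2 out of 10 docs
    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(230, ))
    model = SlateQModel(
        obs_space,
        action_space,
        model_config={},
        name="example_slateq_model",
        embedding_size=20,
        q_hiddens=(256, 64),
    )
    assert model.slate_size == 2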


def build_slateq_losses(policy: Policy, model: SlateQModel, _,
                        train_batch: SampleBatch) -> List[TensorType]:
    """Constructs the losses for SlateQPolicy.

    Args:
        policy (Policy): The Policy to calculate the losses for.
        model (ModelV2): The Model to calculate the losses for.
        train_batch (SampleBatch): The training data.

    Returns:
        List[TensorType]: The choice-model loss and the Q-value loss.
    """
    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS],
        policy.observation_space,
        tensorlib=torch)
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # actions.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS],
        policy.observation_space,
        tensorlib=torch)

    # Step 1: Build the user choice model loss.
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size))

    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.beta.item(),
    #       model.choice_model.score_no_click.item())

    # Step 2: Build the Q-value loss.
    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    if learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size))
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(
                q_values * scores, dim=1) / torch.sum(
                    scores, dim=1)
            next_q_values[dones] = 0.0
    elif learning_strategy == "MYOP":
        next_q_values = 0.
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            _, next_q_values = model.choose_slate(next_user, next_doc)
        next_q_values[dones] = 0.0
    else:
        raise ValueError(learning_strategy)

    # target_q_values.shape: [batch_size]
    target_q_values = next_q_values + train_batch["rewards"]

    q_values = model.q_model(user,
                             selected_doc)  # shape: [batch_size, slate_size+1]
    # raw_scores.shape: [batch_size, slate_size+1]
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    q_values = torch.sum(
        q_values * scores, dim=1) / torch.sum(
            scores, dim=1)  # shape=[batch_size]

    q_value_loss = nn.MSELoss()(q_values, target_q_values)
    return [choice_loss, q_value_loss]
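

# Illustrative sketch (not part of the original RLlib file): how the choice-loss
# target class is derived from per-position click indicators.  A click on slate
# position i yields class i; no click at all yields the extra class slate_size.
# The toy batch below is arbitrary.
def _example_choice_loss_target():
    clicks = torch.tensor([[0., 1., 0.],   # user clicked slate position 1
                           [0., 0., 0.]])  # user clicked nothing
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    targets = torch.cat([clicks, no_clicks], dim=1)
    assert torch.argmax(targets, dim=1).tolist() == [1, 3]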


def build_slateq_optimizers(policy: Policy, config: TrainerConfigDict
                            ) -> List["torch.optim.Optimizer"]:
    optimizer_choice = torch.optim.Adam(
        policy.model.choice_model.parameters(), lr=config["lr_choice_model"])
    optimizer_q_value = torch.optim.Adam(
        policy.model.q_model.parameters(),
        lr=config["lr_q_model"],
        eps=config["adam_epsilon"])
    return [optimizer_choice, optimizer_q_value]


def action_sampler_fn(policy: Policy, model: SlateQModel, input_dict, state,
                      explore, timestep):
    """Determine which action to take."""
    # First, we transform the observation into its unflattened form.
    obs = restore_original_dimensions(
        input_dict[SampleBatch.CUR_OBS],
        policy.observation_space,
        tensorlib=torch)

    # user.shape: [batch_size(=1), embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size(=1), num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)

    selected_slates, _ = model.choose_slate(user, doc)

    action = selected_slates
    logp = None
    state_out = []
    return action, logp, state_out


def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
                                              batch: SampleBatch,
                                              other_agent=None,
                                              episode=None) -> SampleBatch:
    """Add next_actions to SampleBatch for SARSA training."""
    if policy.config["slateq_strategy"] == "SARSA":
        if not batch["dones"][-1]:
            raise RuntimeError(
                "Expected a complete episode in each sample batch. "
                f"But this batch is not: {batch}.")
        batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)
    return batch
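

# Illustrative sketch (not part of the original RLlib file): np.roll(..., -1,
# axis=0) shifts the actions up by one timestep, so row t holds the action taken
# at t + 1.  The last row wraps around to the first action, but the SARSA loss
# zeroes the bootstrapped next-Q term for done timesteps, so the wrapped value
# never affects the target of the final step of a complete episode.
def _example_next_actions_roll():
    actions = np.array([[0, 1], [2, 3], [4, 5]])  # one slate action per timestep
    next_actions = np.roll(actions, -1, axis=0)
    assert np.array_equal(next_actions, np.array([[2, 3], [4, 5], [0, 1]]))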


SlateQTorchPolicy = build_policy_class(
    name="SlateQTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,
    # Build the model, loss functions, and optimizers.
    make_model_and_action_dist=build_slateq_model_and_distribution,
    optimizer_fn=build_slateq_optimizers,
    loss_fn=build_slateq_losses,
    # Define how to act.
    action_sampler_fn=action_sampler_fn,
    # Post-process sampled batch data.
    postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
)
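

# Illustrative reference (not part of the original RLlib file): the SlateQ-
# specific config keys read by the functions above.  The values shown are
# arbitrary placeholders, not the actual defaults in slateq.DEFAULT_CONFIG.
_EXAMPLE_SLATEQ_CONFIG_KEYS = {
    "slateq_strategy": "QL",      # "SARSA", "MYOP", or "QL" (see build_slateq_losses)
    "recsim_embedding_size": 20,  # user/doc embedding size for the models
    "hiddens": [256, 64],         # hidden layer sizes of QValueModel
    "lr_choice_model": 1e-3,      # Adam learning rate for UserChoiceModel
    "lr_q_model": 1e-3,           # Adam learning rate for QValueModel
    "adam_epsilon": 1e-8,         # eps for the Q-model's Adam optimizer
}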