bandit_envs_recommender_system.py

  1. """Examples for recommender system simulating envs ready to be used by RLlib Algorithms.
  2. This env follows RecSim obs and action APIs.
  3. """
  4. import gymnasium as gym
  5. import numpy as np
  6. from typing import Optional
  7. from ray.rllib.utils.numpy import softmax


class ParametricRecSys(gym.Env):
    """A recommendation environment which generates items with visible features
    randomly (parametric actions).

    The environment can be configured to be multi-user, i.e. different models
    will be learned independently for each user, by setting the `num_users_in_db`
    parameter.

    To enable slate recommendation, the `slate_size` config parameter can be
    set to a value > 1.
    """

    def __init__(
        self,
        embedding_size: int = 20,
        num_docs_to_select_from: int = 10,
        slate_size: int = 1,
        num_docs_in_db: Optional[int] = None,
        num_users_in_db: Optional[int] = None,
        user_time_budget: float = 60.0,
    ):
        """Initializes a ParametricRecSys instance.

        Args:
            embedding_size: Embedding size for both users and docs.
                Each value in the user/doc embeddings can have values between
                -1.0 and 1.0.
            num_docs_to_select_from: The number of documents to present to the
                agent each timestep. The agent will then have to pick a slate
                out of these.
            slate_size: The size of the slate to recommend to the user at each
                timestep.
            num_docs_in_db: The total number of documents in the DB. Set this
                to None, in case you would like to resample docs from an
                infinite pool.
            num_users_in_db: The total number of users in the DB. Set this to
                None, in case you would like to resample users from an infinite
                pool.
            user_time_budget: The total time budget a user has throughout an
                episode. Once this time budget is used up (through engagements
                with clicked/selected documents), the episode ends.
        """
        self.embedding_size = embedding_size
        self.num_docs_to_select_from = num_docs_to_select_from
        self.slate_size = slate_size

        self.num_docs_in_db = num_docs_in_db
        self.docs_db = None
        self.num_users_in_db = num_users_in_db
        self.users_db = None
        self.current_user = None

        self.user_time_budget = user_time_budget
        self.current_user_budget = user_time_budget

        self.observation_space = gym.spaces.Dict(
            {
                # The D docs our agent sees at each timestep.
                # It has to select a k-slate out of these.
                "doc": gym.spaces.Dict(
                    {
                        str(i): gym.spaces.Box(
                            -1.0, 1.0, shape=(self.embedding_size,), dtype=np.float32
                        )
                        for i in range(self.num_docs_to_select_from)
                    }
                ),
                # The user engaging in this timestep/episode.
                "user": gym.spaces.Box(
                    -1.0, 1.0, shape=(self.embedding_size,), dtype=np.float32
                ),
                # For each item in the previous slate, was it clicked?
                # If yes, how long was it being engaged with (e.g. watched)?
                "response": gym.spaces.Tuple(
                    [
                        gym.spaces.Dict(
                            {
                                # Clicked or not?
                                "click": gym.spaces.Discrete(2),
                                # Engagement time (how many minutes watched?).
                                "engagement": gym.spaces.Box(
                                    0.0, 100.0, shape=(), dtype=np.float32
                                ),
                            }
                        )
                        for _ in range(self.slate_size)
                    ]
                ),
            }
        )
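
        # For the default config, a sampled observation from this space has the
        # following structure (shapes are exact, values are illustrative):
        # {
        #     "doc": {"0": <float32 array of shape (20,)>, ..., "9": <...>},
        #     "user": <float32 array of shape (20,)>,
        #     "response": ({"click": 0 or 1, "engagement": <float in [0, 100]>},),
        # }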

        # Our action space is a MultiDiscrete: Pick `slate_size` document indices
        # out of the `num_docs_to_select_from` docs presented at each timestep.
        self.action_space = gym.spaces.MultiDiscrete(
            [self.num_docs_to_select_from for _ in range(self.slate_size)]
        )

    def _get_embedding(self):
        return np.random.uniform(-1, 1, size=(self.embedding_size,)).astype(
            np.float32
        )

    def reset(self, *, seed=None, options=None):
        # Reset the current user's time budget.
        self.current_user_budget = self.user_time_budget

        # Sample a user for the next episode/session.
        # Pick from an only-once-sampled user DB.
        if self.num_users_in_db is not None:
            if self.users_db is None:
                self.users_db = [
                    self._get_embedding() for _ in range(self.num_users_in_db)
                ]
            self.current_user = self.users_db[np.random.choice(self.num_users_in_db)]
        # Pick from an infinite pool of users.
        else:
            self.current_user = self._get_embedding()

        return self._get_obs(), {}

    def step(self, action):
        # The action is the suggested slate (indices into the currently
        # suggested docs).

        # We calculate scores as the dot products between document features and
        # user features. The softmax ensures regret < 1 further down.
        scores = softmax(
            [np.dot(self.current_user, doc) for doc in self.currently_suggested_docs]
        )
        best_reward = np.max(scores)

        # User choice model: The user picks a doc stochastically, where the
        # probabilities are the dot products between user- and doc-feature
        # (category) vectors (the rewards). There is also a no-click doc whose
        # weight is 0.0.
        user_doc_overlaps = np.array([scores[a] for a in action] + [0.0])
        # We have to softmax again so that the probabilities add up to 1.
        probabilities = softmax(user_doc_overlaps)
        which_clicked = np.random.choice(
            np.arange(self.slate_size + 1), p=probabilities
        )
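
        # For intuition, a worked (illustrative) example of the choice model above,
        # assuming `slate_size=1` and a recommended doc whose score is 0.2:
        #   user_doc_overlaps = [0.2, 0.0]   (doc score plus the no-click weight)
        #   probabilities = softmax([0.2, 0.0]) ~= [0.55, 0.45]
        # i.e. the user clicks the recommended doc ~55% of the time and clicks
        # nothing ~45% of the time.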

        reward = 0.0
        if which_clicked < self.slate_size:
            # Reward is 1.0 - regret if clicked, 0.0 if not clicked.
            regret = best_reward - user_doc_overlaps[which_clicked]
            # The reward also represents the user engagement, which we define to
            # be within the range [0...100].
            reward = (1 - regret) * 100
            # If anything was clicked, deduct from the current user's time budget.
            self.current_user_budget -= 1.0

        # The episode terminates once the user's time budget is used up.
        done = self.current_user_budget <= 0.0
        truncated = False

        # Compile the response.
        response = tuple(
            {
                "click": int(idx == which_clicked),
                "engagement": reward if idx == which_clicked else 0.0,
            }
            for idx in range(len(user_doc_overlaps) - 1)
        )

        return self._get_obs(response=response), reward, done, truncated, {}

    def _get_obs(self, response=None):
        # Sample D docs, either from an infinite pool or from our pre-existing
        # docs DB.
        # Pick from an only-once-sampled docs DB.
        if self.num_docs_in_db is not None:
            if self.docs_db is None:
                self.docs_db = [
                    self._get_embedding() for _ in range(self.num_docs_in_db)
                ]
            self.currently_suggested_docs = [
                self.docs_db[doc_idx].astype(np.float32)
                for doc_idx in np.random.choice(
                    self.num_docs_in_db,
                    size=(self.num_docs_to_select_from,),
                    replace=False,
                )
            ]
        # Pick from an infinite pool of docs.
        else:
            self.currently_suggested_docs = [
                self._get_embedding() for _ in range(self.num_docs_to_select_from)
            ]

        doc = {str(i): d for i, d in enumerate(self.currently_suggested_docs)}

        if not response:
            response = self.observation_space["response"].sample()

        return {
            "user": self.current_user.astype(np.float32),
            "doc": doc,
            "response": response,
        }
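

# The module docstring mentions that this env is ready to be used by RLlib
# Algorithms. Below is a minimal sketch of what that hookup could look like:
# registering the env under a string name via `ray.tune.registry.register_env`,
# so that RLlib algorithm configs can refer to it by that name. The helper and
# the "parametric-recsys" name are illustrative (not part of the original
# example); algorithm choice and any observation preprocessing are left to the
# user.
def register_parametric_recsys_env(name: str = "parametric-recsys") -> None:
    """Registers ParametricRecSys with Ray Tune/RLlib under `name` (illustrative)."""
    from ray.tune.registry import register_env

    register_env(name, lambda env_config: ParametricRecSys(**env_config))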


if __name__ == "__main__":
    """Test the ParametricRecSys env with random actions for baseline performance."""
    env = ParametricRecSys(
        num_docs_in_db=100,
        num_users_in_db=1,
    )
    obs, info = env.reset()

    num_episodes = 0
    episode_rewards = []
    episode_reward = 0.0
    while num_episodes < 100:
        action = env.action_space.sample()
        obs, reward, done, truncated, _ = env.step(action)
        episode_reward += reward
        if done or truncated:
            print(f"episode reward = {episode_reward}")
            obs, info = env.reset()
            num_episodes += 1
            episode_rewards.append(episode_reward)
            episode_reward = 0.0
    print(f"Avg reward={np.mean(episode_rewards)}")