# bandit_envs_discrete.py

import copy
import gymnasium as gym
from gymnasium.spaces import Box, Discrete
import numpy as np
import random


class SimpleContextualBandit(gym.Env):
    """Simple env w/ 2 states and 3 actions (arms): 0, 1, and 2.

    Episodes last only for one timestep, possible observations are:
    [-1.0, 1.0] and [1.0, -1.0], where the first element is the "current context".
    The highest reward (+10.0) is received for selecting arm 0 for context=1.0
    and arm 2 for context=-1.0. Action 1 always yields 0.0 reward.
    """

    def __init__(self, config=None):
        self.action_space = Discrete(3)
        self.observation_space = Box(low=-1.0, high=1.0, shape=(2,))
        self.cur_context = None

    def reset(self, *, seed=None, options=None):
        self.cur_context = random.choice([-1.0, 1.0])
        return np.array([self.cur_context, -self.cur_context]), {}

    def step(self, action):
        rewards_for_context = {
            -1.0: [-10, 0, 10],
            1.0: [10, 0, -10],
        }
        reward = rewards_for_context[self.cur_context][action]
        return (
            np.array([-self.cur_context, self.cur_context]),
            reward,
            True,
            False,
            {"regret": 10 - reward},
        )
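

# Minimal usage sketch (illustrative, not part of the original module): roll out
# a few one-step episodes of SimpleContextualBandit with a random policy and
# print the per-step regret reported in the info dict. The helper name
# `_demo_simple_contextual_bandit` is just for this example.
def _demo_simple_contextual_bandit(num_episodes=3):
    env = SimpleContextualBandit()
    for _ in range(num_episodes):
        obs, _ = env.reset()
        action = env.action_space.sample()  # stand-in for a bandit policy
        _, reward, terminated, _, info = env.step(action)
        assert terminated  # every episode lasts exactly one step
        print(
            f"context={obs[0]:+.1f} action={action} "
            f"reward={reward} regret={info['regret']}"
        )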


class LinearDiscreteEnv(gym.Env):
    """Samples data from linearly parameterized arms.

    The reward for context X and arm i is given by X^T * theta_i, for some
    latent set of parameters {theta_i : i = 1, ..., k}.
    The thetas are sampled uniformly at random, the contexts are Gaussian,
    and Gaussian noise is added to the rewards.
    """

    DEFAULT_CONFIG_LINEAR = {
        "feature_dim": 8,
        "num_actions": 4,
        "reward_noise_std": 0.01,
    }

    def __init__(self, config=None):
        self.config = copy.copy(self.DEFAULT_CONFIG_LINEAR)
        if config is not None and isinstance(config, dict):
            self.config.update(config)

        self.feature_dim = self.config["feature_dim"]
        self.num_actions = self.config["num_actions"]
        self.sigma = self.config["reward_noise_std"]

        self.action_space = Discrete(self.num_actions)
        self.observation_space = Box(low=-10, high=10, shape=(self.feature_dim,))

        # Unit-norm parameter vector theta_i for each arm.
        self.thetas = np.random.uniform(-1, 1, (self.num_actions, self.feature_dim))
        self.thetas /= np.linalg.norm(self.thetas, axis=1, keepdims=True)

        self._elapsed_steps = 0
        self._current_context = None

    def _sample_context(self):
        return np.random.normal(scale=1 / 3, size=(self.feature_dim,))

    def reset(self, *, seed=None, options=None):
        self._current_context = self._sample_context()
        return self._current_context, {}

    def step(self, action):
        assert (
            self._current_context is not None
        ), "Cannot call env.step() before calling reset()"
        assert action < self.num_actions, "Invalid action."
        action = int(action)

        context = self._current_context
        rewards = self.thetas.dot(context)

        opt_action = rewards.argmax()
        regret = rewards.max() - rewards[action]

        # Add Gaussian noise
        rewards += np.random.normal(scale=self.sigma, size=rewards.shape)

        reward = rewards[action]
        self._current_context = self._sample_context()
        return (
            self._current_context,
            reward,
            True,
            False,
            {"regret": regret, "opt_action": opt_action},
        )

    def render(self, mode="human"):
        raise NotImplementedError
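

# Minimal usage sketch (illustrative, not part of the original module): step
# LinearDiscreteEnv with a random policy and compare the noisy reward against
# the reported `opt_action` (the arm maximizing theta_i^T x for the current
# context x) and per-step regret. The helper name `_demo_linear_discrete_env`
# is just for this example.
def _demo_linear_discrete_env(num_steps=5):
    env = LinearDiscreteEnv({"feature_dim": 4, "num_actions": 3})
    env.reset()
    for _ in range(num_steps):
        action = env.action_space.sample()  # stand-in for a contextual bandit policy
        # The env draws a fresh context inside step(), so we can keep stepping.
        _, reward, _, _, info = env.step(action)
        print(
            f"action={action} opt_action={info['opt_action']} "
            f"reward={reward:.3f} regret={info['regret']:.3f}"
        )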


class WheelBanditEnv(gym.Env):
    """Wheel bandit environment for 2D contexts
    (see https://arxiv.org/abs/1802.09127).
    """

    DEFAULT_CONFIG_WHEEL = {
        "delta": 0.5,
        "mu_1": 1.2,
        "mu_2": 1,
        "mu_3": 50,
        "std": 0.01,
    }

    feature_dim = 2
    num_actions = 5

    def __init__(self, config=None):
        self.config = copy.copy(self.DEFAULT_CONFIG_WHEEL)
        if config is not None and isinstance(config, dict):
            self.config.update(config)

        self.delta = self.config["delta"]
        self.mu_1 = self.config["mu_1"]
        self.mu_2 = self.config["mu_2"]
        self.mu_3 = self.config["mu_3"]
        self.std = self.config["std"]

        self.action_space = Discrete(self.num_actions)
        self.observation_space = Box(low=-1, high=1, shape=(self.feature_dim,))

        self.means = [self.mu_1] + 4 * [self.mu_2]

        self._elapsed_steps = 0
        self._current_context = None

    def _sample_context(self):
        # Rejection-sample a point uniformly from the unit disk.
        while True:
            state = np.random.uniform(-1, 1, self.feature_dim)
            if np.linalg.norm(state) <= 1:
                return state

    def reset(self, *, seed=None, options=None):
        self._current_context = self._sample_context()
        return self._current_context, {}

    def step(self, action):
        assert (
            self._current_context is not None
        ), "Cannot call env.step() before calling reset()"
        action = int(action)

        self._elapsed_steps += 1

        rewards = [
            np.random.normal(self.means[j], self.std) for j in range(self.num_actions)
        ]

        context = self._current_context
        r_big = np.random.normal(self.mu_3, self.std)

        if np.linalg.norm(context) >= self.delta:
            if context[0] > 0:
                if context[1] > 0:
                    # First quadrant
                    rewards[1] = r_big
                    opt_action = 1
                else:
                    # Fourth quadrant
                    rewards[4] = r_big
                    opt_action = 4
            else:
                if context[1] > 0:
                    # Second quadrant
                    rewards[2] = r_big
                    opt_action = 2
                else:
                    # Third quadrant
                    rewards[3] = r_big
                    opt_action = 3
        else:
            # Smaller region where action 0 is optimal
            opt_action = 0

        reward = rewards[action]
        regret = rewards[opt_action] - reward

        self._current_context = self._sample_context()
        return (
            self._current_context,
            reward,
            True,
            False,
            {"regret": regret, "opt_action": opt_action},
        )

    def render(self, mode="human"):
        raise NotImplementedError
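

# Minimal usage sketch (illustrative, not part of the original module): in the
# wheel bandit, arm 0 is optimal inside the inner disk of radius `delta`, and
# one of arms 1-4 (chosen by quadrant) pays around mu_3 outside it. The helper
# name `_demo_wheel_bandit_env` is just for this example.
def _demo_wheel_bandit_env(num_steps=5):
    env = WheelBanditEnv({"delta": 0.5})
    env.reset()
    for _ in range(num_steps):
        action = env.action_space.sample()  # stand-in for a bandit policy
        # As above, step() resamples the context, so we can keep stepping.
        _, reward, _, _, info = env.step(action)
        print(
            f"action={action} opt_action={info['opt_action']} "
            f"reward={reward:.3f} regret={info['regret']:.3f}"
        )


if __name__ == "__main__":
    _demo_simple_contextual_bandit()
    _demo_linear_discrete_env()
    _demo_wheel_bandit_env()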