epsilon_greedy.py

import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree
import random
from typing import Union, Optional

from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.exploration.exploration import Exploration, TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch, get_variable
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule
from ray.rllib.utils.torch_utils import FLOAT_MIN

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


@PublicAPI
class EpsilonGreedy(Exploration):
    """Epsilon-greedy Exploration class that produces exploration actions.

    When given a Model's output and a current epsilon value (based on some
    Schedule), it produces a random action (if rand(1) < eps) or
    uses the model-computed one (if rand(1) >= eps).
    """
    def __init__(
        self,
        action_space: gym.spaces.Space,
        *,
        framework: str,
        initial_epsilon: float = 1.0,
        final_epsilon: float = 0.05,
        warmup_timesteps: int = 0,
        epsilon_timesteps: int = int(1e5),
        epsilon_schedule: Optional[Schedule] = None,
        **kwargs,
    ):
        """Create an EpsilonGreedy exploration class.

        Args:
            action_space: The action space the exploration should occur in.
            framework: The framework specifier.
            initial_epsilon: The initial epsilon value to use.
            final_epsilon: The final epsilon value to use.
            warmup_timesteps: The timesteps over which to not change epsilon in the
                beginning.
            epsilon_timesteps: The timesteps (additional to `warmup_timesteps`)
                after which epsilon should always be `final_epsilon`.
                E.g.: warmup_timesteps=20k, epsilon_timesteps=50k -> After 70k
                timesteps, epsilon will have reached its final value.
            epsilon_schedule: An optional Schedule object to use
                (instead of constructing one from the given parameters).
        """
        assert framework is not None
        super().__init__(action_space=action_space, framework=framework, **kwargs)

        self.epsilon_schedule = from_config(
            Schedule, epsilon_schedule, framework=framework
        ) or PiecewiseSchedule(
            endpoints=[
                (0, initial_epsilon),
                (warmup_timesteps, initial_epsilon),
                (warmup_timesteps + epsilon_timesteps, final_epsilon),
            ],
            outside_value=final_epsilon,
            framework=self.framework,
        )

        # The current timestep value (tf-var or python int).
        self.last_timestep = get_variable(
            np.array(0, np.int64),
            framework=framework,
            tf_name="timestep",
            dtype=np.int64,
        )

        # Build the tf-info-op.
        if self.framework == "tf":
            self._tf_state_op = self.get_state()
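
    # Illustrative note (not part of the original module): using the values from
    # the docstring example above (initial_epsilon=1.0, final_epsilon=0.05,
    # warmup_timesteps=20_000, epsilon_timesteps=50_000), the PiecewiseSchedule
    # built in __init__ produces roughly:
    #   t=0      -> 1.0
    #   t=20_000 -> 1.0    (warmup ends, epsilon starts decaying linearly)
    #   t=45_000 -> 0.525  (halfway through the linear decay)
    #   t=70_000 -> 0.05   (final value, held constant afterwards)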

    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: Optional[Union[bool, TensorType]] = True,
    ):
        if self.framework in ["tf2", "tf"]:
            return self._get_tf_exploration_action_op(
                action_distribution, explore, timestep
            )
        else:
            return self._get_torch_exploration_action(
                action_distribution, explore, timestep
            )
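
    # Note (added for clarity): both framework branches below return a tuple of
    # (exploration actions, per-batch-item action log-probs), with the log-probs
    # filled with zeros as a placeholder.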

    def _get_tf_exploration_action_op(
        self,
        action_distribution: ActionDistribution,
        explore: Union[bool, TensorType],
        timestep: Union[int, TensorType],
    ) -> "tf.Tensor":
        """TF method to produce the tf op for an epsilon exploration action.

        Args:
            action_distribution: The instantiated ActionDistribution object
                to work with when creating exploration actions.

        Returns:
            The tf exploration-action op.
        """
        # TODO: Support MultiActionDistr for tf.
        q_values = action_distribution.inputs
        epsilon = self.epsilon_schedule(
            timestep if timestep is not None else self.last_timestep
        )

        # Get the exploit action as the one with the highest logit value.
        exploit_action = tf.argmax(q_values, axis=1)

        batch_size = tf.shape(q_values)[0]
        # Mask out actions with q-value=-inf so that we don't even consider
        # them for exploration.
        random_valid_action_logits = tf.where(
            tf.equal(q_values, tf.float32.min),
            tf.ones_like(q_values) * tf.float32.min,
            tf.ones_like(q_values),
        )
        random_actions = tf.squeeze(
            tf.random.categorical(random_valid_action_logits, 1), axis=1
        )

        chose_random = (
            tf.random.uniform(
                tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
            )
            < epsilon
        )

        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool)
            else explore,
            true_fn=(lambda: tf.where(chose_random, random_actions, exploit_action)),
            false_fn=lambda: exploit_action,
        )

        if self.framework == "tf2" and not self.policy_config["eager_tracing"]:
            self.last_timestep = timestep
            return action, tf.zeros_like(action, dtype=tf.float32)
        else:
            assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
            with tf1.control_dependencies([assign_op]):
                return action, tf.zeros_like(action, dtype=tf.float32)

    def _get_torch_exploration_action(
        self,
        action_distribution: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ) -> "torch.Tensor":
        """Torch method to produce an epsilon exploration action.

        Args:
            action_distribution: The instantiated
                ActionDistribution object to work with when creating
                exploration actions.

        Returns:
            The exploration-action.
        """
        q_values = action_distribution.inputs
        self.last_timestep = timestep
        exploit_action = action_distribution.deterministic_sample()
        batch_size = q_values.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)
            if isinstance(action_distribution, TorchMultiActionDistribution):
                exploit_action = tree.flatten(exploit_action)
                for i in range(batch_size):
                    if random.random() < epsilon:
                        # TODO: (bcahlit) Mask out actions
                        random_action = tree.flatten(self.action_space.sample())
                        for j in range(len(exploit_action)):
                            exploit_action[j][i] = torch.tensor(random_action[j])
                exploit_action = tree.unflatten_as(
                    action_distribution.action_space_struct, exploit_action
                )

                return exploit_action, action_logp
            else:
                # Mask out actions, whose Q-values are -inf, so that we don't
                # even consider them for exploration.
                random_valid_action_logits = torch.where(
                    q_values <= FLOAT_MIN,
                    torch.ones_like(q_values) * 0.0,
                    torch.ones_like(q_values),
                )
                # A random action.
                random_actions = torch.squeeze(
                    torch.multinomial(random_valid_action_logits, 1), axis=1
                )

                # Pick either random or greedy.
                action = torch.where(
                    torch.empty((batch_size,)).uniform_().to(self.device) < epsilon,
                    random_actions,
                    exploit_action,
                )

                return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp

    @override(Exploration)
    def get_state(self, sess: Optional["tf.Session"] = None):
        if sess:
            return sess.run(self._tf_state_op)
        eps = self.epsilon_schedule(self.last_timestep)
        return {
            "cur_epsilon": convert_to_numpy(eps) if self.framework != "tf" else eps,
            "last_timestep": convert_to_numpy(self.last_timestep)
            if self.framework != "tf"
            else self.last_timestep,
        }
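
    # For reference (illustrative values): for a fresh torch instance with the
    # default initial_epsilon=1.0, the state returned/consumed by
    # get_state()/set_state() looks like {"cur_epsilon": 1.0, "last_timestep": 0}.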

    @override(Exploration)
    def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None:
        if self.framework == "tf":
            self.last_timestep.load(state["last_timestep"], session=sess)
        elif isinstance(self.last_timestep, int):
            self.last_timestep = state["last_timestep"]
        else:
            self.last_timestep.assign(state["last_timestep"])
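

# ---------------------------------------------------------------------------
# Minimal, standalone sketch (not part of the original module) of the core
# torch exploration logic above: mask invalid (-inf) Q-values, draw a uniformly
# random valid action with probability `epsilon`, otherwise take the argmax.
# `demo_epsilon_greedy` is a hypothetical helper added purely for illustration;
# it assumes torch was imported successfully at the top of this file.
# ---------------------------------------------------------------------------
def demo_epsilon_greedy(q_values, epsilon):
    batch_size = q_values.size(0)
    # Greedy (exploit) actions: argmax over the Q-values.
    exploit_action = torch.argmax(q_values, dim=1)
    # Weight valid actions equally (1.0); masked (-inf) actions get weight 0.0.
    random_weights = torch.where(
        q_values <= FLOAT_MIN,
        torch.zeros_like(q_values),
        torch.ones_like(q_values),
    )
    # Sample one random valid action index per batch row.
    random_actions = torch.multinomial(random_weights, 1).squeeze(1)
    # With probability epsilon pick the random action, otherwise the greedy one.
    chose_random = torch.rand(batch_size) < epsilon
    return torch.where(chose_random, random_actions, exploit_action)
    # e.g.: demo_epsilon_greedy(torch.randn(4, 6), epsilon=0.1)
    # -> tensor of 4 action indices in [0, 6).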