slate_epsilon_greedy.py

from typing import Union

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy
from ray.rllib.utils.exploration.exploration import TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch
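
# Framework stubs: `try_import_tf()` / `try_import_torch()` return None
# placeholders when the respective framework is not installed, so this module
# imports cleanly with only one of TF / torch available.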
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


@PublicAPI
class SlateEpsilonGreedy(EpsilonGreedy):
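    """Epsilon-greedy exploration for slate-based action spaces.

    With probability epsilon, an entire slate is sampled uniformly from the
    set of candidate slates; otherwise, the slate with the highest Q-value is
    taken via the action distribution's `deterministic_sample()`.
    """
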
    @override(EpsilonGreedy)
    def _get_tf_exploration_action_op(
        self,
        action_distribution: ActionDistribution,
        explore: Union[bool, TensorType],
        timestep: Union[int, TensorType],
    ) -> "tf.Tensor":
        per_slate_q_values = action_distribution.inputs
        all_slates = action_distribution.all_slates

        # Greedy, Q-value-maximizing slate.
        exploit_action = action_distribution.deterministic_sample()

        batch_size, num_slates = (
            tf.shape(per_slate_q_values)[0],
            tf.shape(per_slate_q_values)[1],
        )
        action_logp = tf.zeros(batch_size, dtype=tf.float32)

        # Get the current epsilon.
        epsilon = self.epsilon_schedule(
            timestep if timestep is not None else self.last_timestep
        )

        # A random action: sample one candidate-slate index uniformly per
        # batch row, then gather the corresponding slate.
        random_indices = tf.random.uniform(
            (batch_size,),
            minval=0,
            maxval=num_slates,
            dtype=tf.dtypes.int32,
        )
        random_actions = tf.gather(all_slates, random_indices)

        # Per-row coin flip: explore (True) with probability epsilon.
        choose_random = (
            tf.random.uniform(
                tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
            )
            < epsilon
        )

        # Pick either random or greedy, depending on `explore` and the coin flip.
        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool)
            else explore,
            true_fn=(lambda: tf.where(choose_random, random_actions, exploit_action)),
            false_fn=lambda: exploit_action,
        )

        if self.framework == "tf2" and not self.policy_config["eager_tracing"]:
            self.last_timestep = timestep
            return action, action_logp
        else:
            # Graph mode: update `last_timestep` via an explicit assign op.
            assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
            with tf1.control_dependencies([assign_op]):
                return action, action_logp

    @override(EpsilonGreedy)
    def _get_torch_exploration_action(
        self,
        action_distribution: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ) -> "torch.Tensor":
        per_slate_q_values = action_distribution.inputs
        all_slates = self.model.slates
        device = all_slates.device

        # Greedy, Q-value-maximizing slate.
        exploit_indices = action_distribution.deterministic_sample()
        exploit_indices = exploit_indices.to(device)
        exploit_action = all_slates[exploit_indices]

        batch_size = per_slate_q_values.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        self.last_timestep = timestep

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)

            # A random action: sample one candidate-slate index uniformly per
            # batch row, then look up the corresponding slate.
            random_indices = torch.randint(
                0,
                per_slate_q_values.shape[1],
                (per_slate_q_values.shape[0],),
                device=device,
            )
            random_actions = all_slates[random_indices]

            # Pick either random or greedy, per batch row. The uniform draw is
            # created on the same device as the action tensors and given a
            # trailing unit dim so it broadcasts across the slate dimension.
            choose_random = (
                torch.empty((batch_size, 1), device=device).uniform_() < epsilon
            )
            action = torch.where(choose_random, random_actions, exploit_action)

            return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp
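

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): how this class
# is typically selected through RLlib's `exploration_config`. The string
# "type" is resolved to this class by RLlib's `from_config` utility; the
# epsilon keys are forwarded to `EpsilonGreedy.__init__()`. The concrete
# values below are assumptions for illustration, not library defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        "exploration_config": {
            "type": "SlateEpsilonGreedy",
            "initial_epsilon": 1.0,  # start fully exploratory
            "final_epsilon": 0.05,  # anneal down to 5% random slates
            "epsilon_timesteps": 250000,  # assumed linear-decay horizon
        },
    }
    print(example_config)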