upper_confidence_bound.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from typing import Union
  2. from ray.rllib.models.action_dist import ActionDistribution
  3. from ray.rllib.utils.annotations import override, PublicAPI
  4. from ray.rllib.utils.exploration.exploration import Exploration
  5. from ray.rllib.utils.framework import (
  6. TensorType,
  7. try_import_tf,
  8. )
  9. tf1, tf, tfv = try_import_tf()
  10. @PublicAPI
  11. class UpperConfidenceBound(Exploration):
  12. @override(Exploration)
  13. def get_exploration_action(
  14. self,
  15. action_distribution: ActionDistribution,
  16. timestep: Union[int, TensorType],
  17. explore: bool = True,
  18. ):
  19. if self.framework == "torch":
  20. return self._get_torch_exploration_action(action_distribution, explore)
  21. elif self.framework == "tf2":
  22. return self._get_tf_exploration_action(action_distribution, explore)
  23. else:
  24. raise NotImplementedError
  25. def _get_torch_exploration_action(self, action_dist, explore):
  26. if explore:
  27. return action_dist.inputs.argmax(dim=-1), None
  28. else:
  29. scores = self.model.value_function()
  30. return scores.argmax(dim=-1), None
  31. def _get_tf_exploration_action(self, action_dist, explore):
  32. action = tf.argmax(
  33. tf.cond(
  34. explore, lambda: action_dist.inputs, lambda: self.model.value_function()
  35. ),
  36. axis=-1,
  37. )
  38. return action, None