soft_q.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from gymnasium.spaces import Discrete, MultiDiscrete, Space
  2. from typing import Union, Optional
  3. from ray.rllib.models.action_dist import ActionDistribution
  4. from ray.rllib.models.tf.tf_action_dist import Categorical
  5. from ray.rllib.models.torch.torch_action_dist import TorchCategorical
  6. from ray.rllib.utils.annotations import override, PublicAPI
  7. from ray.rllib.utils.exploration.stochastic_sampling import StochasticSampling
  8. from ray.rllib.utils.framework import TensorType
  9. @PublicAPI
  10. class SoftQ(StochasticSampling):
  11. """Special case of StochasticSampling w/ Categorical and temperature param.
  12. Returns a stochastic sample from a Categorical parameterized by the model
  13. output divided by the temperature. Returns the argmax iff explore=False.
  14. """
  15. def __init__(
  16. self,
  17. action_space: Space,
  18. *,
  19. framework: Optional[str],
  20. temperature: float = 1.0,
  21. **kwargs
  22. ):
  23. """Initializes a SoftQ Exploration object.
  24. Args:
  25. action_space: The gym action space used by the environment.
  26. temperature: The temperature to divide model outputs by
  27. before creating the Categorical distribution to sample from.
  28. framework: One of None, "tf", "torch".
  29. """
  30. assert isinstance(action_space, (Discrete, MultiDiscrete))
  31. super().__init__(action_space, framework=framework, **kwargs)
  32. self.temperature = temperature
  33. @override(StochasticSampling)
  34. def get_exploration_action(
  35. self,
  36. action_distribution: ActionDistribution,
  37. timestep: Union[int, TensorType],
  38. explore: bool = True,
  39. ):
  40. cls = type(action_distribution)
  41. assert issubclass(cls, (Categorical, TorchCategorical))
  42. # Re-create the action distribution with the correct temperature
  43. # applied.
  44. dist = cls(action_distribution.inputs, self.model, temperature=self.temperature)
  45. # Delegate to super method.
  46. return super().get_exploration_action(
  47. action_distribution=dist, timestep=timestep, explore=explore
  48. )