soft_q.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from gym.spaces import Discrete, Space
  2. from typing import Union, Optional
  3. from ray.rllib.models.action_dist import ActionDistribution
  4. from ray.rllib.models.tf.tf_action_dist import Categorical
  5. from ray.rllib.models.torch.torch_action_dist import TorchCategorical
  6. from ray.rllib.utils.annotations import override
  7. from ray.rllib.utils.exploration.stochastic_sampling import StochasticSampling
  8. from ray.rllib.utils.framework import TensorType
  9. class SoftQ(StochasticSampling):
  10. """Special case of StochasticSampling w/ Categorical and temperature param.
  11. Returns a stochastic sample from a Categorical parameterized by the model
  12. output divided by the temperature. Returns the argmax iff explore=False.
  13. """
  14. def __init__(self,
  15. action_space: Space,
  16. *,
  17. framework: Optional[str],
  18. temperature: float = 1.0,
  19. **kwargs):
  20. """Initializes a SoftQ Exploration object.
  21. Args:
  22. action_space (Space): The gym action space used by the environment.
  23. temperature (float): The temperature to divide model outputs by
  24. before creating the Categorical distribution to sample from.
  25. framework (str): One of None, "tf", "torch".
  26. """
  27. assert isinstance(action_space, Discrete)
  28. super().__init__(action_space, framework=framework, **kwargs)
  29. self.temperature = temperature
  30. @override(StochasticSampling)
  31. def get_exploration_action(self,
  32. action_distribution: ActionDistribution,
  33. timestep: Union[int, TensorType],
  34. explore: bool = True):
  35. cls = type(action_distribution)
  36. assert cls in [Categorical, TorchCategorical]
  37. # Re-create the action distribution with the correct temperature
  38. # applied.
  39. dist = cls(
  40. action_distribution.inputs,
  41. self.model,
  42. temperature=self.temperature)
  43. # Delegate to super method.
  44. return super().get_exploration_action(
  45. action_distribution=dist, timestep=timestep, explore=explore)