random.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from gymnasium.spaces import Discrete, Box, MultiDiscrete, Space
  2. import numpy as np
  3. import tree # pip install dm_tree
  4. from typing import Union, Optional
  5. from ray.rllib.models.action_dist import ActionDistribution
  6. from ray.rllib.models.modelv2 import ModelV2
  7. from ray.rllib.utils.annotations import override, PublicAPI
  8. from ray.rllib.utils.exploration.exploration import Exploration
  9. from ray.rllib.utils import force_tuple
  10. from ray.rllib.utils.framework import try_import_tf, try_import_torch, TensorType
  11. from ray.rllib.utils.spaces.simplex import Simplex
  12. from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
  13. from ray.rllib.utils.tf_utils import zero_logps_from_actions
  14. tf1, tf, tfv = try_import_tf()
  15. torch, _ = try_import_torch()
  @PublicAPI
  class Random(Exploration):
      """A random action selector (deterministic/greedy for explore=False).

      If explore=True, returns actions drawn at random from `self.action_space`:
      the torch path calls `Space.sample()`, the TF path builds `tf.random` ops
      for every primitive component of the action space.
      If explore=False, returns the greedy/max-likelihood action
      (`action_dist.deterministic_sample()`).
      """

      def __init__(
          self, action_space: Space, *, model: ModelV2, framework: Optional[str], **kwargs
      ):
          """Initialize a Random Exploration object.

          Args:
              action_space: The gym action space used by the environment.
              model: The policy's model; forwarded to the `Exploration` base
                  class.
              framework: One of None, "tf", "tf2", "torch".
              **kwargs: Extra keyword arguments, forwarded unchanged to the
                  `Exploration` base class.
          """
          super().__init__(
              action_space=action_space, model=model, framework=framework, **kwargs
          )

          # Base-space structure of the action space; the TF path maps over it
          # to sample each component separately.
          self.action_space_struct = get_base_struct_from_space(self.action_space)

      @override(Exploration)
      def get_exploration_action(
          self,
          *,
          action_distribution: ActionDistribution,
          timestep: Union[int, TensorType],
          explore: bool = True
      ):
          """Return an `(action, logp)` pair for the configured framework.

          Args:
              action_distribution: Distribution built from the model output. It
                  supplies the batch size (via `inputs`) and the deterministic
                  action used when `explore` is False.
              timestep: Unused by this exploration type.
              explore: Whether to sample randomly (True) or act greedily
                  (False).

          Returns:
              The `(action, logp)` tuple produced by the framework-specific
              method.
          """
          # Dispatch on the framework: anything that is not TF is handled by
          # the torch implementation.
          if self.framework in ["tf2", "tf"]:
              return self.get_tf_exploration_action_op(action_distribution, explore)
          else:
              return self.get_torch_exploration_action(action_distribution, explore)

      def get_tf_exploration_action_op(
          self,
          action_dist: ActionDistribution,
          explore: Optional[Union[bool, TensorType]],
      ):
          """Build the TF op that yields a random or a deterministic action.

          Args:
              action_dist: Distribution whose `inputs` determine the batch size
                  and whose `deterministic_sample()` is used when not
                  exploring.
              explore: Python bool or boolean tensor selecting the branch of
                  the `tf.cond`.

          Returns:
              Tuple `(action, logp)`, where `logp` is whatever
              `zero_logps_from_actions(action)` returns.
          """

          def true_fn():
              batch_size = 1
              req = force_tuple(
                  action_dist.required_model_output_shape(
                      self.action_space, getattr(self.model, "model_config", None)
                  )
              )
              # Add a batch dimension?
              if len(action_dist.inputs.shape) == len(req) + 1:
                  batch_size = tf.shape(action_dist.inputs)[0]

              # Function to produce random samples from primitive space
              # components: (Multi)Discrete or Box.
              def random_component(component):
                  # Have at least an additional shape of (1,), even if the
                  # component is Box(-1.0, 1.0, shape=()).
                  shape = component.shape or (1,)

                  if isinstance(component, Discrete):
                      # A Discrete space covers {start, ..., start + n - 1}.
                      # Honor `start` so that TF samples stay inside the space,
                      # like `Space.sample()` does on the torch path. Spaces
                      # without a `start` attribute are treated as starting
                      # at 0.
                      start = int(getattr(component, "start", 0))
                      return tf.random.uniform(
                          shape=(batch_size,) + component.shape,
                          minval=start,
                          maxval=start + int(component.n),
                          dtype=component.dtype,
                      )
                  elif isinstance(component, MultiDiscrete):
                      # One column per sub-action, each drawn from [0, n).
                      return tf.concat(
                          [
                              tf.random.uniform(
                                  shape=(batch_size, 1), maxval=n, dtype=component.dtype
                              )
                              for n in component.nvec
                          ],
                          axis=1,
                      )
                  elif isinstance(component, Box):
                      if component.bounded_above.all() and component.bounded_below.all():
                          if component.dtype.name.startswith("int"):
                              # Integer sampling in `tf.random.uniform` needs
                              # scalar bounds, so only the first element's
                              # bounds are used. `maxval` is exclusive for
                              # ints whereas `Box.high` is inclusive: add 1 so
                              # that `high` can be drawn (and `low == high`
                              # does not raise).
                              return tf.random.uniform(
                                  shape=(batch_size,) + shape,
                                  minval=component.low.flat[0],
                                  maxval=int(component.high.flat[0]) + 1,
                                  dtype=component.dtype,
                              )
                          else:
                              return tf.random.uniform(
                                  shape=(batch_size,) + shape,
                                  minval=component.low,
                                  maxval=component.high,
                                  dtype=component.dtype,
                              )
                      else:
                          # Not bounded on every side: fall back to a standard
                          # normal.
                          return tf.random.normal(
                              shape=(batch_size,) + shape, dtype=component.dtype
                          )
                  else:
                      assert isinstance(component, Simplex), (
                          "Unsupported distribution component '{}' for random "
                          "sampling!".format(component)
                      )
                      # Softmax over uniform noise yields a point on the
                      # simplex.
                      return tf.nn.softmax(
                          tf.random.uniform(
                              shape=(batch_size,) + shape,
                              minval=0.0,
                              maxval=1.0,
                              dtype=component.dtype,
                          )
                      )

              actions = tree.map_structure(random_component, self.action_space_struct)
              return actions

          def false_fn():
              return action_dist.deterministic_sample()

          action = tf.cond(
              pred=tf.constant(explore, dtype=tf.bool)
              if isinstance(explore, bool)
              else explore,
              true_fn=true_fn,
              false_fn=false_fn,
          )

          logp = zero_logps_from_actions(action)
          return action, logp

      def get_torch_exploration_action(
          self, action_dist: ActionDistribution, explore: bool
      ):
          """Return a random or a deterministic action as a torch tensor.

          Args:
              action_dist: Distribution whose `inputs` determine the batch size
                  and whose `deterministic_sample()` is used when not
                  exploring.
              explore: Whether to sample from `self.action_space` (True) or
                  to act greedily (False).

          Returns:
              Tuple `(action, logp)`. `logp` is an all-zeros float32 tensor
              with one entry per batch item (a 0-d tensor if `action` itself
              is 0-d).
          """
          if explore:
              req = force_tuple(
                  action_dist.required_model_output_shape(
                      self.action_space, getattr(self.model, "model_config", None)
                  )
              )
              # Add a batch dimension?
              if len(action_dist.inputs.shape) == len(req) + 1:
                  batch_size = action_dist.inputs.shape[0]
                  a = np.stack([self.action_space.sample() for _ in range(batch_size)])
              else:
                  a = self.action_space.sample()
              # Convert action to torch tensor. `np.asarray` is a no-op for
              # arrays but wraps numpy scalars (e.g. a single Discrete sample),
              # which `torch.from_numpy` would otherwise reject.
              action = torch.from_numpy(np.asarray(a)).to(self.device)
          else:
              action = action_dist.deterministic_sample()

          # `size()[:1]` equals `(size()[0],)` for batched actions and is empty
          # for a 0-d action, where indexing `[0]` would raise.
          logp = torch.zeros(
              action.size()[:1], dtype=torch.float32, device=self.device
          )
          return action, logp
  151. return action, logp