ornstein_uhlenbeck_noise.py

import numpy as np
from typing import Optional, Union

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.gaussian_noise import GaussianNoise
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
    get_variable, TensorType
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules import Schedule
from ray.rllib.utils.tf_utils import zero_logps_from_actions

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class OrnsteinUhlenbeckNoise(GaussianNoise):
    """An exploration that adds Ornstein-Uhlenbeck noise to continuous actions.

    If explore=True, returns sampled actions plus a noise term X,
    which changes according to this formula:
    Xt+1 = Xt - theta*Xt + sigma*N[0,stddev], where theta, sigma and stddev
    are constants. Also, a completely random period is possible at the
    beginning.
    If explore=False, returns the deterministic action.
    """
    def __init__(self,
                 action_space,
                 *,
                 framework: str,
                 ou_theta: float = 0.15,
                 ou_sigma: float = 0.2,
                 ou_base_scale: float = 0.1,
                 random_timesteps: int = 1000,
                 initial_scale: float = 1.0,
                 final_scale: float = 0.02,
                 scale_timesteps: int = 10000,
                 scale_schedule: Optional[Schedule] = None,
                 **kwargs):
        """Initializes an Ornstein-Uhlenbeck Exploration object.

        Args:
            action_space (Space): The gym action space used by the environment.
            ou_theta (float): The theta parameter of the Ornstein-Uhlenbeck
                process.
            ou_sigma (float): The sigma parameter of the Ornstein-Uhlenbeck
                process.
            ou_base_scale (float): A fixed scaling factor, by which all OU-
                noise is multiplied. NOTE: This is on top of the parent
                GaussianNoise's scaling.
            random_timesteps (int): The number of timesteps for which to act
                completely randomly. Only after this number of timesteps does
                the `self.scale` annealing process start (see below).
            initial_scale (float): The initial scaling weight to multiply
                the noise with.
            final_scale (float): The final scaling weight to multiply
                the noise with.
            scale_timesteps (int): The timesteps over which to linearly anneal
                the scaling factor (after(!) having used random actions for
                `random_timesteps` steps).
            scale_schedule (Optional[Schedule]): An optional Schedule object
                to use (instead of constructing one from the given parameters).
            framework (Optional[str]): One of None, "tf", "torch".
        """
        # The current OU-state value (gets updated each time an exploration
        # action is computed).
        self.ou_state = get_variable(
            np.array(action_space.low.size * [.0], dtype=np.float32),
            framework=framework,
            tf_name="ou_state",
            torch_tensor=True,
            device=None)

        super().__init__(
            action_space,
            framework=framework,
            random_timesteps=random_timesteps,
            initial_scale=initial_scale,
            final_scale=final_scale,
            scale_timesteps=scale_timesteps,
            scale_schedule=scale_schedule,
            stddev=1.0,  # Force `self.stddev` to 1.0.
            **kwargs)
        self.ou_theta = ou_theta
        self.ou_sigma = ou_sigma
        self.ou_base_scale = ou_base_scale
        # Now that we know the device, move ou_state there, in case of PyTorch.
        if self.framework == "torch" and self.device is not None:
            self.ou_state = self.ou_state.to(self.device)
    @override(GaussianNoise)
    def _get_tf_exploration_action_op(self, action_dist: ActionDistribution,
                                      explore: Union[bool, TensorType],
                                      timestep: Union[int, TensorType]):
        ts = timestep if timestep is not None else self.last_timestep
        scale = self.scale_schedule(ts)

        # The deterministic actions (if explore=False).
        deterministic_actions = action_dist.deterministic_sample()

        # Apply base-scaled and time-annealed scaled OU-noise to
        # deterministic actions.
        gaussian_sample = tf.random.normal(
            shape=[self.action_space.low.size], stddev=self.stddev)
        ou_new = self.ou_theta * -self.ou_state + \
            self.ou_sigma * gaussian_sample
        if self.framework in ["tf2", "tfe"]:
            self.ou_state.assign_add(ou_new)
            ou_state_new = self.ou_state
        else:
            ou_state_new = tf1.assign_add(self.ou_state, ou_new)
        # Scale the noise by the action range; treat infinite bounds as 1.0.
        high_m_low = self.action_space.high - self.action_space.low
        high_m_low = tf.where(
            tf.math.is_inf(high_m_low), tf.ones_like(high_m_low), high_m_low)
        noise = scale * self.ou_base_scale * ou_state_new * high_m_low
        stochastic_actions = tf.clip_by_value(
            deterministic_actions + noise,
            self.action_space.low * tf.ones_like(deterministic_actions),
            self.action_space.high * tf.ones_like(deterministic_actions))

        # Stochastic actions could either be: random OR action + noise.
        random_actions, _ = \
            self.random_exploration.get_tf_exploration_action_op(
                action_dist, explore)
        exploration_actions = tf.cond(
            pred=tf.convert_to_tensor(ts < self.random_timesteps),
            true_fn=lambda: random_actions,
            false_fn=lambda: stochastic_actions)

        # Choose by `explore` (main exploration switch).
        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool) else explore,
            true_fn=lambda: exploration_actions,
            false_fn=lambda: deterministic_actions)
        # Logp=always zero.
        logp = zero_logps_from_actions(deterministic_actions)

        # Increment `last_timestep` by 1 (or set to `timestep`).
        if self.framework in ["tf2", "tfe"]:
            if timestep is None:
                self.last_timestep.assign_add(1)
            else:
                self.last_timestep.assign(tf.cast(timestep, tf.int64))
        else:
            assign_op = (tf1.assign_add(self.last_timestep, 1)
                         if timestep is None else tf1.assign(
                             self.last_timestep, timestep))
            with tf1.control_dependencies([assign_op, ou_state_new]):
                action = tf.identity(action)
                logp = tf.identity(logp)
        return action, logp
    @override(GaussianNoise)
    def _get_torch_exploration_action(self, action_dist: ActionDistribution,
                                      explore: bool,
                                      timestep: Union[int, TensorType]):
        # Set last timestep or (if not given) increase by one.
        self.last_timestep = timestep if timestep is not None else \
            self.last_timestep + 1

        # Apply exploration.
        if explore:
            # Random exploration phase.
            if self.last_timestep < self.random_timesteps:
                action, _ = \
                    self.random_exploration.get_torch_exploration_action(
                        action_dist, explore=True)
            # Apply base-scaled and time-annealed scaled OU-noise to
            # deterministic actions.
            else:
                det_actions = action_dist.deterministic_sample()
                scale = self.scale_schedule(self.last_timestep)
                gaussian_sample = scale * torch.normal(
                    mean=torch.zeros(self.ou_state.size()), std=1.0) \
                    .to(self.device)
                ou_new = self.ou_theta * -self.ou_state + \
                    self.ou_sigma * gaussian_sample
                self.ou_state += ou_new
                # Scale the noise by the action range; treat infinite bounds
                # as 1.0.
                high_m_low = torch.from_numpy(
                    self.action_space.high - self.action_space.low). \
                    to(self.device)
                high_m_low = torch.where(
                    torch.isinf(high_m_low),
                    torch.ones_like(high_m_low).to(self.device), high_m_low)
                noise = scale * self.ou_base_scale * self.ou_state * high_m_low
                action = torch.min(
                    torch.max(
                        det_actions + noise,
                        torch.tensor(
                            self.action_space.low,
                            dtype=torch.float32,
                            device=self.device)),
                    torch.tensor(
                        self.action_space.high,
                        dtype=torch.float32,
                        device=self.device))
        # No exploration -> Return deterministic actions.
        else:
            action = action_dist.deterministic_sample()

        # Logp=always zero.
        logp = torch.zeros(
            (action.size()[0], ), dtype=torch.float32, device=self.device)
        return action, logp
    @override(GaussianNoise)
    def get_state(self, sess: Optional["tf.Session"] = None):
        """Returns the current exploration state.

        Returns:
            dict: The parent's state (including the current scale value) plus
                the current OU-state.
        """
        if sess:
            return sess.run(
                dict(self._tf_state_op, **{
                    "ou_state": self.ou_state,
                }))

        state = super().get_state()
        return dict(
            state, **{
                "ou_state": convert_to_numpy(self.ou_state)
                if self.framework != "tf" else self.ou_state,
            })
    @override(GaussianNoise)
    def set_state(self, state: dict,
                  sess: Optional["tf.Session"] = None) -> None:
        if self.framework == "tf":
            self.ou_state.load(state["ou_state"], session=sess)
        elif isinstance(self.ou_state, np.ndarray) or \
                (torch and torch.is_tensor(self.ou_state)):
            self.ou_state = state["ou_state"]
        else:
            self.ou_state.assign(state["ou_state"])
        super().set_state(state, sess=sess)
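

if __name__ == "__main__":
    # Illustrative sketch only (not part of RLlib's API): simulate the
    # OU-state update used above with plain numpy for a hypothetical 2-dim
    # action space. It shows how the state is pulled back toward zero by
    # `ou_theta` while being perturbed by Gaussian samples scaled by
    # `ou_sigma`.
    rng = np.random.default_rng(0)
    ou_theta, ou_sigma, stddev = 0.15, 0.2, 1.0
    ou_state = np.zeros(2, dtype=np.float32)
    for t in range(5):
        gaussian_sample = rng.normal(
            loc=0.0, scale=stddev, size=ou_state.shape)
        # Same update as in the exploration ops above:
        # X_{t+1} = X_t - theta * X_t + sigma * N[0, stddev].
        ou_state = ou_state - ou_theta * ou_state + ou_sigma * gaussian_sample
        print(f"t={t} ou_state={ou_state}")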