random_encoder.py

from gymnasium.spaces import Box, Discrete, Space
import numpy as np
from typing import List, Optional, Union

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.tf_utils import get_placeholder
from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()


class _MovingMeanStd:
    """Track moving mean, std and count."""

    def __init__(self, epsilon: float = 1e-4, shape: Optional[List[int]] = None):
        """Initialize object.

        Args:
            epsilon: Initial count.
            shape: Shape of the tracked mean and std.
        """
        if not shape:
            shape = []
        self.mean = np.zeros(shape, dtype=np.float32)
        self.var = np.ones(shape, dtype=np.float32)
        self.count = epsilon

    def __call__(self, inputs: np.ndarray) -> np.ndarray:
        """Normalize input batch using moving mean and std.

        Args:
            inputs: Input batch to normalize.

        Returns:
            Log-scaled, normalized output.
        """
        batch_mean = np.mean(inputs, axis=0)
        batch_var = np.var(inputs, axis=0)
        batch_count = inputs.shape[0]
        self.update_params(batch_mean, batch_var, batch_count)
        return np.log(inputs / self.std + 1)

    def update_params(
        self, batch_mean: float, batch_var: float, batch_count: float
    ) -> None:
        """Update moving mean, std and count.

        Args:
            batch_mean: Input batch mean.
            batch_var: Input batch variance.
            batch_count: Number of cases in the batch.
        """
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        # This moving mean calculation is from reference implementation.
        self.mean = self.mean + delta + batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + np.power(delta, 2) * self.count * batch_count / tot_count
        self.var = M2 / tot_count
        self.count = tot_count

    @property
    def std(self) -> float:
        """Get moving standard deviation.

        Returns:
            Returns moving standard deviation.
        """
        return np.sqrt(self.var)
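

# Illustrative sketch (not part of the file's original code): how
# _MovingMeanStd can be used to rescale a batch of per-state intrinsic
# rewards. The numeric values are made up for demonstration only.
#
#     normalizer = _MovingMeanStd(shape=[])
#     intrinsic_rewards = np.array([0.5, 1.2, 0.7], dtype=np.float32)
#     # Updates the moving statistics with this batch, then returns
#     # np.log(intrinsic_rewards / normalizer.std + 1).
#     scaled = normalizer(intrinsic_rewards)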


@PublicAPI
def update_beta(beta_schedule: str, beta: float, rho: float, step: int) -> float:
    """Update beta based on schedule and training step.

    Args:
        beta_schedule: Schedule for beta update.
        beta: Initial beta.
        rho: Schedule decay parameter.
        step: Current training iteration.

    Returns:
        Updated beta as per input schedule.
    """
    if beta_schedule == "linear_decay":
        return beta * ((1.0 - rho) ** step)
    return beta
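

# Worked example: with beta_schedule="linear_decay", beta=0.2 and rho=0.1, the
# intrinsic-reward weight after 10 training iterations is
#     update_beta("linear_decay", 0.2, 0.1, 10) == 0.2 * (0.9 ** 10) ≈ 0.0697.
# With beta_schedule="constant", the initial beta is returned unchanged.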


@PublicAPI
def compute_states_entropy(
    obs_embeds: np.ndarray, embed_dim: int, k_nn: int
) -> np.ndarray:
    """Compute states entropy using the k-nearest-neighbours method.

    Args:
        obs_embeds: Observation latent representations produced by the
            encoder model.
        embed_dim: Embedding vector dimension.
        k_nn: Number of nearest neighbours for k-NN estimation.

    Returns:
        Computed states entropy.
    """
    obs_embeds_ = np.reshape(obs_embeds, [-1, embed_dim])
    dist = np.linalg.norm(obs_embeds_[:, None, :] - obs_embeds_[None, :, :], axis=-1)
    return dist.argsort(axis=-1)[:, :k_nn][:, -1].astype(np.float32)
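

# Note: the per-state statistic above is based on each state's k-th nearest
# neighbour in the fixed, randomly initialized embedding space. As described
# in the RE3 class below, it serves as an intrinsic reward that is weighted by
# beta (see update_beta above) and added to the environment's extrinsic reward
# during policy optimization.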


@PublicAPI
class RE3(Exploration):
    """Random Encoder for Efficient Exploration.

    Implementation of:
    [1] State entropy maximization with random encoders for efficient
    exploration. Seo, Chen, Shin, Lee, Abbeel, & Lee, (2021).
    arXiv preprint arXiv:2102.09430.

    Estimates state entropy using a particle-based k-nearest-neighbours (k-NN)
    estimator in the latent space. The state's latent representation is
    calculated using an encoder with randomly initialized parameters.

    The entropy of a state is considered an intrinsic reward and added to the
    environment's extrinsic reward for policy optimization.
    Entropy is calculated per batch; it does not take the distribution of
    the entire replay buffer into consideration.
    """

    def __init__(
        self,
        action_space: Space,
        *,
        framework: str,
        model: ModelV2,
        embeds_dim: int = 128,
        encoder_net_config: Optional[ModelConfigDict] = None,
        beta: float = 0.2,
        beta_schedule: str = "constant",
        rho: float = 0.1,
        k_nn: int = 50,
        random_timesteps: int = 10000,
        sub_exploration: Optional[FromConfigSpec] = None,
        **kwargs
    ):
        """Initialize RE3.

        Args:
            action_space: The action space in which to explore.
            framework: Supports "tf"; this implementation does not
                support torch.
            model: The policy's model.
            embeds_dim: The dimensionality of the observation embedding
                vectors in latent space.
            encoder_net_config: Optional model configuration for the encoder
                network, producing embedding vectors from observations. This
                can be used to configure fcnet- or conv_net setups to properly
                process any observation space.
            beta: Hyperparameter to choose between exploration and
                exploitation.
            beta_schedule: Schedule to use for beta decay, one of
                "constant" or "linear_decay".
            rho: Beta decay factor, used for on-policy algorithms.
            k_nn: Number of neighbours to use for k-NN entropy
                estimation.
            random_timesteps: The number of timesteps to act completely
                randomly (see [1]).
            sub_exploration: The config dict for the underlying Exploration
                to use (e.g. epsilon-greedy for DQN). If None, uses the
                FromConfigSpec provided in the Policy's default config.

        Raises:
            ValueError: If the input framework is torch.
        """
        # TODO(gjoliver): Add support for PyTorch.
        if framework == "torch":
            raise ValueError("This RE3 implementation does not support Torch.")
        super().__init__(action_space, model=model, framework=framework, **kwargs)

        self.beta = beta
        self.rho = rho
        self.k_nn = k_nn
        self.embeds_dim = embeds_dim
        if encoder_net_config is None:
            encoder_net_config = self.policy_config["model"].copy()
        self.encoder_net_config = encoder_net_config

        # Auto-detection of underlying exploration functionality.
        if sub_exploration is None:
            # For discrete action spaces, use an underlying EpsilonGreedy with
            # a special schedule.
            if isinstance(self.action_space, Discrete):
                sub_exploration = {
                    "type": "EpsilonGreedy",
                    "epsilon_schedule": {
                        "type": "PiecewiseSchedule",
                        # Step function (see [2]).
                        "endpoints": [
                            (0, 1.0),
                            (random_timesteps + 1, 1.0),
                            (random_timesteps + 2, 0.01),
                        ],
                        "outside_value": 0.01,
                    },
                }
            elif isinstance(self.action_space, Box):
                sub_exploration = {
                    "type": "OrnsteinUhlenbeckNoise",
                    "random_timesteps": random_timesteps,
                }
            else:
                raise NotImplementedError
        self.sub_exploration = sub_exploration

        # Creates ModelV2 embedding module / layers.
        self._encoder_net = ModelCatalog.get_model_v2(
            self.model.obs_space,
            self.action_space,
            self.embeds_dim,
            model_config=self.encoder_net_config,
            framework=self.framework,
            name="encoder_net",
        )
        if self.framework == "tf":
            self._obs_ph = get_placeholder(
                space=self.model.obs_space, name="_encoder_obs"
            )
            self._obs_embeds = tf.stop_gradient(
                self._encoder_net({SampleBatch.OBS: self._obs_ph})[0]
            )

        # This is only used to select the correct action.
        self.exploration_submodule = from_config(
            cls=Exploration,
            config=self.sub_exploration,
            action_space=self.action_space,
            framework=self.framework,
            policy_config=self.policy_config,
            model=self.model,
            num_workers=self.num_workers,
            worker_index=self.worker_index,
        )

    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: bool = True
    ):
        # Simply delegate to sub-Exploration module.
        return self.exploration_submodule.get_exploration_action(
            action_distribution=action_distribution, timestep=timestep, explore=explore
        )

    @override(Exploration)
    def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
        """Calculate states' latent representations/embeddings.

        Embeddings are added to the SampleBatch object such that they don't
        need to be calculated during each training step.
        """
        if self.framework != "torch":
            sample_batch = self._postprocess_tf(policy, sample_batch, tf_sess)
        else:
            raise ValueError("Not implemented for Torch.")
        return sample_batch

    def _postprocess_tf(self, policy, sample_batch, tf_sess):
        """Calculate states' embeddings and add them to SampleBatch."""
        if self.framework == "tf":
            obs_embeds = tf_sess.run(
                self._obs_embeds,
                feed_dict={self._obs_ph: sample_batch[SampleBatch.OBS]},
            )
        else:
            obs_embeds = tf.stop_gradient(
                self._encoder_net({SampleBatch.OBS: sample_batch[SampleBatch.OBS]})[0]
            ).numpy()
        sample_batch[SampleBatch.OBS_EMBEDS] = obs_embeds
        return sample_batch
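

# Illustrative usage sketch (assumptions, not verbatim from this file): RE3 is
# typically enabled through a policy's exploration config. The keys below
# mirror the __init__ arguments above; values shown are the defaults.
#
#     config = {
#         "framework": "tf",  # this RE3 implementation does not support torch
#         "exploration_config": {
#             "type": "RE3",
#             "embeds_dim": 128,
#             "beta": 0.2,
#             "beta_schedule": "constant",
#             "rho": 0.1,
#             "k_nn": 50,
#             "random_timesteps": 10000,
#             # None auto-selects EpsilonGreedy (Discrete) or
#             # OrnsteinUhlenbeckNoise (Box) as the sub-exploration.
#             "sub_exploration": None,
#         },
#     }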