from typing import List, Optional, Union

import numpy as np
from gymnasium.spaces import Box, Discrete, Space

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.tf_utils import get_placeholder
from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
class _MovingMeanStd:
    """Track a moving mean, standard deviation, and count."""

    def __init__(self, epsilon: float = 1e-4, shape: Optional[List[int]] = None):
        """Initialize object.

        Args:
            epsilon: Initial count.
            shape: Shape of the tracked mean and std.
        """
        if not shape:
            shape = []
        self.mean = np.zeros(shape, dtype=np.float32)
        self.var = np.ones(shape, dtype=np.float32)
        self.count = epsilon
    def __call__(self, inputs: np.ndarray) -> np.ndarray:
        """Normalize an input batch using the moving mean and std.

        Args:
            inputs: Input batch to normalize.

        Returns:
            Logarithmically scaled, normalized output.
        """
        batch_mean = np.mean(inputs, axis=0)
        batch_var = np.var(inputs, axis=0)
        batch_count = inputs.shape[0]
        self.update_params(batch_mean, batch_var, batch_count)
        return np.log(inputs / self.std + 1)
    def update_params(
        self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int
    ) -> None:
        """Update the moving mean, variance, and count.

        Args:
            batch_mean: Input batch mean.
            batch_var: Input batch variance.
            batch_count: Number of cases in the batch.
        """
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        # This moving mean/variance merge follows the reference
        # implementation (Chan et al.'s parallel algorithm):
        # mean' = mean + delta * n_b / (n_a + n_b).
        self.mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + np.power(delta, 2) * self.count * batch_count / tot_count
        self.var = M2 / tot_count
        self.count = tot_count
    @property
    def std(self) -> np.ndarray:
        """Get the moving standard deviation.

        Returns:
            The moving standard deviation.
        """
        return np.sqrt(self.var)
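
# A minimal usage sketch (illustrative only, not part of the module): turn a
# batch of per-state k-NN distances into log-scaled intrinsic rewards.
#
#     normalizer = _MovingMeanStd()
#     knn_dists = np.array([0.5, 1.2, 0.8], dtype=np.float32)
#     intrinsic_rewards = normalizer(knn_dists)  # np.log(d / moving_std + 1)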
@PublicAPI
def update_beta(beta_schedule: str, beta: float, rho: float, step: int) -> float:
    """Update beta based on schedule and training step.

    Args:
        beta_schedule: Schedule for the beta update, one of "constant" or
            "linear_decay".
        beta: Initial beta.
        rho: Schedule decay parameter.
        step: Current training iteration.

    Returns:
        Updated beta as per the input schedule.
    """
    if beta_schedule == "linear_decay":
        # Despite the schedule's name, this decays beta geometrically in
        # `step`.
        return beta * ((1.0 - rho) ** step)
    return beta
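
# For example, with beta=0.2 and rho=0.1, the "linear_decay" schedule yields
# step 0 -> 0.2, step 1 -> 0.18, step 2 -> 0.162, and so on.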
@PublicAPI
def compute_states_entropy(
    obs_embeds: np.ndarray, embed_dim: int, k_nn: int
) -> np.ndarray:
    """Compute a states-entropy estimate using the k-nearest-neighbors method.

    Args:
        obs_embeds: Observation latent representations produced by the
            encoder model.
        embed_dim: Embedding vector dimension.
        k_nn: Number of nearest neighbors for the k-NN estimate.

    Returns:
        The computed states entropy.
    """
    obs_embeds_ = np.reshape(obs_embeds, [-1, embed_dim])
    dist = np.linalg.norm(obs_embeds_[:, None, :] - obs_embeds_[None, :, :], axis=-1)
    # The per-state entropy estimate is the Euclidean distance to the state's
    # k-th nearest neighbor in latent space (sorted distances, not argsort
    # indices).
    return np.sort(dist, axis=-1)[:, :k_nn][:, -1].astype(np.float32)
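
# In the RE3 paper [1], the intrinsic reward for a state s_i is
# r(s_i) = log(||y_i - y_i^(k)||_2 + 1), where y_i is s_i's random-encoder
# embedding and y_i^(k) its k-th nearest neighbor within the batch. The
# function above returns the raw k-NN distances; the log scaling (with a
# moving-std normalization) is applied by `_MovingMeanStd.__call__`.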
@PublicAPI
class RE3(Exploration):
    """Random Encoder for Efficient Exploration (RE3).

    Implementation of:
    [1] State entropy maximization with random encoders for efficient
    exploration. Seo, Chen, Shin, Lee, Abbeel, & Lee, (2021).
    arXiv preprint arXiv:2102.09430.

    Estimates state entropy using a particle-based k-nearest-neighbors (k-NN)
    estimator in latent space. A state's latent representation is computed
    by an encoder with randomly initialized, fixed parameters.
    The entropy of a state is treated as an intrinsic reward and added to the
    environment's extrinsic reward for policy optimization.
    Entropy is calculated per batch; it does not take the distribution of
    the entire replay buffer into consideration.
    """
    def __init__(
        self,
        action_space: Space,
        *,
        framework: str,
        model: ModelV2,
        embeds_dim: int = 128,
        encoder_net_config: Optional[ModelConfigDict] = None,
        beta: float = 0.2,
        beta_schedule: str = "constant",
        rho: float = 0.1,
        k_nn: int = 50,
        random_timesteps: int = 10000,
        sub_exploration: Optional[FromConfigSpec] = None,
        **kwargs
    ):
        """Initialize RE3.

        Args:
            action_space: The action space in which to explore.
            framework: Supports "tf"; this implementation does not
                support torch.
            model: The policy's model.
            embeds_dim: The dimensionality of the observation embedding
                vectors in latent space.
            encoder_net_config: Optional model configuration for the
                encoder network, producing embedding vectors from
                observations. This can be used to configure fcnet- or
                conv_net setups to properly process any observation space.
            beta: Hyperparameter to trade off exploration and
                exploitation.
            beta_schedule: Schedule to use for beta decay, one of
                "constant" or "linear_decay".
            rho: Beta decay factor, used for on-policy algorithms.
            k_nn: Number of neighbors to use for the k-NN entropy
                estimation.
            random_timesteps: The number of timesteps to act completely
                randomly (see [1]).
            sub_exploration: The config dict for the underlying Exploration
                to use (e.g. epsilon-greedy for DQN). If None, uses the
                FromSpecDict provided in the Policy's default config.

        Raises:
            ValueError: If the input framework is torch.
        """
        # TODO(gjoliver): Add support for PyTorch.
        if framework == "torch":
            raise ValueError("This RE3 implementation does not support torch.")
        super().__init__(action_space, model=model, framework=framework, **kwargs)

        self.beta = beta
        self.rho = rho
        self.k_nn = k_nn
        self.embeds_dim = embeds_dim
        if encoder_net_config is None:
            encoder_net_config = self.policy_config["model"].copy()
        self.encoder_net_config = encoder_net_config

        # Auto-detection of the underlying exploration functionality.
        if sub_exploration is None:
            # For discrete action spaces, use an underlying EpsilonGreedy with
            # a special schedule.
            if isinstance(self.action_space, Discrete):
                sub_exploration = {
                    "type": "EpsilonGreedy",
                    "epsilon_schedule": {
                        "type": "PiecewiseSchedule",
                        # Step function: act fully randomly for the first
                        # `random_timesteps` steps, then drop epsilon to 0.01.
                        "endpoints": [
                            (0, 1.0),
                            (random_timesteps + 1, 1.0),
                            (random_timesteps + 2, 0.01),
                        ],
                        "outside_value": 0.01,
                    },
                }
            elif isinstance(self.action_space, Box):
                sub_exploration = {
                    "type": "OrnsteinUhlenbeckNoise",
                    "random_timesteps": random_timesteps,
                }
            else:
                raise NotImplementedError
        self.sub_exploration = sub_exploration
        # Create the ModelV2 embedding module / layers.
        self._encoder_net = ModelCatalog.get_model_v2(
            self.model.obs_space,
            self.action_space,
            self.embeds_dim,
            model_config=self.encoder_net_config,
            framework=self.framework,
            name="encoder_net",
        )
        if self.framework == "tf":
            self._obs_ph = get_placeholder(
                space=self.model.obs_space, name="_encoder_obs"
            )
            self._obs_embeds = tf.stop_gradient(
                self._encoder_net({SampleBatch.OBS: self._obs_ph})[0]
            )

        # This submodule is only used to select the actions to take.
        self.exploration_submodule = from_config(
            cls=Exploration,
            config=self.sub_exploration,
            action_space=self.action_space,
            framework=self.framework,
            policy_config=self.policy_config,
            model=self.model,
            num_workers=self.num_workers,
            worker_index=self.worker_index,
        )
    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: bool = True
    ):
        # Simply delegate to the sub-Exploration module.
        return self.exploration_submodule.get_exploration_action(
            action_distribution=action_distribution, timestep=timestep, explore=explore
        )
    @override(Exploration)
    def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
        """Calculate states' latent representations/embeddings.

        Embeddings are added to the SampleBatch object so that they don't
        need to be recalculated during each training step.
        """
        if self.framework != "torch":
            sample_batch = self._postprocess_tf(policy, sample_batch, tf_sess)
        else:
            raise ValueError("Not implemented for torch.")
        return sample_batch
    def _postprocess_tf(self, policy, sample_batch, tf_sess):
        """Calculate states' embeddings and add them to the SampleBatch."""
        if self.framework == "tf":
            # Static graph: run the pre-built embedding op in the session.
            obs_embeds = tf_sess.run(
                self._obs_embeds,
                feed_dict={self._obs_ph: sample_batch[SampleBatch.OBS]},
            )
        else:
            # Eager (tf2): compute the embeddings directly.
            obs_embeds = tf.stop_gradient(
                self._encoder_net({SampleBatch.OBS: sample_batch[SampleBatch.OBS]})[0]
            ).numpy()
        sample_batch[SampleBatch.OBS_EMBEDS] = obs_embeds
        return sample_batch
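

if __name__ == "__main__":
    # Smoke test (illustrative sketch, not part of the RLlib API): exercise
    # the pure-numpy helpers above on random data.
    rng = np.random.default_rng(0)

    # Fake a batch of 8 observation embeddings of dimension 4.
    embeds = rng.normal(size=(8, 4)).astype(np.float32)
    entropies = compute_states_entropy(embeds, embed_dim=4, k_nn=3)
    print("k-NN entropy estimates:", entropies)

    # Turn the raw k-NN distances into log-scaled intrinsic rewards.
    normalizer = _MovingMeanStd()
    print("intrinsic rewards:", normalizer(entropies))

    # Beta decays geometrically under the "linear_decay" schedule.
    for step in range(3):
        print(step, update_beta("linear_decay", beta=0.2, rho=0.1, step=step))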