import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree
import random
from typing import Union, Optional

from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.exploration.exploration import Exploration, TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch, get_variable
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule
from ray.rllib.utils.torch_utils import FLOAT_MIN

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


@PublicAPI
class EpsilonGreedy(Exploration):
    """Epsilon-greedy Exploration class that produces exploration actions.

    When given a Model's output and a current epsilon value (based on some
    Schedule), it produces a random action (if rand(1) < eps) or uses the
    model-computed one (if rand(1) >= eps).
    """

    def __init__(
        self,
        action_space: gym.spaces.Space,
        *,
        framework: str,
        initial_epsilon: float = 1.0,
        final_epsilon: float = 0.05,
        warmup_timesteps: int = 0,
        epsilon_timesteps: int = int(1e5),
        epsilon_schedule: Optional[Schedule] = None,
        **kwargs,
    ):
- """Create an EpsilonGreedy exploration class.
- Args:
- action_space: The action space the exploration should occur in.
- framework: The framework specifier.
- initial_epsilon: The initial epsilon value to use.
- final_epsilon: The final epsilon value to use.
- warmup_timesteps: The timesteps over which to not change epsilon in the
- beginning.
- epsilon_timesteps: The timesteps (additional to `warmup_timesteps`)
- after which epsilon should always be `final_epsilon`.
- E.g.: warmup_timesteps=20k epsilon_timesteps=50k -> After 70k timesteps,
- epsilon will reach its final value.
- epsilon_schedule: An optional Schedule object
- to use (instead of constructing one from the given parameters).
- """
        assert framework is not None
        super().__init__(action_space=action_space, framework=framework, **kwargs)

        self.epsilon_schedule = from_config(
            Schedule, epsilon_schedule, framework=framework
        ) or PiecewiseSchedule(
            endpoints=[
                (0, initial_epsilon),
                (warmup_timesteps, initial_epsilon),
                (warmup_timesteps + epsilon_timesteps, final_epsilon),
            ],
            outside_value=final_epsilon,
            framework=self.framework,
        )
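
        # Illustrative example (not in the original source), assuming
        # PiecewiseSchedule's default linear interpolation: with
        # warmup_timesteps=20_000, epsilon_timesteps=50_000,
        # initial_epsilon=1.0 and final_epsilon=0.05, the schedule yields
        #   epsilon(0)      = 1.0
        #   epsilon(20_000) = 1.0    (end of warmup)
        #   epsilon(45_000) = 0.525  (halfway through the linear anneal)
        #   epsilon(70_000) = 0.05   (and 0.05 for all later timesteps).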

        # The current timestep value (tf-var or python int).
        self.last_timestep = get_variable(
            np.array(0, np.int64),
            framework=framework,
            tf_name="timestep",
            dtype=np.int64,
        )

        # Build the tf-info-op.
        if self.framework == "tf":
            self._tf_state_op = self.get_state()

    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: Optional[Union[bool, TensorType]] = True,
    ):
        if self.framework in ["tf2", "tf"]:
            return self._get_tf_exploration_action_op(
                action_distribution, explore, timestep
            )
        else:
            return self._get_torch_exploration_action(
                action_distribution, explore, timestep
            )
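
    # Note (added for clarity): both framework branches below return a tuple
    # of (actions, logp). The logp tensor is all zeros, i.e. epsilon-greedy
    # does not report the true log-probability of the action it picked.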

    def _get_tf_exploration_action_op(
        self,
        action_distribution: ActionDistribution,
        explore: Union[bool, TensorType],
        timestep: Union[int, TensorType],
    ) -> "tf.Tensor":
        """TF method to produce the tf op for an epsilon exploration action.

        Args:
            action_distribution: The instantiated ActionDistribution object
                to work with when creating exploration actions.
            explore: Whether to explore at all (if False, always return the
                greedy action).
            timestep: The current sampling timestep, used to query the
                epsilon schedule.

        Returns:
            The tf exploration-action op.
        """
        # TODO: Support MultiActionDistr for tf.
        q_values = action_distribution.inputs
        epsilon = self.epsilon_schedule(
            timestep if timestep is not None else self.last_timestep
        )

        # Get the exploit action as the one with the highest logit value.
        exploit_action = tf.argmax(q_values, axis=1)

        batch_size = tf.shape(q_values)[0]
        # Mask out actions with a q-value of -inf so that we don't even
        # consider them for exploration.
        random_valid_action_logits = tf.where(
            tf.equal(q_values, tf.float32.min),
            tf.ones_like(q_values) * tf.float32.min,
            tf.ones_like(q_values),
        )
        random_actions = tf.squeeze(
            tf.random.categorical(random_valid_action_logits, 1), axis=1
        )
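
        # Illustrative example (not in the original source): for a q_values
        # row like [2.0, tf.float32.min, 0.5], the logits above become
        # [1.0, tf.float32.min, 1.0], so tf.random.categorical assigns
        # (effectively) zero probability to the masked-out second action and
        # samples uniformly among the remaining valid ones.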

        chose_random = (
            tf.random.uniform(
                tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
            )
            < epsilon
        )

        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool)
            else explore,
            true_fn=(lambda: tf.where(chose_random, random_actions, exploit_action)),
            false_fn=lambda: exploit_action,
        )

        if self.framework == "tf2" and not self.policy_config["eager_tracing"]:
            self.last_timestep = timestep
            return action, tf.zeros_like(action, dtype=tf.float32)
        else:
            assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
            with tf1.control_dependencies([assign_op]):
                return action, tf.zeros_like(action, dtype=tf.float32)

    def _get_torch_exploration_action(
        self,
        action_distribution: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ) -> "torch.Tensor":
        """Torch method to produce an epsilon exploration action.

        Args:
            action_distribution: The instantiated ActionDistribution object
                to work with when creating exploration actions.
            explore: Whether to explore at all (if False, always return the
                greedy action).
            timestep: The current sampling timestep, used to query the
                epsilon schedule.

        Returns:
            The exploration-action.
        """
        q_values = action_distribution.inputs
        self.last_timestep = timestep
        exploit_action = action_distribution.deterministic_sample()
        batch_size = q_values.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)
            if isinstance(action_distribution, TorchMultiActionDistribution):
                exploit_action = tree.flatten(exploit_action)
                for i in range(batch_size):
                    if random.random() < epsilon:
                        # TODO: (bcahlit) Mask out actions.
                        random_action = tree.flatten(self.action_space.sample())
                        for j in range(len(exploit_action)):
                            exploit_action[j][i] = torch.tensor(random_action[j])
                exploit_action = tree.unflatten_as(
                    action_distribution.action_space_struct, exploit_action
                )

                return exploit_action, action_logp
            else:
                # Mask out actions whose Q-values are -inf, so that we don't
                # even consider them for exploration.
                random_valid_action_logits = torch.where(
                    q_values <= FLOAT_MIN,
                    torch.zeros_like(q_values),
                    torch.ones_like(q_values),
                )
                # A random action, sampled uniformly from the valid actions.
                random_actions = torch.squeeze(
                    torch.multinomial(random_valid_action_logits, 1), dim=1
                )

                # Pick either random or greedy per batch item.
                action = torch.where(
                    torch.empty((batch_size,)).uniform_().to(self.device) < epsilon,
                    random_actions,
                    exploit_action,
                )

                return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp
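
    # Illustrative note (not in the original source): in the
    # TorchMultiActionDistribution branch above, a Dict action space such as
    # {"move": Discrete(4), "fire": Discrete(2)} flattens to a list of two
    # batched tensors [move_batch, fire_batch]; row i of every component is
    # overwritten with a freshly sampled random action before
    # tree.unflatten_as() restores the original dict structure.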

    @override(Exploration)
    def get_state(self, sess: Optional["tf.Session"] = None):
        if sess:
            return sess.run(self._tf_state_op)
        eps = self.epsilon_schedule(self.last_timestep)
        return {
            "cur_epsilon": convert_to_numpy(eps) if self.framework != "tf" else eps,
            "last_timestep": convert_to_numpy(self.last_timestep)
            if self.framework != "tf"
            else self.last_timestep,
        }

    @override(Exploration)
    def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None:
        if self.framework == "tf":
            self.last_timestep.load(state["last_timestep"], session=sess)
        elif isinstance(self.last_timestep, int):
            self.last_timestep = state["last_timestep"]
        else:
            self.last_timestep.assign(state["last_timestep"])
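

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the
# standard RLlib pattern of selecting an Exploration class via an algorithm
# config's `exploration_config` dict; the hyperparameter values below are
# illustrative, not defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from ray.rllib.algorithms.dqn import DQNConfig

    config = (
        DQNConfig()
        .environment("CartPole-v1")
        .framework("torch")
        .exploration(
            explore=True,
            exploration_config={
                "type": "EpsilonGreedy",
                "initial_epsilon": 1.0,
                "final_epsilon": 0.05,
                "warmup_timesteps": 0,
                "epsilon_timesteps": 100_000,
            },
        )
    )
    algo = config.build()
    # The policy's exploration object is an EpsilonGreedy instance; its state
    # reports the current epsilon and the last sampling timestep.
    print(algo.get_policy().exploration.get_state())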