- """
- TensorFlow policy class used for SAC.
- """
- import gym
- from gym.spaces import Box, Discrete
- from functools import partial
- import logging
- from typing import Dict, List, Optional, Tuple, Type, Union
- import ray
- import ray.experimental.tf_utils
- from ray.rllib.agents.ddpg.ddpg_tf_policy import ComputeTDErrorMixin, \
- TargetNetworkMixin
- from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
- PRIO_WEIGHTS
- from ray.rllib.agents.sac.sac_tf_model import SACTFModel
- from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
- from ray.rllib.evaluation.episode import Episode
- from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \
- DiagGaussian, Dirichlet, SquashedGaussian, TFActionDistribution
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.policy.tf_policy_template import build_tf_policy
- from ray.rllib.utils.error import UnsupportedSpaceException
- from ray.rllib.utils.framework import get_variable, try_import_tf
- from ray.rllib.utils.spaces.simplex import Simplex
- from ray.rllib.utils.tf_utils import huber_loss
- from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \
- TensorType, TrainerConfigDict
- tf1, tf, tfv = try_import_tf()
- logger = logging.getLogger(__name__)


def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> ModelV2:
    """Constructs the necessary ModelV2 for the Policy and returns it.

    Args:
        policy (Policy): The TFPolicy that will use the models.
        obs_space (gym.spaces.Space): The observation space.
        action_space (gym.spaces.Space): The action space.
        config (TrainerConfigDict): The SAC trainer's config dict.

    Returns:
        ModelV2: The ModelV2 to be used by the Policy. Note: An additional
            target model will be created in this function and assigned to
            `policy.target_model`.
    """
    # Force-ignore any additionally provided hidden layer sizes.
    # Everything should be configured using SAC's "Q_model" and
    # "policy_model" settings.
    policy_model_config = MODEL_DEFAULTS.copy()
    policy_model_config.update(config["policy_model"])
    q_model_config = MODEL_DEFAULTS.copy()
    q_model_config.update(config["Q_model"])

    default_model_cls = SACTorchModel if config["framework"] == "torch" \
        else SACTFModel

    model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=None,
        model_config=config["model"],
        framework=config["framework"],
        default_model=default_model_cls,
        name="sac_model",
        policy_model_config=policy_model_config,
        q_model_config=q_model_config,
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"])

    assert isinstance(model, default_model_cls)

    # Create an exact copy of the model and store it in `policy.target_model`.
    # This will be used for tau-synched Q-target models that run behind the
    # actual Q-networks and are used for target q-value calculations in the
    # loss terms.
    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=None,
        model_config=config["model"],
        framework=config["framework"],
        default_model=default_model_cls,
        name="target_sac_model",
        policy_model_config=policy_model_config,
        q_model_config=q_model_config,
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"])

    assert isinstance(policy.target_model, default_model_cls)

    return model
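

# Illustrative sketch of the config keys consumed by `build_sac_model` above
# (values here are examples only, not RLlib's authoritative defaults):
#
#   config = {
#       "Q_model": {"fcnet_hiddens": [256, 256],
#                   "fcnet_activation": "relu"},
#       "policy_model": {"fcnet_hiddens": [256, 256],
#                        "fcnet_activation": "relu"},
#       "twin_q": True,
#       "initial_alpha": 1.0,
#       "target_entropy": "auto",
#   }
#
# Both sub-configs are merged on top of MODEL_DEFAULTS, so any key not given
# falls back to RLlib's general model defaults.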


def postprocess_trajectory(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy (Policy): The Policy used to generate the trajectory
            (`sample_batch`).
        sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from
            the same episode). NOTE: The other agents use the same policy.
        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """
    return postprocess_nstep_and_prio(policy, sample_batch)
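

# For orientation (a rough sketch of what the DQN helper used above does;
# see dqn_tf_policy.py for the exact behavior): `postprocess_nstep_and_prio`
# folds n-step returns into the batch, approximately
#
#   rewards[t] <- sum_{k=0..n-1} gamma^k * rewards[t + k]
#   new_obs[t] <- new_obs[t + n - 1]
#
# and, if a prioritized replay buffer is configured, also computes initial
# TD-error-based priorities for the new batch items.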


def _get_dist_class(policy: Policy,
                    config: TrainerConfigDict,
                    action_space: gym.spaces.Space) -> \
        Type[TFActionDistribution]:
    """Helper function to return a dist class based on config and action space.

    Args:
        policy (Policy): The policy for which to return the action
            dist class.
        config (TrainerConfigDict): The Trainer's config dict.
        action_space (gym.spaces.Space): The action space used.

    Returns:
        Type[TFActionDistribution]: A TF distribution class.
    """
    if hasattr(policy, "dist_class") and policy.dist_class is not None:
        return policy.dist_class
    elif config["model"].get("custom_action_dist"):
        action_dist_class, _ = ModelCatalog.get_action_dist(
            action_space, config["model"], framework="tf")
        return action_dist_class
    elif isinstance(action_space, Discrete):
        return Categorical
    elif isinstance(action_space, Simplex):
        return Dirichlet
    else:
        assert isinstance(action_space, Box)
        if config["normalize_actions"]:
            return SquashedGaussian if \
                not config["_use_beta_distribution"] else Beta
        else:
            return DiagGaussian
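

# Summary of the space -> distribution mapping implemented above
# (derived directly from the branches in `_get_dist_class`):
#
#   Discrete                -> Categorical
#   Simplex                 -> Dirichlet
#   Box + normalize_actions -> SquashedGaussian
#                              (or Beta if `_use_beta_distribution`)
#   Box (unnormalized)      -> DiagGaussian
#
# A `custom_action_dist` in the model config always takes precedence.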


def get_distribution_inputs_and_class(
        policy: Policy,
        model: ModelV2,
        obs_batch: TensorType,
        *,
        explore: bool = True,
        **kwargs) \
        -> Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]:
    """The action distribution function to be used by the algorithm.

    An action distribution function is used to customize the choice of action
    distribution class and the resulting action distribution inputs (to
    parameterize the distribution object).
    After parameterizing the distribution, a `sample()` call
    will be made on it to generate actions.

    Args:
        policy (Policy): The Policy being queried for actions and calling this
            function.
        model (SACTFModel): The SAC specific Model to use to generate the
            distribution inputs (see sac_tf|torch_model.py). Must support the
            `get_policy_output` method.
        obs_batch (TensorType): The observations to be used as inputs to the
            model.
        explore (bool): Whether to activate exploration or not.

    Returns:
        Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]: The
            dist inputs, dist class, and a list of internal state outputs
            (in the RNN case).
    """
    # Get base-model (forward) output (this should be a noop call).
    forward_out, state_out = model(
        SampleBatch(
            obs=obs_batch, _is_training=policy._get_is_training_placeholder()),
        [], None)
    # Use the base output to get the policy outputs from the SAC model's
    # policy components.
    distribution_inputs = model.get_policy_output(forward_out)
    # Get a distribution class to be used with the just calculated dist-inputs.
    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)

    return distribution_inputs, action_dist_class, state_out
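

# For reference, the action-computation chain implied by the function above:
#
#   obs_batch -> model(...) (pass-through) -> model.get_policy_output(...)
#       -> dist_class(distribution_inputs) -> sample()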


def sac_actor_critic_loss(
        policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    _is_training = policy._get_is_training_placeholder()

    # Get the base model output from the train batch.
    model_out_t, _ = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS], _is_training=_is_training),
        [], None)

    # Get the base model output from the next observations in the train batch.
    model_out_tp1, _ = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS], _is_training=_is_training),
        [], None)

    # Get the target model's base outputs from the next observations in the
    # train batch.
    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS], _is_training=_is_training),
        [], None)

    # Discrete actions case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = tf.nn.log_softmax(model.get_policy_output(model_out_t), -1)
        policy_t = tf.math.exp(log_pis_t)
        log_pis_tp1 = tf.nn.log_softmax(
            model.get_policy_output(model_out_tp1), -1)
        policy_tp1 = tf.math.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(model_out_t)
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(
            train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1])
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        q_tp1_best_masked = \
            (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t), policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1), policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(
            model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t,
                tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0)

        # Target q-network evaluation.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best

    # Compute RHS of bellman equation for the Q-loss (critic(s)).
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
        policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.math.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Calculate one or two critic losses (2 in the twin_q case).
    prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
    critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
    if policy.config["twin_q"]:
        critic_loss.append(
            tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly; here we optimize its log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t), -model.log_alpha *
                    tf.stop_gradient(log_pis_t + model.target_entropy)),
                axis=-1))
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t)),
                axis=-1))
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha *
            tf.stop_gradient(log_pis_t + model.target_entropy))
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    policy.target_entropy = model.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss here.
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
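

# Worked equations for the loss above (continuous-actions case; a summary of
# the code, not additional functionality). With discount gamma, n-step
# horizon n, entropy coefficient alpha, and target entropy H:
#
#   Critic target (per batch item):
#     y = r + gamma^n * (1 - done)
#           * (min_i Q_target_i(s', a') - alpha * log pi(a'|s')),
#     where a' ~ pi(.|s').
#   Critic loss (per Q-net, with prioritized-replay weights w):
#     L_Q = mean(w * huber(|Q(s, a) - y|))
#   Actor loss:
#     L_pi = mean(alpha * log pi(a|s) - min_i Q_i(s, a)),  a ~ pi(.|s)
#   Alpha loss:
#     L_alpha = -mean(log_alpha * stop_grad(log pi(a|s) + H))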


def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
                               loss: TensorType) -> ModelGradients:
    """Gradients computing function (from loss tensor, using local optimizer).

    Note: For SAC, optimizer and loss are ignored b/c we have 3
    losses and 3 local optimizers (all stored in policy).
    `optimizer` will be used, though, in the tf-eager case b/c it is then a
    fake optimizer (OptimizerWrapper) object with a `tape` property to
    generate a GradientTape object for gradient recording.

    Args:
        policy (Policy): The Policy object that generated the loss tensor and
            that holds the given local optimizer.
        optimizer (LocalOptimizer): The tf (local) optimizer object to
            calculate the gradients with.
        loss (TensorType): The loss tensor for which gradients should be
            calculated.

    Returns:
        ModelGradients: List of the (possibly clipped) gradient-and-variable
            tuples.
    """
    # Eager: Use GradientTape (which is a property of the `optimizer` object
    # (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
    if policy.config["framework"] in ["tf2", "tfe"]:
        tape = optimizer.tape
        pol_weights = policy.model.policy_variables()
        actor_grads_and_vars = list(
            zip(tape.gradient(policy.actor_loss, pol_weights), pol_weights))
        q_weights = policy.model.q_variables()
        if policy.config["twin_q"]:
            half_cutoff = len(q_weights) // 2
            grads_1 = tape.gradient(policy.critic_loss[0],
                                    q_weights[:half_cutoff])
            grads_2 = tape.gradient(policy.critic_loss[1],
                                    q_weights[half_cutoff:])
            critic_grads_and_vars = \
                list(zip(grads_1, q_weights[:half_cutoff])) + \
                list(zip(grads_2, q_weights[half_cutoff:]))
        else:
            critic_grads_and_vars = list(
                zip(
                    tape.gradient(policy.critic_loss[0], q_weights),
                    q_weights))
        alpha_vars = [policy.model.log_alpha]
        alpha_grads_and_vars = list(
            zip(tape.gradient(policy.alpha_loss, alpha_vars), alpha_vars))
    # Tf1.x: Use optimizer.compute_gradients().
    else:
        actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
            policy.actor_loss, var_list=policy.model.policy_variables())

        q_weights = policy.model.q_variables()
        if policy.config["twin_q"]:
            half_cutoff = len(q_weights) // 2
            base_q_optimizer, twin_q_optimizer = policy._critic_optimizer
            critic_grads_and_vars = base_q_optimizer.compute_gradients(
                policy.critic_loss[0], var_list=q_weights[:half_cutoff]
            ) + twin_q_optimizer.compute_gradients(
                policy.critic_loss[1], var_list=q_weights[half_cutoff:])
        else:
            critic_grads_and_vars = policy._critic_optimizer[
                0].compute_gradients(
                    policy.critic_loss[0], var_list=q_weights)
        alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
            policy.alpha_loss, var_list=[policy.model.log_alpha])

    # Clip if necessary.
    if policy.config["grad_clip"]:
        clip_func = partial(
            tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
    else:
        clip_func = tf.identity

    # Save grads and vars for later use in `build_apply_op`.
    policy._actor_grads_and_vars = [(clip_func(g), v)
                                    for (g, v) in actor_grads_and_vars
                                    if g is not None]
    policy._critic_grads_and_vars = [(clip_func(g), v)
                                     for (g, v) in critic_grads_and_vars
                                     if g is not None]
    policy._alpha_grads_and_vars = [(clip_func(g), v)
                                    for (g, v) in alpha_grads_and_vars
                                    if g is not None]

    grads_and_vars = (
        policy._actor_grads_and_vars + policy._critic_grads_and_vars +
        policy._alpha_grads_and_vars)
    return grads_and_vars
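

# Note on the clipping above: `tf.clip_by_norm` clips each gradient tensor
# individually to `grad_clip`, i.e.
#
#   g <- g * grad_clip / max(||g||_2, grad_clip)
#
# This differs from global-norm clipping (`tf.clip_by_global_norm`), which
# would rescale all gradients by a single shared factor.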


def apply_gradients(
        policy: Policy, optimizer: LocalOptimizer,
        grads_and_vars: ModelGradients) -> Union["tf.Operation", None]:
    """Gradients applying function (from list of "grad_and_var" tuples).

    Note: For SAC, optimizer and grads_and_vars are ignored b/c we have 3
    losses and optimizers (stored in policy).

    Args:
        policy (Policy): The Policy object whose Model(s) the given gradients
            should be applied to.
        optimizer (LocalOptimizer): The tf (local) optimizer object through
            which to apply the gradients.
        grads_and_vars (ModelGradients): The list of grad_and_var tuples to
            apply via the given optimizer.

    Returns:
        Union[tf.Operation, None]: The tf op to be used to run the apply
            operation. None for eager mode.
    """
    actor_apply_ops = policy._actor_optimizer.apply_gradients(
        policy._actor_grads_and_vars)

    cgrads = policy._critic_grads_and_vars
    half_cutoff = len(cgrads) // 2
    if policy.config["twin_q"]:
        critic_apply_ops = [
            policy._critic_optimizer[0].apply_gradients(cgrads[:half_cutoff]),
            policy._critic_optimizer[1].apply_gradients(cgrads[half_cutoff:])
        ]
    else:
        critic_apply_ops = [
            policy._critic_optimizer[0].apply_gradients(cgrads)
        ]

    # Eager mode -> Just apply and return None.
    if policy.config["framework"] in ["tf2", "tfe"]:
        policy._alpha_optimizer.apply_gradients(policy._alpha_grads_and_vars)
        return
    # Tf static graph -> Return op.
    else:
        alpha_apply_ops = policy._alpha_optimizer.apply_gradients(
            policy._alpha_grads_and_vars,
            global_step=tf1.train.get_or_create_global_step())
        return tf.group([actor_apply_ops, alpha_apply_ops] + critic_apply_ops)


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for SAC. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    return {
        "mean_td_error": tf.reduce_mean(policy.td_error),
        "actor_loss": tf.reduce_mean(policy.actor_loss),
        "critic_loss": tf.reduce_mean(policy.critic_loss),
        "alpha_loss": tf.reduce_mean(policy.alpha_loss),
        "alpha_value": tf.reduce_mean(policy.alpha_value),
        "target_entropy": tf.constant(policy.target_entropy),
        "mean_q": tf.reduce_mean(policy.q_t),
        "max_q": tf.reduce_max(policy.q_t),
        "min_q": tf.reduce_min(policy.q_t),
    }


class ActorCriticOptimizerMixin:
    """Mixin class to generate the necessary optimizers for actor-critic algos.

    - Creates global step for counting the number of update operations.
    - Creates separate optimizers for actor, critic, and alpha.
    """

    def __init__(self, config):
        # Eager mode.
        if config["framework"] in ["tf2", "tfe"]:
            self.global_step = get_variable(0, tf_name="global_step")
            self._actor_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["optimization"]["actor_learning_rate"])
            self._critic_optimizer = [
                tf.keras.optimizers.Adam(learning_rate=config["optimization"][
                    "critic_learning_rate"])
            ]
            if config["twin_q"]:
                self._critic_optimizer.append(
                    tf.keras.optimizers.Adam(learning_rate=config[
                        "optimization"]["critic_learning_rate"]))
            self._alpha_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["optimization"]["entropy_learning_rate"])
        # Static graph mode.
        else:
            self.global_step = tf1.train.get_or_create_global_step()
            self._actor_optimizer = tf1.train.AdamOptimizer(
                learning_rate=config["optimization"]["actor_learning_rate"])
            self._critic_optimizer = [
                tf1.train.AdamOptimizer(learning_rate=config["optimization"][
                    "critic_learning_rate"])
            ]
            if config["twin_q"]:
                self._critic_optimizer.append(
                    tf1.train.AdamOptimizer(learning_rate=config[
                        "optimization"]["critic_learning_rate"]))
            self._alpha_optimizer = tf1.train.AdamOptimizer(
                learning_rate=config["optimization"]["entropy_learning_rate"])
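

# Illustrative sketch of the "optimization" sub-config read by the mixin
# above (example values only):
#
#   config["optimization"] = {
#       "actor_learning_rate": 3e-4,
#       "critic_learning_rate": 3e-4,
#       "entropy_learning_rate": 3e-4,
#   }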


def setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space,
                       action_space: gym.spaces.Space,
                       config: TrainerConfigDict) -> None:
    """Call mixin classes' constructors before Policy's initialization.

    Adds the necessary optimizers to the given Policy.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    ActorCriticOptimizerMixin.__init__(policy, config)


def setup_mid_mixins(policy: Policy, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space,
                     config: TrainerConfigDict) -> None:
    """Call mixin classes' constructors before Policy's loss initialization.

    Adds the `compute_td_error` method to the given policy.
    Calling `compute_td_error` with batch data will re-calculate the loss
    on that batch AND return the per-batch-item TD-error for prioritized
    replay buffer record weight updating (in case a prioritized replay buffer
    is used).

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    ComputeTDErrorMixin.__init__(policy, sac_actor_critic_loss)


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call mixin classes' constructors after Policy initialization.

    Adds the `update_target` method to the given policy.
    Calling `update_target` updates all target Q-networks' weights from their
    respective "main" Q-networks, based on tau (smooth, partial updating).

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, config)
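

# The tau-based target update described above follows the standard
# Polyak-averaging scheme, per target variable:
#
#   w_target <- tau * w + (1 - tau) * w_target
#
# with tau == 1.0 reducing to a hard (full) copy of the main Q-networks.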


def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> None:
    """Validates the observation- and action spaces used for the Policy.

    Args:
        policy (Policy): The policy, whose spaces are being validated.
        observation_space (gym.spaces.Space): The observation space to
            validate.
        action_space (gym.spaces.Space): The action space to validate.
        config (TrainerConfigDict): The Policy's config dict.

    Raises:
        UnsupportedSpaceException: If one of the spaces is not supported.
    """
    # Only support single Box or single Discrete spaces.
    if not isinstance(action_space, (Box, Discrete, Simplex)):
        raise UnsupportedSpaceException(
            "Action space ({}) of {} is not supported for "
            "SAC. Must be [Box|Discrete|Simplex].".format(
                action_space, policy))
    # If Box, make sure it's a 1D vector space.
    elif isinstance(action_space,
                    (Box, Simplex)) and len(action_space.shape) > 1:
        raise UnsupportedSpaceException(
            "Action space ({}) of {} has multiple dimensions "
            "{}. ".format(action_space, policy, action_space.shape) +
            "Consider reshaping this into a single dimension, "
            "using a Tuple action space, or the multi-agent API.")


# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
SACTFPolicy = build_tf_policy(
    name="SACTFPolicy",
    get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
    make_model=build_sac_model,
    postprocess_fn=postprocess_trajectory,
    action_distribution_fn=get_distribution_inputs_and_class,
    loss_fn=sac_actor_critic_loss,
    stats_fn=stats,
    compute_gradients_fn=compute_and_clip_gradients,
    apply_gradients_fn=apply_gradients,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    mixins=[
        TargetNetworkMixin, ActorCriticOptimizerMixin, ComputeTDErrorMixin
    ],
    validate_spaces=validate_spaces,
    before_init=setup_early_mixins,
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
)
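

# SACTFPolicy is normally not instantiated by hand: it serves as the default
# policy class of RLlib's SAC trainer (ray/rllib/agents/sac/sac.py), whose
# DEFAULT_CONFIG (referenced in `get_default_config` above) supplies the
# config dict consumed by all functions in this module.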