- """
- TensorFlow policy class used for APPO.
- Adapted from VTraceTFPolicy to use the PPO surrogate loss.
- Keep in sync with changes to VTraceTFPolicy.
- """
- import numpy as np
- import logging
- import gym
- from typing import Dict, List, Optional, Type, Union
- from ray.rllib.agents.impala import vtrace_tf as vtrace
- from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
- clip_gradients, choose_optimizer
- from ray.rllib.evaluation.episode import Episode
- from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
- Postprocessing
- from ray.rllib.models.tf.tf_action_dist import Categorical
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.policy.tf_policy_template import build_tf_policy
- from ray.rllib.policy.tf_policy import EntropyCoeffSchedule, \
- LearningRateSchedule, TFPolicy
- from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin
- from ray.rllib.models.catalog import ModelCatalog
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
- from ray.rllib.utils.annotations import override
- from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
- from ray.rllib.utils.typing import AgentID, TensorType, TrainerConfigDict
- tf1, tf, tfv = try_import_tf()
- POLICY_SCOPE = "func"
- TARGET_POLICY_SCOPE = "target_func"
- logger = logging.getLogger(__name__)


def make_appo_model(policy: Policy, obs_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> ModelV2:
    """Builds model and target model for APPO.

    Args:
        policy (Policy): The Policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Policy's config.

    Returns:
        ModelV2: The Model for the Policy to use.
            Note: The target model will not be returned, just assigned to
            `policy.target_model`.
    """
    # Get the num_outputs for the following model construction calls.
    _, logit_dim = ModelCatalog.get_action_dist(action_space, config["model"])

    # Construct the (main) model.
    policy.model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        logit_dim,
        config["model"],
        name=POLICY_SCOPE,
        framework="torch" if config["framework"] == "torch" else "tf")
    policy.model_variables = policy.model.variables()

    # Construct the target model.
    policy.target_model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        logit_dim,
        config["model"],
        name=TARGET_POLICY_SCOPE,
        framework="torch" if config["framework"] == "torch" else "tf")
    policy.target_model_variables = policy.target_model.variables()

    # Return only the model (not the target model).
    return policy.model


def appo_surrogate_loss(
        policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1
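
    # `output_hidden_shape` is used further below to split the flattened
    # (behaviour/old-policy) logits back into one set of logits per
    # sub-action for MultiDiscrete action spaces.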

    # TODO: (sven) deprecate this when trajectory view API gets activated.
    def make_time_major(*args, **kw):
        return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                                *args, **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = policy.target_model(train_batch)
    prev_action_dist = dist_class(behaviour_logits, policy.model)
    values = policy.model.value_function()
    values_time_major = make_time_major(values)

    policy.model_vars = policy.model.variables()
    policy.target_model_vars = policy.target_model.variables()
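
    # For recurrent policies, mask out padded (invalid) timesteps so that
    # they do not contribute to any of the mean loss terms computed below.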
    if policy.is_recurrent():
        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = tf.reshape(mask, [-1])
        mask = make_time_major(
            mask, drop_last=policy.config["vtrace"]
            and policy.config["vtrace_drop_last_ts"])

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, mask))

    else:
        reduce_mean_valid = tf.reduce_mean
- if policy.config["vtrace"]:
- drop_last = policy.config["vtrace_drop_last_ts"]
- logger.debug("Using V-Trace surrogate loss (vtrace=True; "
- f"drop_last={drop_last})")
- # Prepare actions for loss.
- loss_actions = actions if is_multidiscrete else tf.expand_dims(
- actions, axis=1)
- old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
- old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
- # Prepare KL for Loss
- mean_kl = make_time_major(
- old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last)
- unpacked_behaviour_logits = tf.split(
- behaviour_logits, output_hidden_shape, axis=1)
- unpacked_old_policy_behaviour_logits = tf.split(
- old_policy_behaviour_logits, output_hidden_shape, axis=1)
- # Compute vtrace on the CPU for better perf.
- with tf.device("/cpu:0"):
- vtrace_returns = vtrace.multi_from_logits(
- behaviour_policy_logits=make_time_major(
- unpacked_behaviour_logits, drop_last=drop_last),
- target_policy_logits=make_time_major(
- unpacked_old_policy_behaviour_logits, drop_last=drop_last),
- actions=tf.unstack(
- make_time_major(loss_actions, drop_last=drop_last),
- axis=2),
- discounts=tf.cast(
- ~make_time_major(
- tf.cast(dones, tf.bool), drop_last=drop_last),
- tf.float32) * policy.config["gamma"],
- rewards=make_time_major(rewards, drop_last=drop_last),
- values=values_time_major[:-1]
- if drop_last else values_time_major,
- bootstrap_value=values_time_major[-1],
- dist_class=Categorical if is_multidiscrete else dist_class,
- model=model,
- clip_rho_threshold=tf.cast(
- policy.config["vtrace_clip_rho_threshold"], tf.float32),
- clip_pg_rho_threshold=tf.cast(
- policy.config["vtrace_clip_pg_rho_threshold"], tf.float32),
- )
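
        # `vtrace_returns.vs` holds the V-trace corrected value targets and
        # `vtrace_returns.pg_advantages` the matching policy-gradient
        # advantages; these replace the GAE targets/advantages used in the
        # non-vtrace branch below.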
        actions_logp = make_time_major(
            action_dist.logp(actions), drop_last=drop_last)
        prev_actions_logp = make_time_major(
            prev_action_dist.logp(actions), drop_last=drop_last)
        old_policy_actions_logp = make_time_major(
            old_policy_action_dist.logp(actions), drop_last=drop_last)
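
        # `is_ratio` re-weights from the behaviour policy (which collected the
        # samples) to the slowly updated target ("old") policy and is clipped
        # to [0.0, 2.0] for stability; multiplying it with the
        # current/behaviour ratio makes `logp_ratio` effectively the PPO ratio
        # of the current policy w.r.t. the target policy.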
        is_ratio = tf.clip_by_value(
            tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages *
            tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))

        action_kl = tf.reduce_mean(mean_kl, axis=0) \
            if is_multidiscrete else mean_kl
        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        if drop_last:
            delta = values_time_major[:-1] - vtrace_returns.vs
        else:
            delta = values_time_major - vtrace_returns.vs
        value_targets = vtrace_returns.vs
        mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

        # The entropy loss.
        actions_entropy = make_time_major(
            action_dist.multi_entropy(), drop_last=drop_last)
        mean_entropy = reduce_mean_valid(actions_entropy)

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))

        logp_ratio = tf.math.exp(
            make_time_major(action_dist.logp(actions)) -
            make_time_major(prev_action_dist.logp(actions)))

        advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages *
            tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))

        action_kl = tf.reduce_mean(mean_kl, axis=0) \
            if is_multidiscrete else mean_kl
        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            make_time_major(action_dist.multi_entropy()))

    # The summed weighted loss.
    total_loss = mean_policy_loss - \
        mean_entropy * policy.entropy_coeff

    # Optional KL loss.
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl_loss

    # Optional vf loss (or in a separate term due to separate
    # optimizers/networks).
    loss_wo_vf = total_loss
    if not policy.config["_separate_vf_optimizer"]:
        total_loss += mean_vf_loss * policy.config["vf_loss_coeff"]

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._loss_wo_vf = loss_wo_vf
    policy._mean_policy_loss = mean_policy_loss
    # Backward compatibility: Deprecate policy._mean_kl.
    policy._mean_kl_loss = policy._mean_kl = mean_kl_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets

    # Return one total loss or two losses: vf vs rest (policy + kl).
    if policy.config["_separate_vf_optimizer"]:
        return loss_wo_vf, mean_vf_loss
    else:
        return total_loss


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    values_batched = _make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"]
        and policy.config["vtrace_drop_last_ts"])

    stats_dict = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": policy._total_loss,
        "policy_loss": policy._mean_policy_loss,
        "entropy": policy._mean_entropy,
        "var_gnorm": tf.linalg.global_norm(policy.model.trainable_variables()),
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy._value_targets, [-1]),
            tf.reshape(values_batched, [-1])),
        "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
    }

    if policy.config["vtrace"]:
        is_stat_mean, is_stat_var = tf.nn.moments(policy._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy._mean_kl_loss
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict


def postprocess_trajectory(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy (Policy): The Policy used to generate the trajectory
            (`sample_batch`).
        sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from
            the same episode). NOTE: The other agents use the same policy.
        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """
    if not policy.config["vtrace"]:
        sample_batch = compute_gae_for_sample_batch(
            policy, sample_batch, other_agent_batches, episode)

    return sample_batch


def add_values(policy):
    out = {}
    if not policy.config["vtrace"]:
        out[SampleBatch.VF_PREDS] = policy.model.value_function()
    return out


class TargetNetworkMixin:
    """Target NN is updated by master learner via the `update_target` method.

    Updates happen every `trainer.update_target_frequency` steps. All worker
    batches are importance sampled wrt the target network to ensure a more
    stable pi_old in PPO.
    """

    def __init__(self, obs_space, action_space, config):
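        # Build a tf callable that performs a hard update, i.e. copies every
        # main-model variable into its corresponding target-model variable.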
        @make_tf_callable(self.get_session())
        def do_update():
            assign_ops = []
            assert len(self.model_vars) == len(self.target_model_vars)
            for var, var_target in zip(self.model_vars,
                                       self.target_model_vars):
                assign_ops.append(var_target.assign(var))
            return tf.group(*assign_ops)

        self.update_target = do_update

    @override(TFPolicy)
    def variables(self):
        return self.model_vars + self.target_model_vars


def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors before APPOPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    KLCoeffMixin.__init__(policy, config)
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors after APPOPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
AsyncPPOTFPolicy = build_tf_policy(
    name="AsyncPPOTFPolicy",
    make_model=make_appo_model,
    loss_fn=appo_surrogate_loss,
    stats_fn=stats,
    postprocess_fn=postprocess_trajectory,
    optimizer_fn=choose_optimizer,
    compute_gradients_fn=clip_gradients,
    extra_action_out_fn=add_values,
    before_loss_init=setup_mixins,
    after_init=setup_late_mixins,
    mixins=[
        LearningRateSchedule,
        KLCoeffMixin,
        TargetNetworkMixin,
        ValueNetworkMixin,
        EntropyCoeffSchedule,
    ],
    get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
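
# Usage (a minimal sketch): `AsyncPPOTFPolicy` is normally created by RLlib's
# APPO trainer rather than being instantiated directly, e.g.:
#
#   from ray.rllib.agents.ppo import APPOTrainer
#   trainer = APPOTrainer(env="CartPole-v0", config={"vtrace": True})
#   trainer.train()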