- """
- PyTorch policy class used for CQL.
- """
- import numpy as np
- import gym
- import logging
- from typing import Dict, List, Tuple, Type, Union
- import ray
- import ray.experimental.tf_utils
- from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \
- validate_spaces
- from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \
- build_sac_model_and_action_dist, optimizer_fn, ComputeTDErrorMixin, \
- TargetNetworkMixin, setup_late_mixins, action_distribution_fn
- from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
- from ray.rllib.policy.policy_template import build_policy_class
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.framework import try_import_torch
- from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
- from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
- TrainerConfigDict
- from ray.rllib.utils.torch_utils import apply_grad_clipping, \
- convert_to_torch_tensor, concat_multi_gpu_td_errors
- torch, nn = try_import_torch()
- F = nn.functional
- logger = logging.getLogger(__name__)
- MEAN_MIN = -9.0
- MEAN_MAX = 9.0
# Returns policy tiled actions and log probabilities for CQL Loss
def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
    obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
        obs.shape[0] * num_repeat, obs.shape[1])
    logits = model.get_policy_output(obs_temp)
    policy_dist = action_dist(logits, model)
    actions, logp_ = policy_dist.sample_logp()
    logp = logp_.unsqueeze(-1)
    return actions, logp.view(obs.shape[0], num_repeat, 1)

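# Tiles each observation `num_repeat` times and evaluates the Q-network (or
# the twin Q-network) on the given actions, returning predictions of shape
# (batch_size, num_repeat, 1). Used by the CQL regularizer in `cql_loss`.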
def q_values_repeat(model, obs, actions, twin=False):
    action_shape = actions.shape[0]
    obs_shape = obs.shape[0]
    num_repeat = int(action_shape / obs_shape)
    obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
        obs.shape[0] * num_repeat, obs.shape[1])
    if not twin:
        preds_ = model.get_q_values(obs_temp, actions)
    else:
        preds_ = model.get_twin_q_values(obs_temp, actions)
    preds = preds_.view(obs.shape[0], num_repeat, 1)
    return preds

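# The CQL loss combines the standard SAC losses (alpha, actor, critic) with a
# conservative regularizer on each Q-function:
#     min_q_weight * (logsumexp_a Q(s, a) - E_{a ~ data} Q(s, a))
# The (temperature-scaled) logsumexp term is approximated below from
# uniform-random actions plus actions sampled from the current policy at s_t
# and s_{t+1}, each importance-corrected by its log-density.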
def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # Look up the target model (tower) corresponding to this model tower.
    target_model = policy.target_models[model]

    # Deterministic loss must be off: the entropy and CQL terms require
    # stochastic sampling.
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL parameters.
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS].float()
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, alpha and actor losses are computed first.
    # Alpha loss.
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()
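    # Note: optimizer steps are performed inside this loss function itself.
    # The `train_batch_size` check here (and at the analogous spots further
    # down) appears intended to skip these manual steps for dummy
    # initialization batches, whose size differs from `train_batch_size`.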
    if obs.shape[0] == policy.config["train_batch_size"]:
        policy.alpha_optim.zero_grad()
        alpha_loss.backward()
        policy.alpha_optim.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
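    # For the first `bc_iters` iterations the actor is trained to imitate the
    # dataset actions (behavior cloning, still including the entropy term);
    # after that, the standard SAC actor objective (alpha * logp - Q) is used.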
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
        # actor_loss = -bc_logp.mean()

    if obs.shape[0] == policy.config["train_batch_size"]:
        policy.actor_optim.zero_grad()
        actor_loss.backward(retain_graph=True)
        policy.actor_optim.step()

    # Critic loss (standard SAC critic L2 loss + CQL regularizer).
    # SAC loss:
    # Q-values for the batched actions.
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target Q-network evaluation.
    q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = target_model.get_twin_q_values(target_model_out_tp1,
                                                    policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped) for the priority replay
    # buffer.
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)
    # CQL loss. This implementation uses the entropy (CQL(H)) variant of the
    # conservative penalty.
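    # Build the action sets for the logsumexp estimate: uniform-random actions
    # over the action space, plus `num_actions` samples from the current
    # policy at both s_t and s_{t+1}.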
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device)
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1, num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)
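    # Each Q-value is importance-corrected by the log-density of the
    # distribution its action was drawn from: `random_density` is the
    # log-density of the uniform proposal (0.5 per action dimension, which
    # matches a width-2 action range such as [-1, 1]); the policy samples use
    # their own log-probabilities.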
    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    min_qf1_loss_ = torch.logsumexp(
        cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = torch.logsumexp(
            cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)
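    # Lagrangian variant: a learned multiplier `alpha_prime` scales the amount
    # by which the CQL penalty exceeds `lagrangian_thresh` (the target action
    # gap). Its own loss (`alpha_prime_loss`) drives `alpha_prime` up when the
    # penalty is above the threshold and down otherwise.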
    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if obs.shape[0] == policy.config["train_batch_size"]:
        policy.critic_optims[0].zero_grad()
        critic_loss[0].backward(retain_graph=True)
        policy.critic_optims[0].step()
        if twin_q:
            policy.critic_optims[1].zero_grad()
            critic_loss[1].backward(retain_graph=False)
            policy.critic_optims[1].step()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    # SAC stats.
    model.tower_stats["q_t"] = q_t_selected
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    model.tower_stats["log_alpha_value"] = model.log_alpha
    model.tower_stats["alpha_value"] = alpha
    model.tower_stats["target_entropy"] = model.target_entropy
    # CQL stats.
    model.tower_stats["cql_loss"] = cql_loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    if use_lagrange:
        model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
        model.tower_stats["alpha_prime_value"] = alpha_prime
        model.tower_stats["alpha_prime_loss"] = alpha_prime_loss

        if obs.shape[0] == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss] +
                 ([alpha_prime_loss] if use_lagrange else []))

def cql_stats(policy: Policy,
              train_batch: SampleBatch) -> Dict[str, TensorType]:
    # Get SAC loss stats.
    stats_dict = stats(policy, train_batch)

    # Add CQL loss stats to the dict.
    stats_dict["cql_loss"] = torch.mean(
        torch.stack(*policy.get_tower_stats("cql_loss")))

    if policy.config["lagrangian"]:
        stats_dict["log_alpha_prime_value"] = torch.mean(
            torch.stack(policy.get_tower_stats("log_alpha_prime_value")))
        stats_dict["alpha_prime_value"] = torch.mean(
            torch.stack(policy.get_tower_stats("alpha_prime_value")))
        stats_dict["alpha_prime_loss"] = torch.mean(
            torch.stack(policy.get_tower_stats("alpha_prime_loss")))
    return stats_dict

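# Extends the standard SAC optimizer setup: when `lagrangian=True`, an extra
# Adam optimizer is created for the learned CQL multiplier `log_alpha_prime`
# and appended to the optimizer list, matching the extra loss term returned
# by `cql_loss`.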
def cql_optimizer_fn(policy: Policy, config: TrainerConfigDict) -> \
        Tuple[LocalOptimizer]:
    policy.cur_iter = 0
    opt_list = optimizer_fn(policy, config)
    if config["lagrangian"]:
        log_alpha_prime = nn.Parameter(
            torch.zeros(1, requires_grad=True).float())
        policy.model.register_parameter("log_alpha_prime", log_alpha_prime)
        policy.alpha_prime_optim = torch.optim.Adam(
            params=[policy.model.log_alpha_prime],
            lr=config["optimization"]["critic_learning_rate"],
            eps=1e-7,  # to match tf.keras.optimizers.Adam's epsilon default
        )
        return tuple([policy.actor_optim] + policy.critic_optims +
                     [policy.alpha_optim] + [policy.alpha_prime_optim])
    return opt_list


def cql_setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                          action_space: gym.spaces.Space,
                          config: TrainerConfigDict) -> None:
    setup_late_mixins(policy, obs_space, action_space, config)
    if config["lagrangian"]:
        policy.model.log_alpha_prime = policy.model.log_alpha_prime.to(
            policy.device)

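# Because `cql_loss` already steps its optimizers manually, gradient handling
# in the policy is effectively a no-op: `compute_gradients_fn` only runs the
# loss (to populate the learner stats) and returns no gradients, and
# `apply_gradients_fn` does nothing.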
def compute_gradients_fn(policy, postprocessed_batch):
    batches = [policy._lazy_tensor_dict(postprocessed_batch)]
    model = policy.model
    policy._loss(policy, model, policy.dist_class, batches[0])
    stats = {
        LEARNER_STATS_KEY: policy._convert_to_non_torch_type(
            cql_stats(policy, batches[0]))
    }
    return [None, stats]


def apply_gradients_fn(policy, gradients):
    return

# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
CQLTorchPolicy = build_policy_class(
    name="CQLTorchPolicy",
    framework="torch",
    loss_fn=cql_loss,
    get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
    stats_fn=cql_stats,
    postprocess_fn=postprocess_trajectory,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=cql_optimizer_fn,
    validate_spaces=validate_spaces,
    before_loss_init=cql_setup_late_mixins,
    make_model_and_action_dist=build_sac_model_and_action_dist,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
    action_distribution_fn=action_distribution_fn,
    compute_gradients_fn=compute_gradients_fn,
    apply_gradients_fn=apply_gradients_fn,
)
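# Minimal usage sketch (a non-authoritative example, not part of this module;
# assumes Ray's `CQLTrainer` from `ray.rllib.agents.cql` and an offline
# dataset; the env name and dataset path below are placeholders):
#
#   from ray.rllib.agents.cql import CQLTrainer
#
#   trainer = CQLTrainer(config={
#       "env": "Pendulum-v1",  # only used for spaces / evaluation
#       "framework": "torch",
#       "input": "/path/to/offline/json/data",  # hypothetical dataset path
#       "twin_q": True,
#       "lagrangian": False,
#   })
#   for _ in range(10):
#       print(trainer.train())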