123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431 |
- """TensorFlow policy class used for DQN"""
- from typing import Dict
- import gym
- import numpy as np
- import ray
- from ray.rllib.agents.dqn.distributional_q_tf_model import \
- DistributionalQTFModel
- from ray.rllib.agents.dqn.simple_q_tf_policy import TargetNetworkMixin
- from ray.rllib.evaluation.postprocessing import adjust_nstep
- from ray.rllib.models import ModelCatalog
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.models.tf.tf_action_dist import Categorical
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.policy.tf_policy import LearningRateSchedule
- from ray.rllib.policy.tf_policy_template import build_tf_policy
- from ray.rllib.utils.error import UnsupportedSpaceException
- from ray.rllib.utils.exploration import ParameterNoise
- from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.numpy import convert_to_numpy
- from ray.rllib.utils.tf_utils import (
- huber_loss, make_tf_callable, minimize_and_clip, reduce_mean_ignore_inf)
- from ray.rllib.utils.typing import (ModelGradients, TensorType,
- TrainerConfigDict)
# Framework handles. `tf1` is used below for the static-graph
# (`tf1.train.AdamOptimizer`) code path; `tf` for everything else.
tf1, tf, tfv = try_import_tf()

# Model name given to the online Q-network built in build_q_model.
Q_SCOPE = "q_func"
# Model name given to the target Q-network built in build_q_model.
Q_TARGET_SCOPE = "target_q_func"

# Importance sampling weights for prioritized replay.
# Batch column key: filled with ones when missing, and overwritten with
# |TD error| + eps under worker-side prioritization.
PRIO_WEIGHTS = "weights"
class QLoss:
    """Builds the DQN per-sample TD error, the scalar loss and stats.

    Two variants are supported, selected by ``num_atoms``:

    * ``num_atoms > 1`` -- categorical (distributional) Q-learning. The
      n-step target distribution is projected back onto the fixed support
      of ``num_atoms`` atoms spanning ``[v_min, v_max]``. The per-sample
      error is the softmax cross-entropy between that projection and the
      predicted logits.
    * otherwise -- scalar Q-learning with a Huber loss on the n-step
      Bellman error (target wrapped in ``stop_gradient``).

    In both cases the per-sample errors are multiplied by
    ``importance_weights`` before averaging. After construction the
    instance exposes ``td_error``, ``loss`` and ``stats``.
    """

    def __init__(self,
                 q_t_selected: TensorType,
                 q_logits_t_selected: TensorType,
                 q_tp1_best: TensorType,
                 q_dist_tp1_best: TensorType,
                 importance_weights: TensorType,
                 rewards: TensorType,
                 done_mask: TensorType,
                 gamma: float = 0.99,
                 n_step: int = 1,
                 num_atoms: int = 1,
                 v_min: float = -10.0,
                 v_max: float = 10.0):
        """Build the loss graph.

        Args:
            q_t_selected: Q-values of the taken actions (scalar variant).
            q_logits_t_selected: Per-atom logits of the taken actions
                (distributional variant).
            q_tp1_best: Q-value of the best next action (scalar variant).
            q_dist_tp1_best: Per-atom values of the best next action,
                presumably probabilities (distributional variant).
            importance_weights: Per-sample weights; cast to float32.
            rewards: (n-step) rewards.
            done_mask: Float done flags; 1.0 masks out the bootstrap term.
            gamma: Discount factor; the bootstrap is scaled by
                ``gamma ** n_step``.
            n_step: Number of steps folded into ``rewards``.
            num_atoms: Number of support atoms; > 1 selects the
                distributional loss.
            v_min: Lowest support value.
            v_max: Highest support value.
        """
        discount = gamma**n_step
        if num_atoms > 1:
            self._build_distributional(q_logits_t_selected, q_dist_tp1_best,
                                       importance_weights, rewards,
                                       done_mask, discount, num_atoms, v_min,
                                       v_max)
        else:
            self._build_scalar(q_t_selected, q_tp1_best, importance_weights,
                               rewards, done_mask, discount)

    def _build_distributional(self, q_logits_t_selected, q_dist_tp1_best,
                              importance_weights, rewards, done_mask,
                              discount, num_atoms, v_min, v_max):
        """Categorical projection + cross-entropy loss (num_atoms > 1)."""
        # Fixed support: num_atoms evenly spaced values in [v_min, v_max].
        support = v_min + tf.range(
            num_atoms, dtype=tf.float32) * (v_max - v_min) / float(num_atoms - 1)

        # Shift/scale every atom by the Bellman operator:
        # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
        not_done = tf.expand_dims(1.0 - done_mask, -1)
        shifted = tf.expand_dims(rewards, -1) + \
            discount * not_done * tf.expand_dims(support, 0)
        shifted = tf.clip_by_value(shifted, v_min, v_max)

        # Fractional atom index of every shifted atom and its neighbours.
        position = (shifted - v_min) / ((v_max - v_min) / float(num_atoms - 1))
        lower = tf.floor(position)
        upper = tf.math.ceil(position)

        # indispensable judgement which is missed in most implementations:
        # when the position happens to be an integer, lower == upper and
        # both (upper - position) and (position - lower) are 0, so the
        # probability mass would be discarded. Route it to `lower` instead.
        lands_on_atom = tf.cast(tf.less(upper - lower, 0.5), tf.float32)

        # (batch_size, num_atoms, num_atoms)
        lower_onehot = tf.one_hot(tf.cast(lower, dtype=tf.int32), num_atoms)
        upper_onehot = tf.one_hot(tf.cast(upper, dtype=tf.int32), num_atoms)

        lower_mass = q_dist_tp1_best * (upper - position + lands_on_atom)
        upper_mass = q_dist_tp1_best * (position - lower)
        projected = (
            tf.reduce_sum(lower_onehot * tf.expand_dims(lower_mass, -1),
                          axis=1) +
            tf.reduce_sum(upper_onehot * tf.expand_dims(upper_mass, -1),
                          axis=1))

        # Rainbow paper claims that using this cross entropy loss for
        # priority is robust and insensitive to `prioritized_replay_alpha`
        self.td_error = tf.nn.softmax_cross_entropy_with_logits(
            labels=projected, logits=q_logits_t_selected)
        self.loss = tf.reduce_mean(
            self.td_error * tf.cast(importance_weights, tf.float32))
        # TODO: better Q stats for dist dqn
        self.stats = {"mean_td_error": tf.reduce_mean(self.td_error)}

    def _build_scalar(self, q_t_selected, q_tp1_best, importance_weights,
                      rewards, done_mask, discount):
        """Huber loss on the n-step Bellman error (num_atoms <= 1)."""
        # Right-hand side of the Bellman equation; terminal steps do not
        # bootstrap.
        bellman_target = rewards + discount * ((1.0 - done_mask) * q_tp1_best)
        self.td_error = q_t_selected - tf.stop_gradient(bellman_target)
        self.loss = tf.reduce_mean(
            tf.cast(importance_weights, tf.float32) *
            huber_loss(self.td_error))
        self.stats = {
            "mean_q": tf.reduce_mean(q_t_selected),
            "min_q": tf.reduce_min(q_t_selected),
            "max_q": tf.reduce_max(q_t_selected),
            "mean_td_error": tf.reduce_mean(self.td_error),
        }
class ComputeTDErrorMixin:
    """Mixin that gives the DQN TF policy a ``compute_td_error`` method.

    The method rebuilds the Q-loss on a batch given as plain arrays and
    returns the per-sample TD errors, so that replay priorities can be
    computed on the rollout-worker side.
    """

    def __init__(self):
        session = self.get_session()

        @make_tf_callable(session, dynamic_shape=True)
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            # Column key -> raw array, in the order the loss expects them.
            columns = (
                (SampleBatch.CUR_OBS, obs_t),
                (SampleBatch.ACTIONS, act_t),
                (SampleBatch.REWARDS, rew_t),
                (SampleBatch.NEXT_OBS, obs_tp1),
                (SampleBatch.DONES, done_mask),
                (PRIO_WEIGHTS, importance_weights),
            )
            tensor_batch = {
                key: tf.convert_to_tensor(value)
                for key, value in columns
            }
            # Only the forward pass matters: build_q_losses stores a fresh
            # QLoss on self.q_loss, whose td_error we hand back.
            build_q_losses(self, self.model, None, tensor_batch)
            return self.q_loss.td_error

        self.compute_td_error = compute_td_error
def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
                  action_space: gym.spaces.Space,
                  config: TrainerConfigDict) -> ModelV2:
    """Build q_model and target_model for DQN.

    Args:
        policy (Policy): The Policy, which will use the model for
            optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Policy's config. Reads "hiddens",
            "model", "num_atoms", "dueling", "noisy", "v_min", "v_max",
            "sigma0" and "exploration_config". Note: when
            ``config["hiddens"]`` is non-empty, this function sets
            ``config["model"]["no_final_linear"] = True`` in place.

    Returns:
        ModelV2: The Model for the Policy to use.
        Note: The target q model will not be returned, just assigned to
        `policy.target_model`.

    Raises:
        UnsupportedSpaceException: If `action_space` is not Discrete.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    if config["hiddens"]:
        # try to infer the last layer size, otherwise fall back to 256
        num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    # Add LayerNorm when parameter-space noise exploration is in use.
    # TODO(sven): Move option to add LayerNorm after each Dense
    # generically into ModelCatalog.
    add_layer_norm = (
        isinstance(getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise")

    # The online and target networks are built from ONE shared argument set
    # so they cannot drift apart; they differ only in their `name`.
    model_kwargs = dict(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="tf",
        model_interface=DistributionalQTFModel,
        num_atoms=config["num_atoms"],
        dueling=config["dueling"],
        q_hiddens=config["hiddens"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        add_layer_norm=add_layer_norm,
    )

    q_model = ModelCatalog.get_model_v2(name=Q_SCOPE, **model_kwargs)
    policy.target_model = ModelCatalog.get_model_v2(
        name=Q_TARGET_SCOPE, **model_kwargs)
    return q_model
def get_distribution_inputs_and_class(policy: Policy,
                                      model: ModelV2,
                                      input_dict: SampleBatch,
                                      *,
                                      explore=True,
                                      **kwargs):
    """Compute Q-values to serve as action-distribution inputs.

    The Q-values are also stored on ``policy.q_values`` (read back by the
    policy's extra-action-output function).

    Returns:
        Tuple of (Q-values, the ``Categorical`` distribution class, an empty
        list of state outputs).
    """
    outputs = compute_q_values(
        policy, model, input_dict, state_batches=None, explore=explore)
    if isinstance(outputs, tuple):
        # compute_q_values returns (value, logits, dist, state); keep value.
        outputs = outputs[0]
    policy.q_values = outputs
    no_state_out = []
    return policy.q_values, Categorical, no_state_out
def build_q_losses(policy: Policy, model, _,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        _: Unused positional slot (presumably the action-distribution class
            supplied by the policy template; callers in this file pass None).
        train_batch (SampleBatch): The training data. Must contain the
            CUR_OBS, NEXT_OBS, ACTIONS, REWARDS, DONES and PRIO_WEIGHTS
            columns.

    Returns:
        TensorType: A single loss tensor.

    Side effects:
        Sets ``policy.q_loss`` to the new QLoss and, on first call,
        ``policy.target_q_func_vars`` to the target model's variables.
    """
    config = policy.config
    num_actions = policy.action_space.n

    # q network evaluation
    q_t, q_logits_t, _, _ = compute_q_values(
        policy,
        model,
        SampleBatch({"obs": train_batch[SampleBatch.CUR_OBS]}),
        state_batches=None,
        explore=False)

    # target q network evaluation
    q_tp1, _, q_dist_tp1, _ = compute_q_values(
        policy,
        policy.target_model,
        SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}),
        state_batches=None,
        explore=False)
    if not hasattr(policy, "target_q_func_vars"):
        policy.target_q_func_vars = policy.target_model.variables()

    # q scores for actions which we know were selected in the given state.
    one_hot_selection = tf.one_hot(
        tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32), num_actions)
    q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
    q_logits_t_selected = tf.reduce_sum(
        q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)

    # Choose the best next action a*. With double Q-learning a* comes from
    # the online network; otherwise from the target network. In both cases
    # a* is evaluated with the target network's outputs below.
    if config["double_q"]:
        q_tp1_using_online_net, _, _, _ = compute_q_values(
            policy,
            model,
            SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}),
            state_batches=None,
            explore=False)
        best_next_action = tf.argmax(q_tp1_using_online_net, 1)
    else:
        best_next_action = tf.argmax(q_tp1, 1)

    # compute estimate of best possible value starting from state at t + 1
    q_tp1_best_one_hot_selection = tf.one_hot(best_next_action, num_actions)
    q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_dist_tp1_best = tf.reduce_sum(
        q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)

    policy.q_loss = QLoss(
        q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best,
        train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS],
        tf.cast(train_batch[SampleBatch.DONES],
                tf.float32), config["gamma"], config["n_step"],
        config["num_atoms"], config["v_min"], config["v_max"])

    return policy.q_loss.loss
def adam_optimizer(policy: Policy, config: TrainerConfigDict
                   ) -> "tf.keras.optimizers.Optimizer":
    """Create the Adam optimizer for the DQN TF policy.

    For the eager frameworks ("tf2", "tfe") a Keras Adam optimizer is
    returned, otherwise the TF1 ``AdamOptimizer``. Both use the policy's
    current learning rate and ``config["adam_epsilon"]``.
    """
    eager_mode = policy.config["framework"] in ["tf2", "tfe"]
    adam_cls = (tf.keras.optimizers.Adam
                if eager_mode else tf1.train.AdamOptimizer)
    return adam_cls(
        learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])
def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer",
                   loss: TensorType) -> ModelGradients:
    """Compute clipped gradients of ``loss`` for the online Q-model only.

    The online model's variables are looked up the first time and stored on
    ``policy.q_func_vars``. The work is delegated to ``minimize_and_clip``
    with ``policy.config["grad_clip"]`` as the clip value.
    """
    try:
        online_vars = policy.q_func_vars
    except AttributeError:
        policy.q_func_vars = policy.model.variables()
        online_vars = policy.q_func_vars
    grad_clip = policy.config["grad_clip"]
    return minimize_and_clip(
        optimizer, loss, var_list=online_vars, clip_val=grad_clip)
def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    """Return learner stats: current learning rate plus the QLoss stats.

    Entries from ``policy.q_loss.stats`` take precedence on key clashes.
    """
    stats = {"cur_lr": tf.cast(policy.cur_lr, tf.float64)}
    stats.update(policy.q_loss.stats)
    return stats
def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
    """Initialize the mixins the policy needs before its loss is built.

    The learning-rate schedule is initialized first, then the worker-side
    TD-error helper.
    """
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)
    ComputeTDErrorMixin.__init__(policy)
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Initialize the target-network mixin after policy initialization.

    Registered as the policy's ``after_init`` hook; it only forwards its
    arguments to ``TargetNetworkMixin.__init__``.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
def compute_q_values(policy: Policy,
                     model: ModelV2,
                     input_batch: SampleBatch,
                     state_batches=None,
                     seq_lens=None,
                     explore=None,
                     is_training: bool = False):
    """Run ``model`` on ``input_batch`` and assemble its Q-value heads.

    Args:
        policy: Policy whose config keys "num_atoms" and "dueling" select the
            head layout.
        model: Model providing ``get_q_value_distributions`` (and
            ``get_state_value`` when dueling).
        input_batch: Model input.
        state_batches: Optional state inputs; an empty list is used if None.
        seq_lens: Optional sequence lengths, forwarded to the model.
        explore: Unused here.
        is_training: Unused here.

    Returns:
        Tuple ``(q_values, logits, dist, state_out)``. Without dueling these
        are the model's raw heads. With dueling, the advantage outputs are
        mean-centred and added to the state value. In the distributional
        case this happens on the per-atom logits, which are then
        soft-maxed and reduced against the support ``z``.
    """
    cfg = policy.config
    model_out, state_out = model(input_batch, state_batches or [], seq_lens)

    distributional = cfg["num_atoms"] > 1
    if distributional:
        (action_scores, z, support_logits_per_action, logits,
         dist) = model.get_q_value_distributions(model_out)
    else:
        action_scores, logits, dist = model.get_q_value_distributions(
            model_out)

    if not cfg["dueling"]:
        return action_scores, logits, dist, state_out

    state_score = model.get_state_value(model_out)

    if not distributional:
        # Q = V + (A - mean(A)); the mean is taken with
        # reduce_mean_ignore_inf over axis 1.
        advantage_mean = reduce_mean_ignore_inf(action_scores, 1)
        centered_scores = action_scores - tf.expand_dims(advantage_mean, 1)
        return state_score + centered_scores, logits, dist, state_out

    # Distributional dueling: center the per-atom logits over axis 1
    # (presumably the action axis), add the state-value logits, then turn
    # the result into probabilities and expected values over the support.
    centered_logits = support_logits_per_action - tf.expand_dims(
        tf.reduce_mean(support_logits_per_action, 1), 1)
    combined_logits = tf.expand_dims(state_score, 1) + centered_logits
    combined_probs = tf.nn.softmax(logits=combined_logits)
    q_values = tf.reduce_sum(input_tensor=z * combined_probs, axis=-1)
    return q_values, combined_logits, combined_probs, state_out
def postprocess_nstep_and_prio(policy: Policy,
                               batch: SampleBatch,
                               other_agent=None,
                               episode=None) -> SampleBatch:
    """DQN trajectory postprocessing.

    Steps:
        1. If ``n_step > 1``, call ``adjust_nstep`` on the batch.
        2. Add a PRIO_WEIGHTS column of ones if the batch has none.
        3. If the batch is non-empty and ``worker_side_prioritization`` is
           set, overwrite PRIO_WEIGHTS with ``|TD error| +
           prioritized_replay_eps``.

    Columns are assigned on ``batch`` itself, which is then returned.
    """
    cfg = policy.config

    # N-step Q adjustments.
    n_step = cfg["n_step"]
    if n_step > 1:
        adjust_nstep(n_step, cfg["gamma"], batch)

    # Dummy prio-weights (1.0) in case we don't have any in the batch.
    if PRIO_WEIGHTS not in batch:
        batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])

    # Prioritize on the worker side (only for non-empty batches).
    if not (batch.count > 0 and cfg["worker_side_prioritization"]):
        return batch

    td_errors = policy.compute_td_error(
        batch[SampleBatch.OBS], batch[SampleBatch.ACTIONS],
        batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
        batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
    batch[PRIO_WEIGHTS] = (np.abs(convert_to_numpy(td_errors)) +
                           cfg["prioritized_replay_eps"])
    return batch
# The DQN TensorFlow policy class, assembled from the hook functions above
# via the build_tf_policy template.
DQNTFPolicy = build_tf_policy(
    name="DQNTFPolicy",
    # The lambda defers the DEFAULT_CONFIG lookup until call time
    # (presumably to avoid a circular import with ray.rllib.agents.dqn.dqn).
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    make_model=build_q_model,
    action_distribution_fn=get_distribution_inputs_and_class,
    loss_fn=build_q_losses,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    compute_gradients_fn=clip_gradients,
    # Q-values stored on the policy by get_distribution_inputs_and_class.
    extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
    # Per-sample TD errors of the QLoss stored on policy.q_loss.
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])
|