- """PyTorch policy class used for Simple Q-Learning"""
- import logging
- from typing import Dict, Tuple
- import gym
- import ray
- from ray.rllib.agents.dqn.simple_q_tf_policy import (
- build_q_models, compute_q_values, get_distribution_inputs_and_class)
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
- TorchDistributionWrapper
- from ray.rllib.policy import Policy
- from ray.rllib.policy.policy_template import build_policy_class
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.policy.torch_policy import TorchPolicy
- from ray.rllib.utils.annotations import override
- from ray.rllib.utils.framework import try_import_torch
- from ray.rllib.utils.torch_utils import concat_multi_gpu_td_errors, huber_loss
- from ray.rllib.utils.typing import TensorType, TrainerConfigDict
- torch, nn = try_import_torch()
- F = None
- if nn:
- F = nn.functional
- logger = logging.getLogger(__name__)


class TargetNetworkMixin:
    """Assigns the `update_target` method to the SimpleQTorchPolicy.

    The method is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(self):
        # Hard initial update from Q-net(s) to target Q-net(s).
        self.update_target()

    def update_target(self):
        # Called periodically to copy the Q-network's weights into the
        # target Q-network(s).
        state_dict = self.model.state_dict()
        for target in self.target_models.values():
            target.load_state_dict(state_dict)

    @override(TorchPolicy)
    def set_weights(self, weights):
        # Makes sure that whenever we restore weights for this policy's
        # model, we sync the target network (from the main model)
        # at the same time.
        TorchPolicy.set_weights(self, weights)
        self.update_target()
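

# Illustrative sketch, not part of the original module: the hard target sync
# above boils down to copying one nn.Module's state dict into another. The
# demo function below is hypothetical and never called by RLlib; it only
# shows that mechanic in plain PyTorch.
def _demo_hard_target_sync():
    q_net = nn.Linear(4, 2)  # Stand-in for the online Q-network.
    target_net = nn.Linear(4, 2)  # Stand-in for the target network.
    # Hard update: overwrite all target parameters with the online ones.
    target_net.load_state_dict(q_net.state_dict())
    # After the copy, both networks produce identical outputs.
    x = torch.randn(1, 4)
    assert torch.allclose(q_net(x), target_net(x))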


def build_q_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    return build_q_models(policy, obs_space, action_space, config), \
        TorchCategorical
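

# Illustrative sketch (hypothetical helper, never called by RLlib): the
# TorchCategorical distribution returned above interprets the model's
# Q-value outputs as logits, so its deterministic sample is simply the
# argmax over Q-values. Plain torch.distributions shows the same idea:
def _demo_q_values_as_logits():
    q_values = torch.tensor([[1.0, 3.0, 0.5]])  # Fake Q-values, 3 actions.
    dist = torch.distributions.Categorical(logits=q_values)
    sampled_action = dist.sample()  # Stochastic draw, proportional to exp(Q).
    greedy_action = torch.argmax(q_values, dim=1)  # Deterministic pick: 1.
    return sampled_action, greedy_action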


def build_q_losses(policy: Policy, model, dist_class,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for SimpleQTorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]

    # Q-network evaluation.
    q_t = compute_q_values(
        policy,
        model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True)

    # Target Q-network evaluation.
    q_tp1 = compute_q_values(
        policy,
        target_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True)

    # Q scores for the actions that were actually taken in the given states.
    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                  policy.action_space.n)
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # Compute an estimate of the best possible value starting from the state
    # at t+1, masked to zero at terminal states.
    dones = train_batch[SampleBatch.DONES].float()
    q_tp1_best_one_hot_selection = F.one_hot(
        torch.argmax(q_tp1, 1), policy.action_space.n)
    q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           policy.config["gamma"] * q_tp1_best_masked)

    # Compute the error (Square/Huber).
    td_error = q_t_selected - q_t_selected_target.detach()
    loss = torch.mean(huber_loss(td_error))

    # Store values for the stats function in the model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["loss"] = loss
    # The TD-error tensor in the final stats will be concatenated and
    # retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    return loss
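

# Illustrative numeric sketch (hypothetical, never called by RLlib) of the
# one-step TD target computed in build_q_losses above, using hand-written
# tensors instead of a SampleBatch. Assumes 2 actions and gamma = 0.9.
def _demo_td_target():
    gamma = 0.9
    q_t = torch.tensor([[1.0, 2.0]])  # Q(s_t, .) from the online net.
    q_tp1 = torch.tensor([[0.5, 1.5]])  # Q(s_{t+1}, .) from the target net.
    actions = torch.tensor([1])  # The action actually taken.
    rewards = torch.tensor([1.0])
    dones = torch.tensor([0.0])  # 0.0 -> episode continues.
    # Q-value of the taken action: picks q_t[0, 1] == 2.0.
    one_hot = F.one_hot(actions, num_classes=2)
    q_t_selected = torch.sum(q_t * one_hot, 1)
    # Best next-state value, zeroed out at episode ends: max == 1.5.
    q_tp1_best = torch.max(q_tp1, 1)[0] * (1.0 - dones)
    # Bellman target: 1.0 + 0.9 * 1.5 == 2.35; TD error: 2.0 - 2.35 == -0.35.
    target = rewards + gamma * q_tp1_best
    return q_t_selected - target  # tensor([-0.3500])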


def stats_fn(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
    return {"loss": torch.mean(torch.stack(policy.get_tower_stats("loss")))}


def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                        action_dist) -> Dict[str, TensorType]:
    """Adds q-values to the action out dict."""
    return {"q_values": policy.q_values}


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors before SimpleQTorchPolicy
    initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy)


SimpleQTorchPolicy = build_policy_class(
    name="SimpleQPolicy",
    framework="torch",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.agents.dqn.simple_q.DEFAULT_CONFIG,
    stats_fn=stats_fn,
    extra_action_out_fn=extra_action_out_fn,
    after_init=setup_late_mixins,
    make_model_and_action_dist=build_q_model_and_distribution,
    mixins=[TargetNetworkMixin],
    action_distribution_fn=get_distribution_inputs_and_class,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
)
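

# Minimal usage sketch, assuming `ray[rllib]` and a Gym environment are
# installed, and matching the old `ray.rllib.agents` API this module belongs
# to. The policy is normally built indirectly by SimpleQTrainer rather than
# instantiated by hand; the __main__ guard keeps importing this module
# side-effect free.
if __name__ == "__main__":
    from ray.rllib.agents.dqn import SimpleQTrainer

    ray.init()
    trainer = SimpleQTrainer(
        env="CartPole-v0", config={"framework": "torch", "num_workers": 0})
    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])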