from gym.spaces import Tuple, Discrete, Dict
import logging
import numpy as np
import tree  # pip install dm_tree

import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import _unpack_obs
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.annotations import override

# Torch must be installed.
torch, nn = try_import_torch(error=True)

logger = logging.getLogger(__name__)


class QMixLoss(nn.Module):
    def __init__(self,
                 model,
                 target_model,
                 mixer,
                 target_mixer,
                 n_agents,
                 n_actions,
                 double_q=True,
                 gamma=0.99):
        nn.Module.__init__(self)
        self.model = model
        self.target_model = target_model
        self.mixer = mixer
        self.target_mixer = target_mixer
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.double_q = double_q
        self.gamma = gamma

    def forward(self,
                rewards,
                actions,
                terminated,
                mask,
                obs,
                next_obs,
                action_mask,
                next_action_mask,
                state=None,
                next_state=None):
        """Forward pass of the loss.

        Args:
            rewards: Tensor of shape [B, T, n_agents]
            actions: Tensor of shape [B, T, n_agents]
            terminated: Tensor of shape [B, T, n_agents]
            mask: Tensor of shape [B, T, n_agents]
            obs: Tensor of shape [B, T, n_agents, obs_size]
            next_obs: Tensor of shape [B, T, n_agents, obs_size]
            action_mask: Tensor of shape [B, T, n_agents, n_actions]
            next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
            state: Tensor of shape [B, T, state_dim] (optional)
            next_state: Tensor of shape [B, T, state_dim] (optional)
        """
        # Either none or both of `state` and `next_state` must be given.
        if state is None and next_state is None:
            state = obs  # default to state being all agents' observations
            next_state = next_obs
        elif (state is None) != (next_state is None):
            raise ValueError("Expected either neither or both of `state` and "
                             "`next_state` to be given. Got: "
                             "\n`state` = {}\n`next_state` = {}".format(
                                 state, next_state))

        # Calculate estimated Q-Values.
        mac_out = _unroll_mac(self.model, obs)

        # Pick the Q-Values for the actions taken -> [B, T, n_agents]
        chosen_action_qvals = torch.gather(
            mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)
        # Calculate the Q-Values necessary for the target.
        target_mac_out = _unroll_mac(self.target_model, next_obs)

        # Mask out unavailable actions for the t+1 step.
        ignore_action_tp1 = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
        target_mac_out[ignore_action_tp1] = -np.inf

        # Max over target Q-Values.
        if self.double_q:
            # Double Q learning computes the target Q values by selecting the
            # t+1 timestep action according to the "policy" neural network and
            # then estimating the Q-value of that action with the "target"
            # neural network.

            # Compute the t+1 Q-values to be used in action selection
            # using next_obs.
            mac_out_tp1 = _unroll_mac(self.model, next_obs)

            # Mask out unallowed actions.
            mac_out_tp1[ignore_action_tp1] = -np.inf

            # Obtain best actions at t+1 according to policy NN.
            cur_max_actions = mac_out_tp1.argmax(dim=3, keepdim=True)

            # Use the target network to estimate the Q-values of the policy
            # network's selected actions.
            target_max_qvals = torch.gather(target_mac_out, 3,
                                            cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        assert target_max_qvals.min().item() != -np.inf, \
            "target_max_qvals contains a masked action; there may be a " \
            "state with no valid actions."

        # Mix.
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals, state)
            target_max_qvals = self.target_mixer(target_max_qvals, next_state)

        # Calculate 1-step Q-Learning targets.
        targets = rewards + self.gamma * (1 - terminated) * target_max_qvals

        # TD-error.
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data.
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data.
        loss = (masked_td_error**2).sum() / mask.sum()
        return loss, mask, masked_td_error, chosen_action_qvals, targets
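
# For reference, the loss computed above is the standard masked 1-step TD
# objective on the mixed (joint) action-value Q_tot. With padding mask m_t,
# termination flag d_t, and discount gamma, it can be summarized as:
#
#   y_t = r_t + gamma * (1 - d_t) * max_a' Q_tot_target(s_{t+1}, a')
#         (with double_q=True, the max is replaced by evaluating the target
#          network at the online network's argmax actions)
#   L   = sum_t m_t * (Q_tot(s_t, a_t) - y_t)^2 / sum_t m_t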


# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
class QMixTorchPolicy(Policy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.framework = "torch"
        super().__init__(obs_space, action_space, config)
        self.n_agents = len(obs_space.original_space.spaces)
        config["model"]["n_agents"] = self.n_agents
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]
        self.has_env_global_state = False
        self.has_action_mask = False
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if "obs" not in space_keys:
                raise ValueError(
                    "Dict obs space must have subspace labeled `obs`")
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            if "action_mask" in space_keys:
                mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
                if mask_shape != (self.n_actions, ):
                    raise ValueError(
                        "Action mask shape must be {}, got {}".format(
                            (self.n_actions, ), mask_shape))
                self.has_action_mask = True
            if ENV_STATE in space_keys:
                self.env_global_state_shape = _get_size(
                    agent_obs_space.spaces[ENV_STATE])
                self.has_env_global_state = True
            else:
                self.env_global_state_shape = (self.obs_size, self.n_agents)
            # The real agent obs space is nested inside the dict.
            config["model"]["full_obs_space"] = agent_obs_space
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.obs_size = _get_size(agent_obs_space)
            self.env_global_state_shape = (self.obs_size, self.n_agents)

        self.model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="model",
            default_model=RNNModel).to(self.device)

        self.target_model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="target_model",
            default_model=RNNModel).to(self.device)

        self.exploration = self._create_exploration()

        # Setup the mixer network.
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.env_global_state_shape,
                                config["mixing_embed_dim"]).to(self.device)
            self.target_mixer = QMixer(
                self.n_agents, self.env_global_state_shape,
                config["mixing_embed_dim"]).to(self.device)
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer().to(self.device)
            self.target_mixer = VDNMixer().to(self.device)
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup the optimizer.
        self.params = list(self.model.parameters())
        if self.mixer:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        from torch.optim import RMSprop
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        explore = explore if explore is not None else self.config["explore"]
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # We need to ensure we do not use the env global state
        # to compute actions.

        # Compute actions.
        with torch.no_grad():
            q_values, hiddens = _mac(
                self.model,
                torch.as_tensor(
                    obs_batch, dtype=torch.float, device=self.device),
                [
                    torch.as_tensor(
                        np.array(s), dtype=torch.float, device=self.device)
                    for s in state_batches
                ])
            avail = torch.as_tensor(
                action_mask, dtype=torch.float, device=self.device)
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            masked_q_values_folded = torch.reshape(
                masked_q_values, [-1] + list(masked_q_values.shape)[2:])
            actions, _ = self.exploration.get_exploration_action(
                action_distribution=TorchCategorical(masked_q_values_folded),
                timestep=timestep,
                explore=explore)
            actions = torch.reshape(
                actions,
                list(masked_q_values.shape)[:-1]).cpu().numpy()
            hiddens = [s.cpu().numpy() for s in hiddens]

        return tuple(actions.transpose([1, 0])), hiddens, {}

    @override(Policy)
    def compute_log_likelihoods(self,
                                actions,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None):
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # Log-likelihoods are not computed here; return one zero per
        # observation in the (unpacked) batch.
        return np.zeros(obs_batch.shape[0])

    @override(Policy)
    def learn_on_batch(self, samples):
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        (next_obs_batch, next_action_mask,
         next_env_global_state) = self._unpack_observation(
             samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        input_list = [
            group_rewards, action_mask, next_action_mask,
            samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
            obs_batch, next_obs_batch
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = \
            chop_into_sequences(
                episode_ids=samples[SampleBatch.EPS_ID],
                unroll_ids=samples[SampleBatch.UNROLL_ID],
                agent_indices=samples[SampleBatch.AGENT_INDEX],
                feature_columns=input_list,
                state_columns=[],  # RNN states not used here
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True,
            )

        # These will be padded to shape [B * T, ...].
        if self.has_env_global_state:
            (rew, action_mask, next_action_mask, act, dones, obs, next_obs,
             env_global_state, next_env_global_state) = output_list
        else:
            (rew, action_mask, next_action_mask, act, dones, obs,
             next_obs) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(
                np.reshape(arr, new_shape), dtype=dtype, device=self.device)

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state,
                                               torch.float)

        # TODO(ekl) this treats group termination as individual termination.
        terminated = to_batches(dones, torch.float).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length.
        filled = np.reshape(
            np.tile(np.arange(T, dtype=np.float32), B),
            [B, T]) < np.expand_dims(seq_lens, 1)
        mask = torch.as_tensor(
            filled, dtype=torch.float, device=self.device).unsqueeze(2).expand(
                B, T, self.n_agents)

        # Compute loss.
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = (
            self.loss(rewards, actions, terminated, mask, obs, next_obs,
                      action_mask, next_action_mask, env_global_state,
                      next_env_global_state))

        # Optimise.
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    def get_initial_state(self):  # initial RNN state
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(Policy)
    def get_weights(self):
        return {
            "model": self._cpu_dict(self.model.state_dict()),
            "target_model": self._cpu_dict(self.target_model.state_dict()),
            "mixer": self._cpu_dict(self.mixer.state_dict())
            if self.mixer else None,
            "target_mixer": self._cpu_dict(self.target_mixer.state_dict())
            if self.mixer else None,
        }

    @override(Policy)
    def set_weights(self, weights):
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(
            self._device_dict(weights["target_model"]))
        if weights["mixer"] is not None:
            self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
            self.target_mixer.load_state_dict(
                self._device_dict(weights["target_mixer"]))

    @override(Policy)
    def get_state(self):
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(Policy)
    def set_state(self, state):
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _device_dict(self, state_dict):
        return {
            k: torch.as_tensor(v, device=self.device)
            for k, v in state_dict.items()
        }

    @staticmethod
    def _cpu_dict(state_dict):
        return {k: v.cpu().detach().numpy() for k, v in state_dict.items()}

    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask, if any
            state (np.ndarray or None): state tensor of shape [B, state_size]
                or None if it is not in the batch
        """
        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np)

        if isinstance(unpacked[0], dict):
            assert "obs" in unpacked[0]
            unpacked_obs = [
                np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
            ]
        else:
            unpacked_obs = unpacked

        obs = np.concatenate(
            unpacked_obs,
            axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])

        if self.has_action_mask:
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions],
                dtype=np.float32)

        if self.has_env_global_state:
            state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
        else:
            state = None
        return obs, action_mask, state


def _validate(obs_space, action_space):
    if not hasattr(obs_space, "original_space") or \
            not isinstance(obs_space.original_space, Tuple):
        raise ValueError("Obs space must be a Tuple, got {}. Use ".format(
            obs_space) + "MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix.")
    if not isinstance(action_space, Tuple):
        raise ValueError(
            "Action space must be a Tuple, got {}. ".format(action_space) +
            "Use MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix.")
    if not isinstance(action_space.spaces[0], Discrete):
        raise ValueError(
            "QMix requires a discrete action space, got {}".format(
                action_space.spaces[0]))
    if len({str(x) for x in obs_space.original_space.spaces}) > 1:
        raise ValueError(
            "Implementation limitation: observations of grouped agents "
            "must be homogeneous, got {}".format(
                obs_space.original_space.spaces))
    if len({str(x) for x in action_space.spaces}) > 1:
        raise ValueError(
            "Implementation limitation: action space of grouped agents "
            "must be homogeneous, got {}".format(action_space.spaces))


def _mac(model, obs, h):
    """Forward pass of the multi-agent controller.

    Args:
        model: TorchModelV2 class
        obs: Tensor of shape [B, n_agents, obs_size]
        h: List of tensors of shape [B, n_agents, h_size]

    Returns:
        q_vals: Tensor of shape [B, n_agents, n_actions]
        h: Tensor of shape [B, n_agents, h_size]
    """
    B, n_agents = obs.size(0), obs.size(1)
    if not isinstance(obs, dict):
        obs = {"obs": obs}
    obs_agents_as_batches = {k: _drop_agent_dim(v) for k, v in obs.items()}
    h_flat = [s.reshape([B * n_agents, -1]) for s in h]
    q_flat, h_flat = model(obs_agents_as_batches, h_flat, None)
    return q_flat.reshape(
        [B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]


def _unroll_mac(model, obs_tensor):
    """Computes the estimated Q values for an entire trajectory batch."""
    B = obs_tensor.size(0)
    T = obs_tensor.size(1)
    n_agents = obs_tensor.size(2)
    mac_out = []
    h = [s.expand([B, n_agents, -1]) for s in model.get_initial_state()]
    for t in range(T):
        q, h = _mac(model, obs_tensor[:, t], h)
        mac_out.append(q)
    mac_out = torch.stack(mac_out, dim=1)  # Concat over time
    return mac_out


def _drop_agent_dim(T):
    shape = list(T.shape)
    B, n_agents = shape[0], shape[1]
    return T.reshape([B * n_agents] + shape[2:])


def _add_agent_dim(T, n_agents):
    shape = list(T.shape)
    B = shape[0] // n_agents
    assert shape[0] % n_agents == 0
    return T.reshape([B, n_agents] + shape[1:])
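
# Illustrative usage sketch (not part of this module): this policy is normally
# created by RLlib's QMix trainer rather than instantiated directly. Assuming
# a grouped multi-agent env has been registered under the (hypothetical) name
# "grouped_env", training could be launched roughly like this:
#
#   from ray import tune
#
#   tune.run(
#       "QMIX",
#       config={
#           "env": "grouped_env",
#           "mixer": "qmix",  # "qmix", "vdn", or None (see __init__ above)
#           "mixing_embed_dim": 32,
#           "double_q": True,
#       })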
|