qmix_policy.py

from gym.spaces import Tuple, Discrete, Dict
import logging
import numpy as np
import tree  # pip install dm_tree

import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import _unpack_obs
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.annotations import override

# Torch must be installed.
torch, nn = try_import_torch(error=True)

logger = logging.getLogger(__name__)


class QMixLoss(nn.Module):
    def __init__(self,
                 model,
                 target_model,
                 mixer,
                 target_mixer,
                 n_agents,
                 n_actions,
                 double_q=True,
                 gamma=0.99):
        nn.Module.__init__(self)
        self.model = model
        self.target_model = target_model
        self.mixer = mixer
        self.target_mixer = target_mixer
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.double_q = double_q
        self.gamma = gamma

    def forward(self,
                rewards,
                actions,
                terminated,
                mask,
                obs,
                next_obs,
                action_mask,
                next_action_mask,
                state=None,
                next_state=None):
        """Forward pass of the loss.

        Args:
            rewards: Tensor of shape [B, T, n_agents]
            actions: Tensor of shape [B, T, n_agents]
            terminated: Tensor of shape [B, T, n_agents]
            mask: Tensor of shape [B, T, n_agents]
            obs: Tensor of shape [B, T, n_agents, obs_size]
            next_obs: Tensor of shape [B, T, n_agents, obs_size]
            action_mask: Tensor of shape [B, T, n_agents, n_actions]
            next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
            state: Tensor of shape [B, T, state_dim] (optional)
            next_state: Tensor of shape [B, T, state_dim] (optional)
        """
        # Assert either none or both of state and next_state are given.
        if state is None and next_state is None:
            state = obs  # default to state being all agents' observations
            next_state = next_obs
        elif (state is None) != (next_state is None):
            raise ValueError("Expected either neither or both of `state` and "
                             "`next_state` to be given. Got: "
                             "\n`state` = {}\n`next_state` = {}".format(
                                 state, next_state))

        # Calculate estimated Q-Values.
        mac_out = _unroll_mac(self.model, obs)

        # Pick the Q-Values for the actions taken -> [B, T, n_agents]
        chosen_action_qvals = torch.gather(
            mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)

        # Calculate the Q-Values necessary for the target.
        target_mac_out = _unroll_mac(self.target_model, next_obs)

        # Mask out unavailable actions for the t+1 step.
        ignore_action_tp1 = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
        target_mac_out[ignore_action_tp1] = -np.inf

        # Max over target Q-Values.
        if self.double_q:
            # Double Q learning computes the target Q values by selecting the
            # t+1 timestep action according to the "policy" neural network and
            # then estimating the Q-value of that action with the "target"
            # neural network.

            # Compute the t+1 Q-values to be used in action selection
            # using next_obs.
            mac_out_tp1 = _unroll_mac(self.model, next_obs)

            # Mask out unallowed actions.
            mac_out_tp1[ignore_action_tp1] = -np.inf

            # Obtain best actions at t+1 according to policy NN.
            cur_max_actions = mac_out_tp1.argmax(dim=3, keepdim=True)

            # Use the target network to estimate the Q-values of the policy
            # network's selected actions.
            target_max_qvals = torch.gather(target_mac_out, 3,
                                            cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        assert target_max_qvals.min().item() != -np.inf, \
            "target_max_qvals contains a masked action; \
            there may be a state with no valid actions."

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals, state)
            target_max_qvals = self.target_mixer(target_max_qvals, next_state)

        # Calculate 1-step Q-Learning targets.
        targets = rewards + self.gamma * (1 - terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data.
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data.
        loss = (masked_td_error**2).sum() / mask.sum()
        return loss, mask, masked_td_error, chosen_action_qvals, targets
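

# Illustrative usage sketch (added; not part of the original file): how
# `QMixLoss` is wired up, here without a mixer for simplicity (a VDNMixer or
# QMixer plugs into the same two arguments). Shapes follow the docstring of
# `forward()` above, with hypothetical sizes B=8, T=5, n_agents=2,
# n_actions=4, obs_size=10 and hypothetical `model`/`target_model` networks.
#
#   loss_fn = QMixLoss(model, target_model, None, None,
#                      n_agents=2, n_actions=4, double_q=True, gamma=0.99)
#   rewards = torch.zeros(8, 5, 2)             # [B, T, n_agents]
#   actions = torch.zeros(8, 5, 2, dtype=torch.long)
#   terminated = torch.zeros(8, 5, 2)
#   mask = torch.ones(8, 5, 2)                 # 1 = real step, 0 = padding
#   obs = torch.zeros(8, 5, 2, 10)             # [B, T, n_agents, obs_size]
#   next_obs = torch.zeros(8, 5, 2, 10)
#   action_mask = torch.ones(8, 5, 2, 4)       # [B, T, n_agents, n_actions]
#   next_action_mask = torch.ones(8, 5, 2, 4)
#   loss, mask, td, q_taken, targets = loss_fn(
#       rewards, actions, terminated, mask, obs, next_obs,
#       action_mask, next_action_mask)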


# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
class QMixTorchPolicy(Policy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.framework = "torch"
        super().__init__(obs_space, action_space, config)
        self.n_agents = len(obs_space.original_space.spaces)
        config["model"]["n_agents"] = self.n_agents
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]
        self.has_env_global_state = False
        self.has_action_mask = False
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if "obs" not in space_keys:
                raise ValueError(
                    "Dict obs space must have subspace labeled `obs`")
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            if "action_mask" in space_keys:
                mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
                if mask_shape != (self.n_actions, ):
                    raise ValueError(
                        "Action mask shape must be {}, got {}".format(
                            (self.n_actions, ), mask_shape))
                self.has_action_mask = True
            if ENV_STATE in space_keys:
                self.env_global_state_shape = _get_size(
                    agent_obs_space.spaces[ENV_STATE])
                self.has_env_global_state = True
            else:
                self.env_global_state_shape = (self.obs_size, self.n_agents)
            # The real agent obs space is nested inside the dict.
            config["model"]["full_obs_space"] = agent_obs_space
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.obs_size = _get_size(agent_obs_space)
            self.env_global_state_shape = (self.obs_size, self.n_agents)

        self.model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="model",
            default_model=RNNModel).to(self.device)
        self.target_model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="target_model",
            default_model=RNNModel).to(self.device)

        self.exploration = self._create_exploration()

        # Setup the mixer network.
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.env_global_state_shape,
                                config["mixing_embed_dim"]).to(self.device)
            self.target_mixer = QMixer(
                self.n_agents, self.env_global_state_shape,
                config["mixing_embed_dim"]).to(self.device)
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer().to(self.device)
            self.target_mixer = VDNMixer().to(self.device)
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer.
        self.params = list(self.model.parameters())
        if self.mixer:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        from torch.optim import RMSprop
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])
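
    # Config note (added comment, not in the original file): the keys read in
    # `__init__` above come from QMix's DEFAULT_CONFIG; the values below are
    # illustrative only, not necessarily the defaults.
    #
    #   config = {
    #       "mixer": "qmix",          # "qmix", "vdn", or None
    #       "mixing_embed_dim": 32,
    #       "double_q": True,
    #       "gamma": 0.99,
    #       "lr": 0.0005,
    #       "optim_alpha": 0.99,
    #       "optim_eps": 0.00001,
    #       ...
    #   }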

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        explore = explore if explore is not None else self.config["explore"]
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # We need to ensure we do not use the env global state
        # to compute actions.

        # Compute actions.
        with torch.no_grad():
            q_values, hiddens = _mac(
                self.model,
                torch.as_tensor(
                    obs_batch, dtype=torch.float, device=self.device), [
                        torch.as_tensor(
                            np.array(s), dtype=torch.float, device=self.device)
                        for s in state_batches
                    ])
            avail = torch.as_tensor(
                action_mask, dtype=torch.float, device=self.device)
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            masked_q_values_folded = torch.reshape(
                masked_q_values, [-1] + list(masked_q_values.shape)[2:])
            actions, _ = self.exploration.get_exploration_action(
                action_distribution=TorchCategorical(masked_q_values_folded),
                timestep=timestep,
                explore=explore)
            actions = torch.reshape(
                actions,
                list(masked_q_values.shape)[:-1]).cpu().numpy()
            hiddens = [s.cpu().numpy() for s in hiddens]

        return tuple(actions.transpose([1, 0])), hiddens, {}
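
    # Shape note (added comment, not in the original file): `q_values` from
    # `_mac` is [B, n_agents, n_actions]; after masking and exploration the
    # `actions` array is [B, n_agents], so the returned
    # `tuple(actions.transpose([1, 0]))` is the Tuple-action-space batch
    # format: one length-B array per agent, plus the updated RNN states.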

    @override(Policy)
    def compute_log_likelihoods(self,
                                actions,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None):
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # `obs_batch` is a numpy array here, so use `.shape` (not `.size()`).
        return np.zeros(obs_batch.shape[0])

    @override(Policy)
    def learn_on_batch(self, samples):
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        (next_obs_batch, next_action_mask,
         next_env_global_state) = self._unpack_observation(
             samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        input_list = [
            group_rewards, action_mask, next_action_mask,
            samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
            obs_batch, next_obs_batch
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = \
            chop_into_sequences(
                episode_ids=samples[SampleBatch.EPS_ID],
                unroll_ids=samples[SampleBatch.UNROLL_ID],
                agent_indices=samples[SampleBatch.AGENT_INDEX],
                feature_columns=input_list,
                state_columns=[],  # RNN states not used here
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True,
            )
        # These will be padded to shape [B * T, ...].
        if self.has_env_global_state:
            (rew, action_mask, next_action_mask, act, dones, obs, next_obs,
             env_global_state, next_env_global_state) = output_list
        else:
            (rew, action_mask, next_action_mask, act, dones, obs,
             next_obs) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(
                np.reshape(arr, new_shape), dtype=dtype, device=self.device)

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state,
                                               torch.float)

        # TODO(ekl) this treats group termination as individual termination.
        terminated = to_batches(dones, torch.float).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length.
        filled = np.reshape(
            np.tile(np.arange(T, dtype=np.float32), B),
            [B, T]) < np.expand_dims(seq_lens, 1)
        mask = torch.as_tensor(
            filled, dtype=torch.float, device=self.device).unsqueeze(2).expand(
                B, T, self.n_agents)

        # Compute loss.
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = (
            self.loss(rewards, actions, terminated, mask, obs, next_obs,
                      action_mask, next_action_mask, env_global_state,
                      next_env_global_state))

        # Optimise.
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    def get_initial_state(self):  # initial RNN state
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(Policy)
    def get_weights(self):
        return {
            "model": self._cpu_dict(self.model.state_dict()),
            "target_model": self._cpu_dict(self.target_model.state_dict()),
            "mixer": self._cpu_dict(self.mixer.state_dict())
            if self.mixer else None,
            "target_mixer": self._cpu_dict(self.target_mixer.state_dict())
            if self.mixer else None,
        }

    @override(Policy)
    def set_weights(self, weights):
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(
            self._device_dict(weights["target_model"]))
        if weights["mixer"] is not None:
            self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
            self.target_mixer.load_state_dict(
                self._device_dict(weights["target_mixer"]))

    @override(Policy)
    def get_state(self):
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(Policy)
    def set_state(self, state):
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _device_dict(self, state_dict):
        return {
            k: torch.as_tensor(v, device=self.device)
            for k, v in state_dict.items()
        }

    @staticmethod
    def _cpu_dict(state_dict):
        return {k: v.cpu().detach().numpy() for k, v in state_dict.items()}

    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask, if any
            state (np.ndarray or None): state tensor of shape [B, state_size]
                or None if it is not in the batch
        """
        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np)

        if isinstance(unpacked[0], dict):
            assert "obs" in unpacked[0]
            unpacked_obs = [
                np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
            ]
        else:
            unpacked_obs = unpacked

        obs = np.concatenate(
            unpacked_obs,
            axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])

        if self.has_action_mask:
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions],
                dtype=np.float32)

        if self.has_env_global_state:
            state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
        else:
            state = None
        return obs, action_mask, state
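

# Illustrative sketch (added; not part of the original file): the per-agent
# Dict obs format that `_unpack_observation` above expects when action
# masking and an env global state are used. Keys other than "obs" are
# optional; the sizes below are hypothetical.
#
#   from gym.spaces import Box
#   agent_space = Dict({
#       "obs": Box(-1.0, 1.0, (10, )),          # per-agent observation
#       "action_mask": Box(0.0, 1.0, (4, )),    # one 0/1 flag per action
#       ENV_STATE: Box(-1.0, 1.0, (20, )),      # shared global state
#   })
#   # The grouped (Tuple) obs space holds one such Dict per agent, and
#   # `_unpack_observation` then returns:
#   #   obs          -> [B, n_agents, obs_size]
#   #   action_mask  -> [B, n_agents, n_actions]
#   #   state        -> [B, state_size] or None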


def _validate(obs_space, action_space):
    if not hasattr(obs_space, "original_space") or \
            not isinstance(obs_space.original_space, Tuple):
        raise ValueError("Obs space must be a Tuple, got {}. Use ".format(
            obs_space) + "MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix.")
    if not isinstance(action_space, Tuple):
        raise ValueError(
            "Action space must be a Tuple, got {}. ".format(action_space) +
            "Use MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix.")
    if not isinstance(action_space.spaces[0], Discrete):
        raise ValueError(
            "QMix requires a discrete action space, got {}".format(
                action_space.spaces[0]))
    if len({str(x) for x in obs_space.original_space.spaces}) > 1:
        raise ValueError(
            "Implementation limitation: observations of grouped agents "
            "must be homogeneous, got {}".format(
                obs_space.original_space.spaces))
    if len({str(x) for x in action_space.spaces}) > 1:
        raise ValueError(
            "Implementation limitation: action space of grouped agents "
            "must be homogeneous, got {}".format(action_space.spaces))


def _mac(model, obs, h):
    """Forward pass of the multi-agent controller.

    Args:
        model: TorchModelV2 class
        obs: Tensor of shape [B, n_agents, obs_size]
        h: List of tensors of shape [B, n_agents, h_size]

    Returns:
        q_vals: Tensor of shape [B, n_agents, n_actions]
        h: Tensor of shape [B, n_agents, h_size]
    """
    B, n_agents = obs.size(0), obs.size(1)
    if not isinstance(obs, dict):
        obs = {"obs": obs}
    obs_agents_as_batches = {k: _drop_agent_dim(v) for k, v in obs.items()}
    h_flat = [s.reshape([B * n_agents, -1]) for s in h]
    q_flat, h_flat = model(obs_agents_as_batches, h_flat, None)
    return q_flat.reshape(
        [B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]


def _unroll_mac(model, obs_tensor):
    """Computes the estimated Q values for an entire trajectory batch."""
    B = obs_tensor.size(0)
    T = obs_tensor.size(1)
    n_agents = obs_tensor.size(2)

    mac_out = []
    h = [s.expand([B, n_agents, -1]) for s in model.get_initial_state()]
    for t in range(T):
        q, h = _mac(model, obs_tensor[:, t], h)
        mac_out.append(q)
    mac_out = torch.stack(mac_out, dim=1)  # Concat over time

    return mac_out
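

# Shape sketch (added comment, not part of the original file): `_unroll_mac`
# feeds `_mac` one timestep at a time, carrying the RNN state `h`:
#
#   obs_tensor: [B, T, n_agents, obs_size]
#   per step t: _mac(model, obs_tensor[:, t], h) -> q [B, n_agents, n_actions]
#   stacked over t: mac_out [B, T, n_agents, n_actions]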


def _drop_agent_dim(T):
    shape = list(T.shape)
    B, n_agents = shape[0], shape[1]
    return T.reshape([B * n_agents] + shape[2:])


def _add_agent_dim(T, n_agents):
    shape = list(T.shape)
    B = shape[0] // n_agents
    assert shape[0] % n_agents == 0
    return T.reshape([B, n_agents] + shape[1:])