mixers.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import numpy as np
  2. from ray.rllib.utils.framework import try_import_torch
  3. torch, nn = try_import_torch()
  4. class VDNMixer(nn.Module):
  5. def __init__(self):
  6. super(VDNMixer, self).__init__()
  7. def forward(self, agent_qs, batch):
  8. return torch.sum(agent_qs, dim=2, keepdim=True)
  9. class QMixer(nn.Module):
  10. def __init__(self, n_agents, state_shape, mixing_embed_dim):
  11. super(QMixer, self).__init__()
  12. self.n_agents = n_agents
  13. self.embed_dim = mixing_embed_dim
  14. self.state_dim = int(np.prod(state_shape))
  15. self.hyper_w_1 = nn.Linear(self.state_dim,
  16. self.embed_dim * self.n_agents)
  17. self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)
  18. # State dependent bias for hidden layer
  19. self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)
  20. # V(s) instead of a bias for the last layers
  21. self.V = nn.Sequential(
  22. nn.Linear(self.state_dim, self.embed_dim), nn.ReLU(),
  23. nn.Linear(self.embed_dim, 1))
  24. def forward(self, agent_qs, states):
  25. """Forward pass for the mixer.
  26. Args:
  27. agent_qs: Tensor of shape [B, T, n_agents, n_actions]
  28. states: Tensor of shape [B, T, state_dim]
  29. """
  30. bs = agent_qs.size(0)
  31. states = states.reshape(-1, self.state_dim)
  32. agent_qs = agent_qs.view(-1, 1, self.n_agents)
  33. # First layer
  34. w1 = torch.abs(self.hyper_w_1(states))
  35. b1 = self.hyper_b_1(states)
  36. w1 = w1.view(-1, self.n_agents, self.embed_dim)
  37. b1 = b1.view(-1, 1, self.embed_dim)
  38. hidden = nn.functional.elu(torch.bmm(agent_qs, w1) + b1)
  39. # Second layer
  40. w_final = torch.abs(self.hyper_w_final(states))
  41. w_final = w_final.view(-1, self.embed_dim, 1)
  42. # State-dependent bias
  43. v = self.V(states).view(-1, 1, 1)
  44. # Compute final output
  45. y = torch.bmm(hidden, w_final) + v
  46. # Reshape and return
  47. q_tot = y.view(bs, -1, 1)
  48. return q_tot