# experts.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import copy


class Experts(torch.nn.Module):
    """Holds num_local_experts deep copies of an expert module and applies
    each copy to its own chunk of the input."""

    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        super(Experts, self).__init__()

        self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for _ in range(num_local_experts)])
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
        for expert in self.deepspeed_experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for name, param in expert.named_parameters():
                # Expert parameters are not data-parallel, so they are excluded
                # from the default allreduce and tagged with their expert-parallel
                # group name instead.
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs):
        # Split the input along dim=1 into one chunk per local expert.
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.deepspeed_experts):
            out = expert(chunk)
            if type(out) is tuple:
                out = out[0]  # Ignore the bias term for now
            expert_outputs.append(out)

        # Reassemble the per-expert outputs along the chunked dimension.
        expert_output = torch.cat(expert_outputs, dim=1)
        return expert_output
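

# --- Usage sketch (illustrative, not part of the DeepSpeed source) ---
# A minimal example of how Experts might be exercised, assuming a plain
# torch.nn.Linear stands in for the expert module and the group name string
# is arbitrary. Inputs are chunked along dim=1, so that dimension must be
# divisible by num_local_experts; the tensor shape here is only an example.
if __name__ == "__main__":
    expert = torch.nn.Linear(8, 8)
    experts = Experts(expert, num_local_experts=2, expert_group_name="ep_group_0")

    inputs = torch.randn(4, 2, 8)  # dim=1 (size 2) is split across the 2 experts
    output = experts(inputs)       # each expert sees a (4, 1, 8) chunk
    print(output.shape)            # torch.Size([4, 2, 8]) after concatenation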