experts.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import copy
from typing import List, Optional
import torch
from torch import nn
  8. class Experts(nn.Module):
  9. def __init__(self, expert: nn.Module, num_local_experts: int = 1, expert_group_name: Optional[str] = None) -> None:
  10. super(Experts, self).__init__()
  11. self.deepspeed_experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(num_local_experts)])
  12. self.num_local_experts = num_local_experts
  13. # TODO: revisit allreduce for moe.gate...
  14. for expert in self.deepspeed_experts:
  15. # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
  16. for param in expert.parameters():
  17. param.allreduce = False
  18. param.group_name = expert_group_name
  19. def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  20. chunks = inputs.chunk(self.num_local_experts, dim=1)
  21. expert_outputs: List[torch.Tensor] = []
  22. for chunk, expert in zip(chunks, self.deepspeed_experts):
  23. out = expert(chunk)
  24. if isinstance(out, tuple):
  25. out = out[0] # Ignore the bias term for now
  26. expert_outputs += [out]
  27. return torch.cat(expert_outputs, dim=1)