1234567891011121314151617181920212223242526272829303132333435363738 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import copy
- from typing import List, Optional
- import torch
- from torch import nn
class Experts(nn.Module):
    """Container holding the experts that live on the local rank.

    The supplied ``expert`` module is deep-copied ``num_local_experts`` times
    so that every local expert owns an independent set of parameters. Each
    parameter of every copy is tagged with ``allreduce = False`` and with
    ``group_name = expert_group_name``.
    """

    def __init__(self, expert: nn.Module, num_local_experts: int = 1, expert_group_name: Optional[str] = None) -> None:
        """Build the local expert copies and tag their parameters.

        Args:
            expert: Template module; it is deep-copied, never stored directly.
            num_local_experts: Number of independent copies to create.
            expert_group_name: Value written to ``param.group_name`` on every
                expert parameter.
        """
        super().__init__()

        self.deepspeed_experts = nn.ModuleList(copy.deepcopy(expert) for _ in range(num_local_experts))
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
        for local_expert in self.deepspeed_experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for param in local_expert.parameters():
                # NOTE(review): these attributes are presumably consumed by the
                # training engine to treat expert parameters separately from
                # data-parallel ones -- confirm against the consumer.
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run each local expert on its slice of ``inputs``.

        ``inputs`` is split along dimension 1 into ``num_local_experts``
        chunks; the i-th chunk is passed to the i-th expert and the results
        are concatenated back along dimension 1.

        Args:
            inputs: Tensor to split along dimension 1 and feed to the experts.

        Returns:
            The expert outputs concatenated along dimension 1.
        """
        # NOTE: ``torch.chunk`` may return fewer chunks than requested when
        # dimension 1 cannot be split that many ways; ``zip`` then stops at the
        # shorter sequence, so the surplus experts are simply not invoked.
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs: List[torch.Tensor] = []

        for chunk, local_expert in zip(chunks, self.deepspeed_experts):
            out = local_expert(chunk)
            if isinstance(out, tuple):
                out = out[0]  # Ignore the bias term for now
            expert_outputs.append(out)

        return torch.cat(expert_outputs, dim=1)
|