# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import copy
from typing import List, Optional

import torch
from torch import nn


class Experts(nn.Module):
    """Holds ``num_local_experts`` deep copies of a single expert module.

    Each expert's parameters are tagged with ``allreduce = False`` and the
    expert group name so that downstream DeepSpeed components can treat them
    as expert (MoE) parameters rather than standard data-parallel parameters.
    """

    def __init__(self, expert: nn.Module, num_local_experts: int = 1, expert_group_name: Optional[str] = None) -> None:
        super(Experts, self).__init__()

        self.deepspeed_experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(num_local_experts)])
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
        for expert in self.deepspeed_experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for param in expert.parameters():
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Split the input along dim=1 so each local expert processes its own chunk.
        chunks = inputs.chunk(self.num_local_experts, dim=1)

        expert_outputs: List[torch.Tensor] = []
        for chunk, expert in zip(chunks, self.deepspeed_experts):
            out = expert(chunk)
            if isinstance(out, tuple):
                out = out[0]  # Ignore the bias term for now
            expert_outputs += [out]

        # Re-assemble the per-expert outputs along the same dimension they were split on.
        return torch.cat(expert_outputs, dim=1)
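

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows how
# Experts replicates one expert and splits the input along dim=1 across the
# local copies. The toy sizes, the nn.Linear expert, and the group name
# "ep_group_0" are assumptions made just for this example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hidden_dim = 8
    num_local_experts = 2

    # A single expert module; Experts deep-copies it num_local_experts times.
    expert = nn.Linear(hidden_dim, hidden_dim)
    experts = Experts(expert, num_local_experts=num_local_experts, expert_group_name="ep_group_0")

    # Input shaped so dim=1 chunks evenly across the local experts,
    # e.g. (groups, num_local_experts * capacity, hidden_dim).
    tokens = torch.randn(4, num_local_experts * 3, hidden_dim)
    out = experts(tokens)
    print(out.shape)  # torch.Size([4, 6, 8]) -- same layout as the input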