''' Copyright 2020 The Microsoft DeepSpeed Team '''

import torch
import copy


class Experts(torch.nn.Module):
    """Hold ``num_local_experts`` independent copies of an expert module.

    Each copy is produced with ``copy.deepcopy``, so the experts do not share
    parameters with each other or with the template module. Every expert
    parameter is tagged with ``param.allreduce = False``; presumably the
    data-parallel gradient all-reduce consults this flag to skip expert
    weights -- TODO confirm against the engine/optimizer code.
    """

    def __init__(self, expert, num_local_experts=1):
        """Create the expert copies.

        Args:
            expert: ``torch.nn.Module`` used as a template; it is deep-copied
                ``num_local_experts`` times and is not itself registered.
            num_local_experts: number of expert copies held by this module.
        """
        super(Experts, self).__init__()

        self.deepspeed_experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)])
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
        for local_expert in self.deepspeed_experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for param in local_expert.parameters():
                param.allreduce = False

    def forward(self, inputs):
        """Feed consecutive slices of ``inputs`` (along dim 1) to the experts.

        ``inputs`` is split with ``torch.chunk`` into ``num_local_experts``
        pieces along dim 1. Piece ``i`` goes through expert ``i``, and the
        results are concatenated back along dim 1. If an expert returns a
        tuple, only its first element is kept.

        Args:
            inputs: tensor with at least 2 dims; dim 1 is split across the
                local experts.

        Returns:
            Tensor of expert outputs concatenated along dim 1.

        Raises:
            ValueError: if ``torch.chunk`` yields a different number of pieces
                than there are experts (e.g. dim 1 of size 6 with 4 experts
                yields only 3 chunks). Without this check ``zip`` would
                silently skip the trailing expert(s).
        """
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        if len(chunks) != self.num_local_experts:
            raise ValueError(
                "Experts.forward: dim 1 of inputs (size {}) split into {} "
                "chunks, but there are {} local experts".format(
                    inputs.size(1), len(chunks), self.num_local_experts))

        expert_outputs = []
        for chunk, local_expert in zip(chunks, self.deepspeed_experts):
            out = local_expert(chunk)
            if isinstance(out, tuple):
                out = out[0]  # Ignore the bias term for now
            expert_outputs.append(out)

        return torch.cat(expert_outputs, dim=1)