- '''
- Copyright 2021 The Microsoft DeepSpeed Team
- '''
- # The file has been adapted from two fairscale files:
- # (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
- # (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
- # Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
- # We retain the following license from the original files:
- # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
- #
- # This source code is licensed under the BSD license found in the
- # LICENSE file in the root directory of this source tree.
- from deepspeed.utils.timer import SynchronizedWallClockTimer
- from deepspeed.utils import logger
- from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple
- import torch
- from torch import Tensor
- import torch.distributed as dist
- from torch.nn import Module, ModuleList
- import torch.nn.functional as F
- if TYPE_CHECKING:
- Base = Module[Tensor]
- else:
- Base = Module
- uniform_map: Dict[torch.device, Callable] = {}
- gumbel_map: Dict[torch.device, Callable] = {}
- exp_selection_uniform_map: Dict[torch.device, Callable] = {}
- try:
- # To enable Tutel MoE optimizations:
- # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
- from tutel import moe as tutel_moe
- TUTEL_INSTALLED = True
- except ImportError:
- # Fail silently so we don't spam logs unnecessarily if the user isn't using Tutel.
- TUTEL_INSTALLED = False
- def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
- """
- Modified from the Switch Transformer paper and the Mesh TensorFlow transformer implementation.
- Multiply values by a random number between 1-epsilon and 1+epsilon.
- Makes models more resilient to rounding errors introduced by bfloat16.
- This seems particularly important for logits.
- Args:
- x: a torch.tensor
- device: torch.device
- epsilon: a floating point value
- Returns:
- a jittered x.
- """
- if epsilon == 0:
- return x
- uniform = uniform_map.get(device)
- if uniform is None:
- uniform = torch.distributions.uniform.Uniform(
- low=torch.tensor(1.0 - epsilon,
- device=device),
- high=torch.tensor(1.0 + epsilon,
- device=device)).rsample # type: ignore
- uniform_map[device] = uniform
- return x * uniform(x.shape)
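- # Worked example for multiplicative_jitter (illustrative): with epsilon=1e-2, each
- # element of x is scaled by an independent factor drawn from U(0.99, 1.01), so a
- # logit of 10.0 lands somewhere in [9.9, 10.1]. The sampler is cached per device in
- # uniform_map so the Uniform distribution object is constructed only once per device.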
- def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
- gumbel = gumbel_map.get(device)
- if gumbel is None:
- one = torch.tensor(1.0, device=device)
- zero = torch.tensor(0.0, device=device)
- gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
- gumbel_map[device] = gumbel
- return gumbel(shape)
- # Based on https://github.com/pytorch/pytorch/pull/40762
- class _AllToAll(torch.autograd.Function):
- @staticmethod
- def forward(ctx: Any,
- group: dist.ProcessGroup,
- input: Tensor) -> Tensor: # type: ignore
- ctx.group = group
- input = input.contiguous()
- output = torch.empty_like(input)
- dist.all_to_all_single(output, input, group=group)
- return output
- @staticmethod
- def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
- return (None, _AllToAll.apply(ctx.group, *grad_output))
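- # Note on _AllToAll: with equal splits, all-to-all sends chunk j of rank i to chunk i
- # of rank j, i.e. the collective is its own inverse. The backward pass therefore just
- # applies the same all-to-all to the incoming gradients to route them back to their
- # source ranks.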
- # The manual rewrites below are on par with or faster than torch.einsum for these rules;
- # the USE_EINSUM switch can be bubbled up as a config option in the future.
- USE_EINSUM = True
- # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
- # See https://arxiv.org/pdf/2006.16668.pdf for details.
- def einsum(rule, a, b):
- if USE_EINSUM:
- return torch.einsum(rule, a, b)
- elif rule == 's,se->se':
- return a.reshape(a.shape[0], -1) * b
- elif rule == 'se,sc->sec':
- return a.unsqueeze(2) * b.unsqueeze(1)
- elif rule == 'se,se->s':
- return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
- elif rule == 'sec,sm->ecm':
- s = a.shape[0]
- e = a.shape[1]
- c = a.shape[2]
- m = b.shape[1]
- return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
- elif rule == 'sec,ecm->sm':
- return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
- elif rule == 'ks,ksm->sm':
- k = b.shape[0]
- s = b.shape[1]
- m = b.shape[2]
- # [k, s] -> [s, k] -> [s, 1, k]
- a = a.t().unsqueeze(1)
- # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
- b = b.reshape(k, -1).t().reshape(s, m, k)
- # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
- return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
- else:
- return torch.einsum(rule, a, b)
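- # Shape example for the dispatch rule 'sec,sm->ecm' (illustrative numbers): with
- # s=4 tokens, e=2 experts, c=2 capacity slots and m=8 model dims, a (4, 2, 2) dispatch
- # tensor combined with a (4, 8) token batch yields a (2, 2, 8) buffer holding the
- # tokens assigned to each expert's capacity slots.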
- # The following functions are extracted and scripted
- # because otherwise during a torch.jit.trace, the non-Tensor
- # values used in the calculations get recorded as constants.
- # torch.jit.script coerces them into Tensors and preserves
- # their dynamic shapes. This enables ONNX export.
- # We can't script the entire top1gating function because it
- # includes stateful caching logic which is incompatible with ONNX.
- @torch.jit.script
- def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
- # gates has shape of SE
- num_tokens = gates.shape[0]
- num_experts = gates.shape[1]
- # to(torch.int64) works around a bug in torch.onnx.export:
- # it should cast k to int64 when converting torch.topk but it doesn't.
- capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
- if capacity < min_capacity:
- capacity = min_capacity.to(torch.int64)
- return capacity
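- # Worked example for _capacity (illustrative): 8 tokens, 4 experts, capacity_factor=1.0
- # gives ceil((8 / 4) * 1.0) = 2, which is then raised to min_capacity (e.g. 4).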
- @torch.jit.script
- def _top_idx(source, k):
- return torch.topk(source, k=k, dim=0)[1]
- @torch.jit.script
- def _one_hot_to_float(x, num_classes):
- return F.one_hot(x, num_classes=num_classes).float()
- def top1gating(logits: Tensor,
- capacity_factor: float,
- min_capacity: int,
- used_token: Tensor = None,
- noisy_gate_policy: Optional[str] = None,
- drop_tokens: bool = True,
- use_rts: bool = True,
- use_tutel: bool = False) -> Tuple[Tensor,
- Tensor,
- Tensor,
- Tensor]:
- """Implements Top1Gating on logits."""
- if noisy_gate_policy == 'RSample':
- logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
- # everything is in fp32 in this function
- gates = F.softmax(logits, dim=1)
- capacity = _capacity(gates,
- torch.tensor(capacity_factor),
- torch.tensor(min_capacity))
- # Create a mask for the top-1 expert per token
- # (use the noisy logits when the RSample noisy gate policy is active)
- indices1_s = torch.argmax(
- logits_w_noise if noisy_gate_policy == 'RSample' else gates,
- dim=1)
- num_experts = int(gates.shape[1])
- mask1 = F.one_hot(indices1_s, num_classes=num_experts)
- # mask only used tokens
- if used_token is not None:
- mask1 = einsum("s,se->se", used_token, mask1)
- # gating decisions
- exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
- # if we don't want to drop any tokens
- if not drop_tokens:
- new_capacity = torch.max(exp_counts).to(logits.device)
- dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
- capacity = new_capacity
- # Compute l_aux
- me = torch.mean(gates, dim=0)
- ce = torch.mean(mask1.float(), dim=0)
- l_aux = torch.sum(me * ce) * num_experts
- # Random Token Selection
- if use_rts:
- uniform = exp_selection_uniform_map.get(logits.device)
- if uniform is None:
- uniform = torch.distributions.uniform.Uniform(
- low=torch.tensor(0.0,
- device=logits.device),
- high=torch.tensor(1.0,
- device=logits.device)).rsample
- exp_selection_uniform_map[logits.device] = uniform
- mask1_rand = mask1 * uniform(mask1.shape)
- assert logits.shape[0] >= min_capacity, "Number of tokens (batch size) must be at least min_capacity. Either set min_capacity to 0 or increase your batch size."
- top_idx = _top_idx(mask1_rand, capacity)
- new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
- mask1 = new_mask1
- if use_tutel:
- # Tutel doesn't support index values masked with zero
- # so we need to replace masked indices with -1
- indices_mask = mask1.sum(dim=1) * num_experts - 1
- indices1_s = torch.min(indices1_s, indices_mask)
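- # Example: for a kept token, mask1.sum(dim=1) is 1, so indices_mask = num_experts - 1
- # and min() leaves indices1_s unchanged; for a dropped token the row sum is 0, so
- # indices_mask = -1 and the token's expert index becomes -1, which Tutel skips.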
- # Compute locations in capacity buffer
- if use_tutel:
- locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
- else:
- locations1 = torch.cumsum(mask1, dim=0) - 1
- if use_tutel:
- gates1_s = (gates * mask1).sum(dim=1)
- locations1_s = torch.sum(locations1 * mask1, dim=1)
- return l_aux, capacity, num_experts, [indices1_s], [locations1_s], [gates1_s], exp_counts
- # Store the capacity location for each token
- locations1_s = torch.sum(locations1 * mask1, dim=1)
- # Normalize gate probabilities
- mask1_float = mask1.float()
- gates = gates * mask1_float
- locations1_sc = _one_hot_to_float(locations1_s, capacity)
- combine_weights = einsum("se,sc->sec", gates, locations1_sc)
- dispatch_mask = combine_weights.bool()
- return l_aux, combine_weights, dispatch_mask, exp_counts
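- # top1gating output shapes (for reference): l_aux is a scalar, combine_weights is
- # (s, e, c) float, dispatch_mask is the same shape but boolean, and exp_counts is a
- # CPU tensor of per-expert token counts with shape (e,).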
- def top2gating(logits: Tensor,
- capacity_factor: float,
- min_capacity: int) -> Tuple[Tensor,
- Tensor,
- Tensor,
- Tensor]:
- """Implements Top2Gating on logits."""
- # everything is in fp32 in this function
- gates = F.softmax(logits, dim=1)
- # Each token is routed to two experts, so double the capacity factor.
- capacity = _capacity(gates,
- torch.tensor(capacity_factor * 2),
- torch.tensor(min_capacity))
- # Create a mask for the top-1 expert per token
- indices1_s = torch.argmax(gates, dim=1)
- num_experts = int(gates.shape[1])
- mask1 = F.one_hot(indices1_s, num_classes=num_experts)
- # Create a mask for the second expert per token using the Gumbel-max trick
- # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
- logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
- # Exclude the top-1 expert from the second choice by masking its logit with -inf
- logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
- indices2_s = torch.argmax(logits_except1, dim=1)
- mask2 = F.one_hot(indices2_s, num_classes=num_experts)
- # Compute locations in capacity buffer
- locations1 = torch.cumsum(mask1, dim=0) - 1
- locations2 = torch.cumsum(mask2, dim=0) - 1
- # Offset the second expert's locations by the number of tokens already assigned via mask1
- locations2 += torch.sum(mask1, dim=0, keepdim=True)
- # gating decisions
- exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
- # Compute l_aux
- me = torch.mean(gates, dim=0)
- ce = torch.mean(mask1.float(), dim=0)
- l_aux = torch.mean(me * ce) * num_experts * num_experts
- # Remove locations outside capacity from mask
- mask1 *= torch.lt(locations1, capacity)
- mask2 *= torch.lt(locations2, capacity)
- # Store the capacity location for each token
- locations1_s = torch.sum(locations1 * mask1, dim=1)
- locations2_s = torch.sum(locations2 * mask2, dim=1)
- # Normalize gate probabilities
- mask1_float = mask1.float()
- mask2_float = mask2.float()
- gates1_s = einsum("se,se->s", gates, mask1_float)
- gates2_s = einsum("se,se->s", gates, mask2_float)
- denom_s = gates1_s + gates2_s
- # Avoid divide-by-zero
- denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
- gates1_s /= denom_s
- gates2_s /= denom_s
- # Calculate combine_weights and dispatch_mask
- gates1 = einsum("s,se->se", gates1_s, mask1_float)
- gates2 = einsum("s,se->se", gates2_s, mask2_float)
- locations1_sc = _one_hot_to_float(locations1_s, capacity)
- locations2_sc = _one_hot_to_float(locations2_s, capacity)
- combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
- combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
- combine_weights = combine1_sec + combine2_sec
- dispatch_mask = combine_weights.bool()
- return l_aux, combine_weights, dispatch_mask, exp_counts
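- # Note on top2gating: the per-token weights of the two selected experts are
- # renormalized to sum to 1 (via denom_s above), and the capacity computation doubles
- # capacity_factor because every token occupies a slot in two expert buffers.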
- class TopKGate(Module):
- """Gate module which implements Top2Gating as described in Gshard_.
- ::
- gate = TopKGate(model_dim, num_experts)
- l_aux, combine_weights, dispatch_mask = gate(input)
- .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
- Args:
- model_dim (int):
- size of model embedding dimension
- num_experts (int):
- number of experts in model
- """
- wg: torch.nn.Linear
- def __init__(self,
- model_dim: int,
- num_experts: int,
- k: int = 1,
- capacity_factor: float = 1.0,
- eval_capacity_factor: float = 1.0,
- min_capacity: int = 4,
- noisy_gate_policy: Optional[str] = None,
- drop_tokens: bool = True,
- use_rts: bool = True) -> None:
- super().__init__()
- # Only top-1 and top-2 are supported at the moment.
- if k != 1 and k != 2:
- raise ValueError('Only top-1 and top-2 gatings are supported.')
- self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
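- # The gate weights are kept in fp32 even under fp16/bf16 training (see the cast in
- # forward); routing decisions are sensitive to small numerical errors in the logits.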
- self.k = k
- self.capacity_factor = capacity_factor
- self.eval_capacity_factor = eval_capacity_factor
- self.min_capacity = min_capacity
- self.noisy_gate_policy = noisy_gate_policy
- self.timers = SynchronizedWallClockTimer()
- self.wall_clock_breakdown = False
- self.gate_time = 0.0
- self.drop_tokens = drop_tokens
- self.use_rts = use_rts
- def forward(
- self,
- input: torch.Tensor,
- used_token: torch.Tensor = None,
- use_tutel: bool = False) -> Tuple[Tensor,
- Tensor,
- Tensor]: # type: ignore
- if self.wall_clock_breakdown:
- self.timers('TopKGate').start()
- if self.wg.weight.dtype != torch.float32:
- self.wg = self.wg.float()
- input_fp32 = input.float()
- # input jittering
- if self.noisy_gate_policy == 'Jitter' and self.training:
- input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
- logits = self.wg(input_fp32)
- if self.k == 1:
- gate_output = top1gating(
- logits,
- self.capacity_factor if self.training else self.eval_capacity_factor,
- self.min_capacity,
- used_token,
- self.noisy_gate_policy if self.training else None,
- self.drop_tokens,
- self.use_rts,
- use_tutel)
- else:
- gate_output = top2gating(
- logits,
- self.capacity_factor if self.training else self.eval_capacity_factor,
- self.min_capacity)
- if self.wall_clock_breakdown:
- self.timers('TopKGate').stop()
- self.gate_time = self.timers('TopKGate').elapsed(reset=False) * 1000
- return gate_output
- class MOELayer(Base):
- """MOELayer module which implements MixtureOfExperts as described in Gshard_.
- ::
- gate = TopKGate(model_dim, num_experts)
- moe = MOELayer(gate, expert)
- output = moe(input)
- l_aux = moe.l_aux
- .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
- Args:
- gate (torch.nn.Module):
- gate network
- experts (torch.nn.Module):
- expert network
- """
- def __init__(self,
- gate: Module,
- experts: Module,
- num_local_experts: int,
- group: Optional[Any] = None,
- use_tutel: bool = False) -> None:
- super().__init__()
- self.gate = gate
- self.experts = experts
- self.group = group
- self.world_size = dist.get_world_size(group)
- self.num_local_experts = num_local_experts
- self.time_falltoall = 0.0
- self.time_salltoall = 0.0
- self.time_moe = 0.0
- self.timers = SynchronizedWallClockTimer()
- self.wall_clock_breakdown = False
- self.use_tutel = use_tutel and TUTEL_INSTALLED
- if self.use_tutel:
- logger.info('Using Tutel optimizations.')
- elif use_tutel and not TUTEL_INSTALLED:
- logger.warning("Tutel optimization requested but not installed. "
- "Proceeding without Tutel.")
- def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
- if self.wall_clock_breakdown:
- self.timers('moe').start()
- # Implement Algorithm 2 from GShard paper.
- d_model = input[0].shape[-1]
- # Initial implementation: reshape into S tokens by dropping the sequence dimension.
- # A later variant may reshape into G groups so that each group can distribute tokens equally, e.g.:
- # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
- reshaped_input = input[0].reshape(-1, d_model)
- if self.use_tutel:
- self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
- S, M = reshaped_input.size(0), reshaped_input.size(1)
- if not hasattr(self, '_tutel_dispatcher'):
- self._tutel_dispatcher = tutel_moe.fast_dispatcher(
- E,
- C,
- M,
- dispatch_dtype=reshaped_input.dtype)
- self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
- dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
- else:
- self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
- dispatched_input = einsum("sec,sm->ecm",
- dispatch_mask.type_as(input[0]),
- reshaped_input)
- if self.wall_clock_breakdown:
- self.timers('falltoall').start()
- dispatched_input = _AllToAll.apply(self.group, dispatched_input)
- if self.wall_clock_breakdown:
- self.timers('falltoall').stop()
- self.time_falltoall = self.timers('falltoall').elapsed(reset=False) * 1000
- # Re-shape after all-to-all: ecm -> gecm
- dispatched_input = dispatched_input.reshape(self.world_size,
- self.num_local_experts,
- -1,
- d_model)
- expert_output = self.experts(dispatched_input)
- if self.wall_clock_breakdown:
- self.timers('salltoall').start()
- expert_output = _AllToAll.apply(self.group, expert_output)
- if self.wall_clock_breakdown:
- self.timers('salltoall').stop()
- self.time_salltoall = self.timers('salltoall').elapsed(reset=False) * 1000
- # Re-shape back: gecm -> ecm
- expert_output = expert_output.reshape(self.world_size * self.num_local_experts,
- -1,
- d_model)
- if self.use_tutel:
- combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
- else:
- combined_output = einsum("sec,ecm->sm",
- combine_weights.type_as(input[0]),
- expert_output)
- out = combined_output.reshape(input[0].shape)
- if self.wall_clock_breakdown:
- self.timers('moe').stop()
- self.time_moe = self.timers('moe').elapsed(reset=False) * 1000
- return out
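- # A minimal single-process usage sketch (illustrative only, not part of the original
- # module; hyperparameters are assumptions): exercises TopKGate with top-1 routing on
- # random data. MOELayer itself needs an initialized torch.distributed process group
- # for the all-to-all exchanges, so it is not demonstrated here.
- if __name__ == "__main__":
-     torch.manual_seed(0)
-     gate = TopKGate(model_dim=16, num_experts=4, k=1, min_capacity=4)
-     tokens = torch.randn(8, 16)  # 8 tokens, model_dim=16
-     l_aux, combine_weights, dispatch_mask, exp_counts = gate(tokens)
-     # combine_weights: (tokens, experts, capacity); exp_counts: per-expert token counts
-     print(l_aux.item(), combine_weights.shape, dispatch_mask.shape, exp_counts)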