# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
The file has been adapted from two fairscale files:
 (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
 (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
 Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
 We retain the following license from the original files:
"""

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.utils import logger
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F

from deepspeed import comm as dist
from deepspeed.utils import groups
from .mappings import drop_tokens, gather_tokens

if TYPE_CHECKING:
    Base = Module[Tensor]
else:
    Base = Module

TOPK_GATE_TIMER = 'topk_gate'
MOE_TIMER = 'moe'
FIRST_ALLTOALL_TIMER = '1st_a2a'
SECOND_ALLTOALL_TIMER = '2nd_a2a'

uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}

try:
    # To enable Tutel MoE optimizations:
    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
    from tutel import moe as tutel_moe
    TUTEL_INSTALLED = True
except:
    # Fail silently so we don't spam logs unnecessarily if the user isn't using Tutel.
    TUTEL_INSTALLED = False
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
    """
    Modified from the Switch Transformer paper's mesh-transformers implementation.
    Multiply values by a random number between 1-epsilon and 1+epsilon.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.

    Args:
        x: a torch.tensor
        device: torch.device
        epsilon: a floating point value

    Returns:
        a jittered x.
    """
    if epsilon == 0:
        return x
    uniform = uniform_map.get(device)
    if uniform is None:
        uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),
                                                      high=torch.tensor(1.0 + epsilon,
                                                                        device=device)).rsample  # type: ignore
        uniform_map[device] = uniform
    return x * uniform(x.shape)
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    gumbel = gumbel_map.get(device)
    if gumbel is None:
        one = torch.tensor(1.0, device=device)
        zero = torch.tensor(0.0, device=device)
        gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore
        gumbel_map[device] = gumbel
    return gumbel(shape)
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
        ctx.group = group
        input = input.contiguous()
        output = torch.empty_like(input)
        dist.all_to_all_single(output, input, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _AllToAll.apply(ctx.group, *grad_output))
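

# Editorial note (not part of the original file): dist.all_to_all_single exchanges
# equally-sized slices of the first dimension with every rank in `group`, so the
# dispatched [experts, capacity, model] tensor must split evenly across the
# expert-parallel world size; the backward pass is simply another all-to-all
# applied to the incoming gradient.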

# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True


# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
    if USE_EINSUM:
        return torch.einsum(rule, a, b)
    elif rule == 's,se->se':
        return a.reshape(a.shape[0], -1) * b
    elif rule == 'se,sc->sec':
        return a.unsqueeze(2) * b.unsqueeze(1)
    elif rule == 'se,se->s':
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
    elif rule == 'sec,sm->ecm':
        s = a.shape[0]
        e = a.shape[1]
        c = a.shape[2]
        m = b.shape[1]
        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
    elif rule == 'sec,ecm->sm':
        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
    elif rule == 'ks,ksm->sm':
        k = b.shape[0]
        s = b.shape[1]
        m = b.shape[2]
        # [k, s] -> [s, k] -> [s, 1, k]
        a = a.t().unsqueeze(1)
        # [k, s, m] -> [k, sm] -> [sm, k] -> [s, m, k]
        b = b.reshape(k, -1).t().reshape(s, m, k)
        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
    else:
        return torch.einsum(rule, a, b)
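

# Illustrative shape check for the dispatch rule (hypothetical sizes, editorial note):
# with s=16 tokens, e=4 experts, c=4 capacity slots and m=8 model dims,
#   einsum('sec,sm->ecm', dispatch_mask [16, 4, 4], reshaped_input [16, 8]) -> [4, 4, 8],
# i.e. one [capacity, model] buffer per expert, ready for the all-to-all.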

# The following functions are extracted and scripted
# because otherwise during a torch.jit.trace, the non-Tensor
# values used in the calculations get recorded as constants.
# torch.jit.script coerces them into Tensors and preserves
# their dynamic shapes. This enables ONNX export.
# We can't script the entire top1gating function because it
# includes stateful caching logic which is incompatible with ONNX.


@torch.jit.script
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    # gates has shape of SE
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    return capacity
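

# Worked example (hypothetical numbers, editorial note): with 512 tokens, 8 experts
# and capacity_factor=1.0,
#   capacity = ceil(512 / 8 * 1.0) = 64
# tokens per expert; the min_capacity floor only kicks in for smaller batches.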

@torch.jit.script
def _top_idx(source, k):
    return torch.topk(source, k=k, dim=0)[1]


@torch.jit.script
def _one_hot_to_float(x, num_classes):
    return F.one_hot(x, num_classes=num_classes).float()

def top1gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int,
               used_token: Tensor = None,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               ep_group: Union[torch.distributed.ProcessGroup, None] = None,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == 'RSample':
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    # Create a mask for the 1st expert per token
    # noisy gating
    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        # Communicate across expert processes to pick the maximum capacity.
        if ep_group is not None:
            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
        if groups._get_expert_model_parallel_world_size() == 1:
            # If the non-expert is tensor-parallel, we need to pad the capacity to a multiple of 'tp'.
            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
            tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
        # Make sure the capacity value does not exceed the number of tokens.
        capacity = min(new_capacity, torch.tensor(mask1.size(0)).to(new_capacity.device))

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
                                                          high=torch.tensor(1.0, device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, (
        "Number of tokens (batch size) should be at least min_capacity. "
        "Either set min_capacity to 0 or increase your batch size.")

    top_idx = _top_idx(mask1_rand, capacity)

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    if use_tutel:
        # Tutel doesn't support index values masked with zero,
        # so we need to replace masked indices with -1.
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
    if use_tutel:
        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
    else:
        locations1 = torch.cumsum(mask1, dim=0) - 1

    if use_tutel:
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [
            indices1_s,
        ], [
            locations1_s,
        ], [
            gates1_s,
        ], exp_counts

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
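

# Return shapes for the non-Tutel path above (editorial note), with s tokens,
# e experts and computed capacity c:
#   l_aux: scalar load-balancing loss, combine_weights: [s, e, c] float,
#   dispatch_mask: [s, e, c] bool, exp_counts: [e] on CPU.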

def top2gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int,
               drop_tokens: bool = True,
               ep_group: Union[torch.distributed.ProcessGroup, None] = None,
               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    # Create a mask for the 1st expert per token
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    if top2_2nd_expert_sampling:
        # Sample the 2nd expert per token using the Gumbel-max trick
        # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
        logits += gumbel_rsample(logits.shape, device=logits.device)

    # Replace the top expert with the min value
    logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    # Update the 2nd expert's locations by accounting for the locations of the 1st
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # gating decisions
    exp_counts = torch.sum(mask1 + mask2, dim=0)

    if drop_tokens:
        # Calculate configured capacity and remove locations outside capacity from mask
        capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
        mask1 *= torch.lt(locations1, capacity)
        mask2 *= torch.lt(locations2, capacity)
    else:
        # Do not drop tokens - set capacity according to current expert assignments
        new_capacity = torch.max(exp_counts)
        if ep_group is not None:
            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
        if groups._get_expert_model_parallel_world_size() == 1:
            # If the non-expert is tensor-parallel, we need to pad the capacity to a multiple of 'tp'.
            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
            tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
        capacity = new_capacity

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s
    # Avoid divide-by-zero
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')
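

# Worked example (hypothetical numbers, editorial note): with 16 tokens, 4 experts,
# capacity_factor=1.0 and min_capacity=2, top-2 routing doubles the effective factor:
#   capacity = ceil(16 / 4 * 2.0) = 8
# slots per expert, and each token contributes weight to at most two (expert, slot) pairs.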

class TopKGate(Module):
    """Gate module which implements Top-k gating (k=1 or 2) as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask, exp_counts = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
    """

    wg: torch.nn.Linear

    def __init__(self,
                 model_dim: int,
                 num_experts: int,
                 k: int = 1,
                 capacity_factor: float = 1.0,
                 eval_capacity_factor: float = 1.0,
                 min_capacity: int = 8,
                 noisy_gate_policy: Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts: bool = True,
                 ep_group: Union[torch.distributed.ProcessGroup, None] = None,
                 top2_2nd_expert_sampling: bool = True) -> None:
        super().__init__()

        # Only top-1 and top-2 are supported at the moment.
        if k != 1 and k != 2:
            raise ValueError('Only top-1 and top-2 gatings are supported.')
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
        self.ep_group = ep_group
        self.k = k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False
        self.gate_time = 0.0
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts
        self.top2_2nd_expert_sampling = top2_2nd_expert_sampling

    def _set_ep_group(self, ep_group):
        assert self.ep_group is None, 'Attempting to override an existing ep_group'
        self.ep_group = ep_group

    def forward(self,
                input: torch.Tensor,
                used_token: torch.Tensor = None,
                use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
        if self.wall_clock_breakdown:
            self.timers(TOPK_GATE_TIMER).start()

        input_fp32 = input.float()
        # input jittering
        if self.noisy_gate_policy == 'Jitter' and self.training:
            input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
        logits = torch.nn.functional.linear(input_fp32, weight=self.wg.weight.float(), bias=None)

        if self.k == 1:
            gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                     self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
                                     self.drop_tokens, self.use_rts, self.ep_group, use_tutel)
        else:
            gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)

        if self.wall_clock_breakdown:
            self.timers(TOPK_GATE_TIMER).stop()
            self.gate_time = self.timers(TOPK_GATE_TIMER).elapsed(reset=False)

        return gate_output

class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        moe = MOELayer(gate, expert)
        output = moe(input)
        l_aux = moe.l_aux

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        expert (torch.nn.Module):
            expert network
    """

    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_group_name,
                 ep_size,
                 num_local_experts: int,
                 use_tutel: bool = False) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.ep_group_name = ep_group_name
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1

        if self.use_tutel:
            logger.info('Using Tutel optimizations.')
        elif use_tutel and not TUTEL_INSTALLED:
            logger.warning("Tutel optimization requested but not installed. "
                           "Proceeding without Tutel.")
        elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
            logger.warning("To enable Tutel optimization, use top-1 instead of top-2 gate. "
                           "Proceeding without Tutel.")

    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group
        self.gate._set_ep_group(ep_group)
    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:

        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).start()

        # Implement Algorithm 2 from GShard paper.
        d_model = input[0].shape[-1]

        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
        # Reshape into G groups so that each group can distribute tokens equally
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_input = input[0].reshape(-1, d_model)

        if self.use_tutel:
            self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
            S, M = reshaped_input.size(0), reshaped_input.size(1)

            if not hasattr(self, '_tutel_dispatcher'):
                self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
        else:
            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
            dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)

        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).start()

        if groups._get_expert_model_parallel_world_size() == 1:
            # If the non-expert is tensor-parallel, it will create
            # duplicate tokens on the tensor-parallel ranks.
            # Since our experts are not tensor-parallel, these duplicates
            # need to be dropped to ensure correctness.
            # This also doubles up as a communication optimization as we are
            # reducing the all-to-all communication volume.
            dispatched_input = drop_tokens(dispatched_input, dim=1)

        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).stop()
            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_input)

        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).stop()
            self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

        if groups._get_expert_model_parallel_world_size() == 1:
            # The dropped duplicate tokens need to be gathered on each
            # tensor-parallel rank again for the tensor-parallel
            # non-expert of the next layer.
            expert_output = gather_tokens(expert_output, dim=1)

        if self.use_tutel:
            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
        else:
            combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)

        a = combined_output.reshape(input[0].shape)

        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).stop()
            self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)

        return a
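

# Illustrative wiring of the pieces above (editorial sketch, not part of the original
# file; `experts` stands for any module mapping [ep_size, num_local_experts, capacity,
# model_dim] activations through the local experts, e.g. DeepSpeed's Experts wrapper,
# and `aux_coef` is a hypothetical hyperparameter):
#
#   gate = TopKGate(model_dim=512, num_experts=8, k=1)
#   moe = MOELayer(gate, experts, ep_group_name='ep_group_8', ep_size=1, num_local_experts=8)
#   moe._set_ep_group(ep_group)                 # expert-parallel process group
#   output = moe(hidden_states, used_token)     # hidden_states: [batch, seq, model_dim]
#   loss = task_loss + aux_coef * moe.l_aux     # add the auxiliary load-balancing loss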