sharded_moe.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. The file has been adapted from two fairscale files:
  6. (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
  7. (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
  8. Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
  9. We retain the following license from the original files:
  10. """
  11. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  12. #
  13. # This source code is licensed under the BSD license found in the
  14. # LICENSE file in the root directory of this source tree.
  15. from deepspeed.utils.timer import SynchronizedWallClockTimer
  16. from deepspeed.utils import logger
  17. from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple
  18. import torch
  19. from torch import Tensor
  20. from torch.nn import Module
  21. import torch.nn.functional as F
  22. from deepspeed.utils import groups
  23. from .mappings import drop_tokens, gather_tokens
  24. if TYPE_CHECKING:
  25. Base = Module[Tensor]
  26. else:
  27. Base = Module
  28. TOPK_GATE_TIMER = 'topk_gate'
  29. MOE_TIMER = 'moe'
  30. FIRST_ALLTOALL_TIMER = '1st_a2a'
  31. SECOND_ALLTOALL_TIMER = '2nd_a2a'
  32. uniform_map: Dict[torch.device, Callable] = {}
  33. gumbel_map: Dict[torch.device, Callable] = {}
  34. exp_selection_uniform_map: Dict[torch.device, Callable] = {}
  35. try:
  36. # To enable Tutel MoE optimizations:
  37. # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
  38. from tutel import moe as tutel_moe
  39. TUTEL_INSTALLED = True
  40. except:
  41. # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
  42. TUTEL_INSTALLED = False
  43. pass
  44. def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
  45. """
  46. Modified from switch transformer paper. mesh transformers
  47. Multiply values by a random number between 1-epsilon and 1+epsilon.
  48. Makes models more resilient to rounding errors introduced by bfloat16.
  49. This seems particularly important for logits.
  50. Args:
  51. x: a torch.tensor
  52. device: torch.device
  53. epsilon: a floating point value
  54. Returns:
  55. a jittered x.
  56. """
  57. if epsilon == 0:
  58. return x
  59. uniform = uniform_map.get(device)
  60. if uniform is None:
  61. uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),
  62. high=torch.tensor(1.0 + epsilon,
  63. device=device)).rsample # type: ignore
  64. uniform_map[device] = uniform
  65. return x * uniform(x.shape)
  66. def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
  67. gumbel = gumbel_map.get(device)
  68. if gumbel is None:
  69. one = torch.tensor(1.0, device=device)
  70. zero = torch.tensor(0.0, device=device)
  71. gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
  72. gumbel_map[device] = gumbel
  73. return gumbel(shape)
  74. from deepspeed import comm as dist
  75. # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
  76. # See https://arxiv.org/pdf/2006.16668.pdf for details.
  77. # Based on https://github.com/pytorch/pytorch/pull/40762
  78. class _AllToAll(torch.autograd.Function):
  79. @staticmethod
  80. def forward(
  81. ctx: Any,
  82. # TODO: replace with DS process group
  83. group: torch.distributed.ProcessGroup,
  84. input: Tensor) -> Tensor: # type: ignore
  85. ctx.group = group
  86. input = input.contiguous()
  87. output = torch.empty_like(input)
  88. dist.all_to_all_single(output, input, group=group)
  89. return output
  90. @staticmethod
  91. def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
  92. return (None, _AllToAll.apply(ctx.group, *grad_output))
  93. # einsum rewrites are on par or more performant
  94. # switch can be bubbled up in future
  95. USE_EINSUM = True
  96. # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
  97. # See https://arxiv.org/pdf/2006.16668.pdf for details.
  98. def einsum(rule, a, b):
  99. if USE_EINSUM:
  100. return torch.einsum(rule, a, b)
  101. elif rule == 's,se->se':
  102. return a.reshape(a.shape[0], -1) * b
  103. elif rule == 'se,sc->sec':
  104. return a.unsqueeze(2) * b.unsqueeze(1)
  105. elif rule == 'se,se->s':
  106. return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
  107. elif rule == 'sec,sm->ecm':
  108. s = a.shape[0]
  109. e = a.shape[1]
  110. c = a.shape[2]
  111. m = b.shape[1]
  112. return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
  113. elif rule == 'sec,ecm->sm':
  114. return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
  115. elif rule == 'ks,ksm->sm':
  116. k = b.shape[0]
  117. s = b.shape[1]
  118. m = b.shape[2]
  119. # [k, s] -> [s, k] -> [s, 1, k]
  120. a = a.t().unsqueeze(1)
  121. # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
  122. b = b.reshape(k, -1).t().reshape(s, m, k)
  123. # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
  124. return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
  125. else:
  126. return torch.einsum(rule, a, b)
  127. # The following functions are extracted and scripted
  128. # because otherwise during a torch.jit.trace, the non-Tensor
  129. # values used in the calculations get recorded as constants.
  130. # torch.jit.script coerces them into Tensors and preserves
  131. # their dynamic shapes. This enables ONNX export.
  132. # We can't script the entire top1gating function because it
  133. # includes stateful caching logic which is incompatible with ONNX.
  134. @torch.jit.script
  135. def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
  136. # gates has shape of SE
  137. num_tokens = gates.shape[0]
  138. num_experts = gates.shape[1]
  139. # to(torch.int64) works around a bug in torch.onnx.export:
  140. # it should cast k to int64 when converting torch.topk but it doesn't.
  141. capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
  142. if capacity < min_capacity:
  143. capacity = min_capacity.to(torch.int64)
  144. return capacity
  145. @torch.jit.script
  146. def _top_idx(source, k):
  147. return torch.topk(source, k=k, dim=0)[1]
  148. @torch.jit.script
  149. def _one_hot_to_float(x, num_classes):
  150. return F.one_hot(x, num_classes=num_classes).float()
  151. def top1gating(logits: Tensor,
  152. capacity_factor: float,
  153. min_capacity: int,
  154. used_token: Tensor = None,
  155. noisy_gate_policy: Optional[str] = None,
  156. drop_tokens: bool = True,
  157. use_rts: bool = True,
  158. use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  159. """Implements Top1Gating on logits."""
  160. if noisy_gate_policy == 'RSample':
  161. logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
  162. # everything is in fp32 in this function
  163. gates = F.softmax(logits, dim=1)
  164. capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))
  165. # Create a mask for 1st's expert per token
  166. # noisy gating
  167. indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
  168. num_experts = int(gates.shape[1])
  169. mask1 = F.one_hot(indices1_s, num_classes=num_experts)
  170. # mask only used tokens
  171. if used_token is not None:
  172. mask1 = einsum("s,se->se", used_token, mask1)
  173. # gating decisions
  174. exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
  175. # if we don't want to drop any tokens
  176. if not drop_tokens:
  177. new_capacity = torch.max(exp_counts).to(logits.device)
  178. dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
  179. capacity = new_capacity
  180. # Compute l_aux
  181. me = torch.mean(gates, dim=0)
  182. ce = torch.mean(mask1.float(), dim=0)
  183. l_aux = torch.sum(me * ce) * num_experts
  184. # Random Token Selection
  185. if use_rts:
  186. uniform = exp_selection_uniform_map.get(logits.device)
  187. if uniform is None:
  188. uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
  189. high=torch.tensor(1.0, device=logits.device)).rsample
  190. exp_selection_uniform_map[logits.device] = uniform
  191. mask1_rand = mask1 * uniform(mask1.shape)
  192. else:
  193. mask1_rand = mask1
  194. assert logits.shape[
  195. 0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
  196. top_idx = _top_idx(mask1_rand, capacity)
  197. new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
  198. mask1 = new_mask1
  199. if use_tutel:
  200. # Tutel doesn't support index values masked with zero
  201. # so we need to replace masked indices with -1
  202. indices_mask = mask1.sum(dim=1) * num_experts - 1
  203. indices1_s = torch.min(indices1_s, indices_mask)
  204. # Compute locations in capacity buffer
  205. if use_tutel:
  206. locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
  207. else:
  208. locations1 = torch.cumsum(mask1, dim=0) - 1
  209. if use_tutel:
  210. gates1_s = (gates * mask1).sum(dim=1)
  211. locations1_s = torch.sum(locations1 * mask1, dim=1)
  212. return l_aux, capacity, num_experts, [
  213. indices1_s,
  214. ], [
  215. locations1_s,
  216. ], [
  217. gates1_s,
  218. ], exp_counts
  219. # Store the capacity location for each token
  220. locations1_s = torch.sum(locations1 * mask1, dim=1)
  221. # Normalize gate probabilities
  222. mask1_float = mask1.float()
  223. gates = gates * mask1_float
  224. locations1_sc = _one_hot_to_float(locations1_s, capacity)
  225. combine_weights = einsum("se,sc->sec", gates, locations1_sc)
  226. dispatch_mask = combine_weights.bool()
  227. return l_aux, combine_weights, dispatch_mask, exp_counts
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits.

    Each token (row of ``logits``) is routed to its highest-probability expert
    and to a second expert chosen by argmax over Gumbel-noised logits with the
    first expert excluded. The capacity is computed with ``2 * capacity_factor``
    and shared by both choices. In each expert's buffer the first choices are
    placed first and the second choices after them. Anything at or beyond
    ``capacity`` is dropped.

    Args:
        logits: gate logits indexed as (tokens, experts).
        capacity_factor: per-choice capacity factor (doubled here for two choices).
        min_capacity: lower bound on the per-expert capacity.

    Returns:
        ``(l_aux, combine_weights, dispatch_mask, exp_counts)``:
        l_aux - auxiliary loss built from the first choices only;
        combine_weights - (tokens, experts, capacity) weights; each token's kept
        gate values are divided by their sum (clamped to eps);
        dispatch_mask - ``combine_weights.bool()``;
        exp_counts - per-expert count of first-choice assignments before
        capacity dropping, moved to CPU.
    """
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))

    # One-hot mask of each token's first-choice expert
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # Create a mask for the 2nd expert per token using the Gumbel-max trick
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # Mask out the first-choice expert with -inf so the argmax picks a different one
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    # Shift 2nd-choice locations past all 1st-choice tokens of the same expert
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # gating decisions
    # NOTE: exp_counts and ce (below) must be read before the in-place capacity
    # masking of mask1 further down; keep this statement order.
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # Compute l_aux (mean * E * E == E * sum over experts)
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # Remove locations outside capacity from mask (in place)
    mask1 *= torch.lt(locations1, capacity)
    mask2 *= torch.lt(locations2, capacity)

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s
    # Avoid divide-by-zero
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
  281. class TopKGate(Module):
  282. """Gate module which implements Top2Gating as described in Gshard_.
  283. ::
  284. gate = TopKGate(model_dim, num_experts)
  285. l_aux, combine_weights, dispatch_mask = gate(input)
  286. .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
  287. Args:
  288. model_dim (int):
  289. size of model embedding dimension
  290. num_experts (ints):
  291. number of experts in model
  292. """
  293. wg: torch.nn.Linear
  294. def __init__(self,
  295. model_dim: int,
  296. num_experts: int,
  297. k: int = 1,
  298. capacity_factor: float = 1.0,
  299. eval_capacity_factor: float = 1.0,
  300. min_capacity: int = 8,
  301. noisy_gate_policy: Optional[str] = None,
  302. drop_tokens: bool = True,
  303. use_rts: bool = True) -> None:
  304. super().__init__()
  305. # Only top-1 and top-2 are supported at the moment.
  306. if k != 1 and k != 2:
  307. raise ValueError('Only top-1 and top-2 gatings are supported.')
  308. self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
  309. self.k = k
  310. self.capacity_factor = capacity_factor
  311. self.eval_capacity_factor = eval_capacity_factor
  312. self.min_capacity = min_capacity
  313. self.noisy_gate_policy = noisy_gate_policy
  314. self.timers = SynchronizedWallClockTimer()
  315. self.wall_clock_breakdown = False
  316. self.gate_time = 0.0
  317. self.drop_tokens = drop_tokens
  318. self.use_rts = use_rts
  319. def forward(self,
  320. input: torch.Tensor,
  321. used_token: torch.Tensor = None,
  322. use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
  323. if self.wall_clock_breakdown:
  324. self.timers(TOPK_GATE_TIMER).start()
  325. if self.wg.weight.dtype != torch.float32:
  326. self.wg = self.wg.float()
  327. input_fp32 = input.float()
  328. # input jittering
  329. if self.noisy_gate_policy == 'Jitter' and self.training:
  330. input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
  331. logits = self.wg(input_fp32)
  332. if self.k == 1:
  333. gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
  334. self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
  335. self.drop_tokens, self.use_rts, use_tutel)
  336. else:
  337. gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
  338. self.min_capacity)
  339. if self.wall_clock_breakdown:
  340. self.timers(TOPK_GATE_TIMER).stop()
  341. self.gate_time = self.timers(TOPK_GATE_TIMER).elapsed(reset=False)
  342. return gate_output
  343. class MOELayer(Base):
  344. """MOELayer module which implements MixtureOfExperts as described in Gshard_.
  345. ::
  346. gate = TopKGate(model_dim, num_experts)
  347. moe = MOELayer(gate, expert)
  348. output = moe(input)
  349. l_aux = moe.l_aux
  350. .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
  351. Args:
  352. gate (torch.nn.Module):
  353. gate network
  354. expert (torch.nn.Module):
  355. expert network
  356. """
  357. def __init__(self,
  358. gate: Module,
  359. experts: Module,
  360. ep_group_name,
  361. ep_size,
  362. num_local_experts: int,
  363. use_tutel: bool = False) -> None:
  364. super().__init__()
  365. self.gate = gate
  366. self.experts = experts
  367. self.ep_group = None
  368. self.ep_size = ep_size
  369. self.ep_group_name = ep_group_name
  370. self.num_local_experts = num_local_experts
  371. self.time_falltoall = 0.0
  372. self.time_salltoall = 0.0
  373. self.time_moe = 0.0
  374. self.timers = SynchronizedWallClockTimer()
  375. self.wall_clock_breakdown = False
  376. self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
  377. if self.use_tutel:
  378. logger.info('Using Tutel optimizations.')
  379. elif use_tutel and not TUTEL_INSTALLED:
  380. logger.warning("Tutel optimization requested but not installed. "
  381. "Proceeding without Tutel.")
  382. elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
  383. logger.warning("To enable Tutel optimization, use top-1 instead of top-2 gate. "
  384. "Proceeding without Tutel.")
  385. def _set_ep_group(self, ep_group):
  386. self.ep_group = ep_group
  387. def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
  388. if self.wall_clock_breakdown:
  389. self.timers(MOE_TIMER).start()
  390. # Implement Algorithm 2 from GShard paper.
  391. d_model = input[0].shape[-1]
  392. # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
  393. # Reshape into G groups so that each group can distribute tokens equally
  394. # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
  395. reshaped_input = input[0].reshape(-1, d_model)
  396. if self.use_tutel:
  397. self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
  398. S, M = reshaped_input.size(0), reshaped_input.size(1)
  399. if not hasattr(self, '_tutel_dispatcher'):
  400. self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
  401. self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
  402. dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
  403. else:
  404. self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
  405. dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
  406. if self.wall_clock_breakdown:
  407. self.timers(FIRST_ALLTOALL_TIMER).start()
  408. if groups._get_expert_model_parallel_world_size() == 1:
  409. # If the non-expert is tensor-parallel, it will create
  410. # duplicate tokens on the tensor-parallel ranks.
  411. # Since our experts are not tensor-parallel, these duplicates
  412. # need to be dropped to ensure correctness.
  413. # this also doubles up as a communication optimization as we are
  414. # reducing the all-to-all communication volume.
  415. dispatched_input = drop_tokens(dispatched_input, dim=1)
  416. dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
  417. if self.wall_clock_breakdown:
  418. self.timers(FIRST_ALLTOALL_TIMER).stop()
  419. self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
  420. # Re-shape after all-to-all: ecm -> gecm
  421. dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
  422. expert_output = self.experts(dispatched_input)
  423. if self.wall_clock_breakdown:
  424. self.timers(SECOND_ALLTOALL_TIMER).start()
  425. expert_output = _AllToAll.apply(self.ep_group, expert_output)
  426. if self.wall_clock_breakdown:
  427. self.timers(SECOND_ALLTOALL_TIMER).stop()
  428. self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)
  429. # Re-shape back: gecm -> ecm
  430. expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
  431. if groups._get_expert_model_parallel_world_size() == 1:
  432. # the dropped duplicate tokens need to be gathered on each
  433. # tensor parallel rank again for the tensor-parallel
  434. # non-expert of the next layer.
  435. expert_output = gather_tokens(expert_output, dim=1)
  436. if self.use_tutel:
  437. combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
  438. else:
  439. combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
  440. a = combined_output.reshape(input[0].shape)
  441. if self.wall_clock_breakdown:
  442. self.timers(MOE_TIMER).stop()
  443. self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)
  444. return a