sharded_moe.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. '''
  2. Copyright 2021 The Microsoft DeepSpeed Team
  3. '''
  4. # The file has been adapted from two fairscale files:
  5. # (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
  6. # (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
  7. # Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
  8. # We retain the following license from the original files:
  9. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  10. #
  11. # This source code is licensed under the BSD license found in the
  12. # LICENSE file in the root directory of this source tree.
  13. from deepspeed.utils.timer import SynchronizedWallClockTimer
  14. from deepspeed.utils import logger
  15. from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple
  16. import torch
  17. from torch import Tensor
  18. from torch.nn import Module
  19. import torch.nn.functional as F
  20. from deepspeed.utils import groups
  21. from .mappings import drop_tokens, gather_tokens
  22. if TYPE_CHECKING:
  23. Base = Module[Tensor]
  24. else:
  25. Base = Module
  26. uniform_map: Dict[torch.device, Callable] = {}
  27. gumbel_map: Dict[torch.device, Callable] = {}
  28. exp_selection_uniform_map: Dict[torch.device, Callable] = {}
  29. try:
  30. # To enable Tutel MoE optimizations:
  31. # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
  32. from tutel import moe as tutel_moe
  33. TUTEL_INSTALLED = True
  34. except:
  35. # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
  36. TUTEL_INSTALLED = False
  37. pass
  38. def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
  39. """
  40. Modified from switch transformer paper. mesh transformers
  41. Multiply values by a random number between 1-epsilon and 1+epsilon.
  42. Makes models more resilient to rounding errors introduced by bfloat16.
  43. This seems particularly important for logits.
  44. Args:
  45. x: a torch.tensor
  46. device: torch.device
  47. epsilon: a floating point value
  48. Returns:
  49. a jittered x.
  50. """
  51. if epsilon == 0:
  52. return x
  53. uniform = uniform_map.get(device)
  54. if uniform is None:
  55. uniform = torch.distributions.uniform.Uniform(
  56. low=torch.tensor(1.0 - epsilon,
  57. device=device),
  58. high=torch.tensor(1.0 + epsilon,
  59. device=device)).rsample # type: ignore
  60. uniform_map[device] = uniform
  61. return x * uniform(x.shape)
  62. def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
  63. gumbel = gumbel_map.get(device)
  64. if gumbel is None:
  65. one = torch.tensor(1.0, device=device)
  66. zero = torch.tensor(0.0, device=device)
  67. gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
  68. gumbel_map[device] = gumbel
  69. return gumbel(shape)
  70. from deepspeed import comm as dist
  71. # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
  72. # See https://arxiv.org/pdf/2006.16668.pdf for details.
  73. # Based on https://github.com/pytorch/pytorch/pull/40762
  74. class _AllToAll(torch.autograd.Function):
  75. @staticmethod
  76. def forward(
  77. ctx: Any,
  78. # TODO: replace with DS process group
  79. group: torch.distributed.ProcessGroup,
  80. input: Tensor) -> Tensor: # type: ignore
  81. ctx.group = group
  82. input = input.contiguous()
  83. output = torch.empty_like(input)
  84. dist.all_to_all_single(output, input, group=group)
  85. return output
  86. @staticmethod
  87. def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
  88. return (None, _AllToAll.apply(ctx.group, *grad_output))
  89. # einsum rewrites are on par or more performant
  90. # switch can be bubbled up in future
  91. USE_EINSUM = True
  92. # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
  93. # See https://arxiv.org/pdf/2006.16668.pdf for details.
  94. def einsum(rule, a, b):
  95. if USE_EINSUM:
  96. return torch.einsum(rule, a, b)
  97. elif rule == 's,se->se':
  98. return a.reshape(a.shape[0], -1) * b
  99. elif rule == 'se,sc->sec':
  100. return a.unsqueeze(2) * b.unsqueeze(1)
  101. elif rule == 'se,se->s':
  102. return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
  103. elif rule == 'sec,sm->ecm':
  104. s = a.shape[0]
  105. e = a.shape[1]
  106. c = a.shape[2]
  107. m = b.shape[1]
  108. return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
  109. elif rule == 'sec,ecm->sm':
  110. return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
  111. elif rule == 'ks,ksm->sm':
  112. k = b.shape[0]
  113. s = b.shape[1]
  114. m = b.shape[2]
  115. # [k, s] -> [s, k] -> [s, 1, k]
  116. a = a.t().unsqueeze(1)
  117. # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
  118. b = b.reshape(k, -1).t().reshape(s, m, k)
  119. # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
  120. return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
  121. else:
  122. return torch.einsum(rule, a, b)
  123. # The following functions are extracted and scripted
  124. # because otherwise during a torch.jit.trace, the non-Tensor
  125. # values used in the calculations get recorded as constants.
  126. # torch.jit.script coerces them into Tensors and preserves
  127. # their dynamic shapes. This enables ONNX export.
  128. # We can't script the entire top1gating function because it
  129. # includes stateful caching logic which is incompatible with ONNX.
  130. @torch.jit.script
  131. def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
  132. # gates has shape of SE
  133. num_tokens = gates.shape[0]
  134. num_experts = gates.shape[1]
  135. # to(torch.int64) works around a bug in torch.onnx.export:
  136. # it should cast k to int64 when converting torch.topk but it doesn't.
  137. capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
  138. if capacity < min_capacity:
  139. capacity = min_capacity.to(torch.int64)
  140. return capacity
  141. @torch.jit.script
  142. def _top_idx(source, k):
  143. return torch.topk(source, k=k, dim=0)[1]
  144. @torch.jit.script
  145. def _one_hot_to_float(x, num_classes):
  146. return F.one_hot(x, num_classes=num_classes).float()
def top1gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int,
               used_token: Optional[Tensor] = None,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Any, ...]:
    """Implements Top1Gating on logits.

    Each token is routed to a single expert (argmax over dim 1). Tokens beyond
    an expert's capacity are then dropped for that expert.

    Args:
        logits: gate logits indexed as [tokens, experts] (softmax/argmax over dim 1).
        capacity_factor: multiplier passed to ``_capacity``.
        min_capacity: lower bound on the per-expert capacity.
        used_token: optional per-token multiplier applied to the expert mask
            (presumably 0/1 to exclude e.g. padding tokens -- confirm with caller).
        noisy_gate_policy: 'RSample' adds Gumbel noise to the logits used for
            expert selection; any other value adds no noise here.
        drop_tokens: if False, capacity is raised to the largest per-expert
            token count (MAX all-reduced over the world group), so no routed
            token is dropped.
        use_rts: if True, the tokens kept within capacity are chosen by random
            priority (Random Token Selection); otherwise topk is applied to the
            0/1 mask directly and tie-breaking is left to torch.topk.
        use_tutel: return Tutel-style sparse routing info instead of dense
            combine/dispatch tensors.

    Returns:
        Default: ``(l_aux, combine_weights, dispatch_mask, exp_counts)``.
            ``combine_weights`` is [S, E, C] float, ``dispatch_mask`` is its
            boolean mask, and ``exp_counts`` holds the per-expert routed-token
            counts on the CPU.
        With ``use_tutel``: ``(l_aux, capacity, num_experts, [indices1_s],
            [locations1_s], [gates1_s], exp_counts)``.
    """
    if noisy_gate_policy == 'RSample':
        # Noise only affects expert selection (argmax below); `gates` and
        # l_aux are computed from the noise-free logits.
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates,
                         torch.tensor(capacity_factor),
                         torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token
    # noisy gating
    indices1_s = torch.argmax(
        logits_w_noise if noisy_gate_policy == 'RSample' else gates,
        dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    # Per-expert count of routed tokens, taken BEFORE capacity dropping,
    # detached and copied to the CPU.
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # if we don't want to drop any tokens
    if not drop_tokens:
        # Use the busiest expert's load as capacity, MAX-reduced over the world
        # group so every rank agrees on the same capacity.
        new_capacity = torch.max(exp_counts).to(logits.device)
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
        capacity = new_capacity

    # Compute l_aux
    # Load-balancing loss: num_experts * sum_e(mean gate prob of e * fraction of tokens routed to e).
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0,
                                 device=logits.device),
                high=torch.tensor(1.0,
                                  device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = uniform

        # Routed tokens get a random priority in [0, 1); non-routed entries stay 0.
        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."

    # For each expert (column), pick the `capacity` highest-priority tokens.
    top_idx = _top_idx(mask1_rand, capacity)

    # Multiplying by mask1 keeps tokens that topk picked for an expert they
    # were not routed to (possible when the expert is under capacity) unassigned.
    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    if use_tutel:
        # Tutel doesn't support index values masked with zero
        # so we need to replace masked indices with -1
        # (assuming 0/1 rows: kept tokens -> min(idx, E - 1) == idx; dropped -> -1)
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
    # (0-based slot of each token among the tokens kept by its expert)
    if use_tutel:
        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
    else:
        locations1 = torch.cumsum(mask1, dim=0) - 1

    if use_tutel:
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,], exp_counts

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Zero every gate except the selected, non-dropped expert's
    # (top-1 gates are masked, not renormalized).
    mask1_float = mask1.float()
    gates = gates * mask1_float

    # [S, E] x [S, C] -> [S, E, C] combine weights
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
def top2gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int) -> Tuple[Tensor,
                                           Tensor,
                                           Tensor,
                                           Tensor]:
    """Implements Top2Gating on logits.

    Each token is routed to its highest-probability expert and to a second,
    different expert. The second expert is chosen by argmax over
    Gumbel-perturbed logits. Gumbel noise is sampled on every call; this
    function has no train/eval switch. Assignments beyond an expert's capacity
    are dropped.

    Args:
        logits: gate logits indexed as [tokens, experts] (softmax over dim 1).
        capacity_factor: per-expert capacity multiplier. It is doubled when
            computing capacity because each token can take two expert slots.
        min_capacity: lower bound on the per-expert capacity.

    Returns:
        ``(l_aux, combine_weights, dispatch_mask, exp_counts)``:
        - ``combine_weights`` is [S, E, C] and holds the renormalized top-2
          gate values.
        - ``dispatch_mask`` is its boolean mask.
        - ``exp_counts`` counts only first-choice assignments per expert,
          taken before capacity dropping and copied to the CPU.
    """
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates,
                         torch.tensor(capacity_factor * 2),
                         torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # Create a mask for 2nd's expert per token using Gumbel-max trick
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # Exclude the top-1 expert (set its logit to -inf) so the 2nd choice differs
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    # Update 2nd's location by accounting for locations of 1st
    # (2nd-choice tokens are queued after all 1st-choice tokens of that expert)
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # Compute l_aux
    # mean(me * ce) * E * E == sum(me * ce) * E, the same loss form as top-1
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # Remove locations outside capacity from mask
    mask1 *= torch.lt(locations1, capacity)
    mask2 *= torch.lt(locations2, capacity)

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Normalize gate probabilities
    # (the kept 1st/2nd gate values are rescaled to sum to 1; a dropped
    # assignment contributes 0)
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s
    # Avoid divide-by-zero
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
  287. class TopKGate(Module):
  288. """Gate module which implements Top2Gating as described in Gshard_.
  289. ::
  290. gate = TopKGate(model_dim, num_experts)
  291. l_aux, combine_weights, dispatch_mask = gate(input)
  292. .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
  293. Args:
  294. model_dim (int):
  295. size of model embedding dimension
  296. num_experts (ints):
  297. number of experts in model
  298. """
  299. wg: torch.nn.Linear
  300. def __init__(self,
  301. model_dim: int,
  302. num_experts: int,
  303. k: int = 1,
  304. capacity_factor: float = 1.0,
  305. eval_capacity_factor: float = 1.0,
  306. min_capacity: int = 8,
  307. noisy_gate_policy: Optional[str] = None,
  308. drop_tokens: bool = True,
  309. use_rts: bool = True) -> None:
  310. super().__init__()
  311. # Only top-1 and top-2 are supported at the moment.
  312. if k != 1 and k != 2:
  313. raise ValueError('Only top-1 and top-2 gatings are supported.')
  314. self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
  315. self.k = k
  316. self.capacity_factor = capacity_factor
  317. self.eval_capacity_factor = eval_capacity_factor
  318. self.min_capacity = min_capacity
  319. self.noisy_gate_policy = noisy_gate_policy
  320. self.timers = SynchronizedWallClockTimer()
  321. self.wall_clock_breakdown = False
  322. self.gate_time = 0.0
  323. self.drop_tokens = drop_tokens
  324. self.use_rts = use_rts
  325. def forward(
  326. self,
  327. input: torch.Tensor,
  328. used_token: torch.Tensor = None,
  329. use_tutel: bool = False) -> Tuple[Tensor,
  330. Tensor,
  331. Tensor]: # type: ignore
  332. if self.wall_clock_breakdown:
  333. self.timers('TopKGate').start()
  334. if self.wg.weight.dtype != torch.float32:
  335. self.wg = self.wg.float()
  336. input_fp32 = input.float()
  337. # input jittering
  338. if self.noisy_gate_policy == 'Jitter' and self.training:
  339. input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
  340. logits = self.wg(input_fp32)
  341. if self.k == 1:
  342. gate_output = top1gating(
  343. logits,
  344. self.capacity_factor if self.training else self.eval_capacity_factor,
  345. self.min_capacity,
  346. used_token,
  347. self.noisy_gate_policy if self.training else None,
  348. self.drop_tokens,
  349. self.use_rts,
  350. use_tutel)
  351. else:
  352. gate_output = top2gating(
  353. logits,
  354. self.capacity_factor if self.training else self.eval_capacity_factor,
  355. self.min_capacity)
  356. if self.wall_clock_breakdown:
  357. self.timers('TopKGate').stop()
  358. self.gate_time = self.timers('TopKGate').elapsed(reset=False)
  359. return gate_output
  360. class MOELayer(Base):
  361. """MOELayer module which implements MixtureOfExperts as described in Gshard_.
  362. ::
  363. gate = TopKGate(model_dim, num_experts)
  364. moe = MOELayer(gate, expert)
  365. output = moe(input)
  366. l_aux = moe.l_aux
  367. .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
  368. Args:
  369. gate (torch.nn.Module):
  370. gate network
  371. expert (torch.nn.Module):
  372. expert network
  373. """
  374. def __init__(self,
  375. gate: Module,
  376. experts: Module,
  377. ep_group_name,
  378. ep_size,
  379. num_local_experts: int,
  380. use_tutel: bool = False) -> None:
  381. super().__init__()
  382. self.gate = gate
  383. self.experts = experts
  384. self.ep_group = None
  385. self.ep_size = ep_size
  386. self.ep_group_name = ep_group_name
  387. self.num_local_experts = num_local_experts
  388. self.time_falltoall = 0.0
  389. self.time_salltoall = 0.0
  390. self.time_moe = 0.0
  391. self.timers = SynchronizedWallClockTimer()
  392. self.wall_clock_breakdown = False
  393. self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
  394. if self.use_tutel:
  395. logger.info('Using Tutel optimizations.')
  396. elif use_tutel and not TUTEL_INSTALLED:
  397. logger.warning("Tutel optimization requested but not installed. "
  398. "Proceeding without Tutel.")
  399. elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
  400. logger.warning(
  401. "To enable Tutel optimization, use top-1 instead of top-2 gate. "
  402. "Proceeding without Tutel.")
  403. def _set_ep_group(self, ep_group):
  404. self.ep_group = ep_group
  405. def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
  406. if self.wall_clock_breakdown:
  407. self.timers('moe').start()
  408. # Implement Algorithm 2 from GShard paper.
  409. d_model = input[0].shape[-1]
  410. # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
  411. # Reshape into G groups so that each group can distribute tokens equally
  412. # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
  413. reshaped_input = input[0].reshape(-1, d_model)
  414. if self.use_tutel:
  415. self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
  416. S, M = reshaped_input.size(0), reshaped_input.size(1)
  417. if not hasattr(self, '_tutel_dispatcher'):
  418. self._tutel_dispatcher = tutel_moe.fast_dispatcher(
  419. E,
  420. C,
  421. M,
  422. dispatch_dtype=reshaped_input.dtype)
  423. self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
  424. dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
  425. else:
  426. self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
  427. dispatched_input = einsum("sec,sm->ecm",
  428. dispatch_mask.type_as(input[0]),
  429. reshaped_input)
  430. if self.wall_clock_breakdown:
  431. self.timers('falltoall').start()
  432. if groups._get_expert_model_parallel_world_size() == 1:
  433. # If the non-expert is tensor-parallel, it will create
  434. # duplicate tokens on the tensor-parallel ranks.
  435. # Since our experts are not tensor-parallel, these duplicates
  436. # need to be dropped to ensure correctness.
  437. # this also doubles up as a communication optimization as we are
  438. # reducing the all-to-all communication volume.
  439. dispatched_input = drop_tokens(dispatched_input, dim=1)
  440. dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
  441. if self.wall_clock_breakdown:
  442. self.timers('falltoall').stop()
  443. self.time_falltoall = self.timers('falltoall').elapsed(reset=False)
  444. # Re-shape after all-to-all: ecm -> gecm
  445. dispatched_input = dispatched_input.reshape(self.ep_size,
  446. self.num_local_experts,
  447. -1,
  448. d_model)
  449. expert_output = self.experts(dispatched_input)
  450. if self.wall_clock_breakdown:
  451. self.timers('salltoall').start()
  452. expert_output = _AllToAll.apply(self.ep_group, expert_output)
  453. if self.wall_clock_breakdown:
  454. self.timers('salltoall').stop()
  455. self.time_salltoall = self.timers('salltoall').elapsed(reset=False)
  456. # Re-shape back: gecm -> ecm
  457. expert_output = expert_output.reshape(self.ep_size * self.num_local_experts,
  458. -1,
  459. d_model)
  460. if groups._get_expert_model_parallel_world_size() == 1:
  461. # the dropped duplicate tokens need to be gathered on each
  462. # tensor parallel rank again for the tensor-parallel
  463. # non-expert of the next layer.
  464. expert_output = gather_tokens(expert_output, dim=1)
  465. if self.use_tutel:
  466. combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
  467. else:
  468. combined_output = einsum("sec,ecm->sm",
  469. combine_weights.type_as(input[0]),
  470. expert_output)
  471. a = combined_output.reshape(input[0].shape)
  472. if self.wall_clock_breakdown:
  473. self.timers('moe').stop()
  474. self.time_moe = self.timers('moe').elapsed(reset=False)
  475. return a