# sharded_moe.py (20 KB)
  1. '''
  2. Copyright 2021 The Microsoft DeepSpeed Team
  3. '''
  4. # The file has been adapted from two fairscale files:
  5. # (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
  6. # (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
  7. # Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
  8. # We retain the following license from the original files:
  9. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  10. #
  11. # This source code is licensed under the BSD license found in the
  12. # LICENSE file in the root directory of this source tree.
  13. from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
  14. from deepspeed.utils import logger, log_dist
  15. from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union, cast
  16. import time
  17. from time import perf_counter
  18. import torch
  19. from torch import Tensor
  20. import torch.distributed as dist
  21. from torch.nn import Module, ModuleList
  22. import torch.nn.functional as F
  23. if TYPE_CHECKING:
  24. Base = Module[Tensor]
  25. else:
  26. Base = Module
  27. uniform_map: Dict[torch.device, Callable] = {}
  28. gumbel_map: Dict[torch.device, Callable] = {}
  29. exp_selection_uniform_map: Dict[torch.device, Callable] = {}
  30. try:
  31. # To enable Tutel MoE optimizations:
  32. # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
  33. from tutel import moe as tutel_moe
  34. TUTEL_INSTALLED = True
  35. except:
  36. # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
  37. TUTEL_INSTALLED = False
  38. pass
  39. def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
  40. """
  41. Modified from switch transformer paper. mesh transformers
  42. Multiply values by a random number between 1-epsilon and 1+epsilon.
  43. Makes models more resilient to rounding errors introduced by bfloat16.
  44. This seems particularly important for logits.
  45. Args:
  46. x: a torch.tensor
  47. device: torch.device
  48. epsilon: a floating point value
  49. Returns:
  50. a jittered x.
  51. """
  52. if epsilon == 0:
  53. return x
  54. uniform = uniform_map.get(device)
  55. if uniform is None:
  56. uniform = torch.distributions.uniform.Uniform(
  57. low=torch.tensor(1.0 - epsilon,
  58. device=device),
  59. high=torch.tensor(1.0 + epsilon,
  60. device=device)).rsample # type: ignore
  61. uniform_map[device] = uniform
  62. return x * uniform(x.shape)
  63. def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
  64. gumbel = gumbel_map.get(device)
  65. if gumbel is None:
  66. one = torch.tensor(1.0, device=device)
  67. zero = torch.tensor(0.0, device=device)
  68. gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
  69. gumbel_map[device] = gumbel
  70. return gumbel(shape)
  71. # Based on https://github.com/pytorch/pytorch/pull/40762
  72. class _AllToAll(torch.autograd.Function):
  73. @staticmethod
  74. def forward(ctx: Any,
  75. group: dist.ProcessGroup,
  76. input: Tensor) -> Tensor: # type: ignore
  77. ctx.group = group
  78. input = input.contiguous()
  79. output = torch.empty_like(input)
  80. dist.all_to_all_single(output, input, group=group)
  81. return output
  82. @staticmethod
  83. def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
  84. return (None, _AllToAll.apply(ctx.group, *grad_output))
  85. # einsum rewrites are on par or more performant
  86. # switch can be bubbled up in future
  87. USE_EINSUM = True
  88. # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
  89. # See https://arxiv.org/pdf/2006.16668.pdf for details.
  90. def einsum(rule, a, b):
  91. if USE_EINSUM:
  92. return torch.einsum(rule, a, b)
  93. elif rule == 's,se->se':
  94. return a.reshape(a.shape[0], -1) * b
  95. elif rule == 'se,sc->sec':
  96. return a.unsqueeze(2) * b.unsqueeze(1)
  97. elif rule == 'se,se->s':
  98. return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
  99. elif rule == 'sec,sm->ecm':
  100. s = a.shape[0]
  101. e = a.shape[1]
  102. c = a.shape[2]
  103. m = b.shape[1]
  104. return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
  105. elif rule == 'sec,ecm->sm':
  106. return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
  107. elif rule == 'ks,ksm->sm':
  108. k = b.shape[0]
  109. s = b.shape[1]
  110. m = b.shape[2]
  111. # [k, s] -> [s, k] -> [s, 1, k]
  112. a = a.t().unsqueeze(1)
  113. # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
  114. b = b.reshape(k, -1).t().reshape(s, m, k)
  115. # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
  116. return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
  117. else:
  118. return torch.einsum(rule, a, b)
  119. # The following functions are extracted and scripted
  120. # because otherwise during a torch.jit.trace, the non-Tensor
  121. # values used in the calculations get recorded as constants.
  122. # torch.jit.script coerces them into Tensors and preserves
  123. # their dynamic shapes. This enables ONNX export.
  124. # We can't script the entire top1gating function because it
  125. # includes stateful caching logic which is incompatible with ONNX.
  126. @torch.jit.script
  127. def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
  128. # gates has shape of SE
  129. num_tokens = gates.shape[0]
  130. num_experts = gates.shape[1]
  131. # to(torch.int64) works around a bug in torch.onnx.export:
  132. # it should cast k to int64 when converting torch.topk but it doesn't.
  133. capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
  134. if capacity < min_capacity:
  135. capacity = min_capacity.to(torch.int64)
  136. return capacity
  137. @torch.jit.script
  138. def _top_idx(source, k):
  139. return torch.topk(source, k=k, dim=0)[1]
  140. @torch.jit.script
  141. def _one_hot_to_float(x, num_classes):
  142. return F.one_hot(x, num_classes=num_classes).float()
  143. def top1gating(logits: Tensor,
  144. capacity_factor: float,
  145. min_capacity: int,
  146. used_token: Tensor = None,
  147. noisy_gate_policy: Optional[str] = None,
  148. drop_tokens: bool = True,
  149. use_rts: bool = True,
  150. use_tutel: bool = False) -> Tuple[Tensor,
  151. Tensor,
  152. Tensor,
  153. Tensor]:
  154. """Implements Top1Gating on logits."""
  155. if noisy_gate_policy == 'RSample':
  156. logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
  157. # everything is in fp32 in this function
  158. gates = F.softmax(logits, dim=1)
  159. capacity = _capacity(gates,
  160. torch.tensor(capacity_factor),
  161. torch.tensor(min_capacity))
  162. # Create a mask for 1st's expert per token
  163. # noisy gating
  164. indices1_s = torch.argmax(
  165. logits_w_noise if noisy_gate_policy == 'RSample' else gates,
  166. dim=1)
  167. num_experts = int(gates.shape[1])
  168. mask1 = F.one_hot(indices1_s, num_classes=num_experts)
  169. # mask only used tokens
  170. if used_token is not None:
  171. mask1 = einsum("s,se->se", used_token, mask1)
  172. # gating decisions
  173. exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
  174. # if we don't want to drop any tokens
  175. if not drop_tokens:
  176. new_capacity = torch.max(exp_counts).to(logits.device)
  177. dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
  178. capacity = new_capacity
  179. # Compute l_aux
  180. me = torch.mean(gates, dim=0)
  181. ce = torch.mean(mask1.float(), dim=0)
  182. l_aux = torch.sum(me * ce) * num_experts
  183. # Random Token Selection
  184. if use_rts:
  185. uniform = exp_selection_uniform_map.get(logits.device)
  186. if uniform is None:
  187. uniform = torch.distributions.uniform.Uniform(
  188. low=torch.tensor(0.0,
  189. device=logits.device),
  190. high=torch.tensor(1.0,
  191. device=logits.device)).rsample
  192. exp_selection_uniform_map[logits.device] = uniform
  193. mask1_rand = mask1 * uniform(mask1.shape)
  194. assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
  195. top_idx = _top_idx(mask1_rand, capacity)
  196. new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
  197. mask1 = new_mask1
  198. if use_tutel:
  199. # Tutel doesn't support index values masked with zero
  200. # so we need to replace masked indices with -1
  201. indices_mask = mask1.sum(dim=1) * num_experts - 1
  202. indices1_s = torch.min(indices1_s, indices_mask)
  203. # Compute locations in capacity buffer
  204. if use_tutel:
  205. locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
  206. else:
  207. locations1 = torch.cumsum(mask1, dim=0) - 1
  208. if use_tutel:
  209. gates1_s = (gates * mask1).sum(dim=1)
  210. locations1_s = torch.sum(locations1 * mask1, dim=1)
  211. return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,], exp_counts
  212. # Store the capacity location for each token
  213. locations1_s = torch.sum(locations1 * mask1, dim=1)
  214. # Normalize gate probabilities
  215. mask1_float = mask1.float()
  216. gates = gates * mask1_float
  217. locations1_sc = _one_hot_to_float(locations1_s, capacity)
  218. combine_weights = einsum("se,sc->sec", gates, locations1_sc)
  219. dispatch_mask = combine_weights.bool()
  220. return l_aux, combine_weights, dispatch_mask, exp_counts
  221. def top2gating(logits: Tensor,
  222. capacity_factor: float) -> Tuple[Tensor,
  223. Tensor,
  224. Tensor,
  225. Tensor]:
  226. """Implements Top2Gating on logits."""
  227. # everything is in fp32 in this function
  228. gates = F.softmax(logits, dim=1)
  229. capacity = _capacity(gates,
  230. torch.tensor(capacity_factor * 2),
  231. torch.tensor(min_capacity))
  232. # Create a mask for 1st's expert per token
  233. indices1_s = torch.argmax(gates, dim=1)
  234. num_experts = int(gates.shape[1])
  235. mask1 = F.one_hot(indices1_s, num_classes=num_experts)
  236. # Create a mask for 2nd's expert per token using Gumbel-max trick
  237. # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
  238. logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
  239. # Replace top-expert with min value
  240. logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
  241. indices2_s = torch.argmax(logits_except1, dim=1)
  242. mask2 = F.one_hot(indices2_s, num_classes=num_experts)
  243. # Compute locations in capacity buffer
  244. locations1 = torch.cumsum(mask1, dim=0) - 1
  245. locations2 = torch.cumsum(mask2, dim=0) - 1
  246. # Update 2nd's location by accounting for locations of 1st
  247. locations2 += torch.sum(mask1, dim=0, keepdim=True)
  248. # gating decisions
  249. exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
  250. # Compute l_aux
  251. me = torch.mean(gates, dim=0)
  252. ce = torch.mean(mask1.float(), dim=0)
  253. l_aux = torch.mean(me * ce) * num_experts * num_experts
  254. # Remove locations outside capacity from mask
  255. mask1 *= torch.lt(locations1, capacity)
  256. mask2 *= torch.lt(locations2, capacity)
  257. # Store the capacity location for each token
  258. locations1_s = torch.sum(locations1 * mask1, dim=1)
  259. locations2_s = torch.sum(locations2 * mask2, dim=1)
  260. # Normalize gate probabilities
  261. mask1_float = mask1.float()
  262. mask2_float = mask2.float()
  263. gates1_s = einsum("se,se->s", gates, mask1_float)
  264. gates2_s = einsum("se,se->s", gates, mask2_float)
  265. denom_s = gates1_s + gates2_s
  266. # Avoid divide-by-zero
  267. denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
  268. gates1_s /= denom_s
  269. gates2_s /= denom_s
  270. # Calculate combine_weights and dispatch_mask
  271. gates1 = einsum("s,se->se", gates1_s, mask1_float)
  272. gates2 = einsum("s,se->se", gates2_s, mask2_float)
  273. locations1_sc = _one_hot_to_float(locations1_s, capacity)
  274. locations2_sc = _one_hot_to_float(locations2_s, capacity)
  275. combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
  276. combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
  277. combine_weights = combine1_sec + combine2_sec
  278. dispatch_mask = combine_weights.bool()
  279. return l_aux, combine_weights, dispatch_mask, exp_counts
  280. class TopKGate(Module):
  281. """Gate module which implements Top2Gating as described in Gshard_.
  282. ::
  283. gate = TopKGate(model_dim, num_experts)
  284. l_aux, combine_weights, dispatch_mask = gate(input)
  285. .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
  286. Args:
  287. model_dim (int):
  288. size of model embedding dimension
  289. num_experts (ints):
  290. number of experts in model
  291. """
  292. wg: torch.nn.Linear
  293. def __init__(self,
  294. model_dim: int,
  295. num_experts: int,
  296. k: int = 1,
  297. capacity_factor: float = 1.0,
  298. eval_capacity_factor: float = 1.0,
  299. min_capacity: int = 4,
  300. noisy_gate_policy: Optional[str] = None,
  301. drop_tokens: bool = True,
  302. use_rts: bool = True) -> None:
  303. super().__init__()
  304. # Only top-1 and top-2 are supported at the moment.
  305. if k != 1 and k != 2:
  306. raise ValueError('Only top-1 and top-2 gatings are supported.')
  307. self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
  308. self.k = k
  309. self.capacity_factor = capacity_factor
  310. self.eval_capacity_factor = eval_capacity_factor
  311. self.min_capacity = min_capacity
  312. self.noisy_gate_policy = noisy_gate_policy
  313. self.timers = SynchronizedWallClockTimer()
  314. self.wall_clock_breakdown = False
  315. self.gate_time = 0.0
  316. self.drop_tokens = drop_tokens
  317. self.use_rts = use_rts
  318. def forward(
  319. self,
  320. input: torch.Tensor,
  321. used_token: torch.Tensor = None,
  322. use_tutel: bool = False) -> Tuple[Tensor,
  323. Tensor,
  324. Tensor]: # type: ignore
  325. if self.wall_clock_breakdown:
  326. self.timers('TopKGate').start()
  327. if self.wg.weight.dtype != torch.float32:
  328. self.wg = self.wg.float()
  329. input_fp32 = input.float()
  330. # input jittering
  331. if self.noisy_gate_policy == 'Jitter' and self.training:
  332. input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
  333. logits = self.wg(input_fp32)
  334. if self.k == 1:
  335. gate_output = top1gating(
  336. logits,
  337. self.capacity_factor if self.training else self.eval_capacity_factor,
  338. self.min_capacity,
  339. used_token,
  340. self.noisy_gate_policy if self.training else None,
  341. self.drop_tokens,
  342. self.use_rts,
  343. use_tutel)
  344. else:
  345. gate_output = top2gating(
  346. logits,
  347. self.capacity_factor if self.training else self.eval_capacity_factor)
  348. if self.wall_clock_breakdown:
  349. self.timers('TopKGate').stop()
  350. self.gate_time = self.timers('TopKGate').elapsed(reset=False) * 1000
  351. return gate_output
  352. class MOELayer(Base):
  353. """MOELayer module which implements MixtureOfExperts as described in Gshard_.
  354. ::
  355. gate = TopKGate(model_dim, num_experts)
  356. moe = MOELayer(gate, expert)
  357. output = moe(input)
  358. l_aux = moe.l_aux
  359. .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
  360. Args:
  361. gate (torch.nn.Module):
  362. gate network
  363. expert (torch.nn.Module):
  364. expert network
  365. """
  366. def __init__(self,
  367. gate: Module,
  368. experts: Module,
  369. num_local_experts: int,
  370. group: Optional[Any] = None,
  371. use_tutel: bool = False) -> None:
  372. super().__init__()
  373. self.gate = gate
  374. self.experts = experts
  375. self.group = group
  376. self.world_size = dist.get_world_size(group)
  377. self.num_local_experts = num_local_experts
  378. self.time_falltoall = 0.0
  379. self.time_salltoall = 0.0
  380. self.time_moe = 0.0
  381. self.timers = SynchronizedWallClockTimer()
  382. self.wall_clock_breakdown = False
  383. self.use_tutel = use_tutel and TUTEL_INSTALLED
  384. if self.use_tutel:
  385. logger.info('Using Tutel optimizations.')
  386. elif use_tutel and not TUTEL_INSTALLED:
  387. logger.warning("Tutel optimization requested but not installed. "
  388. "Proceeding without Tutel.")
  389. def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
  390. if self.wall_clock_breakdown:
  391. self.timers('moe').start()
  392. # Implement Algorithm 2 from GShard paper.
  393. d_model = input[0].shape[-1]
  394. # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
  395. # Reshape into G groups so that each group can distribute tokens equally
  396. # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
  397. reshaped_input = input[0].reshape(-1, d_model)
  398. if self.use_tutel:
  399. self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
  400. S, M = reshaped_input.size(0), reshaped_input.size(1)
  401. if not hasattr(self, '_tutel_dispatcher'):
  402. self._tutel_dispatcher = tutel_moe.fast_dispatcher(
  403. E,
  404. C,
  405. M,
  406. dispatch_dtype=reshaped_input.dtype)
  407. self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
  408. dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
  409. else:
  410. self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
  411. dispatched_input = einsum("sec,sm->ecm",
  412. dispatch_mask.type_as(input[0]),
  413. reshaped_input)
  414. if self.wall_clock_breakdown:
  415. self.timers('falltoall').start()
  416. dispatched_input = _AllToAll.apply(self.group, dispatched_input)
  417. if self.wall_clock_breakdown:
  418. self.timers('falltoall').stop()
  419. self.time_falltoall = self.timers('falltoall').elapsed(reset=False) * 1000
  420. # Re-shape after all-to-all: ecm -> gecm
  421. dispatched_input = dispatched_input.reshape(self.world_size,
  422. self.num_local_experts,
  423. -1,
  424. d_model)
  425. expert_output = self.experts(dispatched_input)
  426. if self.wall_clock_breakdown:
  427. self.timers('salltoall').start()
  428. expert_output = _AllToAll.apply(self.group, expert_output)
  429. if self.wall_clock_breakdown:
  430. self.timers('salltoall').stop()
  431. self.time_salltoall = self.timers('salltoall').elapsed(reset=False) * 1000
  432. # Re-shape back: gecm -> ecm
  433. expert_output = expert_output.reshape(self.world_size * self.num_local_experts,
  434. -1,
  435. d_model)
  436. if self.use_tutel:
  437. combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
  438. else:
  439. combined_output = einsum("sec,ecm->sm",
  440. combine_weights.type_as(input[0]),
  441. expert_output)
  442. a = combined_output.reshape(input[0].shape)
  443. if self.wall_clock_breakdown:
  444. self.timers('moe').stop()
  445. self.time_moe = self.timers('moe').elapsed(reset=False) * 1000
  446. return a