# sharded_moe.py

'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
# The file has been adapted from two fairscale files:
# (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
# (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
# Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
# We retain the following license from the original files:
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils import logger, log_dist

from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union, cast

import time
from time import perf_counter

import torch
from torch import Tensor
import torch.distributed as dist
from torch.nn import Module, ModuleList

if TYPE_CHECKING:
    Base = Module[Tensor]
else:
    Base = Module

uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}

try:
    # To enable Tutel MoE optimizations:
    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
    from tutel import moe as tutel_moe
    TUTEL_INSTALLED = True
except ImportError:
    # Fail silently so we don't spam logs unnecessarily if the user isn't using Tutel.
    TUTEL_INSTALLED = False
    pass

def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
    """
    Modified from the Switch Transformer paper / Mesh Transformers.
    Multiply values by a random number between 1-epsilon and 1+epsilon.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.
    Args:
        x: a torch.tensor
        device: torch.device
        epsilon: a floating point value
    Returns:
        a jittered x.
    """
    if epsilon == 0:
        return x
    uniform = uniform_map.get(device)
    if uniform is None:
        uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(1.0 - epsilon, device=device),
            high=torch.tensor(1.0 + epsilon, device=device)).rsample  # type: ignore
        uniform_map[device] = uniform
    return x * uniform(x.shape)

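# Illustrative sketch (not part of the original file): applying multiplicative jitter
# to a random tensor; every element is scaled by a factor drawn from U(1-eps, 1+eps).
# The tensor shape here is an arbitrary example value.
def _jitter_example():
    x = torch.randn(4, 8)
    x_jittered = multiplicative_jitter(x, device=x.device, epsilon=1e-2)
    return x_jittered
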
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    gumbel = gumbel_map.get(device)
    if gumbel is None:
        one = torch.tensor(1.0, device=device)
        zero = torch.tensor(0.0, device=device)
        gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore
        gumbel_map[device] = gumbel
    return gumbel(shape)

# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
        ctx.group = group
        input = input.contiguous()
        output = torch.empty_like(input)
        dist.all_to_all_single(output, input, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _AllToAll.apply(ctx.group, *grad_output))

from torch import nn
import torch.nn.functional as F

import math

# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True


def einsum(rule, a, b):
    if USE_EINSUM:
        return torch.einsum(rule, a, b)
    elif rule == 's,se->se':
        return a.reshape(a.shape[0], -1) * b
    elif rule == 'se,sc->sec':
        return a.unsqueeze(2) * b.unsqueeze(1)
    elif rule == 'se,se->s':
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
    elif rule == 'sec,sm->ecm':
        s = a.shape[0]
        e = a.shape[1]
        c = a.shape[2]
        m = b.shape[1]
        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
    elif rule == 'sec,ecm->sm':
        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
    elif rule == 'ks,ksm->sm':
        k = b.shape[0]
        s = b.shape[1]
        m = b.shape[2]
        # [k, s] -> [s, k] -> [s, 1, k]
        a = a.t().unsqueeze(1)
        # [k, s, m] -> [k, sm] -> [sm, k] -> [s, m, k]
        b = b.reshape(k, -1).t().reshape(s, m, k)
        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
    else:
        return torch.einsum(rule, a, b)

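# Illustrative sketch (not part of the original file): a quick check that one of the
# manual rewrite paths in einsum() above matches torch.einsum. The shapes
# (s=4 tokens, e=2 experts, c=3 capacity) are arbitrary example values.
def _einsum_rewrite_example():
    a = torch.randn(4, 2)  # 'se' tensor
    b = torch.randn(4, 3)  # 'sc' tensor
    reference = torch.einsum('se,sc->sec', a, b)
    manual = a.unsqueeze(2) * b.unsqueeze(1)  # rewrite used when USE_EINSUM is False
    assert torch.allclose(reference, manual)
    return manual
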
def top1gating(logits: torch.Tensor,
               capacity_factor: float,
               min_capacity: int,
               used_token: torch.Tensor = None,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == 'RSample':
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    # gates has shape of SE
    num_tokens = int(gates.shape[0])
    num_experts = int(gates.shape[1])

    # round-up
    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
    if capacity < min_capacity:
        capacity = min_capacity

    # Create a mask for 1st's expert per token
    # noisy gating
    indices1_s = torch.argmax(
        logits_w_noise if noisy_gate_policy == 'RSample' else gates,
        dim=1)
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
        capacity = new_capacity

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=logits.device),
                high=torch.tensor(1.0, device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = uniform
        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be at least min_capacity. Either set min_capacity to 0 or increase your batch size."

    _, top_idx = torch.topk(mask1_rand, k=capacity, dim=0)

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1
    if use_tutel:
        # Tutel doesn't support index values masked with zero,
        # so we need to replace masked indices with -1.
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
    if use_tutel:
        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
    else:
        locations1 = torch.cumsum(mask1, dim=0) - 1

    if use_tutel:
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,], exp_counts

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts

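# Illustrative sketch (not part of the original file): running top1gating on random
# logits on CPU with the default options. The sizes (8 tokens, 4 experts) and the
# capacity settings are arbitrary example values.
def _top1gating_example():
    logits = torch.randn(8, 4)  # [tokens, experts]
    l_aux, combine_weights, dispatch_mask, exp_counts = top1gating(
        logits,
        capacity_factor=1.0,
        min_capacity=4)
    # combine_weights has shape [tokens, experts, capacity];
    # dispatch_mask is its boolean version used to scatter tokens to experts.
    return l_aux, combine_weights.shape, dispatch_mask.shape, exp_counts
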
def top2gating(logits: torch.Tensor,
               capacity_factor: float) -> Tuple[Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    # everything is in fp32 in this function
    # logits_fp32 = logits.to(torch.float32)
    gates = F.softmax(logits, dim=1)

    # gates has shape of SE
    num_tokens = int(gates.shape[0])
    num_experts = int(gates.shape[1])

    # capacity = (2 * num_tokens // num_experts) * capacity_factor
    # round-up
    capacity = math.ceil((2 * num_tokens / num_experts) * capacity_factor)

    # Create a mask for 1st's expert per token
    indices1_s = torch.argmax(gates, dim=1)
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # Create a mask for 2nd's expert per token using Gumbel-max trick
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # Replace top-expert with min value
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    # Update 2nd's location by accounting for locations of 1st
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # Remove locations outside capacity from mask
    mask1 *= torch.lt(locations1, capacity)
    mask2 *= torch.lt(locations2, capacity)

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s
    # Avoid divide-by-zero
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
    locations2_sc = F.one_hot(locations2_s, num_classes=capacity).float()
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts

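# Illustrative sketch (not part of the original file): running top2gating on random
# logits on CPU. Sizes are arbitrary; the capacity is derived inside the function.
def _top2gating_example():
    logits = torch.randn(8, 4)  # [tokens, experts]
    l_aux, combine_weights, dispatch_mask, exp_counts = top2gating(logits, capacity_factor=1.0)
    # Each token contributes weight to up to two experts in combine_weights.
    return l_aux, combine_weights.shape, dispatch_mask.shape, exp_counts
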
class TopKGate(torch.nn.Module):
    """Gate module which implements top-1/top-2 gating as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
    """
    wg: torch.nn.Linear

    def __init__(self,
                 model_dim: int,
                 num_experts: int,
                 k: int = 1,
                 capacity_factor: float = 1.0,
                 eval_capacity_factor: float = 1.0,
                 min_capacity: int = 4,
                 noisy_gate_policy: Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts: bool = True) -> None:
        super().__init__()

        # Only top-1 and top-2 are supported at the moment.
        if k != 1 and k != 2:
            raise ValueError('Only top-1 and top-2 gatings are supported.')
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
        self.k = k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False
        self.gate_time = 0.0
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts

    def forward(self,
                input: torch.Tensor,
                used_token: torch.Tensor = None,
                use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore

        if self.wall_clock_breakdown:
            self.timers('TopKGate').start()

        if self.wg.weight.dtype != torch.float32:
            self.wg = self.wg.float()
        input_fp32 = input.float()
        # input jittering
        if self.noisy_gate_policy == 'Jitter' and self.training:
            input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
        logits = self.wg(input_fp32)

        if self.k == 1:
            gate_output = top1gating(
                logits,
                self.capacity_factor if self.training else self.eval_capacity_factor,
                self.min_capacity,
                used_token,
                self.noisy_gate_policy if self.training else None,
                self.drop_tokens,
                self.use_rts,
                use_tutel)
        else:
            gate_output = top2gating(
                logits,
                self.capacity_factor if self.training else self.eval_capacity_factor)

        if self.wall_clock_breakdown:
            self.timers('TopKGate').stop()
            self.gate_time = self.timers('TopKGate').elapsed(reset=False) * 1000

        return gate_output

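# Illustrative sketch (not part of the original file): building a TopKGate and calling
# it on a flat [tokens, model_dim] batch. The sizes are arbitrary example values.
def _topk_gate_example():
    gate = TopKGate(model_dim=16, num_experts=4, k=1, min_capacity=4)
    tokens = torch.randn(8, 16)  # [tokens, model_dim]
    l_aux, combine_weights, dispatch_mask, exp_counts = gate(tokens)
    return l_aux, combine_weights.shape, exp_counts
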
class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        moe = MOELayer(gate, expert)
        output = moe(input)
        l_aux = moe.l_aux

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        expert (torch.nn.Module):
            expert network
    """
    def __init__(self,
                 gate: Module,
                 experts: Module,
                 num_local_experts: int,
                 group: Optional[Any] = None,
                 use_tutel: bool = False) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.group = group
        self.world_size = dist.get_world_size(group)
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

        self.use_tutel = use_tutel and TUTEL_INSTALLED

        if self.use_tutel:
            logger.info('Using Tutel optimizations.')
        elif use_tutel and not TUTEL_INSTALLED:
            logger.warning("Tutel optimization requested but not installed. "
                           "Proceeding without Tutel.")

    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:

        if self.wall_clock_breakdown:
            self.timers('moe').start()

        # Implement Algorithm 2 from GShard paper.
        d_model = input[0].shape[-1]

        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
        # Reshape into G groups so that each group can distribute tokens equally
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_input = input[0].reshape(-1, d_model)

        if self.use_tutel:
            self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
            S, M = reshaped_input.size(0), reshaped_input.size(1)

            if not hasattr(self, '_tutel_dispatcher'):
                self._tutel_dispatcher = tutel_moe.fast_dispatcher(
                    E,
                    C,
                    M,
                    dispatch_dtype=reshaped_input.dtype)
            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
        else:
            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
            dispatched_input = einsum("sec,sm->ecm",
                                      dispatch_mask.type_as(input[0]),
                                      reshaped_input)

        if self.wall_clock_breakdown:
            self.timers('falltoall').start()

        dispatched_input = _AllToAll.apply(self.group, dispatched_input)

        if self.wall_clock_breakdown:
            self.timers('falltoall').stop()
            self.time_falltoall = self.timers('falltoall').elapsed(reset=False) * 1000

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.world_size,
                                                    self.num_local_experts,
                                                    -1,
                                                    d_model)

        expert_output = self.experts(dispatched_input)

        if self.wall_clock_breakdown:
            self.timers('salltoall').start()

        expert_output = _AllToAll.apply(self.group, expert_output)

        if self.wall_clock_breakdown:
            self.timers('salltoall').stop()
            self.time_salltoall = self.timers('salltoall').elapsed(reset=False) * 1000

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.world_size * self.num_local_experts,
                                              -1,
                                              d_model)

        if self.use_tutel:
            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
        else:
            combined_output = einsum("sec,ecm->sm",
                                     combine_weights.type_as(input[0]),
                                     expert_output)

        a = combined_output.reshape(input[0].shape)

        if self.wall_clock_breakdown:
            self.timers('moe').stop()
            self.time_moe = self.timers('moe').elapsed(reset=False) * 1000

        return a

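# Illustrative sketch (not part of the original file): wiring a TopKGate and a single
# placeholder expert into MOELayer. This assumes torch.distributed has already been
# initialized with a backend that supports all_to_all_single (e.g. NCCL); the Linear
# expert below is a stand-in for illustration, not the DeepSpeed Experts wrapper.
def _moe_layer_example(model_dim: int = 16, num_experts: int = 4):
    gate = TopKGate(model_dim, num_experts, k=1)
    experts = torch.nn.Linear(model_dim, model_dim)  # placeholder expert network
    moe = MOELayer(gate, experts, num_local_experts=num_experts)
    tokens = torch.randn(2, 8, model_dim)  # [batch, seq, model_dim]
    # forward expects the optional used-token mask as the second positional input;
    # passing None skips the masking step inside the gate.
    output = moe(tokens, None)
    return output, moe.l_aux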