# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py

import torch
import triton
import triton.language as tl


def next_power_of_2(n):
    # round n up to the next power of two with the classic bit-smearing trick
    # (valid for 32-bit n); Triton requires power-of-two range sizes, so the
    # padded row width TN must be a power of two
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


def num_warps(n):
    if n < 512:
        return 4
    if n < 2048:
        return 8
    return 16
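

# The heuristics below derive kernel meta-parameters from the launch
# arguments of _forward: args[6] is `sizemax` (the maximum number of non-zero
# blocks in any block-row), so TN = next_power_of_2(sizemax * BLOCK) is the
# padded width of one row of the sparse matrix, and the warp count scales
# with that width.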
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})
@triton.jit
def _forward(X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm,
             stride_zattnm, **meta):
    TN = meta['TN']
    BLOCK = meta['BLOCK']
    pidhm = tl.program_id(0)
    pidz = tl.program_id(1)
    # create index ranges
    rxm = pidhm % BLOCK
    rbm = pidhm // BLOCK
    rxn = tl.arange(0, TN) % BLOCK
    rbn = tl.arange(0, TN) // BLOCK
    # extract information from LUT
    header = LUT + rbm * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    check = rbn < size
    rbmn = tl.where(check, rbn, size - 1)
    # block id and column id
    blockid = tl.load(LUT + offset + rbmn * 4 + 0)
    columnid = tl.load(LUT + offset + rbmn * 4 + 1)
    rowid = tl.load(LUT + offset + rbmn * 4 + 2)
    headid = tl.load(LUT + offset + rbmn * 4 + 3)
    # pointers to X
    px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    x = tl.load(px, mask=check, other=-float('inf'))
    x = x.to(tl.float32)
    # apply scale
    if meta['APPLY_SCALE']:
        x = x * scale
    # apply RPE
    if meta['APPLY_RPE']:
        prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
        rpe = tl.load(prpe, mask=check, other=0)
        x = x + rpe
    # apply key-padding mask
    if meta['APPLY_KP_MASK']:
        pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
        kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
        if meta['KP_MASK_MUL']:
            kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
        x = x + kp_m
    # apply attention mask
    if meta['APPLY_ATTN_MASK']:
        pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
        attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
        if meta['ATTN_MASK_MUL']:
            attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
        x = x + attn_m
    # softmax over the padded row; padding lanes hold -inf and contribute zero
    x = tl.softmax(x)
    tl.store(px, x, mask=check)


@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
    pidhm = tl.program_id(0)
    pidz = tl.program_id(1)
    TN = meta['TN']
    BLOCK = meta['BLOCK']
    # create index ranges
    rxm = pidhm % BLOCK
    rbm = pidhm // BLOCK
    rxn = tl.arange(0, TN) % BLOCK
    rbn = tl.arange(0, TN) // BLOCK
    # extract information from look-up table
    header = LUT + rbm * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # bounds checking on lut
    check = rbn < size
    rbmn = tl.where(check, rbn, size - 1)
    # initialize pointers to block-sparse input
    blockid = tl.load(LUT + offset + rbmn * 4)
    X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    # compute fused softmax backward: with y = softmax(x) (stored in X by the
    # in-place forward), dL/dx_i = y_i * (dy_i - sum_j y_j * dy_j), scaled
    x = tl.load(X, mask=check, other=0)
    dx = tl.load(DX, mask=check, other=0)
    x = x.to(tl.float32)
    dx = dx.to(tl.float32)
    y = x * (dx - tl.sum(x * dx, 0)) * scale
    tl.store(DX, y, mask=check)


class _sparse_softmax(torch.autograd.Function):

    bwd_kernels = dict()

    @staticmethod
    def make_lut(layout, block, device):
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        sizes = _empty.clone()
        # number of non-zero blocks in each block-row, per head
        for h in range(layout.shape[0]):
            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
        # offsets in block format
        offsets = torch.zeros_like(sizes)
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
        # block indices
        idx = torch.arange(layout.sum())
        head = layout.nonzero()[:, 0]
        rows = layout.nonzero()[:, 1]
        columns = layout.nonzero()[:, 2]
        core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
        # construct look-up table: per-row (size, offset) headers followed by
        # (blockid, columnid, rowid, headid) records
        offsets = offsets * 4 + 2 * sizes.numel()
        header = torch.stack((sizes, offsets), dim=1).view(-1)
        lut = torch.cat((header, core)).type(torch.int32).to(device)
        return lut, int(sizes.max())
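
    # Worked example (illustration only): for a single head with
    #     layout = torch.tensor([[[1, 0],
    #                             [1, 1]]])
    # the block-rows contain [1, 2] non-zero blocks, so make_lut returns
    #     header: 1, 4, 2, 8            -> (size, offset) per block-row
    #     core:   0, 0, 0, 0            -> (blockid, columnid, rowid, headid)
    #             1, 0, 1, 0
    #             2, 1, 1, 0
    # and maxlut == 2.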

    @staticmethod
    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
                num_blocks, maxlut, bench, time):
        apply_scale = False if scale == 1.0 else True
        # handle None rpe
        if rpe is None:
            apply_rpe = False
            stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
            rpe = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_rpe = True
            stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
        # handle None key_padding_mask
        if key_padding_mask is None:
            apply_kp_mask = False
            stride_zkpm = 0
            key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_kp_mask = True
            stride_zkpm = key_padding_mask.stride(0)
        # handle None attn_mask
        if attn_mask is None:
            apply_attn_mask = False
            stride_zattnm = 0
            attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_attn_mask = True
            stride_zattnm = attn_mask.stride(0)
        # run kernel
        M = x.shape[0]
        meta = {
            'BLOCK': block,
            'APPLY_SCALE': apply_scale,
            'APPLY_RPE': apply_rpe,
            'APPLY_KP_MASK': apply_kp_mask,
            'APPLY_ATTN_MASK': apply_attn_mask,
            'KP_MASK_MUL': kp_mask_mode == 'mul',
            'ATTN_MASK_MUL': attn_mask_mode == 'mul',
        }
        grid = lambda opt: [spdims[0] * spdims[1] * block, M]
        _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),
                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
        # save to context
        ctx.mark_dirty(x)
        ctx.save_for_backward(x, lut)
        ctx.spdims = spdims
        ctx.block = block
        ctx.maxlut = maxlut
        ctx.scale = scale
        ctx.apply_scale = apply_scale
        ctx.apply_rpe = apply_rpe
        ctx.apply_kp_mask = apply_kp_mask
        ctx.apply_attn_mask = apply_attn_mask
        ctx.kp_mask_mode = kp_mask_mode
        ctx.attn_mask_mode = attn_mask_mode
        return x

    @staticmethod
    def backward(ctx, dx):
        # retrieve from context
        x, lut = ctx.saved_tensors
        # run kernel
        M = x.shape[0]
        grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
        # one gradient per forward input; only x receives a gradient
        return dx, None, None, None, None, None, None, None, None, None, None, None, None, None


class Softmax:
    """Block-Sparse Softmax class; this class computes softmax on a block-sparse matrix. It is also able to apply any or all of the following masks:
         - relative position embedding
         - key padding mask
         - attention mask

       For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
    """

    @staticmethod
    def sparse_softmax(*args, **kwargs):
        return _sparse_softmax.apply(*args, **kwargs)

    def make_lut(self, device):
        """Generates (and caches per device) the look-up table used by the block-sparse softmax kernels
        """
        key = (device, )
        if key not in self.lut_cache:
            self.lut_cache[key] = _sparse_softmax.make_lut(self.layout, self.block, device)
        return self.lut_cache[key]
    def __init__(self, layout, block, bench=False):
        """Initialize the Block-Sparse Softmax class.

        Arguments:
             layout: required: sparsity layout tensor
             block: required: an integer determining the block size.
             bench: optional: set if you want to do benchmarking
        """
        self.num_blocks = layout.sum().item()
        self.spdims = layout.shape
        self.layout = layout
        self.block = block
        self.bench = bench
        self.lut_cache = dict()
    def __call__(self,
                 x,
                 scale=1.,
                 rpe=None,
                 key_padding_mask=None,
                 attn_mask=None,
                 key_padding_mask_mode='add',
                 attn_mask_mode='add'):
        """Applies softmax on a Block-Sparse input tensor.

        For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509

        Arguments:
             x: required: a block-sparse tensor that softmax is applied to; the computation is done in place and the result is returned in the same tensor
             scale: optional: a float value; x values will be multiplied by this value before normalization. Default value is 1.0.
             rpe: optional: a tensor of the same dimensions as x that is used as a relative position embedding
             key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
             attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
             key_padding_mask_mode: optional: a string, either 'add' or 'mul', determining whether key_padding_mask is added to or multiplied into x
             attn_mask_mode: optional: a string, either 'add' or 'mul', determining whether attn_mask is added to or multiplied into x

        Return:
             x: the block-sparse input tensor, normalized in place with softmax, with any given masks applied

        A minimal usage sketch follows the class definition below.
        """
        time_y = [None]
        if rpe is not None and rpe.dtype != x.dtype:
            raise ValueError('relative position embedding must be %s' % x.dtype)
        if attn_mask is not None and attn_mask.dtype != x.dtype:
            raise ValueError('Attention mask must be %s' % x.dtype)
        if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
            raise ValueError('Key padding mask must be %s' % x.dtype)
        lut, maxlut = self.make_lut(x.device)
        x = Softmax.sparse_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
                                   self.spdims, self.block, lut, self.num_blocks, maxlut, self.bench, time_y)
        self.time_y = time_y[0]
        return x
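

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a CUDA
# device and a Triton version supporting the legacy **meta kernel API used
# above. Shapes follow the kernel conventions: the layout is
# (heads, seq_blocks, seq_blocks) and x packs only the non-zero blocks as
# (batch, layout.sum(), block, block). The variable names here are
# illustrative, not part of the API.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    batch, heads, seq, block = 2, 4, 128, 32
    blocks_per_row = seq // block
    # hypothetical causal (lower-triangular) block layout
    layout = torch.tril(torch.ones((heads, blocks_per_row, blocks_per_row), dtype=torch.int64))
    sparse_softmax = Softmax(layout, block)
    # one (block x block) tile per non-zero layout entry, for every batch item
    x = torch.randn((batch, int(layout.sum()), block, block), dtype=torch.half, device='cuda')
    y = sparse_softmax(x, scale=1.0 / seq**0.5)  # computed in place; y aliases x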