# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
import warnings
import importlib
import torch
import math
import triton
import triton.language as tl
import triton._C.libtriton as libtriton
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


def num_warps(n):
    if n < 512:
        return 4
    if n < 2048:
        return 8
    return 16
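# Worked examples for the two helpers above (added for illustration, not part of
# the original source):
#   next_power_of_2(5)  -> 8    # 5 = 0b101 is smeared to 0b111, then +1 gives 0b1000
#   next_power_of_2(64) -> 64   # exact powers of two map to themselves
#   num_warps(256)      -> 4    # padded rows shorter than 512 elements use 4 warps
#   num_warps(4096)     -> 16   # long rows get the maximum of 16 warps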
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})
@triton.jit
def _forward(X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm,
             stride_zattnm, **meta):
    TN = meta['TN']
    BLOCK = meta['BLOCK']
    pidhm = tl.program_id(0)
    pidz = tl.program_id(1)
    # create index ranges
    rxm = pidhm % BLOCK
    rbm = pidhm // BLOCK
    rxn = tl.arange(0, TN) % BLOCK
    rbn = tl.arange(0, TN) // BLOCK
    # extract information from LUT
    header = LUT + rbm * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    check = rbn < size
    rbmn = tl.where(check, rbn, size - 1)
    # block id and column id
    blockid = tl.load(LUT + offset + rbmn * 4 + 0)
    columnid = tl.load(LUT + offset + rbmn * 4 + 1)
    rowid = tl.load(LUT + offset + rbmn * 4 + 2)
    headid = tl.load(LUT + offset + rbmn * 4 + 3)
    # pointers to X
    px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    x = tl.load(px, mask=check, other=-float('inf'))
    x = x.to(tl.float32)
    # apply scale
    if meta['APPLY_SCALE']:
        x = x * scale
    # apply RPE
    if meta['APPLY_RPE']:
        prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
        rpe = tl.load(prpe, mask=check, other=0)
        x = x + rpe
    # apply key-padding mask
    if meta['APPLY_KP_MASK']:
        pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
        kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
        if meta['KP_MASK_MUL']:
            kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
        x = x + kp_m
    # apply attention mask
    if meta['APPLY_ATTN_MASK']:
        pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
        attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
        if meta['ATTN_MASK_MUL']:
            attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
        x = x + attn_m
    # computation
    x = tl.softmax(x)
    tl.store(px, x, mask=check)
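# Note on the launch geometry of _forward (added comment, not part of the original
# source): program_id(0) enumerates (head, block-row, row-within-block) triples,
# program_id(1) enumerates the batch dimension, and TN is the length of the longest
# block-row in elements, padded up to a power of two so that a single tl.arange
# covers every block in the row; lanes past the actual row size are disabled via
# `check` and loaded as -inf so they do not affect the softmax.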
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
    pidhm = tl.program_id(0)
    pidz = tl.program_id(1)
    TN = meta['TN']
    BLOCK = meta['BLOCK']
    # create index ranges
    rxm = pidhm % BLOCK
    rbm = pidhm // BLOCK
    rxn = tl.arange(0, TN) % BLOCK
    rbn = tl.arange(0, TN) // BLOCK
    # extract information from look-up table
    header = LUT + rbm * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # bounds checking on lut
    check = rbn < size
    rbmn = tl.where(check, rbn, size - 1)
    # initialize pointers to block-sparse input
    blockid = tl.load(LUT + offset + rbmn * 4)
    X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    # compute fused softmax backward
    x = tl.load(X, mask=check, other=0)
    dx = tl.load(DX, mask=check, other=0)
    x = x.to(tl.float32)
    dx = dx.to(tl.float32)
    y = x * (dx - tl.sum(x * dx, 0)) * scale
    tl.store(DX, y, mask=check)
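# Math note (added for illustration, not part of the original source): the forward
# pass stores the softmax output p in place of X, so the kernel above evaluates the
# standard softmax vector-Jacobian product
#     dL/dz_i = p_i * (dL/dp_i - sum_j p_j * dL/dp_j),
# multiplied by `scale` to account for the scaling applied to the logits before the
# softmax in the forward pass.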
class _sparse_softmax(torch.autograd.Function):

    bwd_kernels = dict()

    @staticmethod
    def make_lut(layout, block, device):
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        sizes = _empty.clone()
        # sizes along rows
        for h in range(layout.shape[0]):
            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
        # offsets in block format
        offsets = torch.zeros_like(sizes)
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
        # block indices
        idx = torch.arange(layout.sum())
        head = layout.nonzero()[:, 0]
        rows = layout.nonzero()[:, 1]
        columns = layout.nonzero()[:, 2]
        core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
        # construct look-up table
        offsets = offsets * 4 + 2 * sizes.numel()
        header = torch.stack((sizes, offsets), dim=1).view(-1)
        lut = torch.cat((header, core)).type(torch.int32).to(device)
        return lut, int(sizes.max())
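    # LUT layout produced by make_lut (added comment, not part of the original
    # source): the first 2 * num_heads * num_block_rows int32 entries form a header
    # of (size, offset) pairs, one per (head, block-row); each offset points past
    # the header into the body, which holds 4 int32 values per nonzero block:
    # (block index, column index, row index, head index). The second return value
    # is the largest number of nonzero blocks in any block-row and sizes the
    # per-program work in the kernels above.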
    @staticmethod
    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
                num_blocks, maxlut, bench, time):
        apply_scale = False if scale == 1.0 else True
        # handle None rpe
        if rpe is None:
            apply_rpe = False
            stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
            rpe = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_rpe = True
            stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
        # handle None key_padding_mask
        if key_padding_mask is None:
            apply_kp_mask = False
            stride_zkpm = 0
            key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_kp_mask = True
            stride_zkpm = key_padding_mask.stride(0)
        # handle None attention_mask
        if attn_mask is None:
            apply_attn_mask = False
            stride_zattnm = 0
            attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_attn_mask = True
            stride_zattnm = attn_mask.stride(0)
        # run kernel
        M = x.shape[0]
        meta = {
            'BLOCK': block,
            'APPLY_SCALE': apply_scale,
            'APPLY_RPE': apply_rpe,
            'APPLY_KP_MASK': apply_kp_mask,
            'APPLY_ATTN_MASK': apply_attn_mask,
            'KP_MASK_MUL': kp_mask_mode == 'mul',
            'ATTN_MASK_MUL': attn_mask_mode == 'mul',
        }
        grid = lambda opt: [spdims[0] * spdims[1] * block, M]
        _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0), stride_zrpe, stride_hrpe,
                       stride_srpe, stride_zkpm, stride_zattnm, **meta)
        # save to context
        ctx.mark_dirty(x)
        ctx.save_for_backward(x, lut)
        ctx.spdims = spdims
        ctx.block = block
        ctx.maxlut = maxlut
        ctx.scale = scale
        ctx.apply_scale = apply_scale
        ctx.apply_rpe = apply_rpe
        ctx.apply_kp_mask = apply_kp_mask
        ctx.apply_attn_mask = apply_attn_mask
        ctx.kp_mask_mode = kp_mask_mode
        ctx.attn_mask_mode = attn_mask_mode
        return x
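    # Note (added comment, not part of the original source): the forward pass writes
    # the softmax result into `x` itself (hence ctx.mark_dirty(x)) and returns the
    # same tensor; the `x` saved for backward therefore holds the probabilities,
    # not the original logits.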
    @staticmethod
    def backward(ctx, dx):
        # retrieve from context
        x, lut = ctx.saved_tensors
        # run kernel
        M = x.shape[0]
        grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
        return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class Softmax:
    """Block-Sparse Softmax class; this class computes softmax on a block-sparse matrix. It can also apply any combination of the following masks:
         - relative position embedding
         - key padding mask
         - attention mask

       For more details about the sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509.
       A usage sketch is provided at the end of this file.
    """
    def sparse_softmax(*args, **kwargs):
        # note: defined as a plain function (no `self`) and invoked as Softmax.sparse_softmax(...)
        return _sparse_softmax.apply(*args, **kwargs)
    def make_lut(self, device):
        """Generates (and caches per device) the look-up table and maximum row size used by the block-sparse softmax kernels.
        """
        key = (device, )
        if key not in self.lut_cache:
            self.lut_cache[key] = _sparse_softmax.make_lut(self.layout, self.block, device)
        return self.lut_cache[key]
    def __init__(self, layout, block, bench=False):
        """Initialize the Block-Sparse Softmax class.

        Arguments:
             layout: required: sparsity layout tensor
             block: required: an integer determining the block size.
             bench: optional: set if you want to do benchmarking
        """
        self.num_blocks = layout.sum().item()
        self.spdims = layout.shape
        self.layout = layout
        self.block = block
        self.bench = bench
        self.lut_cache = dict()
    def __call__(self,
                 x,
                 scale=1.,
                 rpe=None,
                 key_padding_mask=None,
                 attn_mask=None,
                 key_padding_mask_mode='add',
                 attn_mask_mode='add'):
        """Applies softmax on a Block-Sparse input tensor.

        For more details about the sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509

        Arguments:
             x: required: a block-sparse tensor that softmax is applied on; the computation is done in place and the result is returned in the same tensor
             scale: optional: a float value; x values will be multiplied by this value before normalization. Default value is 1.0.
             rpe: optional: a tensor of the same dimensions as x that is used as relative position embedding
             key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
             attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
             key_padding_mask_mode: optional: either 'add' or 'mul', determining whether key_padding_mask is added to or multiplied with x
             attn_mask_mode: optional: either 'add' or 'mul', determining whether attn_mask is added to or multiplied with x

        Return:
             x: a block-sparse tensor containing the softmax-normalized input, with the given masks applied
        """
        time_y = [None]
        if rpe is not None and rpe.dtype != x.dtype:
            raise ValueError('relative position embedding must be %s' % x.dtype)
        if attn_mask is not None and attn_mask.dtype != x.dtype:
            raise ValueError('Attention mask must be %s' % x.dtype)
        if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
            raise ValueError('Key padding mask must be %s' % x.dtype)
        lut, maxlut = self.make_lut(x.device)
        x = Softmax.sparse_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
                                   self.spdims, self.block, lut, self.num_blocks, maxlut, self.bench, time_y)
        self.time_y = time_y[0]
        return x
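

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# DeepSpeed source). It assumes a CUDA device and a Triton version compatible
# with the kernel API above; the layout, block size and tensor shapes below
# are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    block = 16
    num_heads, seq_blocks = 2, 4
    # lower-triangular (causal) layout: one entry per (head, block-row, block-col)
    layout = torch.tril(torch.ones(num_heads, seq_blocks, seq_blocks, dtype=torch.int64))
    sparse_softmax = Softmax(layout, block)
    # block-sparse scores laid out as (batch, nonzero blocks, block, block)
    batch = 3
    scores = torch.randn(batch, int(layout.sum()), block, block, dtype=torch.half, device='cuda')
    probs = sparse_softmax(scores, scale=0.5)  # arbitrary pre-softmax scale for the sketch
    print(probs.shape)  # same storage as `scores`: the softmax is applied in place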