# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# DeepSpeed note: code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py

import importlib

import torch
import triton
import triton.language as tl
import triton._C.libtriton as libtriton

from deepspeed.accelerator import get_accelerator


@triton.jit
def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc,
            stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    BLOCK = meta['BLOCK']
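    # Mode flags (exactly one is set per launch; see the `meta` dicts built in
    # `_sparse_matmul` below):
    #   SDD: sparse = dense x dense   -- the output C is block-sparse
    #   DSD: dense  = sparse x dense  -- the input A is block-sparse
    #   DDS: dense  = dense x sparse  -- the input B is block-sparse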
    #------------#
    #- Prologue -#
    #------------#
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pidz = tl.program_id(2)
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        blockidm = tl.arange(0, TM) // BLOCK
        blockidn = tl.arange(0, TN) // BLOCK
        offlutm = blockidm * (TN // BLOCK) * 4
        offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = tl.load(header + 0)
        i = tl.load(header + 1 + offlutm)
        j = tl.load(header + 2 + offlutn)
        AS1 = SDD_K // TZ
        lockid = tl.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
    else:
        header = lut + pid0 * 6
        offset = tl.load(header + 0)
        AS1 = tl.load(header + 1)
        column = tl.load(header + 2)
        depth = tl.load(header + 3)
        lockid = tl.load(header + 4)
        maxid = tl.load(header + 5)
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = tl.load(pinc)
            offkb = tl.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = tl.load(pinc + 1)
            offpa = tl.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = tl.load(pinc)
            offka = tl.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = tl.load(pinc + 1)
            offpb = tl.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + tl.arange(0, TM)
        rbn = offnb + tl.arange(0, TN)
    # initialize a, b pointers
    rka = offka + tl.arange(0, TK)
    rkb = offkb + tl.arange(0, TK)
    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    a = tl.load(pa, mask=checkam, other=0.)
    b = tl.load(pb, mask=checkbn, other=0.)
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TM, TN), dtype=tl.float32)
    for k in range(AS1, 0, -TK):
        acc += tl.dot(a, b)
        if meta['SDD']:
            inc_a = TK * stride_ka
            inc_b = TK * stride_kb
        else:
            pinc += 2
            if meta['DSD']:
                inc_b = tl.load(pinc)
                inc_a = tl.load(pinc + 1)
                inc_b = tl.multiple_of(inc_b, 8)
                inc_a = tl.multiple_of(inc_a, 8)
                inc_b = inc_b * stride_kb
            if meta['DDS']:
                inc_a = tl.load(pinc)
                inc_b = tl.load(pinc + 1)
                inc_a = tl.multiple_of(inc_a, 8)
                inc_b = tl.multiple_of(inc_b, 8)
                inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = tl.load(pa, mask=checka)
        b = tl.load(pb, mask=checkb)
    c = acc.to(C.dtype.element_ty)
    if meta['SDD']:
        checkc = True
        rr_blockidm = tl.arange(0, TM) // BLOCK
        rr_blockidn = tl.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = tl.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = tl.arange(0, TM) % BLOCK
        rcn = tl.arange(0, TN) % BLOCK
    else:
        rcm = offmc + tl.arange(0, TM)
        rcn = offnc + tl.arange(0, TN)
    if meta['DSD']:
        checkc = rcn[None, :] < DS0
    if meta['DDS']:
        checkc = rcm[:, None] < DS0
    pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
    # write-back directly
    if lockid == 0:
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
        count = tl.load(pcount)
        if count == 0:
            tl.store(pc, c, mask=checkc)
        else:
            d = tl.load(pc, mask=checkc)
            tl.store(pc, d + c, mask=checkc)
        tl.atomic_xchg(pcount, (count + 1) % maxid)
        tl.atomic_xchg(plock, 0)


##############
#  MAIN API  #
##############
class _sparse_matmul(torch.autograd.Function):

    sdd_cache = dict()
    dsd_cache = dict()
    dds_cache = dict()
    locks = dict()

    # Given an array `sizes` holding the reduction size of each column
    # of a block-mode matrix multiplication, performs load-balancing by
    # splitting large reductions into several smaller segments of at
    # most `seg_max` elements each.
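    #
    # Worked example (illustrative): for sizes = [4, 2, 0, 6] the heuristics
    # below give seg_max = 6 and seg_min = max(cdiv(6, 4), 4) = 4, so every
    # column fits in a single segment and no spin-locks are needed:
    #   segments = [4, 2, 0, 6], column = [0, 1, 2, 3],
    #   lockid = maxid = [0, 0, 0, 0], offsets = [0, 4, 6, 6]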
    @staticmethod
    def load_balance(sizes, block):
        #global triton
        #if triton is None:
        #    triton = importlib.import_module('triton')
        # segment size
        # heuristics taken from OpenAI blocksparse code
        # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
        max_size = sizes.max()
        min_size = sizes[sizes != 0].min()
        #if max_size > min_size * 2.0:
        #    seg_max = max(triton.cdiv(max_size, 4), min_size * 2)
        #else:
        #    seg_max = max_size
        seg_max = max_size
        seg_min = max(triton.cdiv(seg_max, 4), 4)
        # split reduction into segments
        div = sizes // seg_max
        rem = sizes % seg_max
        packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
        width = packs.sum()
        segments = torch.empty(width, dtype=sizes.dtype)
        column = torch.empty_like(segments)
        lockid = torch.zeros_like(segments)
        maxid = torch.zeros_like(segments)
        nlocks = 0
        current = 0
        col_idx = 0
        for i in range(len(sizes)):
            d, r = div[i], rem[i]
            isempty = sizes[i] < seg_min
            last = current + d + (r >= seg_min) + isempty
            # column id
            column[current:last] = col_idx
            # lock id
            if d > 1 or (d == 1 and r >= seg_min):
                nlocks += 1
                lockid[current:last] = nlocks
                maxid[current:last] = last - current
            # segment size
            segments[current:current + d] = seg_max
            if r < seg_min and not isempty:
                segments[current + d - 1] += r
            if r >= seg_min or isempty:
                segments[current + d] = r
            current = last
            col_idx += 1
        offsets = torch.zeros_like(segments)
        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
        return segments, column, lockid, maxid, offsets
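
    # Spin-lock buffers are cached per device and over-allocated so they can
    # be reused across calls. `_kernel` uses the first half of the buffer as
    # locks and the second half as partial-sum counters (`plock` / `pcount`
    # above), which is why callers request `2 * ... * num_locks` entries.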
    @staticmethod
    def get_locks(size, dev):
        if dev not in _sparse_matmul.locks or \
                size > _sparse_matmul.locks[dev].size(0):
            _sparse_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
        return _sparse_matmul.locks[dev]

    ##########################
    # SPARSE = DENSE x DENSE #
    ##########################
    @staticmethod
    def make_sdd_lut(layout, block, dtype, device):
        #_sparse_matmul._load_utils()
        #start_width = 64 // block
        #segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width)
        start_width = (128 if block > 16 else 32) // block
        layout = layout.type(torch.int32)
        segmented = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2],
                                         start_width)
        luts, widths, packs = [], [], []
        for size, nnz in segmented:
            """ width = nnz.shape[0] // (size * size)
            h = nnz[:, 0]
            i = nnz[:, 1]
            j = nnz[:, 2]
            b = nnz[:, 3]
            lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
            luts.append(lut.type(torch.int32).to(device))
            widths.append(width)
            packs.append(size) """
            nnz = nnz.reshape(-1, 4)
            width = nnz.shape[0] // (size * size)
            luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
            widths.append(width)
            packs.append(size)
        # no lock metadata is needed for SDD; num_locks is returned as None
        return luts, None, widths, packs
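
    # LUT layout for SDD (as consumed by `_kernel`): one 4-tuple of int32
    # values [h, i, j, block_id] per nonzero block -- head index, block row,
    # block column, and the offset of that block in the packed output `c` --
    # grouped `pack * pack` tuples at a time, one group per super-block.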
    @staticmethod
    def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs, bench, time):
        if trans_c:
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
        AS0 = a.size(0)
        # shape check
        a_dim = -2 if trans_a else -1
        b_dim = -1 if trans_b else -2
        a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
        if a_inner != b_inner:
            raise ValueError(f"Size of tensor A along the {a_dim} dim ({a_inner}) must match size "
                             f"of tensor B along the {b_dim} dim ({b_inner})")
        if a_inner % 16 != 0:
            raise ValueError('Reduction size for SDD must be a multiple of 16')
        batch_size = a.size(0)
        a_outer = a.size(3 if trans_a else 2)
        dtype = a.dtype
        is_32_multiple = a_inner % 32 == 0
        is_64_multiple = a_inner % 64 == 0
        device = a.device
        # create kernel
        total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
        c = torch.empty((batch_size, total_width, block, block), dtype=dtype, device=a.device)
        for lut, width, pack in zip(luts, widths, packs):
            F32TK = [8, 16]
            F16TK = [16]
            F16TK += [32] if is_32_multiple else []
            F16TK += [64] if is_64_multiple else []
            TK = {torch.float32: F32TK, torch.float16: F16TK}[dtype]
            num_lock = 1
            meta = {
                'TM': block * pack,
                'TN': block * pack,
                'BLOCK': block,
                'TK': TK[0],
                'TZ': 1,
                'SDD': True,
                'DSD': False,
                'DDS': False
            }
            # create output
            locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)
            # the maximum grid size is 65535, so the operation might need to be
            # decomposed into multiple kernel calls
            max_width = 49152
            total = 0 if bench else None
            for off_width in range(0, width, max_width):
                grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]
                _kernel[grid](a,
                              b,
                              c,
                              a.stride(0),
                              a.stride(1),
                              a.stride(3 if trans_a else 2),
                              a.stride(2 if trans_a else 3),
                              b.stride(0),
                              b.stride(1),
                              b.stride(3 if trans_b else 2),
                              b.stride(2 if trans_b else 3),
                              c.stride(0),
                              c.stride(0),
                              c.stride(2),
                              c.stride(3),
                              a_outer,
                              a_outer,
                              a_inner,
                              off_width,
                              lut,
                              locks,
                              num_lock,
                              num_warps=4,
                              **meta)
        # save for backward pass
        return c

    ##########################
    # DENSE = DENSE x SPARSE #
    ##########################
    # Given a binary layout of 0s and 1s,
    # construct a look-up table for efficient execution on GPUs
    @staticmethod
    def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
        # load-balancing
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        segments = _empty.clone()
        column = _empty.clone()
        depth = _empty.clone()
        lockid = _empty.clone()
        maxid = _empty.clone()
        offsets = _empty.clone()
        current_offset = 0
        current_maxid = 0
        for z in range(layout.size(0)):
            if trans:
                sizes = torch.sum(layout[z, :, :], 1)
            else:
                sizes = torch.sum(layout[z, :, :], 0)
            z_segments, z_column, z_lockid, z_maxid, z_offsets = _sparse_matmul.load_balance(sizes, block)
            z_depth = z * torch.ones_like(z_segments)
            z_lockid[z_lockid > 0] += current_maxid
            current_maxid = z_lockid.max()
            # concatenate per-depth results
            segments = torch.cat((segments, z_segments))
            column = torch.cat((column, z_column))
            depth = torch.cat((depth, z_depth))
            maxid = torch.cat((maxid, z_maxid))
            offsets = torch.cat((offsets, current_offset + z_offsets))
            lockid = torch.cat((lockid, z_lockid))
            current_offset += layout[z, :, :].sum()
        segments *= step
        # pointer increments
        if trans:
            nnz = layout.nonzero()
        else:
            nnz = layout.transpose(1, 2).nonzero()
        num_blocks = nnz.size(0)
        offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
        idx = transform(nnz[:, 2] * block)
        xincs = idx.clone()
        xincs[1:] -= idx[:-1]
        # divide block into multiple steps
        div = block // step
        xincs = xincs.view(-1, 1).repeat(1, div)
        xincs[:, 1:] = step
        xincs[:, 0] -= (div - 1) * step
        # first increment for each reduction is actually the offset
        xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
        xincs = xincs.view(-1)
        # block-mode input increments
        if trans:
            widx = torch.arange(num_blocks)
        else:
            widx = _empty.clone()
            current_offset = 0
            for z in range(layout.size(0)):
                layoutw = layout[z, :, :].clone()
                msum = layoutw.sum()
                layoutw[layoutw > 0] = 1 + torch.arange(msum)
                widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
                current_offset += msum
        wincs = widx * block * block
        wincs[1:] -= widx[:-1] * block * block
        wincs = wincs.view(-1, 1).repeat(1, div)
        if trans:
            wincs[:, 1:] = step
            wincs[:, 0] -= (div - 1) * step
        else:
            wincs[:, 1:] = step * block
            wincs[:, 0] -= (div - 1) * step * block
        wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
        wincs = wincs.view(-1)
        # adjust offset and segment size
        offsets *= 2 * div
        segments *= div
        # create header
        width = column.size(0)
        offsets += 6 * width
        header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
        incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
        incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
        # create lut
        lut = torch.cat((header, incs))
        lut = lut.type(torch.int32).to(device)
        # create locks
        num_locks = max(1, lockid.max())
        return lut, num_locks, width, None
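
    # LUT layout for DSD/DDS (as consumed by `_kernel`): a header of 6 int32
    # values per segment -- [offset, size, column, depth, lockid, maxid] --
    # followed by interleaved (dense, sparse) pointer increments starting at
    # `offset`, which the kernel walks two at a time in its inner loop.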
    @staticmethod
    def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs, bench, time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')
        # shapes / dtypes
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        AS3 = a.size(2 if trans_a else 3)
        BS0 = spdims[0]
        BS1 = block * spdims[2 if trans_b else 1]
        BS2 = block * spdims[1 if trans_b else 2]
        dtype = a.dtype
        # kernel
        meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1, 'SDD': False, 'DSD': False, 'DDS': True}
        # output
        CS0 = AS0
        CS1 = AS1
        CS2 = BS2 if trans_c else AS2
        CS3 = AS2 if trans_c else BS2
        locks = _sparse_matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]
        _kernel[grid](a,
                      b,
                      c,
                      a.stride(0),
                      a.stride(1),
                      a.stride(3 if trans_a else 2),
                      a.stride(2 if trans_a else 3),
                      b.stride(0),
                      b.stride(1),
                      b.stride(3 if trans_b else 2),
                      b.stride(2 if trans_b else 3),
                      c.stride(0),
                      c.stride(1),
                      c.stride(3 if trans_c else 2),
                      c.stride(2 if trans_c else 3),
                      AS2,
                      BS2,
                      0,
                      0,
                      lut,
                      locks,
                      num_locks,
                      num_warps=4,
                      **meta)
        return c
    @staticmethod
    def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs, bench, time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')
        # shapes / dtypes
        AS0 = spdims[0]
        AS1 = block * spdims[2 if trans_a else 1]
        AS2 = block * spdims[1 if trans_a else 2]
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        # kernel
        meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1, 'SDD': False, 'DSD': True, 'DDS': False}
        # output
        CS0 = BS0
        CS1 = BS1
        CS2 = BS3 if trans_c else AS1
        CS3 = AS1 if trans_c else BS3
        locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
        _kernel[grid](a,
                      b,
                      c,
                      a.stride(0),
                      a.stride(1),
                      a.stride(3 if trans_a else 2),
                      a.stride(2 if trans_a else 3),
                      b.stride(0),
                      b.stride(1),
                      b.stride(3 if trans_b else 2),
                      b.stride(2 if trans_b else 3),
                      c.stride(0),
                      c.stride(1),
                      c.stride(2),
                      c.stride(3),
                      BS3,
                      AS1,
                      0,
                      0,
                      lut,
                      locks,
                      num_locks,
                      num_warps=4,
                      **meta)
        return c

    fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
    @staticmethod
    def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs,
                c_bench, c_time, da_lut, da_num_locks, da_width, da_packs, da_bench, da_time, db_lut, db_num_locks,
                db_width, db_packs, db_bench, db_time):
        c = _sparse_matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width,
                                    c_packs, c_bench, c_time)
        # save for backward
        ctx.save_for_backward(a, b)
        ctx.da_num_locks = da_num_locks
        ctx.da_lut = da_lut
        ctx.da_width = da_width
        ctx.da_packs = da_packs
        ctx.da_bench = da_bench
        ctx.da_time = da_time
        ctx.db_lut = db_lut
        ctx.db_num_locks = db_num_locks
        ctx.db_width = db_width
        ctx.db_bench = db_bench
        ctx.db_packs = db_packs
        ctx.db_time = db_time
        ctx.mode = mode
        ctx.spdims = spdims
        ctx.block = block
        ctx.trans_a = trans_a
        ctx.trans_b = trans_b
        return c

    @staticmethod
    def backward(ctx, dc):
        # saved for backward
        a, b = ctx.saved_tensors
        da, db = None, None
        mode = ctx.mode
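        # The backward modes are permutations of the forward mode string:
        # e.g. for forward mode 'sdd', grad_a uses 'dsd' and grad_b uses 'dds'.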
        # gradients w.r.t. a
        if ctx.needs_input_grad[0]:
            mode_da = mode[1] + mode[0] + mode[2]
            da = _sparse_matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
                                            ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs, ctx.da_bench,
                                            ctx.da_time)
        # gradients w.r.t. b
        if ctx.needs_input_grad[1]:
            mode_db = mode[2] + mode[1] + mode[0]
            db = _sparse_matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
                                            ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs, ctx.db_bench,
                                            ctx.db_time)
        return da, db, None, None, None,\
               None, None, None, None,\
               None, None, None, None, None, None,\
               None, None, None, None, None, None,\
               None, None, None, None, None, None


class MatMul:
    """Block-Sparse MatMul class; this class handles three types of matrix multiplication:
    - sparse = dense X dense
    - dense = sparse X dense
    - dense = dense X sparse

    For more details about the sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
    """
    def make_lut(self, dtype, device):
        """Generates the sparsity layout(s) used in block-sparse matmul
        """
        key = (dtype, device)
        if key in self.lut_cache:
            return self.lut_cache[key]
        # C look-up table
        layout, block = self.layout, self.block
        step = 16
        if self.mode == 'sdd':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dsd':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_a,
                                                                               device)
        elif self.mode == 'dds':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_b,
                                                                               device)
        # DA look-up table
        if self.mode == 'sdd':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, True, device)
        elif self.mode == 'dsd':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step,
                                                                                   not self.trans_b, device)
        # DB look-up table
        if self.mode == 'sdd':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, False, device)
        elif self.mode == 'dsd':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_a,
                                                                                   device)
        elif self.mode == 'dds':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
                               da_lut, da_num_locks, da_width, da_packs,
                               db_lut, db_num_locks, db_width, db_packs)
        return self.lut_cache[key]
    def __init__(self, layout, block, mode, trans_a=False, trans_b=False, bench=False):
        """Initialize the Block-Sparse MatMul class.

        Arguments:
             layout: required: sparsity layout tensor
             block: required: an integer determining the block size
             mode: required: a string determining the type of matmul; ('sdd') sparse = dense X dense, ('dsd') dense = sparse X dense, ('dds') dense = dense X sparse
             trans_a: optional: a boolean determining if multiplication needs to be applied on transpose of input a; default is false
             trans_b: optional: a boolean determining if multiplication needs to be applied on transpose of input b; default is false
             bench: optional: set if you want to do benchmarking
        """
        if mode not in ['sdd', 'dsd', 'dds']:
            raise NotImplementedError('Supported modes are: sdd, dsd, dds')
        # look-up table cache
        self.lut_cache = dict()
        # attributes
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.mode = mode
        self.block = block
        self.layout = layout
        layout_dim = layout.ndim
        assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
        if mode != 'sdd':
            # dims to be reduced on the 'inside' of the matmul, either -1 or -2
            trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b,
                                                                                                    -2)
            self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
            sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)
            # inner dim of the dense input should be equal to the inner dim of the sparse input
            self.dense_inner_size = layout.shape[sparse_inner] * block
            # expected shape for sparse inputs
            self.sparse_shape = (layout.sum().item(), block, block)
        # support using the same layout across attention heads etc.
        if layout_dim == 2:
            layout = layout.unsqueeze(0)
        layout = layout.long()  # the code above assumes the layout tensor is an integral type
        self.spdims = layout.shape
        # timings
        self.bench = bench
        self.time_c = None
        self.time_da = None
        self.time_db = None

    # pad the shape of a tensor with leading singleton dimensions
    # to make it compatible with kernel calls
    @staticmethod
    def _pad_shape(x, is_sparse):
        max_dim = 3 if is_sparse else 4
        for i in range(max_dim - x.dim()):
            x = x.unsqueeze(0)
        return x
    def __call__(self, a, b):
        """Applies Block-Sparse MatMul.

        For more details about the sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509

        Arguments:
             a: required: a dense/block-sparse tensor; first input of mat-mul
             b: required: a dense/block-sparse tensor; second input of mat-mul

        Return:
             c: a dense/block-sparse tensor result of a X b
        """
        c_lut, c_num_locks, c_width, c_packs,\
            da_lut, da_num_locks, da_width, da_packs,\
            db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
        # timings
        time_c = [None]
        time_da = [None]
        time_db = [None]
        original_dims = max(a.ndim, b.ndim)
        a, b = self._validate_inputs(a, b)
        # pad shapes with ones
        a = MatMul._pad_shape(a, self.mode == 'dsd')
        b = MatMul._pad_shape(b, self.mode == 'dds')
        # execute
        c = _sparse_matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
                                 c_num_locks, c_width, c_packs, self.bench, time_c, da_lut, da_num_locks, da_width,
                                 da_packs, self.bench, time_da, db_lut, db_num_locks, db_width, db_packs, self.bench,
                                 time_db)
        # remove any leading singleton dimensions we may have added to the tensor that weren't in the input
        dims_to_trim = c.ndim - original_dims
        for _ in range(dims_to_trim):
            c = c.squeeze(0)
        self.time_c = time_c[0]
        self.time_da = time_da[0]
        self.time_db = time_db[0]
        return c
    def _validate_inputs(self, a, b):
        if a.device != b.device:
            raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
                             f"and {b.device} for tensor B")
        if not get_accelerator().on_accelerator(a):
            raise ValueError("Only GPU devices are supported for now")
        # when autocast is enabled, torch.matmul autocasts to float16, so we do the same here
        if torch.is_autocast_enabled():
            a, b = a.half(), b.half()
        elif a.dtype != b.dtype:
            raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
        mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
        if mode != 'sdd':
            # one input is sparse
            dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
            dense_inner = dense.shape[self.dense_inner_dim]
            if dense_inner != self.dense_inner_size:
                raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
                                 f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
            if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
                raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
                                 f"{sparse_name}, got {sparse.shape}")

        def add_extra_dims(x):
            # add extra leading singleton dimensions if needed
            dims_needed = 4 - x.ndim
            if dims_needed > 0:
                singletons = [1] * dims_needed
                x = x.view(*singletons, *x.shape)
            elif dims_needed < 0:
                raise ValueError("Tensors with more than 4 dimensions are not currently supported")
            return x

        # pad shapes with leading singleton dimensions
        a = add_extra_dims(a)
        b = add_extra_dims(b)
        return a, b
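

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file):
# computing block-sparse attention scores with mode 'sdd'. The import path,
# tensor names and sizes below are hypothetical, and a CUDA device with a
# compatible Triton build is assumed.
#
#   import torch
#   from deepspeed.ops.sparse_attention.matmul import MatMul
#
#   block = 16
#   # causal mask over a 4x4 grid of blocks (sequence length 4 * 16 = 64)
#   layout = torch.tril(torch.ones(1, 4, 4, dtype=torch.int64))
#   sdd = MatMul(layout, block, 'sdd', trans_a=False, trans_b=True)
#   q = torch.randn(1, 1, 64, 64, dtype=torch.float16, device='cuda')
#   k = torch.randn(1, 1, 64, 64, dtype=torch.float16, device='cuda')
#   scores = sdd(q, k)  # packed nonzero blocks: (1, layout.sum(), block, block)
# ----------------------------------------------------------------------------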