# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
import importlib
import warnings
import torch
import math
import triton
import triton.language as tl
import triton._C.libtriton as libtriton
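
# The kernel below serves three flavors of block-sparse matrix multiplication,
# selected by the constexpr meta-parameters 'SDD', 'DSD' and 'DDS':
#   SDD: sparse = dense  x dense   (block-sparse output)
#   DSD: dense  = sparse x dense   (block-sparse first operand)
#   DDS: dense  = dense  x sparse  (block-sparse second operand)


def _dim_to_name(x):
    # Small helper assumed by the shape checks in `_sdd_matmul` below; the
    # snippet references it without defining it, so this definition is a
    # reconstruction, not part of the original excerpt.
    return "last" if x == -1 else "second to last"
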
@triton.jit
def _kernel(A,
            B,
            C,
            stride_za,
            stride_ha,
            stride_ma,
            stride_ka,
            stride_zb,
            stride_hb,
            stride_kb,
            stride_nb,
            stride_zc,
            stride_hc,
            stride_mc,
            stride_nc,
            DS0,
            DS1,
            SDD_K,
            SDD_off_width,
            lut,
            locks,
            nlocks,
            **meta):
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    BLOCK = meta['BLOCK']
    #------------#
    #- Prologue -#
    #------------#
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pidz = tl.program_id(2)
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        blockidm = tl.arange(0, TM) // BLOCK
        blockidn = tl.arange(0, TN) // BLOCK
        offlutm = blockidm * (TN // BLOCK) * 4
        offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = tl.load(header + 0)
        i = tl.load(header + 1 + offlutm)
        j = tl.load(header + 2 + offlutn)
        AS1 = SDD_K // TZ
        lockid = tl.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
    else:
        header = lut + pid0 * 6
        offset = tl.load(header + 0)
        AS1 = tl.load(header + 1)
        column = tl.load(header + 2)
        depth = tl.load(header + 3)
        lockid = tl.load(header + 4)
        maxid = tl.load(header + 5)
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = tl.load(pinc)
            offkb = tl.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = tl.load(pinc + 1)
            offpa = tl.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = tl.load(pinc)
            offka = tl.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = tl.load(pinc + 1)
            offpb = tl.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + tl.arange(0, TM)
        rbn = offnb + tl.arange(0, TN)
    # initialize a, b pointers
    rka = offka + tl.arange(0, TK)
    rkb = offkb + tl.arange(0, TK)
    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    a = tl.load(pa, mask=checkam, other=0.)
    b = tl.load(pb, mask=checkbn, other=0.)
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TM, TN), dtype=tl.float32)
    for k in range(AS1, 0, -TK):
        acc += tl.dot(a, b)
        if meta['SDD']:
            inc_a = TK * stride_ka
            inc_b = TK * stride_kb
        else:
            pinc += 2
            if meta['DSD']:
                inc_b = tl.load(pinc)
                inc_a = tl.load(pinc + 1)
                inc_b = tl.multiple_of(inc_b, 8)
                inc_a = tl.multiple_of(inc_a, 8)
                inc_b = inc_b * stride_kb
            if meta['DDS']:
                inc_a = tl.load(pinc)
                inc_b = tl.load(pinc + 1)
                inc_a = tl.multiple_of(inc_a, 8)
                inc_b = tl.multiple_of(inc_b, 8)
                inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = tl.load(pa, mask=checka)
        b = tl.load(pb, mask=checkb)
    c = acc.to(C.dtype.element_ty)
    if meta['SDD']:
        checkc = True
        rr_blockidm = tl.arange(0, TM) // BLOCK
        rr_blockidn = tl.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = tl.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = tl.arange(0, TM) % BLOCK
        rcn = tl.arange(0, TN) % BLOCK
    else:
        rcm = offmc + tl.arange(0, TM)
        rcn = offnc + tl.arange(0, TN)
    if meta['DSD']:
        checkc = rcn[None, :] < DS0
    if meta['DDS']:
        checkc = rcm[:, None] < DS0
    pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
    # write-back directly
    if lockid == 0:
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
        count = tl.load(pcount)
        if count == 0:
            tl.store(pc, c, mask=checkc)
        else:
            d = tl.load(pc, mask=checkc)
            tl.store(pc, d + c, mask=checkc)
        tl.atomic_xchg(pcount, (count + 1) % maxid)
        tl.atomic_xchg(plock, 0)
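
# Synchronization note: when a reduction is split across several CTAs
# (lockid != 0), the epilogue above serializes the read-modify-write of the
# output tile with a spin-lock (`tl.atomic_cas`) and a companion counter:
# the first writer stores its partial result, later writers accumulate into
# it, and the counter wraps around to zero after `maxid` partial products.
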
##############
#  MAIN API  #
##############
class _sparse_matmul(torch.autograd.Function):

    sdd_cache = dict()
    dsd_cache = dict()
    dds_cache = dict()
    locks = dict()

    # Given an array `sizes` holding the reduction size for each column of a
    # block-sparse matrix multiplication, performs load-balancing by splitting
    # each reduction into segments of at most `seg_max` (and, except for the
    # remainder, at least `seg_min`) elements.
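    # Illustrative example (assumed values, for exposition only): with
    # sizes = [8, 4], the heuristic below picks seg_max = 8, so each nonzero
    # column becomes a single segment: segments = [8, 4], column = [0, 1],
    # lockid = maxid = [0, 0] (no spin-locks needed), and offsets = [0, 8]
    # (cumulative segment starts).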
    @staticmethod
    def load_balance(sizes, block):
        # segment size
        # heuristics taken from OpenAI blocksparse code
        # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
        # if max_size > min_size * 2.0:
        #     seg_max = max(triton.cdiv(max_size, 4), min_size * 2)
        # else:
        #     seg_max = max_size
        max_size = sizes.max()
        min_size = sizes[sizes != 0].min()
        seg_max = max_size
        seg_min = max(triton.cdiv(seg_max, 4), 4)
        # split reduction into segments
        div = sizes // seg_max
        rem = sizes % seg_max
        packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
        width = packs.sum()
        segments = torch.empty(width, dtype=sizes.dtype)
        column = torch.empty_like(segments)
        lockid = torch.zeros_like(segments)
        maxid = torch.zeros_like(segments)
        nlocks = 0
        current = 0
        col_idx = 0
        for i in range(len(sizes)):
            d, r = div[i], rem[i]
            isempty = sizes[i] < seg_min
            last = current + d + (r >= seg_min) + isempty
            # column id
            column[current:last] = col_idx
            # lock id
            if d > 1 or (d == 1 and r >= seg_min):
                nlocks += 1
                lockid[current:last] = nlocks
                maxid[current:last] = last - current
            # segment size
            segments[current:current + d] = seg_max
            if r < seg_min and not isempty:
                segments[current + d - 1] += r
            if r >= seg_min or isempty:
                segments[current + d] = r
            current = last
            col_idx += 1
        offsets = torch.zeros_like(segments)
        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
        return segments, column, lockid, maxid, offsets
    @staticmethod
    def get_locks(size, dev):
        if dev not in _sparse_matmul.locks or \
                size > _sparse_matmul.locks[dev].size(0):
            _sparse_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
        return _sparse_matmul.locks[dev]

    ##########################
    # SPARSE = DENSE x DENSE #
    ##########################
    @staticmethod
    def make_sdd_lut(layout, block, dtype, device):
        start_width = (128 if block > 16 else 32) // block
        layout = layout.type(torch.int32)
        segmented = libtriton.superblock(layout.data_ptr(),
                                         layout.shape[0],
                                         layout.shape[1],
                                         layout.shape[2],
                                         start_width)
        luts, widths, packs = [], [], []
        for size, nnz in segmented:
            nnz = nnz.reshape(-1, 4)
            width = nnz.shape[0] // (size * size)
            luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
            widths.append(width)
            packs.append(size)
        # no locks are needed in SDD mode
        return luts, None, widths, packs
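
    # Layout of the SDD look-up tables built above: four int32 values per
    # nonzero block, (h, i, j, block-id). The kernel reads them as a
    # per-superblock header: h selects the layout head, (i, j) the block
    # row/column, and block-id the offset of the output tile.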
    @staticmethod
    def _sdd_matmul(a,
                    b,
                    trans_a,
                    trans_b,
                    trans_c,
                    spdims,
                    block,
                    luts,
                    num_locks,
                    widths,
                    packs,
                    bench,
                    time):
        if trans_c:
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
        AS0 = a.size(0)
        # shape check
        a_dim = -2 if trans_a else -1
        b_dim = -1 if trans_b else -2
        a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
        if a_inner != b_inner:
            raise ValueError(
                f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size "
                f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})")
        if a_inner % 16 != 0:
            raise ValueError('Reduction size for SDD must be a multiple of 16')
        batch_size = a.size(0)
        a_outer = a.size(3 if trans_a else 2)
        dtype = a.dtype
        is_32_multiple = a_inner % 32 == 0
        is_64_multiple = a_inner % 64 == 0
        device = a.device
        # create output
        total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
        c = torch.empty((batch_size, total_width, block, block),
                        dtype=dtype,
                        device=a.device)
        for lut, width, pack in zip(luts, widths, packs):
            F32TK = [8, 16]
            F16TK = [16]
            F16TK += [32] if is_32_multiple else []
            F16TK += [64] if is_64_multiple else []
            TK = {torch.float32: F32TK, torch.float16: F16TK}[dtype]
            num_lock = 1
            meta = {
                'TM': block * pack,
                'TN': block * pack,
                'BLOCK': block,
                'TK': TK[0],
                'TZ': 1,
                'SDD': True,
                'DSD': False,
                'DDS': False
            }
            locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)
            # the grid dimension is capped at 65535, so a wide operation may
            # have to be decomposed into multiple kernel calls
            max_width = 49152
            for off_width in range(0, width, max_width):
                grid = lambda meta: [
                    meta['TZ'],
                    min(max_width, width - off_width),
                    batch_size
                ]
                _kernel[grid](a,
                              b,
                              c,
                              a.stride(0),
                              a.stride(1),
                              a.stride(3 if trans_a else 2),
                              a.stride(2 if trans_a else 3),
                              b.stride(0),
                              b.stride(1),
                              b.stride(3 if trans_b else 2),
                              b.stride(2 if trans_b else 3),
                              c.stride(0),
                              c.stride(0),
                              c.stride(2),
                              c.stride(3),
                              a_outer,
                              a_outer,
                              a_inner,
                              off_width,
                              lut,
                              locks,
                              num_lock,
                              num_warps=4,
                              **meta)
        return c
    ##########################
    # DENSE = DENSE x SPARSE #
    ##########################
    # Given a binary layout of 0s and 1s,
    # construct a look-up table for efficient execution on GPUs
    @staticmethod
    def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
        # load-balancing
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        segments = _empty.clone()
        column = _empty.clone()
        depth = _empty.clone()
        lockid = _empty.clone()
        maxid = _empty.clone()
        offsets = _empty.clone()
        current_offset = 0
        current_maxid = 0
        for z in range(layout.size(0)):
            if trans:
                sizes = torch.sum(layout[z, :, :], 1)
            else:
                sizes = torch.sum(layout[z, :, :], 0)
            z_segments, z_column, z_lockid, z_maxid, z_offsets = _sparse_matmul.load_balance(sizes, block)
            z_depth = z * torch.ones_like(z_segments)
            z_lockid[z_lockid > 0] += current_maxid
            current_maxid = z_lockid.max()
            # concatenate depth
            segments = torch.cat((segments, z_segments))
            column = torch.cat((column, z_column))
            depth = torch.cat((depth, z_depth))
            maxid = torch.cat((maxid, z_maxid))
            offsets = torch.cat((offsets, current_offset + z_offsets))
            lockid = torch.cat((lockid, z_lockid))
            current_offset += layout[z, :, :].sum()
        segments *= step
        # pointer increments
        if trans:
            nnz = layout.nonzero()
        else:
            nnz = layout.transpose(1, 2).nonzero()
        num_blocks = nnz.size(0)
        offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
        idx = transform(nnz[:, 2] * block)
        xincs = idx.clone()
        xincs[1:] -= idx[:-1]
        # divide block into multiple steps
        div = block // step
        xincs = xincs.view(-1, 1).repeat(1, div)
        xincs[:, 1:] = step
        xincs[:, 0] -= (div - 1) * step
        # first increment for each reduction is actually the offset
        xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
        xincs = xincs.view(-1)
        # block-mode input increments
        if trans:
            widx = torch.arange(num_blocks)
        else:
            widx = _empty.clone()
            current_offset = 0
            for z in range(layout.size(0)):
                layoutw = layout[z, :, :].clone()
                msum = layoutw.sum()
                layoutw[layoutw > 0] = 1 + torch.arange(msum)
                widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
                current_offset += msum
        wincs = widx * block * block
        wincs[1:] -= widx[:-1] * block * block
        wincs = wincs.view(-1, 1).repeat(1, div)
        if trans:
            wincs[:, 1:] = step
            wincs[:, 0] -= (div - 1) * step
        else:
            wincs[:, 1:] = step * block
            wincs[:, 0] -= (div - 1) * step * block
        wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
        wincs = wincs.view(-1)
        # adjust offset and segment size
        offsets *= 2 * div
        segments *= div
        # create header
        width = column.size(0)
        offsets += 6 * width
        header = torch.stack((offsets,
                              segments,
                              column,
                              depth,
                              lockid,
                              maxid),
                             dim=1).view(-1).contiguous()
        incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
        incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
        # create lut
        lut = torch.cat((header, incs))
        lut = lut.type(torch.int32).to(device)
        # create locks
        num_locks = max(1, lockid.max())
        return lut, num_locks, width, None
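
    # The DxD look-up table built above starts with a 6-int32 header per
    # segment (offset, size, column, depth, lockid, maxid), matching the
    # fields the kernel reads via `lut + pid0 * 6`, followed by interleaved
    # (dense, sparse) pointer increments that the inner loop consumes two at
    # a time.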
    @staticmethod
    def _dds_matmul(a,
                    b,
                    trans_a,
                    trans_b,
                    trans_c,
                    spdims,
                    block,
                    lut,
                    num_locks,
                    width,
                    packs,
                    bench,
                    time):
        # shapes / dtypes
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        AS3 = a.size(2 if trans_a else 3)
        BS0 = spdims[0]
        BS1 = block * spdims[2 if trans_b else 1]
        BS2 = block * spdims[1 if trans_b else 2]
        dtype = a.dtype
        # kernel
        meta = {
            'TN': block,
            'TM': 128,
            'TK': 16,
            'BLOCK': block,
            'TZ': 1,
            'SDD': False,
            'DSD': False,
            'DDS': True
        }
        # output
        CS0 = AS0
        CS1 = AS1
        CS2 = BS2 if trans_c else AS2
        CS3 = AS2 if trans_c else BS2
        locks = _sparse_matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]
        _kernel[grid](a,
                      b,
                      c,
                      a.stride(0),
                      a.stride(1),
                      a.stride(3 if trans_a else 2),
                      a.stride(2 if trans_a else 3),
                      b.stride(0),
                      b.stride(1),
                      b.stride(3 if trans_b else 2),
                      b.stride(2 if trans_b else 3),
                      c.stride(0),
                      c.stride(1),
                      c.stride(3 if trans_c else 2),
                      c.stride(2 if trans_c else 3),
                      AS2,
                      BS2,
                      0,
                      0,
                      lut,
                      locks,
                      num_locks,
                      num_warps=4,
                      **meta)
        return c
    @staticmethod
    def _dsd_matmul(a,
                    b,
                    trans_a,
                    trans_b,
                    trans_c,
                    spdims,
                    block,
                    lut,
                    num_locks,
                    width,
                    packs,
                    bench,
                    time):
        # shapes / dtypes
        AS0 = spdims[0]
        AS1 = block * spdims[2 if trans_a else 1]
        AS2 = block * spdims[1 if trans_a else 2]
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        # kernel
        meta = {
            'TM': block,
            'TN': 128,
            'TK': 16,
            'BLOCK': block,
            'TZ': 1,
            'SDD': False,
            'DSD': True,
            'DDS': False
        }
        # output
        CS0 = BS0
        CS1 = BS1
        CS2 = BS3 if trans_c else AS1
        CS3 = AS1 if trans_c else BS3
        locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
        _kernel[grid](a,
                      b,
                      c,
                      a.stride(0),
                      a.stride(1),
                      a.stride(3 if trans_a else 2),
                      a.stride(2 if trans_a else 3),
                      b.stride(0),
                      b.stride(1),
                      b.stride(3 if trans_b else 2),
                      b.stride(2 if trans_b else 3),
                      c.stride(0),
                      c.stride(1),
                      c.stride(2),
                      c.stride(3),
                      BS3,
                      AS1,
                      0,
                      0,
                      lut,
                      locks,
                      num_locks,
                      num_warps=4,
                      **meta)
        return c

    fn = {
        'sdd': _sdd_matmul.__get__(object),
        'dsd': _dsd_matmul.__get__(object),
        'dds': _dds_matmul.__get__(object)
    }
    @staticmethod
    def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
                c_lut, c_num_locks, c_width, c_packs, c_bench, c_time,
                da_lut, da_num_locks, da_width, da_packs, da_bench, da_time,
                db_lut, db_num_locks, db_width, db_packs, db_bench, db_time):
        c = _sparse_matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
                                    c_lut, c_num_locks, c_width, c_packs, c_bench, c_time)
        # save for backward
        ctx.save_for_backward(a, b)
        ctx.da_num_locks = da_num_locks
        ctx.da_lut = da_lut
        ctx.da_width = da_width
        ctx.da_packs = da_packs
        ctx.da_bench = da_bench
        ctx.da_time = da_time
        ctx.db_lut = db_lut
        ctx.db_num_locks = db_num_locks
        ctx.db_width = db_width
        ctx.db_bench = db_bench
        ctx.db_packs = db_packs
        ctx.db_time = db_time
        ctx.mode = mode
        ctx.spdims = spdims
        ctx.block = block
        ctx.trans_a = trans_a
        ctx.trans_b = trans_b
        return c

    @staticmethod
    def backward(ctx, dc):
        # saved for backward
        a, b = ctx.saved_tensors
        da, db = None, None
        mode = ctx.mode
        # gradients w.r.t. a: permuting the mode string keeps the sparse
        # operand in the right position (e.g. 'sdd' -> 'dsd')
        if ctx.needs_input_grad[0]:
            mode_da = mode[1] + mode[0] + mode[2]
            da = _sparse_matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a,
                                            ctx.spdims, ctx.block, ctx.da_lut,
                                            ctx.da_num_locks, ctx.da_width, ctx.da_packs,
                                            ctx.da_bench, ctx.da_time)
        # gradients w.r.t. b (e.g. 'sdd' -> 'dds')
        if ctx.needs_input_grad[1]:
            mode_db = mode[2] + mode[1] + mode[0]
            db = _sparse_matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b,
                                            ctx.spdims, ctx.block, ctx.db_lut,
                                            ctx.db_num_locks, ctx.db_width, ctx.db_packs,
                                            ctx.db_bench, ctx.db_time)
        # one gradient per input of forward(); a and b are the only
        # differentiable inputs
        return (da, db) + (None, ) * 24
class MatMul:
    """Block-Sparse MatMul class; this class handles three types of matrix multiplication:
       - sparse = dense X dense
       - dense = sparse X dense
       - dense = dense X sparse

    For more details about the sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
    """
    def make_lut(self, dtype, device):
        """Generates (and caches) the look-up tables used to execute block-sparse matmul"""
        key = (dtype, device)
        if key in self.lut_cache:
            return self.lut_cache[key]
        # C look-up table
        layout, block = self.layout, self.block
        step = 16
        if self.mode == 'sdd':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dsd':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
        elif self.mode == 'dds':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
        # DA look-up table
        if self.mode == 'sdd':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, True, device)
        elif self.mode == 'dsd':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
        # DB look-up table
        if self.mode == 'sdd':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, False, device)
        elif self.mode == 'dsd':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
        elif self.mode == 'dds':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
                               da_lut, da_num_locks, da_width, da_packs,
                               db_lut, db_num_locks, db_width, db_packs)
        return self.lut_cache[key]

    def __init__(self, layout, block, mode, trans_a=False, trans_b=False, bench=False):
        """Initialize the Block-Sparse MatMul class.

        Arguments:
             layout: required: sparsity layout tensor
             block: required: an integer determining the block size.
             mode: required: a string determining the type of matmul; ('sdd') sparse = dense X dense, ('dsd') dense = sparse X dense, ('dds') dense = dense X sparse
             trans_a: optional: a boolean determining if multiplication needs to be applied on the transpose of input a; default is false
             trans_b: optional: a boolean determining if multiplication needs to be applied on the transpose of input b; default is false
             bench: optional: set if you want to do benchmarking
        """
        if mode not in ('sdd', 'dsd', 'dds'):
            raise NotImplementedError('Supported modes are: sdd, dsd, dds')
        # look-up table cache
        self.lut_cache = dict()
        # attributes
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.mode = mode
        self.block = block
        self.layout = layout
        layout_dim = layout.ndim
        assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
        if mode != 'sdd':
            # dims to be reduced on the 'inside' of the matmul, either -1 or -2
            trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
            self.dense_inner_dim = sparse_inner if trans_dense else -((sparse_inner % 2) + 1)
            sparse_inner = -((sparse_inner % 2) + 1) if trans_sparse else sparse_inner
            # inner dim of the dense input should be equal to the inner dim of the sparse input
            self.dense_inner_size = layout.shape[sparse_inner] * block
            # expected shape for sparse inputs
            self.sparse_shape = (layout.sum().item(), block, block)
        # support using the same layout across attention heads etc.
        if layout_dim == 2:
            layout = layout.unsqueeze(0)
        layout = layout.long()  # the code above assumes the layout tensor is an integral type
        self.spdims = layout.shape
        # timings
        self.bench = bench
        self.time_c = None
        self.time_da = None
        self.time_db = None

    # pad the shape of a tensor to make it compatible with kernel calls
    @staticmethod
    def _pad_shape(x, is_sparse):
        max_dim = 3 if is_sparse else 4
        for i in range(max_dim - x.dim()):
            x = x.unsqueeze(0)
        return x
    def __call__(self, a, b):
        """Applies Block-Sparse MatMul.

        For more details about the sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509

        Arguments:
             a: required: a dense/block-sparse tensor; first input of mat-mul
             b: required: a dense/block-sparse tensor; second input of mat-mul

        Return:
             c: a dense/block-sparse tensor result of a X b
        """
        c_lut, c_num_locks, c_width, c_packs, \
            da_lut, da_num_locks, da_width, da_packs, \
            db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
        # timings
        time_c = [None]
        time_da = [None]
        time_db = [None]
        original_dims = max(a.ndim, b.ndim)
        a, b = self._validate_inputs(a, b)
        # pad shapes with ones
        a = MatMul._pad_shape(a, self.mode == 'dsd')
        b = MatMul._pad_shape(b, self.mode == 'dds')
        # execute
        c = _sparse_matmul.apply(a, b, self.trans_a, self.trans_b, False,
                                 self.mode, self.spdims, self.block,
                                 c_lut, c_num_locks, c_width, c_packs, self.bench, time_c,
                                 da_lut, da_num_locks, da_width, da_packs, self.bench, time_da,
                                 db_lut, db_num_locks, db_width, db_packs, self.bench, time_db)
        # remove any leading singleton dimensions we may have added to the
        # tensor that weren't in the input
        dims_to_trim = c.ndim - original_dims
        for _ in range(dims_to_trim):
            c = c.squeeze(0)
        self.time_c = time_c[0]
        self.time_da = time_da[0]
        self.time_db = time_db[0]
        return c

    def _validate_inputs(self, a, b):
        if a.device != b.device:
            raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
                             f"and {b.device} for tensor B")
        if not a.is_cuda:
            raise ValueError("Only GPU devices are supported for now")
        # when autocast is enabled, torch.matmul autocasts to float16, so we do the same here
        if torch.is_autocast_enabled():
            a, b = a.half(), b.half()
        elif a.dtype != b.dtype:
            raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
        mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
        if mode != 'sdd':
            # one input is sparse
            dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
            dense_inner = dense.shape[self.dense_inner_dim]
            if dense_inner != self.dense_inner_size:
                raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
                                 f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
            if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
                raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
                                 f"{sparse_name}, got {sparse.shape}")

        def add_extra_dims(x):
            # add extra leading singleton dimensions if needed
            dims_needed = 4 - x.ndim
            if dims_needed > 0:
                singletons = [1] * dims_needed
                x = x.view(*singletons, *x.shape)
            elif dims_needed < 0:
                raise ValueError("Tensors with more than 4 dimensions are not currently supported")
            return x

        # pad shapes with leading singleton dimensions
        a = add_extra_dims(a)
        b = add_extra_dims(b)
        return a, b
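

if __name__ == '__main__':
    # Minimal usage sketch (assumed shapes and device, not part of the
    # original module; requires a CUDA GPU with Triton installed).
    # 'sdd' mode: both inputs are dense, the output is block-sparse, with one
    # (block x block) tile per nonzero entry of `layout`.
    layout = torch.ones(1, 2, 2, dtype=torch.int64)  # (heads, M//block, N//block)
    sparse_dot = MatMul(layout, block=16, mode='sdd')
    a = torch.randn(1, 1, 32, 64, dtype=torch.float16, device='cuda')
    b = torch.randn(1, 1, 64, 32, dtype=torch.float16, device='cuda')
    c = sparse_dot(a, b)
    print(c.shape)  # expected: torch.Size([1, 4, 16, 16]), one tile per nonzero block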