123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996 |
- '''Copyright The Microsoft DeepSpeed Team'''
- # DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
- # https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
- import importlib
- import torch
- import triton
- import triton.language as tl
- import triton._C.libtriton as libtriton
- from deepspeed.accelerator import get_accelerator
- @triton.jit
- def _kernel(A,
- B,
- C,
- stride_za,
- stride_ha,
- stride_ma,
- stride_ka,
- stride_zb,
- stride_hb,
- stride_kb,
- stride_nb,
- stride_zc,
- stride_hc,
- stride_mc,
- stride_nc,
- DS0,
- DS1,
- SDD_K,
- SDD_off_width,
- lut,
- locks,
- nlocks,
- **meta):
- TM = meta['TM']
- TN = meta['TN']
- TK = meta['TK']
- TZ = meta['TZ']
- BLOCK = meta['BLOCK']
- #------------#
- #- Prologue -#
- #------------#
- pid0 = tl.program_id(0)
- pid1 = tl.program_id(1)
- pidz = tl.program_id(2)
- if meta['SDD']:
- pid1 = pid1 + SDD_off_width
- blockidm = tl.arange(0, TM) // BLOCK
- blockidn = tl.arange(0, TN) // BLOCK
- offlutm = blockidm * (TN // BLOCK) * 4
- offlutn = blockidn * 4
- header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
- z = tl.load(header + 0)
- i = tl.load(header + 1 + offlutm)
- j = tl.load(header + 2 + offlutn)
- AS1 = SDD_K // TZ
- lockid = tl.where(TZ > 1, 1, 0)
- offka = pid0 * AS1
- offkb = pid0 * AS1
- offmc = 0
- offnc = 0
- offpa = 0
- offpb = 0
- maxid = TZ
- offhc = 0
- offha = z
- offhb = z
- ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
- rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
- else:
- header = lut + pid0 * 6
- offset = tl.load(header + 0)
- AS1 = tl.load(header + 1)
- column = tl.load(header + 2)
- depth = tl.load(header + 3)
- lockid = tl.load(header + 4)
- maxid = tl.load(header + 5)
- pinc = lut + offset
- offhc = depth
- if meta['DSD']:
- # output offset
- offnc = pid1 * TN
- offmc = column * TM
- offpc = 0
- # dense input offset
- offnb = pid1 * TN
- offkb = tl.load(pinc)
- offkb = tl.multiple_of(offkb, 8) # compiler hint
- offpb = 0
- # sparse input offset
- offma = 0
- offka = 0
- offpa = tl.load(pinc + 1)
- offpa = tl.multiple_of(offpa, 8) # compiler hint
- offpa = offpa * BLOCK * BLOCK
- offha = 0
- offhb = depth
- else:
- # output offset
- offmc = pid1 * TM
- offnc = column * TN
- offpc = 0
- # dense input offset
- offma = pid1 * TM
- offka = tl.load(pinc)
- offka = tl.multiple_of(offka, 8) # compiler hint
- offpa = 0
- # sparse input offset
- offnb = 0
- offkb = 0
- offpb = tl.load(pinc + 1)
- offpb = tl.multiple_of(offpb, 8) # compiler hint
- offpb = offpb * BLOCK * BLOCK
- offha = depth
- offhb = 0
- ram = offma + tl.arange(0, TM)
- rbn = offnb + tl.arange(0, TN)
- # initialize a, b pointers
- rka = offka + tl.arange(0, TK)
- rkb = offkb + tl.arange(0, TK)
- pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
- pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
- if meta['DDS']:
- checkam = ram[:, None] < DS0
- else:
- checkam = AS1 > 0
- if meta['DSD']:
- checkbn = rbn[None, :] < DS0
- else:
- checkbn = AS1 > 0
- a = tl.load(pa, mask=checkam, other=0.)
- b = tl.load(pb, mask=checkbn, other=0.)
- ## ---------------- ##
- ## Inner Loop ##
- ## ---------------- ##
- acc = tl.zeros((TM, TN), dtype=tl.float32)
- for k in range(AS1, 0, -TK):
- acc += tl.dot(a, b)
- if meta['SDD']:
- inc_a = TK * stride_ka
- inc_b = TK * stride_kb
- else:
- pinc += 2
- if meta['DSD']:
- inc_b = tl.load(pinc)
- inc_a = tl.load(pinc + 1)
- inc_b = tl.multiple_of(inc_b, 8)
- inc_a = tl.multiple_of(inc_a, 8)
- inc_b = inc_b * stride_kb
- if meta['DDS']:
- inc_a = tl.load(pinc)
- inc_b = tl.load(pinc + 1)
- inc_a = tl.multiple_of(inc_a, 8)
- inc_b = tl.multiple_of(inc_b, 8)
- inc_a = inc_a * stride_ka
- pa += inc_a
- pb += inc_b
- # pre-fetch
- checkak = k > TK
- checkbk = k > TK
- checka = checkam & checkak
- checkb = checkbn & checkbk
- a = tl.load(pa, mask=checka)
- b = tl.load(pb, mask=checkb)
- c = acc.to(C.dtype.element_ty)
- if meta['SDD']:
- checkc = True
- rr_blockidm = tl.arange(0, TM) // BLOCK
- rr_blockidn = tl.arange(0, TN) // BLOCK
- rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
- rr_offlutn = rr_blockidn * 4
- off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
- bkid = tl.load(header + off_bkid)
- offpc = bkid * BLOCK * BLOCK
- rcm = tl.arange(0, TM) % BLOCK
- rcn = tl.arange(0, TN) % BLOCK
- else:
- rcm = offmc + tl.arange(0, TM)
- rcn = offnc + tl.arange(0, TN)
- if meta['DSD']:
- checkc = rcn[None, :] < DS0
- if meta['DDS']:
- checkc = rcm[:, None] < DS0
- pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
- # write-back directly
- if lockid == 0:
- tl.store(pc, c, mask=checkc)
- # accumulate partial results using spin-locks
- else:
- plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(
- 1) * nlocks + lockid - 1
- pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
- while tl.atomic_cas(plock, 0, 1) == 1:
- pass
- count = tl.load(pcount)
- if count == 0:
- tl.store(pc, c, mask=checkc)
- else:
- d = tl.load(pc, mask=checkc)
- tl.store(pc, d + c, mask=checkc)
- tl.atomic_xchg(pcount, (count + 1) % maxid)
- tl.atomic_xchg(plock, 0)
- ##############
- # MAIN API #
- ##############
- class _sparse_matmul(torch.autograd.Function):
- sdd_cache = dict()
- dsd_cache = dict()
- dds_cache = dict()
- locks = dict()
- # Given an array sizes representing reduction size for each
- # column of a block-mode matrix multiplication,
- # performs load-balancing to achieve more smaller reductions
- # between `seg_size` elements
- @staticmethod
- def load_balance(sizes, block):
- #global triton
- #if triton is None:
- # triton = importlib.import_module('triton')
- # segment size
- # heuristics taken from OpenAI blocksparse code
- # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
- max_size = sizes.max()
- min_size = sizes[sizes != 0].min()
- #if max_size > min_size * 2.0:
- # seg_max = max(triton.cdiv(max_size, 4), min_size*2)
- #else:
- # seg_max = max_size
- seg_max = max_size
- seg_min = max(triton.cdiv(seg_max, 4), 4)
- # split reduction into segments
- div = sizes // seg_max
- rem = sizes % seg_max
- packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
- width = packs.sum()
- segments = torch.empty(width, dtype=sizes.dtype)
- column = torch.empty_like(segments)
- lockid = torch.zeros_like(segments)
- maxid = torch.zeros_like(segments)
- nlocks = 0
- current = 0
- col_idx = 0
- for i in range(len(sizes)):
- d, r = div[i], rem[i]
- isempty = sizes[i] < seg_min
- last = current + d + (r >= seg_min) + isempty
- # column id
- column[current:last] = col_idx
- # lock id
- if d > 1 or (d == 1 and r >= seg_min):
- nlocks += 1
- lockid[current:last] = nlocks
- maxid[current:last] = last - current
- # segment size
- segments[current:current + d] = seg_max
- if r < seg_min and not isempty:
- segments[current + d - 1] += r
- if r >= seg_min or isempty:
- segments[current + d] = r
- current = last
- col_idx += 1
- offsets = torch.zeros_like(segments)
- offsets[1:] = torch.cumsum(segments[:-1], dim=0)
- return segments, column, lockid, maxid, offsets
- @staticmethod
- def get_locks(size, dev):
- if dev not in _sparse_matmul.locks or \
- size > _sparse_matmul.locks[dev].size(0):
- _sparse_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
- return _sparse_matmul.locks[dev]
- ##########################
- # SPARSE = DENSE x DENSE #
- ##########################
- @staticmethod
- def make_sdd_lut(layout, block, dtype, device):
- #_sparse_matmul._load_utils()
- #start_width = 64 // block
- #segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width)
- start_width = (128 if block > 16 else 32) // block
- layout = layout.type(torch.int32)
- segmented = libtriton.superblock(layout.data_ptr(),
- layout.shape[0],
- layout.shape[1],
- layout.shape[2],
- start_width)
- luts, widths, packs = [], [], []
- for size, nnz in segmented:
- """ width = nnz.shape[0] // (size * size)
- h = nnz[:, 0]
- i = nnz[:, 1]
- j = nnz[:, 2]
- b = nnz[:, 3]
- lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
- luts.append(lut.type(torch.int32).to(device))
- widths.append(width)
- packs.append(size) """
- nnz = nnz.reshape(-1, 4)
- width = nnz.shape[0] // (size * size)
- luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
- widths.append(width)
- packs.append(size)
- # create locks
- return luts, None, widths, packs
- @staticmethod
- def _sdd_matmul(a,
- b,
- trans_a,
- trans_b,
- trans_c,
- spdims,
- block,
- luts,
- num_locks,
- widths,
- packs,
- bench,
- time):
- if trans_c:
- a, b = b, a
- trans_a, trans_b = not trans_b, not trans_a
- AS0 = a.size(0)
- # Shape check
- a_dim = -2 if trans_a else -1
- b_dim = -1 if trans_b else -2
- a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
- if a_inner != b_inner:
- raise ValueError(
- f"Size of tensor A along the {a_dim} dim ({a_inner}) must match size "
- f"of tensor B along the {b_dim} dim ({b_inner})")
- if a_inner % 16 != 0:
- raise ValueError('Reduction size for SDD must be a multiple of 16')
- batch_size = a.size(0)
- a_outer = a.size(3 if trans_a else 2)
- dtype = a.dtype
- is_16_multiple = a_inner % 16 == 0
- is_32_multiple = a_inner % 32 == 0
- is_64_multiple = a_inner % 64 == 0
- if not is_16_multiple:
- raise ValueError('Reduction size for SDD must be a multiple of 16')
- device = a.device
- # create kernel
- total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
- c = torch.empty((batch_size,
- total_width,
- block,
- block),
- dtype=dtype,
- device=a.device)
- for lut, width, pack in zip(luts, widths, packs):
- F32TK = [8, 16]
- F16TK = [16]
- F16TK += [32] if is_32_multiple else []
- F16TK += [64] if is_64_multiple else []
- TK = {torch.float32: F32TK, torch.float16: F16TK}[dtype]
- num_lock = 1
- meta = {
- 'TM': block * pack,
- 'TN': block * pack,
- 'BLOCK': block,
- 'TK': TK[0],
- 'TZ': 1,
- 'SDD': True,
- 'DSD': False,
- 'DDS': False
- }
- # create output
- locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)
- # maximum grid size is 65535
- # so operation might be decomposed into multiple
- # kernel calls
- max_width = 49152
- total = 0 if bench else None
- for off_width in range(0, width, max_width):
- grid = lambda meta: [
- meta['TZ'],
- min(max_width,
- width - off_width),
- batch_size
- ]
- _kernel[grid](a,
- b,
- c,
- a.stride(0),
- a.stride(1),
- a.stride(3 if trans_a else 2),
- a.stride(2 if trans_a else 3),
- b.stride(0),
- b.stride(1),
- b.stride(3 if trans_b else 2),
- b.stride(2 if trans_b else 3),
- c.stride(0),
- c.stride(0),
- c.stride(2),
- c.stride(3),
- a_outer,
- a_outer,
- a_inner,
- off_width,
- lut,
- locks,
- num_lock,
- num_warps=4,
- **meta)
- # save for backward pass
- return c
- ##########################
- # DENSE = DENSE x SPARSE #
- ##########################
- # Given a binary layout of 0s and 1s,
- # Construct look-up table for efficient execution on GPUs
- @staticmethod
- def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
- # load-balancing
- _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
- segments = _empty.clone()
- column = _empty.clone()
- depth = _empty.clone()
- lockid = _empty.clone()
- maxid = _empty.clone()
- offsets = _empty.clone()
- current_offset = 0
- current_maxid = 0
- for z in range(layout.size(0)):
- if trans:
- sizes = torch.sum(layout[z, :, :], 1)
- else:
- sizes = torch.sum(layout[z, :, :], 0)
- z_segments, z_column, z_lockid, z_maxid, z_offsets = _sparse_matmul.load_balance(sizes, block)
- z_depth = z * torch.ones_like(z_segments)
- z_lockid[z_lockid > 0] += current_maxid
- current_maxid = z_lockid.max()
- # concatenate depth
- segments = torch.cat((segments, z_segments))
- column = torch.cat((column, z_column))
- depth = torch.cat((depth, z_depth))
- maxid = torch.cat((maxid, z_maxid))
- offsets = torch.cat((offsets, current_offset + z_offsets))
- lockid = torch.cat((lockid, z_lockid))
- current_offset += layout[z, :, :].sum()
- segments *= step
- # pointer increments
- if trans:
- nnz = layout.nonzero()
- else:
- nnz = layout.transpose(1, 2).nonzero()
- num_blocks = nnz.size(0)
- offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
- idx = transform(nnz[:, 2] * block)
- xincs = idx.clone()
- xincs[1:] -= idx[:-1]
- # divide block into multiple steps
- div = block // step
- xincs = xincs.view(-1, 1).repeat(1, div)
- xincs[:, 1:] = step
- xincs[:, 0] -= (div - 1) * step
- # first increment for each reduction is actually the offset
- xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
- xincs = xincs.view(-1)
- # block-mode input increments
- if trans:
- widx = torch.arange(num_blocks)
- else:
- widx = _empty.clone()
- current_offset = 0
- for z in range(layout.size(0)):
- layoutw = layout[z, :, :].clone()
- msum = layoutw.sum()
- layoutw[layoutw > 0] = 1 + torch.arange(msum)
- widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
- current_offset += msum
- widx = widx
- wincs = widx * block * block
- wincs[1:] -= widx[:-1] * block * block
- wincs = wincs.view(-1, 1).repeat(1, div)
- if trans:
- wincs[:, 1:] = step
- wincs[:, 0] -= (div - 1) * step
- else:
- wincs[:, 1:] = step * block
- wincs[:, 0] -= (div - 1) * step * block
- wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
- wincs = wincs.view(-1)
- # adjust offset and segment size
- offsets *= 2 * div
- segments *= div
- # create header
- width = column.size(0)
- offsets += 6 * width
- header = torch.stack((offsets,
- segments,
- column,
- depth,
- lockid,
- maxid),
- dim=1).view(-1).contiguous()
- incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
- incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
- # create lut
- lut = torch.cat((header, incs))
- lut = lut.type(torch.int32).to(device)
- # create locks
- num_locks = max(1, lockid.max())
- return lut, num_locks, width, None
- @staticmethod
- def _dds_matmul(a,
- b,
- trans_a,
- trans_b,
- trans_c,
- spdims,
- block,
- lut,
- num_locks,
- width,
- packs,
- bench,
- time):
- global triton
- if triton is None:
- triton = importlib.import_module('triton')
- # shapes / dtypes
- AS0 = a.size(0)
- AS1 = a.size(1)
- AS2 = a.size(3 if trans_a else 2)
- AS3 = a.size(2 if trans_a else 3)
- BS0 = spdims[0]
- BS1 = block * spdims[2 if trans_b else 1]
- BS2 = block * spdims[1 if trans_b else 2]
- dtype = a.dtype
- # kernel
- meta = {
- 'TN': block,
- 'TM': 128,
- 'TK': 16,
- 'BLOCK': block,
- 'TZ': 1,
- 'SDD': False,
- 'DSD': False,
- 'DDS': True
- }
- # output
- CS0 = AS0
- CS1 = AS1
- CS2 = BS2 if trans_c else AS2
- CS3 = AS2 if trans_c else BS2
- locks = _sparse_matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
- c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
- grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]
- _kernel[grid](a,
- b,
- c,
- a.stride(0),
- a.stride(1),
- a.stride(3 if trans_a else 2),
- a.stride(2 if trans_a else 3),
- b.stride(0),
- b.stride(1),
- b.stride(3 if trans_b else 2),
- b.stride(2 if trans_b else 3),
- c.stride(0),
- c.stride(1),
- c.stride(3 if trans_c else 2),
- c.stride(2 if trans_c else 3),
- AS2,
- BS2,
- 0,
- 0,
- lut,
- locks,
- num_locks,
- num_warps=4,
- **meta)
- return c
- @staticmethod
- def _dsd_matmul(a,
- b,
- trans_a,
- trans_b,
- trans_c,
- spdims,
- block,
- lut,
- num_locks,
- width,
- packs,
- bench,
- time):
- global triton
- if triton is None:
- triton = importlib.import_module('triton')
- # shapes / dtypes
- AS0 = spdims[0]
- AS1 = block * spdims[2 if trans_a else 1]
- AS2 = block * spdims[1 if trans_a else 2]
- BS0 = b.size(0)
- BS1 = b.size(1)
- BS2 = b.size(3 if trans_b else 2)
- BS3 = b.size(2 if trans_b else 3)
- dtype = a.dtype
- # kernel
- meta = {
- 'TM': block,
- 'TN': 128,
- 'TK': 16,
- 'BLOCK': block,
- 'TZ': 1,
- 'SDD': False,
- 'DSD': True,
- 'DDS': False
- }
- # output
- CS0 = BS0
- CS1 = BS1
- CS2 = BS3 if trans_c else AS1
- CS3 = AS1 if trans_c else BS3
- locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
- c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
- grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
- _kernel[grid](a,
- b,
- c,
- a.stride(0),
- a.stride(1),
- a.stride(3 if trans_a else 2),
- a.stride(2 if trans_a else 3),
- b.stride(0),
- b.stride(1),
- b.stride(3 if trans_b else 2),
- b.stride(2 if trans_b else 3),
- c.stride(0),
- c.stride(1),
- c.stride(2),
- c.stride(3),
- BS3,
- AS1,
- 0,
- 0,
- lut,
- locks,
- num_locks,
- num_warps=4,
- **meta)
- return c
- fn = {
- 'sdd': _sdd_matmul.__get__(object),
- 'dsd': _dsd_matmul.__get__(object),
- 'dds': _dds_matmul.__get__(object)
- }
- @staticmethod
- def forward(ctx,
- a,
- b,
- trans_a,
- trans_b,
- trans_c,
- mode,
- spdims,
- block,
- c_lut,
- c_num_locks,
- c_width,
- c_packs,
- c_bench,
- c_time,
- da_lut,
- da_num_locks,
- da_width,
- da_packs,
- da_bench,
- da_time,
- db_lut,
- db_num_locks,
- db_width,
- db_packs,
- db_bench,
- db_time):
- c = _sparse_matmul.fn[mode](a,
- b,
- trans_a,
- trans_b,
- trans_c,
- spdims,
- block,
- c_lut,
- c_num_locks,
- c_width,
- c_packs,
- c_bench,
- c_time)
- # save for backward
- ctx.save_for_backward(a, b)
- ctx.da_num_locks = da_num_locks
- ctx.da_lut = da_lut
- ctx.da_width = da_width
- ctx.da_packs = da_packs
- ctx.da_bench = da_bench
- ctx.da_time = da_time
- ctx.db_lut = db_lut
- ctx.db_num_locks = db_num_locks
- ctx.db_width = db_width
- ctx.db_bench = db_bench
- ctx.db_packs = db_packs
- ctx.db_time = db_time
- ctx.mode = mode
- ctx.spdims = spdims
- ctx.block = block
- ctx.trans_a = trans_a
- ctx.trans_b = trans_b
- return c
- @staticmethod
- def backward(ctx, dc):
- # saved for backward
- a, b = ctx.saved_tensors
- mode = ctx.mode
- # gradients w.r.t. a
- if ctx.needs_input_grad[0]:
- mode_da = mode[1] + mode[0] + mode[2]
- da = _sparse_matmul.fn[mode_da](dc,
- b,
- False,
- not ctx.trans_b,
- ctx.trans_a,
- ctx.spdims,
- ctx.block,
- ctx.da_lut,
- ctx.da_num_locks,
- ctx.da_width,
- ctx.da_packs,
- ctx.da_bench,
- ctx.da_time)
- # gradients w.r.t. b
- if ctx.needs_input_grad[1]:
- mode_db = mode[2] + mode[1] + mode[0]
- db = _sparse_matmul.fn[mode_db](a,
- dc,
- not ctx.trans_a,
- False,
- ctx.trans_b,
- ctx.spdims,
- ctx.block,
- ctx.db_lut,
- ctx.db_num_locks,
- ctx.db_width,
- ctx.db_packs,
- ctx.db_bench,
- ctx.db_time)
- return da, db, None, None, None,\
- None, None, None, None,\
- None, None, None, None, None, None,\
- None, None, None, None, None, None,\
- None, None, None, None, None, None
- class MatMul:
- """Block-Sparse MatMul class; this class handles three types of matrix-multiplication:
- - sparse = dense X dense
- - dense = sparse X dense
- - dense = dense X sparse
- For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
- """
- def make_lut(self, dtype, device):
- """Generates the sparsity layout/s used in block-sparse matmul
- """
- key = (dtype, device)
- if key in self.lut_cache:
- return self.lut_cache[key]
- # C look-up table
- layout, block = self.layout, self.block
- step = 16
- if self.mode == 'sdd':
- c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
- elif self.mode == 'dsd':
- c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
- elif self.mode == 'dds':
- c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
- # DA look-up table
- if self.mode == 'sdd':
- da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, True, device)
- elif self.mode == 'dsd':
- da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
- elif self.mode == 'dds':
- da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
- # DB look-up table
- if self.mode == 'sdd':
- db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, False, device)
- elif self.mode == 'dsd':
- db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
- elif self.mode == 'dds':
- db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
- self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
- da_lut, da_num_locks, da_width, da_packs,\
- db_lut, db_num_locks, db_width, db_packs)
- return self.lut_cache[key]
- def __init__(self, layout, block, mode, trans_a=False, trans_b=False, bench=False):
- """Initialize the Block-Sparse MatMul class.
- Arguments:
- layout: required: sparsity layout tensor
- block: required: an integer determining the block size.
- mode: required: a string determining type of matmul; ('sdd') sparse = dense X dense, ('dsd') dense = sparse X dense, ('dds') dense = dense X sparse
- trans_a: optional: a boolean determining if multiplication needs to be applied on transpose of input a; default is false
- trans_b: optional: a boolean determining if multiplication needs to be applied on transpose of input b; default is false
- bench: optional: set if you want to do benchmarking
- """
- if mode not in ['sdd', 'dsd', 'dds']:
- raise NotImplementedError('Supported modes are: sdd, dsd, dds')
- # look-up table cache
- self.lut_cache = dict()
- # attributes
- self.trans_a = trans_a
- self.trans_b = trans_b
- self.mode = mode
- self.block = block
- self.layout = layout
- layout_dim = layout.ndim
- assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
- if not mode == 'sdd':
- # Dims to be reduced on the 'inside' of the matmul, either -1 or -2
- trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
- self.dense_inner_dim = -(
- (sparse_inner % 2) + 1) if not trans_dense else sparse_inner
- sparse_inner = sparse_inner if not trans_sparse else -(
- (sparse_inner % 2) + 1)
- # Inner dim of the dense input should be equal to the inner dim of the sparse input
- self.dense_inner_size = layout.shape[sparse_inner] * block
- # Expected shape for sparse inputs
- self.sparse_shape = (layout.sum().item(), block, block)
- # Support using the same layout across attention heads etc.
- if layout_dim == 2:
- layout = layout.unsqueeze(0)
- layout = layout.long(
- ) # Above code assumes the layout tensor is an integral type
- self.spdims = layout.shape
- # timings
- self.bench = bench
- self.time_c = None
- self.time_da = None
- self.time_db = None
- # pad shapes of a tensor to make it
- # compatible with kernel calls
- @staticmethod
- def _pad_shape(x, is_sparse):
- max_dim = 3 if is_sparse else 4
- for i in range(max_dim - x.dim()):
- x = x.unsqueeze(0)
- return x
- def __call__(self, a, b):
- """Applies Block-Sparse MatMul.
- For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
- Arguments:
- a: required: a dense/block-sparse tensor; first input of mat-mul
- b: required: a dense/block-sparse tensor; second input of mat-mul
- Return:
- c: a dense/block-sparse tensor result of a X b
- """
- c_lut, c_num_locks, c_width, c_packs,\
- da_lut, da_num_locks, da_width, da_packs,\
- db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
- # timings
- time_c = [None]
- time_da = [None]
- time_db = [None]
- original_dims = max(a.ndim, b.ndim)
- a, b = self._validate_inputs(a, b)
- # pad shapes with ones
- a = MatMul._pad_shape(a, self.mode == 'dsd')
- b = MatMul._pad_shape(b, self.mode == 'dds')
- # execute
- c = _sparse_matmul.apply(a,
- b,
- self.trans_a,
- self.trans_b,
- False,
- self.mode,
- self.spdims,
- self.block,
- c_lut,
- c_num_locks,
- c_width,
- c_packs,
- self.bench,
- time_c,
- da_lut,
- da_num_locks,
- da_width,
- da_packs,
- self.bench,
- time_da,
- db_lut,
- db_num_locks,
- db_width,
- db_packs,
- self.bench,
- time_db)
- # This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
- dims_to_trim = c.ndim - original_dims
- for _ in range(dims_to_trim):
- c = c.squeeze(0)
- self.time_c = time_c[0]
- self.time_da = time_da[0]
- self.time_db = time_db[0]
- return c
- def _validate_inputs(self, a, b):
- if a.device != b.device:
- raise ValueError(
- f"Inputs must be on the same device; got {a.device} for tensor A "
- f"and {b.device} for tensor B")
- if not get_accelerator().on_accelerator(a):
- raise ValueError("Only GPU devices are supported for now")
- # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
- if torch.is_autocast_enabled():
- a, b = a.half(), b.half()
- elif a.dtype != b.dtype:
- raise ValueError(
- f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B"
- )
- mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
- if mode != 'sdd':
- # One input is sparse
- dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
- dense_inner = dense.shape[self.dense_inner_dim]
- if dense_inner != self.dense_inner_size:
- raise ValueError(
- f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
- f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
- if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
- raise ValueError(
- f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
- f"{sparse_name}, got {sparse.shape}")
- def add_extra_dims(x):
- # Add extra leading singleton dimensions if needed
- dims_needed = 4 - x.ndim
- if dims_needed > 0:
- singletons = [1] * dims_needed
- x = x.view(*singletons, *x.shape)
- elif dims_needed < 0:
- raise ValueError(
- "Tensors with more than 4 dimensions are not currently supported")
- return x
- # Pad shapes with leading singleton dimensions
- a = add_extra_dims(a)
- b = add_extra_dims(b)
- return a, b
|