123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- // DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
- // https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
- __global__ void NAME (TYPE* A __readonly __noalias __aligned(16), // left operand of the tile product (a @ b); only read
- TYPE* B __readonly __noalias __aligned(16), // right operand of the tile product; only read
- TYPE* C __noalias __aligned(16), // output; stored (or read-modify-written) in the epilogue
- int lda __multipleof(8), // lda/ldb/ldc are not referenced by name below; presumably consumed through the STRIDE_* macros - TODO confirm
- int ldb __multipleof(8),
- int ldc __multipleof(8),
- long stride_za __multipleof(8), // stride_z*: element offset applied per get_program_id(2)
- long stride_zb __multipleof(8),
- long stride_zc __multipleof(8),
- long stride_ha __multipleof(8), // stride_h*: element offset applied per offha / offhb / offhc
- long stride_hb __multipleof(8),
- long stride_hc __multipleof(8),
- int DS0, int DS1, // DS0 is the bound used by the DSD and DDS masks; DS1 is not referenced by name in this kernel body
- int SDD_K __multipleof(16), // SDD only: total reduction length; each program handles SDD_K / TZ of it
- int SDD_off_width, // SDD only: constant added to program id 1 before indexing lut
- int* lut, int* locks, int nlocks) { // lut: table of ints decoded below; locks: lock words followed by arrival counters; nlocks: lock slots per (program id 2, program id 1) pair
- /* ---------------- */
- /* Prologue
-  *
-  * This is a kernel template: NAME, TYPE, TM, TN, TK, TZ, BLOCK and the
-  * STRIDE_* symbols are not defined in this section and have to be
-  * provided when the source is compiled. The mode macros select how the
-  * operands are addressed:
-  *   SDD       - tile coordinates and the output block id come from a
-  *               per-program LUT record of (TM/BLOCK)*(TN/BLOCK) entries,
-  *               4 ints each.
-  *   DSD / DDS - a 6-int LUT header per program id 0 plus a stream of
-  *               increment pairs. When SDD is not defined, exactly one of
-  *               DSD or DDS has to be defined, because each of them
-  *               declares the offsets that are used after the #endif.
-  * Every program instance accumulates one [TM, TN] tile in float and
-  * stores it to C in the epilogue, either directly or under a spin-lock.
-  */
- /* ---------------- */
- // program ids
- int pid0 = get_program_id(0);
- int pid1 = get_program_id(1);
- int pidz = get_program_id(2);
- #ifdef SDD
- // load LUT header: one record per (offset) pid1; fields 0..3 of its 4-int entries are read below as z, i, j and bkid
- pid1 = pid1 + SDD_off_width;
- int blockidm[TM] = (0 ... TM) / BLOCK; // BLOCK-sized sub-block index of each tile row
- int blockidn[TN] = (0 ... TN) / BLOCK; // BLOCK-sized sub-block index of each tile column
- int offlutm[TM] = blockidm*(TN/BLOCK)*4; // int offset of the first entry of each row's sub-block row
- int offlutn[TN] = blockidn*4; // int offset of each column's entry within the first sub-block row
- int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
- int z = *(header + 0); // field 0 of the first entry; becomes offha and offhb
- int i[TM] = *(header + 1 + offlutm); // field 1: per-row block coordinate, scaled by BLOCK into ram
- int j[TN] = *(header + 2 + offlutn); // field 2: per-column block coordinate, scaled by BLOCK into rbn
- int AS1 = SDD_K / TZ; // reduction length handled by this program
- int lockid = select(TZ > 1, 1, 0); // TZ > 1 selects the locked accumulation path in the epilogue
- int offka = pid0 * AS1; // pid0 picks which slice of the reduction axis this program covers
- int offkb = pid0 * AS1;
- int offmc = 0; // offmc / offnc are not read again in the SDD build
- int offnc = 0;
- int offpa = 0;
- int offpb = 0;
- int maxid = TZ; // modulus for the arrival counter in the epilogue
- int offhc = 0;
- int offha = z;
- int offhb = z;
- int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK); // M index used to address A: block coordinate * BLOCK + position inside the block
- int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK); // N index used to address B, built the same way
- #else
- // load LUT header: 6 ints per pid0, read below as offset, AS1, column, depth, lockid, maxid
- int *header = lut + pid0 * 6;
- int offset = *(header + 0); // position inside lut of this program's increment stream
- int AS1 = *(header + 1); // reduction length; drives the inner loop
- int column = *(header + 2); // output tile index along M (DSD) or N (DDS)
- int depth = *(header + 3); // multiplies stride_hc, and stride_hb (DSD) or stride_ha (DDS)
- int lockid = *(header + 4); // 0 = store directly; otherwise a 1-based lock slot
- int maxid = *(header + 5); // modulus for the arrival counter in the epilogue
- int *pinc = lut + offset; // ints are consumed from here two at a time
- int offhc = depth;
- #ifdef DSD
- // output offset
- int offnc = pid1 * TN;
- int offmc = column * TM;
- int offpc = 0;
- // dense input offset
- int offnb = pid1 * TN;
- int offkb __multipleof(8) = *pinc; // first int of the pair: starting offset along K for B
- int offpb = 0;
- // sparse input offset
- int offma = 0;
- int offka = 0;
- long offpa __multipleof(8) = *(pinc + 1); // second int of the pair: starting offset into A, in units of BLOCK*BLOCK elements
- offpa = offpa * BLOCK * BLOCK; // scaled to elements; the multiplication is done on the long
- int offha = 0;
- int offhb = depth;
- #endif
- #ifdef DDS
- // output offset
- int offmc = pid1 * TM;
- int offnc = column * TN;
- int offpc = 0;
- // dense input offset
- int offma = pid1 * TM;
- int offka __multipleof(8) = *pinc; // first int of the pair: starting offset along K for A
- int offpa = 0;
- // sparse input offset
- int offnb = 0;
- int offkb = 0;
- long offpb __multipleof(8) = *(pinc + 1); // second int of the pair: starting offset into B, in units of BLOCK*BLOCK elements
- offpb = offpb * BLOCK * BLOCK; // scaled to elements; the multiplication is done on the long
- int offha = depth;
- int offhb = 0;
- #endif
- int ram[TM] = offma + 0 ... TM;
- int rbn[TN] = offnb + 0 ... TN;
- #endif
- // initialize a, b pointers: [TM, TK] addresses into A and [TK, TN] addresses into B
- int rka[TK] = offka + 0 ... TK;
- int rkb[TK] = offkb + 0 ... TK;
- TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
- TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
- // pre-fetch the first a / b tiles; lanes whose mask is false get 0
- #ifdef DDS
- bool checkam[TM, TK] = ram[:, newaxis] < DS0; // rows of A at or beyond DS0 are masked
- #else
- bool checkam[TM, TK] = AS1 > 0; // uniform mask: true only when there is something to reduce
- #endif
- #ifdef DSD
- bool checkbn[TK, TN] = rbn[newaxis, :] < DS0; // columns of B at or beyond DS0 are masked
- #else
- bool checkbn[TK, TN] = AS1 > 0; // uniform mask: true only when there is something to reduce
- #endif
- TYPE a[TM, TK] = checkam ? *pa : 0;
- TYPE b[TK, TN] = checkbn ? *pb : 0;
- /* ---------------- */
- /* Inner Loop */
- /* ---------------- */
- // create result tile (accumulated in float, converted to TYPE after the loop)
- float acc[TM, TN] = 0;
- int step = TK;
- for(int k = AS1; k > 0; k -= step) { // k = reduction elements still to process; TK are consumed per iteration
- acc += a @ b;
- // update pointers
- #ifdef SDD
- int inc_a = TK * STRIDE_AK; // fixed step of TK along K
- int inc_b = TK * STRIDE_BK;
- #else
- pinc += 2; // next increment pair; the reads below are unconditional, also on the last iteration - NOTE(review): assumes lut holds a pair past the final step, confirm against the LUT builder
- #ifdef DSD
- int inc_b __multipleof(8) = *pinc; // step along K for B, scaled by STRIDE_BK below
- int inc_a __multipleof(8) = *(pinc + 1); // step for A, added to pa as read from lut
- inc_b = inc_b * STRIDE_BK;
- #endif
- #ifdef DDS
- int inc_a __multipleof(8) = *pinc; // step along K for A, scaled by STRIDE_AK below
- int inc_b __multipleof(8) = *(pinc + 1); // step for B, added to pb as read from lut
- inc_a = inc_a * STRIDE_AK;
- #endif
- #endif
- pa += inc_a;
- pb += inc_b;
- // pre-fetch the next tiles, but only when another iteration follows (k > TK)
- bool checkak[TM, TK] = k > TK;
- bool checkbk[TK, TN] = k > TK;
- bool checka[TM, TK] = checkam && checkak;
- bool checkb[TK, TN] = checkbk && checkbn;
- a = *?(checka)pa; // load guarded by checka
- b = *?(checkb)pb; // load guarded by checkb
- }
- TYPE c[TM, TN] = acc; // float accumulator converted to TYPE for the store
- /* ---------------- */
- /* Epilogue */
- /* ---------------- */
- // initialize c pointers
- #ifdef SDD
- bool checkc[TM, TN] = 1; // no bounds mask in SDD: every lane is stored
- // rematerialize: the LUT offsets are recomputed here instead of reusing blockidm / offlutm / offlutn from the prologue
- int rr_blockidm[TM] = (0 ... TM) / BLOCK;
- int rr_blockidn[TN] = (0 ... TN) / BLOCK;
- int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
- int rr_offlutn[TN] = rr_blockidn*4;
- int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :]; // field 3 of the entry for each (row sub-block, column sub-block)
- int bkid[TM, TN] = *(header + off_bkid);
- long offpc[TM, TN] = bkid * BLOCK * BLOCK; // NOTE(review): bkid is int here, whereas offpa / offpb are widened to long before the same scaling; check whether this product can overflow before the conversion
- // range within blocks
- int rcm[TM] = (0 ... TM) % BLOCK;
- int rcn[TN] = (0 ... TN) % BLOCK;
- #else
- int rcm[TM] = offmc + 0 ... TM;
- int rcn[TN] = offnc + 0 ... TN;
- #ifdef DSD
- bool checkc[TM, TN] = rcn[newaxis, :] < DS0; // same DS0 bound as checkbn, applied to output columns
- #endif
- #ifdef DDS
- bool checkc[TM, TN] = rcm[:, newaxis] < DS0; // same DS0 bound as checkam, applied to output rows
- #endif
- #endif
- TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
- // write-back directly (lockid == 0: no accumulation across programs)
- if(lockid == 0) {
- *?(checkc) pc = c;
- }
- // accumulate partial result using spin-locks
- else {
- int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1; // slot for (program id 2, program id 1, lockid); uses the raw program id 1, not the SDD-offset pid1
- int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks; // matching arrival counter, stored after all the lock words
- for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); // spin until the 0 -> 1 CAS stops yielding 1; presumably atomic_cas returns the previous value - TODO confirm
- int count = *pcount;
- if(count == 0)
- *?(checkc) pc = c; // first arrival overwrites C
- else
- *?(checkc) pc = c + *?(checkc)pc; // later arrivals add to what is already stored
- atomic_xchg(pcount, (count + 1) % maxid); // counter returns to 0 after maxid arrivals
- atomic_xchg(plock, 0); // release the lock
- }
- }
|