// matmul.tr (6.4 KB)
  1. // DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
  2. // https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
  // Tiled matrix-multiply kernel with three compile-time variants selected by
  // the SDD / DSD / DDS preprocessor symbols. Exactly one of them is expected
  // to be defined: the non-SDD path declares offma/offnb/offka/... only inside
  // the DSD or DDS branch. NAME, TYPE, TM, TN, TK, TZ, BLOCK and the STRIDE_*
  // symbols are not defined in this file; presumably they are injected by the
  // host code that compiles the kernel -- TODO confirm.
  //
  // Parameters (as used in the body below):
  //   A, B           read-only input buffers.
  //   C              output buffer.
  //   lda, ldb, ldc  not referenced directly in this body; presumably reached
  //                  through the STRIDE_* macros -- confirm against the caller.
  //   stride_z{a,b,c}  multiplied by program id 2 (pidz) when forming pointers.
  //   stride_h{a,b,c}  multiplied by offha / offhb / offhc when forming pointers.
  //   DS0            upper bound used in the DSD / DDS bounds masks.
  //   DS1            not referenced in this body.
  //   SDD_K          SDD only: divided by TZ to get the reduction length AS1.
  //   SDD_off_width  SDD only: added to program id 1 before indexing lut.
  //   lut            int look-up table; its layout differs per mode (see below).
  //   locks, nlocks  lock and counter storage used when lockid != 0.
  3. __global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
  4. TYPE* B __readonly __noalias __aligned(16),
  5. TYPE* C __noalias __aligned(16),
  6. int lda __multipleof(8),
  7. int ldb __multipleof(8),
  8. int ldc __multipleof(8),
  9. long stride_za __multipleof(8),
  10. long stride_zb __multipleof(8),
  11. long stride_zc __multipleof(8),
  12. long stride_ha __multipleof(8),
  13. long stride_hb __multipleof(8),
  14. long stride_hc __multipleof(8),
  15. int DS0, int DS1,
  16. int SDD_K __multipleof(16),
  17. int SDD_off_width,
  18. int* lut, int* locks, int nlocks) {
  19. /* ---------------- */
  20. /* Prologue */
  21. /* ---------------- */
  22. // program ids
  23. int pid0 = get_program_id(0);
  24. int pid1 = get_program_id(1);
  25. int pidz = get_program_id(2);
  26. #ifdef SDD
  27. // load LUT header
  // SDD: the LUT is indexed by the shifted pid1 and holds 4 ints per
  // BLOCK x BLOCK sub-block of the TM x TN tile: [0] -> z (used as the h-offset
  // of both A and B), [1] -> i, [2] -> j, [3] -> bkid (read in the epilogue).
  28. pid1 = pid1 + SDD_off_width;
  29. int blockidm[TM] = (0 ... TM) / BLOCK;
  30. int blockidn[TN] = (0 ... TN) / BLOCK;
  31. int offlutm[TM] = blockidm*(TN/BLOCK)*4;
  32. int offlutn[TN] = blockidn*4;
  33. int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
  34. int z = *(header + 0);
  35. int i[TM] = *(header + 1 + offlutm);
  36. int j[TN] = *(header + 2 + offlutn);
  // The reduction length SDD_K is divided into TZ parts; pid0 selects the part
  // (offka / offkb below). With TZ > 1 the partial results go through the
  // locked accumulation path in the epilogue (lockid = 1), otherwise lockid = 0.
  37. int AS1 = SDD_K / TZ;
  38. int lockid = select(TZ > 1, 1, 0);
  39. int offka = pid0 * AS1;
  40. int offkb = pid0 * AS1;
  41. int offmc = 0;
  42. int offnc = 0;
  43. int offpa = 0;
  44. int offpb = 0;
  45. int maxid = TZ;
  46. int offhc = 0;
  47. int offha = z;
  48. int offhb = z;
  // Row indices into A and column indices into B: block index from the LUT
  // times BLOCK, plus the position inside the block.
  49. int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
  50. int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
  51. #else
  52. // load LUT header
  // DSD / DDS: one 6-int header per pid0:
  // [0] offset (position of the increment list inside lut), [1] AS1 (loop
  // extent), [2] column, [3] depth, [4] lockid, [5] maxid.
  53. int *header = lut + pid0 * 6;
  54. int offset = *(header + 0);
  55. int AS1 = *(header + 1);
  56. int column = *(header + 2);
  57. int depth = *(header + 3);
  58. int lockid = *(header + 4);
  59. int maxid = *(header + 5);
  // pinc walks pairs of ints in lut; the first pair supplies the initial
  // offsets, later pairs supply per-iteration pointer increments.
  60. int *pinc = lut + offset;
  61. int offhc = depth;
  62. #ifdef DSD
  63. // output offset
  64. int offnc = pid1 * TN;
  65. int offmc = column * TM;
  66. int offpc = 0;
  67. // dense input offset
  68. int offnb = pid1 * TN;
  69. int offkb __multipleof(8) = *pinc;
  70. int offpb = 0;
  71. // sparse input offset
  72. int offma = 0;
  73. int offka = 0;
  // Second int of the pair is scaled by BLOCK*BLOCK to give an element offset
  // into A.
  74. long offpa __multipleof(8) = *(pinc + 1);
  75. offpa = offpa * BLOCK * BLOCK;
  76. int offha = 0;
  77. int offhb = depth;
  78. #endif
  79. #ifdef DDS
  80. // output offset
  81. int offmc = pid1 * TM;
  82. int offnc = column * TN;
  83. int offpc = 0;
  84. // dense input offset
  85. int offma = pid1 * TM;
  86. int offka __multipleof(8) = *pinc;
  87. int offpa = 0;
  88. // sparse input offset
  89. int offnb = 0;
  90. int offkb = 0;
  // Mirror of the DSD case: here the BLOCK*BLOCK-scaled offset applies to B.
  91. long offpb __multipleof(8) = *(pinc + 1);
  92. offpb = offpb * BLOCK * BLOCK;
  93. int offha = depth;
  94. int offhb = 0;
  95. #endif
  96. int ram[TM] = offma + 0 ... TM;
  97. int rbn[TN] = offnb + 0 ... TN;
  98. #endif
  99. // initialize a, b pointers
  100. int rka[TK] = offka + 0 ... TK;
  101. int rkb[TK] = offkb + 0 ... TK;
  // Tiles of element pointers: pa is TM x TK into A, pb is TK x TN into B.
  102. TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
  103. TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
  104. // pre-fetch
  // Load masks: rows of A are bounded by DS0 only in DDS, columns of B are
  // bounded by DS0 only in DSD; in the other cases the mask is just "the loop
  // will run at least once" (AS1 > 0). Masked-off elements are set to 0.
  105. #ifdef DDS
  106. bool checkam[TM, TK] = ram[:, newaxis] < DS0;
  107. #else
  108. bool checkam[TM, TK] = AS1 > 0;
  109. #endif
  110. #ifdef DSD
  111. bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
  112. #else
  113. bool checkbn[TK, TN] = AS1 > 0;
  114. #endif
  115. TYPE a[TM, TK] = checkam ? *pa : 0;
  116. TYPE b[TK, TN] = checkbn ? *pb : 0;
  117. /* ---------------- */
  118. /* Inner Loop */
  119. /* ---------------- */
  120. // create result tile
  // Accumulation is done in float regardless of TYPE; the result is converted
  // to TYPE after the loop.
  121. float acc[TM, TN] = 0;
  122. int step = TK;
  // Each iteration consumes the tiles fetched previously, advances both
  // pointer tiles, then fetches the tiles for the next iteration.
  123. for(int k = AS1; k > 0; k -= step) {
  // NOTE(review): `@` is presumably the tile matrix product -- confirm
  // against the DSL reference.
  124. acc += a @ b;
  125. // update pointers
  126. #ifdef SDD
  127. int inc_a = TK * STRIDE_AK;
  128. int inc_b = TK * STRIDE_BK;
  129. #else
  // Non-SDD: increments come from the next pair in the LUT. The increment
  // of the dense operand is scaled by its K stride; the other one is
  // applied to the pointer as stored.
  130. pinc += 2;
  131. #ifdef DSD
  132. int inc_b __multipleof(8) = *pinc;
  133. int inc_a __multipleof(8) = *(pinc + 1);
  134. inc_b = inc_b * STRIDE_BK;
  135. #endif
  136. #ifdef DDS
  137. int inc_a __multipleof(8) = *pinc;
  138. int inc_b __multipleof(8) = *(pinc + 1);
  139. inc_a = inc_a * STRIDE_AK;
  140. #endif
  141. #endif
  142. pa += inc_a;
  143. pb += inc_b;
  144. // pre-fetch
  // k > TK is false on the final iteration, so nothing is fetched past the
  // last K step.
  145. bool checkak[TM, TK] = k > TK;
  146. bool checkbk[TK, TN] = k > TK;
  147. bool checka[TM, TK] = checkam && checkak;
  148. bool checkb[TK, TN] = checkbk && checkbn;
  // NOTE(review): `*?(mask)ptr` looks like a predicated access (load here,
  // store in the epilogue) -- confirm against the DSL reference.
  149. a = *?(checka)pa;
  150. b = *?(checkb)pb;
  151. }
  152. TYPE c[TM, TN] = acc;
  153. /* ---------------- */
  154. /* Epilogue */
  155. /* ---------------- */
  156. // initialize c pointers
  157. #ifdef SDD
  // SDD: every element of the tile is written (mask is all ones).
  158. bool checkc[TM, TN] = 1;
  159. // rematerialize
  // Same block-index arithmetic as in the prologue, recomputed here under
  // new names instead of reusing blockidm / offlutm etc.
  160. int rr_blockidm[TM] = (0 ... TM) / BLOCK;
  161. int rr_blockidn[TN] = (0 ... TN) / BLOCK;
  162. int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
  163. int rr_offlutn[TN] = rr_blockidn*4;
  // Field [3] of each LUT entry is the output block id; it is scaled by
  // BLOCK*BLOCK to give the element offset of that block inside C.
  164. int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
  165. int bkid[TM, TN] = *(header + off_bkid);
  166. long offpc[TM, TN] = bkid * BLOCK * BLOCK;
  167. // range within blocks
  168. int rcm[TM] = (0 ... TM) % BLOCK;
  169. int rcn[TN] = (0 ... TN) % BLOCK;
  170. #else
  171. int rcm[TM] = offmc + 0 ... TM;
  172. int rcn[TN] = offnc + 0 ... TN;
  // The output mask bounds the same index by DS0 as the input mask did:
  // columns in DSD, rows in DDS.
  173. #ifdef DSD
  174. bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
  175. #endif
  176. #ifdef DDS
  177. bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
  178. #endif
  179. #endif
  180. TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
  181. // write-back directly
  182. if(lockid == 0) {
  183. *?(checkc) pc = c;
  184. }
  185. // accumulate partial result using spin-locks
  186. else {
  // The lock index is built from get_program_id(1) / get_program_id(2)
  // directly. In SDD mode this is the raw program id, not the local pid1
  // that was shifted by SDD_off_width in the prologue. The matching counter
  // lives one full lock-array length
  // (num_programs(2) * num_programs(1) * nlocks ints) after the lock.
  187. int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
  188. int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
  // Spin until atomic_cas(plock, 0, 1) returns something other than 1.
  // NOTE(review): assumes atomic_cas returns the previous value, so a
  // return of 0 means this program took the lock -- confirm.
  189. for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
  // count == 0: first writer for this output tile, so overwrite C;
  // otherwise add to what earlier writers stored.
  190. int count = *pcount;
  191. if(count == 0)
  192. *?(checkc) pc = c;
  193. else
  194. *?(checkc) pc = c + *?(checkc)pc;
  // The counter wraps to 0 after maxid writers, so the next use of this
  // lock slot starts with an overwrite again; then the lock is released.
  195. atomic_xchg(pcount, (count + 1) % maxid);
  196. atomic_xchg(plock, 0);
  197. }
  198. }