softmax_bwd.tr

// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
__global__ void softmax_bwd(TYPE* X __readonly __noalias __aligned(16),
                            float scale,
                            TYPE* DX __readonly __noalias __aligned(16),
                            int* LUT,
                            int sizemax,
                            long stride_zx __multipleof(BLOCK),
                            long stride_zdx __multipleof(BLOCK)) {
    int pidhm = get_program_id(0);
    int pidz = get_program_id(1);
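    // grid layout (inferred from the index math below): axis 0 launches one
    // program per dense row, pidhm = row-block index * BLOCK + row within the
    // block; axis 1 walks the z dimension (batch/head slices via stride_zx).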
    // create index ranges
    int rxm = pidhm % BLOCK;
    int rbm = pidhm / BLOCK;
    int rxn[TN] = (0 ... TN) % BLOCK;
    int rbn[TN] = (0 ... TN) / BLOCK;
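    // (0 ... TN) is Triton-C's integer range constructor: across the TN-wide
    // tile, rxn is the column within a block and rbn the block index along
    // the row; rxm/rbm locate this program's row the same way.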
    // extract information from look-up table
    int* header = LUT + rbm * 2;
    int size = *(header + 0);
    int offset = *(header + 1);
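    // each row of blocks has a 2-int header: the number of nonzero blocks in
    // that row (size) and the offset of its block descriptors inside the LUT.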
    // bounds checking on lut
    bool check[TN] = rbn < size;
    int rbmn[TN] = check ? rbn : size - 1;
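    // out-of-range lanes are clamped to the last valid descriptor so the
    // pointer arithmetic stays in bounds; their loads and stores are masked
    // out below via `check`.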
    // initialize pointers to block-sparse input
    long blockid[TN] = *(LUT + offset + rbmn * 4);
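    // block descriptors are 4 ints wide (hence the stride of 4); only the
    // first field, the block id, is needed by the backward pass.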
    TYPE* px[TN] = X + pidz * stride_zx
                     + blockid * BLOCK * BLOCK
                     + rxm * BLOCK
                     + rxn;
    TYPE* pdx[TN] = DX + pidz * stride_zdx
                      + blockid * BLOCK * BLOCK
                      + rxm * BLOCK
                      + rxn;
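    // nonzero blocks are stored as dense, contiguous BLOCK x BLOCK tiles:
    // blockid * BLOCK * BLOCK + rxm * BLOCK + rxn addresses row rxm, column
    // rxn of the tile, offset by pidz slices along the z dimension.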
    // compute fused softmax backward
    TYPE x[TN] = check ? *px : 0;
    TYPE dx[TN] = check ? *pdx : 0;
    float Fdx[TN] = dx;
    float Fx[TN] = x;
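    // presumably upcast to fp32 so the reduction below accumulates in full
    // precision even when TYPE is a reduced-precision type such as half.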
    float Fxdx[TN] = Fdx * Fx;
    float Fxdxsum = Fxdx[+];
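    // v[+] is Triton-C's sum-reduction, so Fxdxsum = sum_j y_j * dy_j; this
    // only makes sense if X holds the softmax output y from the forward pass,
    // which appears to be the contract here despite the name.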
    float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
    TYPE y[TN] = Fy;
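    // this is the softmax Jacobian-vector product: for y = softmax(scale * z),
    // dL/dz_i = scale * y_i * (dy_i - sum_j y_j * dy_j).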
    // write-back
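    // `*? (mask) ptr = val` is a predicated store: lanes past the end of the
    // row write nothing, and the result overwrites DX in place.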
    *? (check) pdx = y;
}