softmax_fwd.tr 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. // DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
  2. // https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
  3. __global__ void softmax_fwd(TYPE *X __readonly __noalias __aligned(16),
  4. float scale,
  5. int *LUT __readonly __noalias __aligned(16),
  6. TYPE *RPE __readonly __noalias __aligned(16),
  7. TYPE *KP_M __readonly __noalias __aligned(16),
  8. TYPE *ATTN_M __readonly __noalias __aligned(16),
  9. int num_blocks,
  10. int sizemax,
  11. long stride_zx __multipleof(BLOCK),
  12. long stride_zrpe __multipleof(BLOCK),
  13. int stride_hrpe __multipleof(BLOCK),
  14. int stride_srpe __multipleof(BLOCK),
  15. int stride_zkpm __multipleof(BLOCK),
  16. int stride_zattnm __multipleof(BLOCK)){
  17. int pidhm = get_program_id(0);
  18. int pidz = get_program_id(1);
  19. // create index ranges
  20. int rxm = pidhm % BLOCK;
  21. int rbm = pidhm / BLOCK;
  22. int rxn[TN] = (0 ... TN) % BLOCK;
  23. int rbn[TN] = (0 ... TN) / BLOCK;
  24. // extract information from look-up table
  25. int* header = LUT + rbm * 2;
  26. int size = *(header + 0);
  27. int offset = *(header + 1);
  28. bool check[TN] = rbn < size;
  29. int rbmn[TN] = check ? rbn : size - 1;
  30. // block id and column id
  31. long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
  32. long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
  33. long rowid [TN] = *(LUT + offset + rbmn*4 + 2);
  34. long headid [TN] = *(LUT + offset + rbmn*4 + 3);
  35. // pointers to X
  36. TYPE* px[TN] = X + pidz * stride_zx
  37. + blockid * BLOCK * BLOCK
  38. + rxm * BLOCK
  39. + rxn;
  40. #ifdef APPLY_RPE
  41. // pointers to relative position embedding
  42. TYPE* prpe[TN] = RPE + pidz * stride_zrpe
  43. + headid * stride_hrpe
  44. + columnid * BLOCK
  45. + rowid * BLOCK * stride_srpe
  46. + rxm * stride_srpe
  47. + rxn;
  48. #endif
  49. #ifdef APPLY_KP_MASK
  50. // pointers to key padding mask
  51. TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
  52. + columnid * BLOCK
  53. + rxn;
  54. #endif
  55. #ifdef APPLY_ATTN_MASK
  56. // pointers to attention mask
  57. TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
  58. + rowid * BLOCK * stride_zattnm
  59. + rxm * stride_zattnm
  60. + rxn;
  61. #endif
  62. // load input
  63. TYPE x[TN] = check ? *px : -INFINITY;
  64. #ifdef APPLY_RPE
  65. // load relative position embedding
  66. TYPE rpe[TN] = check ? *prpe : 0;
  67. #endif
  68. #ifdef APPLY_KP_MASK
  69. // load key-padding mask
  70. TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
  71. #endif
  72. #ifdef APPLY_ATTN_MASK
  73. // load attention mask
  74. TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
  75. #endif
  76. // compute softmax in float
  77. #ifdef APPLY_RPE
  78. float Frpe[TN] = rpe;
  79. #endif
  80. #ifdef APPLY_KP_MASK
  81. float Fkp_m[TN] = kp_m;
  82. #endif
  83. #ifdef APPLY_ATTN_MASK
  84. float Fattn_m[TN] = attn_m;
  85. #endif
  86. #ifdef KP_MASK_MUL
  87. Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
  88. #endif
  89. #ifdef ATTN_MASK_MUL
  90. Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
  91. #endif
  92. float Fx[TN] = x;
  93. #ifdef APPLY_SCALE
  94. Fx = Fx * scale; // apply scale
  95. #endif
  96. #ifdef APPLY_RPE
  97. Fx = Fx + Frpe; // apply relative position embedding
  98. #endif
  99. #ifdef APPLY_KP_MASK
  100. Fx = Fx + Fkp_m; // apply key padding mask
  101. #endif
  102. #ifdef APPLY_ATTN_MASK
  103. Fx = Fx + Fattn_m; // apply attention mask
  104. #endif
  105. float Fxmax = Fx[max];
  106. float Fy[TN] = exp(Fx - Fxmax);
  107. float Fysum = (check ? Fy : 0)[+];
  108. // write-back in half/float
  109. TYPE y[TN] = Fy;
  110. TYPE ysum = Fysum;
  111. *?(check)px = y / ysum;
  112. }