softmax_fwd.tr 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
*/
  8. __global__ void softmax_fwd(TYPE *X __readonly __noalias __aligned(16),
  9. float scale,
  10. int *LUT __readonly __noalias __aligned(16),
  11. TYPE *RPE __readonly __noalias __aligned(16),
  12. TYPE *KP_M __readonly __noalias __aligned(16),
  13. TYPE *ATTN_M __readonly __noalias __aligned(16),
  14. int num_blocks,
  15. int sizemax,
  16. long stride_zx __multipleof(BLOCK),
  17. long stride_zrpe __multipleof(BLOCK),
  18. int stride_hrpe __multipleof(BLOCK),
  19. int stride_srpe __multipleof(BLOCK),
  20. int stride_zkpm __multipleof(BLOCK),
  21. int stride_zattnm __multipleof(BLOCK)){
  22. int pidhm = get_program_id(0);
  23. int pidz = get_program_id(1);
  24. // create index ranges
  25. int rxm = pidhm % BLOCK;
  26. int rbm = pidhm / BLOCK;
  27. int rxn[TN] = (0 ... TN) % BLOCK;
  28. int rbn[TN] = (0 ... TN) / BLOCK;
  29. // extract information from look-up table
  30. int* header = LUT + rbm * 2;
  31. int size = *(header + 0);
  32. int offset = *(header + 1);
  33. bool check[TN] = rbn < size;
  34. int rbmn[TN] = check ? rbn : size - 1;
  35. // block id and column id
  36. long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
  37. long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
  38. long rowid [TN] = *(LUT + offset + rbmn*4 + 2);
  39. long headid [TN] = *(LUT + offset + rbmn*4 + 3);
  40. // pointers to X
  41. TYPE* px[TN] = X + pidz * stride_zx
  42. + blockid * BLOCK * BLOCK
  43. + rxm * BLOCK
  44. + rxn;
  45. #ifdef APPLY_RPE
  46. // pointers to relative position embedding
  47. TYPE* prpe[TN] = RPE + pidz * stride_zrpe
  48. + headid * stride_hrpe
  49. + columnid * BLOCK
  50. + rowid * BLOCK * stride_srpe
  51. + rxm * stride_srpe
  52. + rxn;
  53. #endif
  54. #ifdef APPLY_KP_MASK
  55. // pointers to key padding mask
  56. TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
  57. + columnid * BLOCK
  58. + rxn;
  59. #endif
  60. #ifdef APPLY_ATTN_MASK
  61. // pointers to attention mask
  62. TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
  63. + rowid * BLOCK * stride_zattnm
  64. + rxm * stride_zattnm
  65. + rxn;
  66. #endif
  67. // load input
  68. TYPE x[TN] = check ? *px : -INFINITY;
  69. #ifdef APPLY_RPE
  70. // load relative position embedding
  71. TYPE rpe[TN] = check ? *prpe : 0;
  72. #endif
  73. #ifdef APPLY_KP_MASK
  74. // load key-padding mask
  75. TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
  76. #endif
  77. #ifdef APPLY_ATTN_MASK
  78. // load attention mask
  79. TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
  80. #endif
  81. // compute softmax in float
  82. #ifdef APPLY_RPE
  83. float Frpe[TN] = rpe;
  84. #endif
  85. #ifdef APPLY_KP_MASK
  86. float Fkp_m[TN] = kp_m;
  87. #endif
  88. #ifdef APPLY_ATTN_MASK
  89. float Fattn_m[TN] = attn_m;
  90. #endif
  91. #ifdef KP_MASK_MUL
  92. Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
  93. #endif
  94. #ifdef ATTN_MASK_MUL
  95. Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
  96. #endif
  97. float Fx[TN] = x;
  98. #ifdef APPLY_SCALE
  99. Fx = Fx * scale; // apply scale
  100. #endif
  101. #ifdef APPLY_RPE
  102. Fx = Fx + Frpe; // apply relative position embedding
  103. #endif
  104. #ifdef APPLY_KP_MASK
  105. Fx = Fx + Fkp_m; // apply key padding mask
  106. #endif
  107. #ifdef APPLY_ATTN_MASK
  108. Fx = Fx + Fattn_m; // apply attention mask
  109. #endif
  110. float Fxmax = Fx[max];
  111. float Fy[TN] = exp(Fx - Fxmax);
  112. float Fysum = (check ? Fy : 0)[+];
  113. // write-back in half/float
  114. TYPE y[TN] = Fy;
  115. TYPE ysum = Fysum;
  116. *?(check)px = y / ysum;
  117. }