/* custom_cuda_layers.h
 *
 * Declarations for the custom CUDA kernel launchers used by this project:
 * quantization, (bias-)GELU, fused layer-norm forward/backward, attention
 * softmax forward/backward, tensor layout transforms, dropout, and parameter
 * updates. Definitions live in the corresponding .cu files; every launcher
 * takes the cudaStream_t on which its kernel(s) should be enqueued.
 */
#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>

#include <cooperative_groups.h>
#include <curand_kernel.h>

#include "context.h"
#include "cublas_wrappers.h"

#define MAX_THREADS 1024
#define THREADS 256

#define MAX_THREAD_STRIDE 32
#define TILE_DIM 32

// Maximum sequence-length support based on the number of threads (2048) allowed in each block and
// this MAX is 8K For higher sequence length we need to use higher Max, like for 64K : 32
#define MAX_THREAD_ITERATIONS 8  // Maximum 8K
#define MAX_WARP_NUM 32

#define MAX_REGISTERS 256

#define MAX_REG 256

// In-place quantization of `vals` (total_count elements split into group_num
// groups, quantized to num_bits). The "sr" variants presumably apply
// stochastic rounding and the "asym" variants an asymmetric quantization
// range — NOTE(review): confirm against the kernel implementations.
template <typename T>
void launch_quantize_kernel(T* vals,
                            int total_count,
                            int group_num,
                            int num_bits,
                            cudaStream_t stream);

template <typename T>
void launch_sr_quantize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               cudaStream_t stream);

template <typename T>
void launch_quantize_kernel_asym(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream);

template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream);

// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream);

// GELU activation without the fused bias add.
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream);

// Backward pass of the fused bias + GELU: scales d_output in place by the
// GELU derivative evaluated at (input + bias) — TODO confirm in-place
// semantics against the kernel source.
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream);

// Custom fused bias add with layer normalization
// Overload that also records the per-row means (needed when the backward
// pass is not run in the invertible/"vals_hat" mode).
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars,
                                     T* means);

// Overload that records only the variances.
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars);

// Layer-norm backward fused with the add of a second incoming gradient
// (out_grad1 + out_grad2). Takes two streams so the input-gradient and the
// gamma/betta-gradient reductions can run concurrently — NOTE(review):
// stream usage inferred from the stream[2] signature; confirm in the .cu file.
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* X_data,
                                         const T* vars,
                                         const T* means,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2]);

// Variant that reconstructs the input from the normalized output vals_hat
// (no means needed); `invertible` additionally requires `betta`.
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* vals_hat,
                                         const T* vars,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2],
                                         bool invertible = false,
                                         const T* betta = nullptr);

// Plain layer-norm backward (single incoming gradient); same two overload
// shapes as the fused-add versions above.
template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* X_data,
                               const T* vars,
                               const T* means,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2]);

template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* vals_hat,
                               const T* vars,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2],
                               bool invertible = false,
                               const T* betta = nullptr);

template <typename T>
void launch_layerNorm_backward_nreversible(const T* out_grad,
                                           const T* vals,
                                           const T* out_grad_trans,
                                           const T* vals_trans,
                                           const T* means,
                                           const T* vars,
                                           const T* gamma,
                                           T* gamma_grad,
                                           T* betta_grad,
                                           T* inp_grad,
                                           int batch_size,
                                           int hidden_dim,
                                           cudaStream_t stream[2]);

// Out-of-place matrix transpose of a rows x cols input.
template <typename T>
void Transpose(const T* inp_mat, T* out_mat, int rows, int cols, cudaStream_t stream);

// Softmax backward over attention scores; operates on out_grad in place
// using the saved softmax output soft_inp.
template <typename T>
void launch_attn_softmax_backward(T* out_grad,
                                  const T* soft_inp,
                                  int batch_size,
                                  int heads,
                                  int seq_length,
                                  cudaStream_t stream);

template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream);

// Custom softmax with scaling and attention mask addition
template <typename T>
void launch_attn_softmax(T* vals,
                         const T* attn_mask,
                         int batch_size,
                         int heads,
                         int sequence_length,
                         cudaStream_t stream);

// [0, 1, 2, 3] -> [0, 2, 1, 3] layout transform of the attention tensor.
template <typename T>
void launch_transform_0213(T* output,
                           const T* vals,
                           int batch_size,
                           int seq_length,
                           int hidden_dim,
                           int heads,
                           cudaStream_t stream);

// Custom bias add
// Fused with the 0213 transform; trans_count selects how many stacked
// tensors (e.g. fused QKV) are transformed in one launch.
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
                                    const T* vals,
                                    const T* bias,
                                    int batch_size,
                                    int seq_length,
                                    int hidden_dim,
                                    int heads,
                                    cudaStream_t stream,
                                    int trans_count);

// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3]
template <typename T>
void launch_transform4d_0213(T* out,
                             const T* in,
                             int batch_size,
                             int heads,
                             int seq_length,
                             int hidden_dim,
                             cudaStream_t stream,
                             int trans_count);

// Dropout fused with bias add, applied to vals in place; `mask` records the
// kept/dropped decision per element for the backward pass.
template <typename T>
void launch_dropout(T* vals,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);

// Out-of-place dropout; with bwd = true it presumably replays `mask` for the
// backward pass instead of sampling — TODO confirm against the kernel.
template <typename T>
void launch_dropout(T* vals_out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd = false);

// Dropout fused with both residual and bias adds.
template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);

// Dropout backward: scales gradients by the saved mask (in place, or from
// vals into vals_out for the second overload).
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);

template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         cudaStream_t stream);

// Column-wise reduction of a rows x cols gradient into a bias gradient,
// fused with the transpose-style access pattern.
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

// Copy/convert fp32 master parameters into an fp16 parameter buffer.
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);