custom_cuda_layers.h

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>

// <cassert> and <iostream> are required by the CUDA_CHECK macro below (assert, std::cerr).
#include <cassert>
#include <iostream>

#include <cooperative_groups.h>
#include <curand_kernel.h>

#include "context.h"
#include "cublas_wrappers.h"
#define CUDA_CHECK(callstr)                                                                    \
    {                                                                                          \
        cudaError_t error_code = callstr;                                                      \
        if (error_code != cudaSuccess) {                                                       \
            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
            assert(0);                                                                         \
        }                                                                                      \
    }
#define MAX_THREADS 1024
#define THREADS 256

#define MAX_THREAD_STRIDE 32
#define TILE_DIM 32

// Maximum sequence-length support, based on the number of threads (2048) allowed in each block:
// with MAX_THREAD_ITERATIONS = 8 the maximum is 8K. Higher sequence lengths need a larger value,
// e.g. 32 for 64K.
#define MAX_THREAD_ITERATIONS 8  // Maximum 8K
#define MAX_WARP_NUM 32

#define MAX_REGISTERS 256
#define MAX_REG 256
template <typename T>
void launch_quantize_kernel(T* vals,
                            int total_count,
                            int group_num,
                            int num_bits,
                            cudaStream_t stream);

template <typename T>
void launch_sr_quantize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               cudaStream_t stream);

template <typename T>
void launch_quantize_kernel_asym(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream);

template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream);
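// Illustrative usage sketch (not part of the original header): in-place grouped quantization of
// an fp16 buffer. Buffer names and sizes are placeholders, and the exact semantics (symmetric vs.
// asymmetric, stochastic rounding for the *_sr_* variants) are assumptions based on the names.
//
//   __half* activations;             // total_count elements, split into group_num quantization groups
//   int total_count = 1024 * 1024;
//   int group_num = 1024;
//   cudaStream_t stream;
//   cudaStreamCreate(&stream);
//   launch_quantize_kernel(activations, total_count, group_num, /*num_bits=*/8, stream);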
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream);

template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream);

template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream);
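// Illustrative usage sketch (assumption, not confirmed by this header): computing
// output = gelu(input + bias) over batch_size rows of intermediate_size elements, as the fused
// launcher name suggests. All pointers are placeholder device buffers.
//
//   launch_bias_gelu(ff_output,           // [batch_size, intermediate_size]
//                    ff_bias,             // [intermediate_size]
//                    activated,           // [batch_size, intermediate_size]
//                    intermediate_size,
//                    batch_size,
//                    stream);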
// Custom fused bias add with layer normalization
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars,
                                     T* means);

template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars);
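// Illustrative usage sketch (assumption): the first overload saves both variances and means for
// the backward pass, the second only the variances. Buffer names and the epsilon value below are
// placeholders.
//
//   launch_bias_residual_layer_norm(vals,          // [rows, hidden_dim], normalized in place
//                                   residual,      // [rows, hidden_dim]
//                                   gamma, beta,   // [hidden_dim]
//                                   1e-12f,
//                                   rows,
//                                   hidden_dim,
//                                   stream,
//                                   /*preLayerNorm=*/false,
//                                   /*training=*/true,
//                                   vars, means);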
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* X_data,
                                         const T* vars,
                                         const T* means,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2]);

template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* vals_hat,
                                         const T* vars,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2],
                                         bool invertible = false,
                                         const T* betta = nullptr);

template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* X_data,
                               const T* vars,
                               const T* means,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2]);

template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* vals_hat,
                               const T* vars,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2],
                               bool invertible = false,
                               const T* betta = nullptr);

template <typename T>
void launch_layerNorm_backward_nreversible(const T* out_grad,
                                           const T* vals,
                                           const T* out_grad_trans,
                                           const T* vals_trans,
                                           const T* means,
                                           const T* vars,
                                           const T* gamma,
                                           T* gamma_grad,
                                           T* betta_grad,
                                           T* inp_grad,
                                           int batch_size,
                                           int hidden_dim,
                                           cudaStream_t stream[2]);
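// Illustrative usage sketch (assumption): these backward launchers take an array of two CUDA
// streams, presumably so the gamma/beta reductions and the input-gradient computation can run
// concurrently. Buffer names are placeholders.
//
//   cudaStream_t streams[2] = {stream_a, stream_b};
//   launch_layerNorm_backward(out_grad,                  // [rows, hidden_dim]
//                             X_data, vars, means,       // saved forward inputs/statistics
//                             gamma,
//                             gamma_grad, betta_grad,    // [hidden_dim]
//                             inp_grad,                  // [rows, hidden_dim]
//                             rows, hidden_dim, streams);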
template <typename T>
void Transpose(const T* inp_mat, T* out_mat, int rows, int cols, cudaStream_t stream);

template <typename T>
void launch_attn_softmax_backward(T* out_grad,
                                  const T* soft_inp,
                                  int batch_size,
                                  int heads,
                                  int seq_length,
                                  cudaStream_t stream);

template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream);
// Custom softmax with scaling and attention mask addition
template <typename T>
void launch_attn_softmax(T* vals,
                         const T* attn_mask,
                         int batch_size,
                         int heads,
                         int sequence_length,
                         cudaStream_t stream);
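// Illustrative usage sketch (assumption): softmax over the last dimension of the attention
// scores, applied in place after the mask is added, as the comment above suggests. The score
// and mask shapes shown here are assumptions, and the names are placeholders.
//
//   launch_attn_softmax(attn_scores,   // e.g. [batch, heads, seq, seq], overwritten with probabilities
//                       attn_mask,     // additive mask, e.g. broadcast over heads and query positions
//                       batch, heads, seq, stream);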
template <typename T>
void launch_transform_0213(T* output,
                           const T* vals,
                           int batch_size,
                           int seq_length,
                           int hidden_dim,
                           int heads,
                           cudaStream_t stream);

// Custom bias add
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
                                    const T* vals,
                                    const T* bias,
                                    int batch_size,
                                    int seq_length,
                                    int hidden_dim,
                                    int heads,
                                    cudaStream_t stream,
                                    int trans_count);

// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3]
template <typename T>
void launch_transform4d_0213(T* out,
                             const T* in,
                             int batch_size,
                             int heads,
                             int seq_length,
                             int hidden_dim,
                             cudaStream_t stream,
                             int trans_count);
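// Illustrative usage sketch (assumption): the 0213 transforms permute
// [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim] for attention, and
// launch_bias_add_transform_0213 fuses the bias add into that permutation. trans_count
// presumably indicates how many tensors are packed together (e.g. 3 for fused Q/K/V); that
// interpretation is an assumption based on the names, and the buffers below are placeholders.
//
//   launch_bias_add_transform_0213(qkv_transposed,   // output buffer
//                                  qkv_gemm_out,     // input buffer from the QKV GEMM
//                                  qkv_bias,
//                                  batch, seq, hidden_dim, heads,
//                                  stream,
//                                  /*trans_count=*/3);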
template <typename T>
void launch_dropout(T* vals,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);

template <typename T>
void launch_dropout(T* vals_out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd = false);

template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);
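// Illustrative usage sketch (assumption): the third overload appears to fuse dropout with the
// bias add and residual connection, roughly out = residual + dropout(vals + bias), recording the
// dropout decisions in `mask` for the backward pass. The exact fusion order is an assumption
// based on the parameter names; the buffers below are placeholders.
//
//   launch_dropout(layer_out,       // [batch, dim]
//                  attn_out,        // [batch, dim]
//                  residual,        // [batch, dim]
//                  attn_out_bias,   // [dim]
//                  dropout_mask,    // [batch, dim] uint8 mask
//                  batch, dim,
//                  /*ratio=*/0.1f,
//                  stream);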
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);

template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         cudaStream_t stream);
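// Illustrative usage sketch (assumption): re-applies the saved dropout mask to the incoming
// gradient, either in place (first overload) or out of place (second). Names are placeholders.
//
//   launch_dropout_grad(grad, dropout_mask, total_count, /*ratio=*/0.1f, stream);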
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);
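// Illustrative usage sketch (assumption): judging by the rows/cols parameters, this launcher
// appears to reduce a [rows, cols] matrix column-wise, e.g. to produce a bias gradient of length
// cols; that interpretation is not confirmed by this header, and the names are placeholders.
//
//   launch_fuse_transpose_bias_kernel(out_grad, bias_grad, rows, cols, stream);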
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);
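// Illustrative usage sketch (assumption): based on the signatures, these copy fp32 master
// parameters into an fp16 (__half) copy used for mixed-precision compute. Names are placeholders.
//
//   launch_param_update(master_params_fp32, model_params_fp16, param_count, stream);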