// custom_cuda_layers.h
  1. #pragma once
  2. #include <cuda.h>
  3. #include <cuda_fp16.h>
  4. #include <stdio.h>
  5. #include <stdlib.h>
  6. #ifdef __HIP_PLATFORM_HCC__
  7. #define HALF_PRECISION_AVAILABLE = 1
  8. #include <hip/hip_cooperative_groups.h>
  9. #else
  10. #if __CUDA_ARCH__ >= 700
  11. #define HALF_PRECISION_AVAILABLE = 1
  12. #endif
  13. #include <cooperative_groups.h>
  14. #endif
  15. #include <curand_kernel.h>
  16. #include "context.h"
  17. #include "cublas_wrappers.h"
  18. #define CUDA_CHECK(callstr) \
  19. { \
  20. cudaError_t error_code = callstr; \
  21. if (error_code != cudaSuccess) { \
  22. std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
  23. assert(0); \
  24. } \
  25. }
  26. #define MAX_THREADS 1024
  27. #define THREADS 256
  28. #define MAX_THREAD_STRIDE 32
  29. #define TILE_DIM 32
  30. // Maximum sequence-length support based on the number of threads (2048) allowed in each block and
  31. // this MAX is 8K For higher sequence length we need to use higher Max, like for 64K : 32
  32. #define MAX_THREAD_ITERATIONS 8 // Maximum 8K
  33. #define MAX_WARP_NUM 32
  34. #define MAX_REGISTERS 256
  35. #define MAX_REG 256
  36. #define WARP_SIZE_BITS 5
// In-place quantization of `vals` (`total_count` elements, partitioned into
// `group_num` groups) to `num_bits` bits, launched on `stream`.
// NOTE(review): symmetric-vs-asymmetric and "sr" = stochastic rounding are
// inferred from the names — confirm against the kernel definitions.
template <typename T>
void launch_quantize_kernel(T* vals,
                            int total_count,
                            int group_num,
                            int num_bits,
                            cudaStream_t stream);

// Stochastic-rounding variant of launch_quantize_kernel.
template <typename T>
void launch_sr_quantize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               cudaStream_t stream);

// Asymmetric-range variant of launch_quantize_kernel.
template <typename T>
void launch_quantize_kernel_asym(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream);

// Asymmetric-range variant with stochastic rounding.
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream);
// Fused bias add with gelu activation
// output = GELU(input + bias); `intermediate_size` is the per-row width,
// `batch_size` the number of rows.
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream);

// GELU activation without the fused bias add: output = GELU(input).
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream);

// Backward of the fused bias+GELU. d_output is non-const, so the gradient is
// presumably written/accumulated in place — confirm against the .cu definition.
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream);
// Custom fused bias add with layer normalization
// Overload that also outputs per-row `vars` and `means` (both non-const),
// presumably saved for the backward pass when `training` is true — confirm.
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars,
                                     T* means);

// Overload that outputs only `vars` (no means buffer).
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars);
// LayerNorm backward fused with the add of a second incoming gradient
// (out_grad1 + out_grad2). Takes the saved forward inputs (`X_data`) plus the
// saved `vars`/`means`. `stream[2]` supplies two streams — presumably one for
// the gamma/betta gradients and one for the input gradient; confirm in the .cu.
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* X_data,
                                         const T* vars,
                                         const T* means,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2]);

// Fused-add backward overload that works from the normalized outputs
// (`vals_hat`) instead of the raw inputs; with `invertible`, `betta` is used
// to recover the pre-affine values.
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* vals_hat,
                                         const T* vars,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2],
                                         bool invertible = false,
                                         const T* betta = nullptr);

// LayerNorm backward from saved inputs and saved vars/means.
template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* X_data,
                               const T* vars,
                               const T* means,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2]);

// LayerNorm backward from the normalized outputs (`vals_hat`); see the
// fused-add overload above for the `invertible`/`betta` convention.
template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* vals_hat,
                               const T* vars,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2],
                               bool invertible = false,
                               const T* betta = nullptr);

// Non-reversible backward variant taking transposed copies of the gradient
// and the values in addition to the saved means/vars.
template <typename T>
void launch_layerNorm_backward_nreversible(const T* out_grad,
                                           const T* vals,
                                           const T* out_grad_trans,
                                           const T* vals_trans,
                                           const T* means,
                                           const T* vars,
                                           const T* gamma,
                                           T* gamma_grad,
                                           T* betta_grad,
                                           T* inp_grad,
                                           int batch_size,
                                           int hidden_dim,
                                           cudaStream_t stream[2]);
// Out-of-place matrix transpose: out_mat = inp_mat^T for a rows x cols matrix.
template <typename T>
void Transpose(const T* inp_mat, T* out_mat, int rows, int cols, cudaStream_t stream);

// Attention-softmax backward; `out_grad` is non-const, so the gradient is
// presumably computed in place from the saved softmax output `soft_inp`.
template <typename T>
void launch_attn_softmax_backward(T* out_grad,
                                  const T* soft_inp,
                                  int batch_size,
                                  int heads,
                                  int seq_length,
                                  cudaStream_t stream);

// Second implementation of the softmax backward with the same contract
// (differences are internal to the kernels — see the .cu file).
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream);

// Custom softmax with scaling and attention mask addition
// Applied in place on `vals` for [batch_size, heads, sequence_length] scores.
template <typename T>
void launch_attn_softmax(T* vals,
                         const T* attn_mask,
                         int batch_size,
                         int heads,
                         int sequence_length,
                         cudaStream_t stream);
// Permutes a 4D tensor from [0, 1, 2, 3] to [0, 2, 1, 3] layout, i.e.
// [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim].
template <typename T>
void launch_transform_0213(T* output,
                           const T* vals,
                           int batch_size,
                           int seq_length,
                           int hidden_dim,
                           int heads,
                           cudaStream_t stream);

// Custom bias add
// Bias add fused with the 0213 permutation; `trans_count` presumably selects
// how many fused tensors (e.g. Q/K/V) are processed — confirm in the .cu.
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
                                    const T* vals,
                                    const T* bias,
                                    int batch_size,
                                    int seq_length,
                                    int hidden_dim,
                                    int heads,
                                    cudaStream_t stream,
                                    int trans_count);

// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3]
template <typename T>
void launch_transform4d_0213(T* out,
                             const T* in,
                             int batch_size,
                             int heads,
                             int seq_length,
                             int hidden_dim,
                             cudaStream_t stream,
                             int trans_count);
// Dropout applied in place on `vals` fused with a bias add; `mask` records
// the kept/dropped decision per element for the backward pass.
template <typename T>
void launch_dropout(T* vals,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);

// Out-of-place dropout: vals_out = dropout(vals). With `bwd` set, presumably
// re-applies the saved mask for the backward pass — confirm in the .cu.
template <typename T>
void launch_dropout(T* vals_out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd = false);

// Dropout fused with both a residual add and a bias add.
template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);

// In-place dropout backward using the saved mask.
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);

// Out-of-place dropout backward using the saved mask.
template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         cudaStream_t stream);
// Transpose fused with a bias-style column reduction over a rows x cols
// input (typically used for bias gradients).
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

// Copy/convert `size` fp32 master parameters into an fp16 buffer.
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
// Variant of the fp32 -> fp16 parameter update (internal difference — see the .cu).
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);