custom_cuda_layers.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include "ds_kernel_utils.h"

#include <cuda.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>

#include <cassert>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>

#include "context.h"
#include "cublas_wrappers.h"

#define CUDA_CHECK(callstr)                                                                    \
    {                                                                                          \
        cudaError_t error_code = callstr;                                                      \
        if (error_code != cudaSuccess) {                                                       \
            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
            assert(0);                                                                         \
        }                                                                                      \
    }

#define MAX_THREADS 1024
#define THREADS 256

#define MAX_THREAD_STRIDE 32
#define TILE_DIM 32

// The maximum supported sequence length follows from the number of threads (2048) allowed in
// each block: with 8 iterations per thread the cap is 8K. Longer sequences need a larger
// iteration count, e.g. 32 for 64K.
#define MAX_THREAD_ITERATIONS 8  // Maximum 8K
#define MAX_WARP_NUM 32

#define MAX_REGISTERS 256
#define MAX_REG 256

#define WARP_SIZE_BITS 5

// Fused bias add with GeLU activation
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream);
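
// A minimal usage sketch (illustrative only; `ff1_out`, `ff1_bias`, `gelu_out`,
// and the dimension variables are hypothetical names, not declared here):
//
//   launch_bias_gelu<float>(
//       ff1_out, ff1_bias, gelu_out, intermediate_size, batch_size, stream);

// GeLU activation without the fused bias add.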
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream);
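
// Backward of the fused bias-GeLU: scales the incoming gradient by the GeLU
// derivative evaluated at (input + bias), updating d_output in place.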
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream);

// Custom fused bias add with layer normalization
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars,
                                     T* means);
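
// Overload that saves only the variances; the means are not stored, which
// matches the "invertible" backward path below that reconstructs what it
// needs from the normalized output.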
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars);
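
// Layer-norm backward fused with the addition of a second incoming gradient
// (out_grad1 + out_grad2). Two CUDA streams are accepted, allowing the
// gamma/betta reductions and the input-gradient kernels to run concurrently.
// The first overload consumes the saved means; the second re-derives the
// statistics from the normalized output vals_hat (pass invertible = true and
// supply betta).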
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* X_data,
                                         const T* vars,
                                         const T* means,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2]);

template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* vals_hat,
                                         const T* vars,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2],
                                         bool invertible = false,
                                         const T* betta = nullptr);
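
// Same as above, without the fused gradient add: a single out_grad stream.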
template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* X_data,
                               const T* vars,
                               const T* means,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2]);

template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* vals_hat,
                               const T* vars,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2],
                               bool invertible = false,
                               const T* betta = nullptr);
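
// Non-reversible backward variant that additionally consumes pre-transposed
// copies of the gradient and activations.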
template <typename T>
void launch_layerNorm_backward_nreversible(const T* out_grad,
                                           const T* vals,
                                           const T* out_grad_trans,
                                           const T* vals_trans,
                                           const T* means,
                                           const T* vars,
                                           const T* gamma,
                                           T* gamma_grad,
                                           T* betta_grad,
                                           T* inp_grad,
                                           int batch_size,
                                           int hidden_dim,
                                           cudaStream_t stream[2]);
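
// Out-of-place 2D matrix transpose.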
template <typename T>
void Transpose(const T* inp_mat, T* out_mat, int rows, int cols, cudaStream_t stream);
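
// Softmax backward for the attention scores, computed in place on out_grad
// from the saved softmax output soft_inp; v2 is an alternative implementation
// with the same contract.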
template <typename T>
void launch_attn_softmax_backward(T* out_grad,
                                  const T* soft_inp,
                                  int batch_size,
                                  int heads,
                                  int seq_length,
                                  cudaStream_t stream);

template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream);

// Custom softmax with scaling and attention mask addition
template <typename T>
void launch_attn_softmax(T* vals,
                         const T* attn_mask,
                         int batch_size,
                         int heads,
                         int sequence_length,
                         cudaStream_t stream);
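
// A minimal usage sketch (illustrative; `attn_scores` and `mask` are
// hypothetical device pointers, not declared here). The softmax runs in
// place over the last dimension of the score tensor:
//
//   launch_attn_softmax<__half>(attn_scores, mask, batch_size, heads, seq_len, stream);

// Permutes a [batch, seq, heads, head_dim] tensor to
// [batch, heads, seq, head_dim] (dimension order [0, 1, 2, 3] -> [0, 2, 1, 3]).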
template <typename T>
void launch_transform_0213(T* output,
                           const T* vals,
                           int batch_size,
                           int seq_length,
                           int hidden_dim,
                           int heads,
                           cudaStream_t stream);

// Fused bias add with the same 0213 transform; trans_count selects how many
// packed tensors (e.g. Q, K, V) are processed in one launch.
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
                                    const T* vals,
                                    const T* bias,
                                    int batch_size,
                                    int seq_length,
                                    int hidden_dim,
                                    int heads,
                                    cudaStream_t stream,
                                    int trans_count);

// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3]
template <typename T>
void launch_transform4d_0213(T* out,
                             const T* in,
                             int batch_size,
                             int heads,
                             int seq_length,
                             int hidden_dim,
                             cudaStream_t stream,
                             int trans_count);
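
// Dropout variants: fused bias + dropout (first overload), plain dropout over
// a flat buffer (second overload, with bwd selecting the backward path), and
// fused bias + residual + dropout (third overload). mask records which
// elements were kept, for use in the backward pass.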
template <typename T>
void launch_dropout(T* vals,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);

template <typename T>
void launch_dropout(T* vals_out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd = false);

template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);
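
// Dropout backward: rescales gradients by the saved mask, either in place
// (first overload) or out of place (second).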
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);

template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         cudaStream_t stream);
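
// Reduces the rows of a [rows, cols] matrix into a length-cols vector,
// typically to produce the bias gradient (the reduction is fused with the
// implied transpose, hence the name).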
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);
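
// Casts an FP32 parameter buffer to __half, typically to refresh the FP16
// model copy from FP32 master weights; the _half variant is an alternative
// implementation with the same contract.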
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);
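
// Token-dropping utilities: sorts the per-layer indices of retained tokens so
// the gather/scatter passes below can address them in order.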
void launch_token_sort(int32_t* indices,
                       int layers,
                       int batch_size,
                       int reserved_size,
                       int original_tokens,
                       cudaStream_t stream);
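
// Gathers the sampled tokens out of the full activation tensor; the explicit
// read/write strides let the same kernel handle different memory layouts.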
template <typename T>
void launch_gather_tokens(T* retained_tokens,
                          T* activations,
                          int32_t* gather_indices,
                          int32_t batch_size,
                          int32_t sampled_tokens,
                          int32_t channels,
                          int32_t read_batch_stride,
                          int32_t read_seq_stride,
                          int32_t write_batch_stride,
                          int32_t write_seq_stride,
                          cudaStream_t stream);
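
// Inverse of launch_gather_tokens: scatters layer activations back to their
// original positions in the full activation tensor.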
template <typename T>
void launch_scatter_tokens(T* all_activations,
                           T* layer_activations,
                           int32_t* gather_indices,
                           int32_t batch_size,
                           int32_t sampled_tokens,
                           int32_t channels,
                           int32_t read_batch_stride,
                           int32_t read_seq_stride,
                           int32_t write_batch_stride,
                           int32_t write_seq_stride,
                           cudaStream_t stream);
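
// Slices attention masks down to the truncated sequence length: GPT-style
// masks first, then BERT-style masks indexed by the retained tokens.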
template <typename T>
void launch_slice_gpt_mask(T* output_mask,
                           const T* input_mask,
                           int batch_size,
                           int truncated_seq_len,
                           int orig_seq_len,
                           cudaStream_t stream);

template <typename T>
void launch_slice_bert_mask(T* output_mask,
                            const T* input_mask,
                            const int32_t* retained_indices,
                            int32_t layers,
                            int32_t batch_size,
                            int32_t truncated_seq_len,
                            int32_t orig_seq_len,
                            cudaStream_t stream);
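
// A minimal end-to-end sketch of the token-dropping path (illustrative only;
// every name below is hypothetical and the call order is an assumption, not
// something this header guarantees):
//
//   launch_token_sort(indices, layers, batch_size, reserved_size, orig_tokens, stream);
//   launch_gather_tokens<__half>(kept, acts, indices, batch_size, kept_tokens, channels,
//                                read_batch_stride, read_seq_stride,
//                                write_batch_stride, write_seq_stride, stream);
//   /* ... run the layer on `kept` ... */
//   launch_scatter_tokens<__half>(acts, kept, indices, batch_size, kept_tokens, channels,
//                                 read_batch_stride, read_seq_stride,
//                                 write_batch_stride, write_seq_stride, stream);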