// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include "ds_kernel_utils.h"

#include <cuda.h>
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>

#include <cassert>
#include <iostream>

#define MAX_WARP_NUM 32
#define WARP_SIZE 32

#define MAX_THREADS 1024
#define SMs 80

#define MAX_REGISTERS 256
template <typename T>
void launch_attn_softmax_v2(T* vals,
                            T* mask,
                            T* alibi,
                            float layer_scale,
                            bool triangular,
                            bool recompute,
                            bool local_attention,
                            int window_size,
                            int batch_size,
                            int heads,
                            int num_seq,
                            int sequence_length,
                            int offset,
                            int mask_stride,
                            int mp_size,
                            cudaStream_t stream);

// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream);
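
// Gated activation (GEGLU-style): each row of `activation` presumably holds
// gate and value halves, producing output = act(gate + bias) * value, with
// GELU when use_gelu is true and SiLU otherwise.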
template <typename T>
void launch_gated_activation(T* output,
                             const T* activation,
                             const T* bias,
                             int rows,
                             int output_stride,
                             int elems_per_row,
                             bool use_gelu,
                             cudaStream_t stream);

// Fused bias add with relu activation
template <typename T>
void launch_bias_relu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream);
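
// Simple elementwise bias add over a [batch_size, hidden_size] tensor.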
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream);
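
// Fused residual + bias add across the attention and MLP paths; mp_size
// presumably rescales the partial sums under tensor parallelism, and preln
// selects pre-LayerNorm residual wiring.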
template <typename T>
void launch_bias_residual(T* input,
                          T* output,
                          T* attn,
                          T* bias,
                          T* attn_bias,
                          int batch,
                          int hidden_dim,
                          int mp_size,
                          bool preln,
                          cudaStream_t stream);
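
// Row-wise LayerNorm: output = gamma * (vals - mean) / sqrt(var + epsilon) + beta,
// computed over elems_per_row elements per row.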
template <typename T>
void launch_fused_ln(T* output,
                     const T* vals,
                     const T* gamma,
                     const T* beta,
                     float epsilon,
                     int rows,
                     int elems_per_row,
                     cudaStream_t stream);
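
// Fused residual add + LayerNorm: normalizes (vals + residual + bias) row-wise.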
template <typename T>
void launch_fused_residual_ln(T* output,
                              const T* vals,
                              const T* residual,
                              const T* bias,
                              const T* gamma,
                              const T* beta,
                              float epsilon,
                              int rows,
                              int elems_per_row,
                              cudaStream_t stream);
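
// Same as launch_fused_residual_ln, but additionally stores the pre-LayerNorm
// sum (vals + residual + bias) in res_output for a later residual connection.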
template <typename T>
void launch_fused_residual_ln_store_pre_ln_res(T* norm_output,
                                               T* res_output,
                                               const T* vals,
                                               const T* residual,
                                               const T* bias,
                                               const T* gamma,
                                               const T* beta,
                                               float epsilon,
                                               int rows,
                                               int elems_per_row,
                                               cudaStream_t stream);
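
// Row-wise RMSNorm: output = gamma * vals / sqrt(mean(vals^2) + epsilon);
// when `residual` is non-null, it is presumably added first and the sum is
// written to res_output.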
template <typename T>
void launch_rms_norm(T* norm_output,
                     T* res_output,
                     const T* vals,
                     const T* residual,
                     const T* gamma,
                     float epsilon,
                     int rows,
                     int elems_per_row,
                     cudaStream_t stream);
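
// Group-wise int8 dequantization: scales each quantization group of `input` by
// its entry in `qscale`. merge_count in the first overload presumably accounts
// for scale groups that were merged when partitioned weights were combined.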
template <typename T>
void launch_dequantize(T* output,
                       const int8_t* input,
                       const float* qscale,
                       unsigned output_size,
                       unsigned hidden_dim,
                       unsigned groups,
                       unsigned merge_count,
                       cudaStream_t stream);

template <typename T>
void launch_dequantize(T* output,
                       const int8_t* input,
                       const float* qscale,
                       unsigned output_size,
                       unsigned hidden_dim,
                       unsigned groups,
                       cudaStream_t stream);
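
// GPT-J style residual add: combines the parallel attention and MLP branches
// with the input residual (both branches share a single residual connection).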
template <typename T>
void launch_gptj_residual_add(T* input,
                              T* output,
                              T* attn,
                              T* bias,
                              T* attn_bias,
                              int batch,
                              int head_size,
                              int mp_size,
                              cudaStream_t stream);
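
// Applies rotary position embeddings (RoPE) in-place to the query and key
// tensors, rotating the first rotary_dim elements of each head starting at
// sequence position `offset`, with base frequency rope_theta.
//
// Hypothetical call with placeholder variable names:
//   launch_apply_rotary_pos_emb<__half>(q, k, head_size, seq_len, rotary_dim,
//                                       past_len, num_heads, batch,
//                                       /*rope_theta=*/10000.f, stream,
//                                       max_out_tokens);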
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
                                 T* key_layer,
                                 unsigned head_size,
                                 unsigned seq_len,
                                 unsigned rotary_dim,
                                 unsigned offset,
                                 unsigned num_heads,
                                 unsigned batch,
                                 float rope_theta,
                                 cudaStream_t stream,
                                 int max_out_tokens);
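
// MoE residual mixing: presumably a coefficient-weighted elementwise
// combination of `residual` and `mlp_out` using the values in `coef`,
// over a [seq_len, hidden_dim] tensor.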
template <typename T>
void launch_moe_res_matmul(T* residual,
                           T* coef,
                           T* mlp_out,
                           int seq_len,
                           int hidden_dim,
                           cudaStream_t stream);

// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3]
template <typename T>
void launch_transform4d_0213(T* out,
                             const T* in,
                             int batch_size,
                             int heads,
                             int seq_length,
                             int hidden_dim,
                             cudaStream_t stream,
                             int trans_count);
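
// Fused bias add + [0, 1, 2, 3] -> [0, 2, 1, 3] transform for the QKV
// projections, with optional rotary embedding applied on the fly
// (rotate_half / rotate_every_two); num_kv presumably supports grouped- and
// multi-query attention layouts.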
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
                                    T* vals,
                                    T* vals1,
                                    const T* vals2,
                                    const T* bias,
                                    int batch_size,
                                    int seq_length,
                                    unsigned seq_offset,
                                    int seq_length1,
                                    int hidden_dim,
                                    int heads,
                                    int num_kv,
                                    int rotary_dim,
                                    bool rotate_half,
                                    bool rotate_every_two,
                                    cudaStream_t stream,
                                    int trans_count,
                                    int max_out_tokens,
                                    float rope_theta);
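
// Padding helpers: copy `output` into `padded_output`, zero-padding the head
// dimension (and, for pad_head_seq, also the sequence dimension) up to the
// padded sizes, presumably to satisfy alignment requirements of downstream
// GEMMs.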
template <typename T>
void pad_data(T* padded_output,
              T* output,
              int bsz,
              int head_size,
              int padded_head_size,
              cudaStream_t stream);

template <typename T>
void pad_head_seq(T* padded_output,
                  T* output,
                  int bsz,
                  int seq_len,
                  int padded_seq_len,
                  int head_size,
                  int padded_head_size,
                  cudaStream_t stream);
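
// Fused pad + [0, 2, 1, 3] transform: permutes the activations while
// zero-padding the sequence and head dimensions to padded_seq_len and
// padded_head_size.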
template <typename T>
void launch_pad_add_transform_0213(T* output,
                                   const T* vals,
                                   int batch_size,
                                   int hidden_dim,
                                   int seq_length,
                                   int padded_seq_len,
                                   int heads,
                                   int padded_head_size,
                                   cudaStream_t stream);
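
// Elementwise vector add of `a` and `b` with a scalar weighting factor gamma.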
template <typename T>
void launch_vector_add(T* out,
                       const T* a,
                       const T* b,
                       float gamma,
                       int num_elems,
                       cudaStream_t stream);