// custom_cuda_layers.h — declarations of host-side launchers for custom CUDA kernels
- #pragma once
- #include <cuda.h>
- #include <cuda_fp16.h>
- #include <stdio.h>
- #include <stdlib.h>
- #include <cooperative_groups.h>
- #include <curand_kernel.h>
- #include "context.h"
- #include "cublas_wrappers.h"
// CUDA_CHECK(callstr): evaluate a CUDA runtime call once and abort with a
// readable diagnostic (error name, file, line) if it did not return
// cudaSuccess.
//
// Fixes relative to the previous version:
//  - Wrapped in do { ... } while (0) so the macro is a single statement and
//    is safe inside an unbraced if/else.
//  - Uses fprintf/abort (declared by the <stdio.h>/<stdlib.h> this header
//    already includes) instead of std::cerr/assert, which were used without
//    including <iostream>/<cassert>; assert(0) is also compiled out under
//    NDEBUG, which would have made release builds ignore CUDA errors.
//  - Adds cudaGetErrorString() and a trailing newline for a usable message.
#define CUDA_CHECK(callstr)                                                   \
    do {                                                                      \
        cudaError_t error_code = (callstr);                                   \
        if (error_code != cudaSuccess) {                                      \
            fprintf(stderr,                                                   \
                    "CUDA error %d (%s) at %s:%d\n",                          \
                    (int)error_code,                                          \
                    cudaGetErrorString(error_code),                           \
                    __FILE__,                                                 \
                    __LINE__);                                                \
            abort();                                                          \
        }                                                                     \
    } while (0)
// Thread-geometry and capacity constants shared by the launchers declared below.
#define MAX_THREADS 1024  // upper bound on threads per block used by these launchers
#define THREADS 256       // default block size
#define MAX_THREAD_STRIDE 32
#define TILE_DIM 32  // tile edge length — presumably for tiled transpose kernels; confirm in the .cu file
// Maximum sequence-length support based on the number of threads (2048) allowed in each block and
// this MAX is 8K. For higher sequence length we need to use a higher max, e.g. for 64K: 32.
#define MAX_THREAD_ITERATIONS 8  // Maximum 8K
#define MAX_WARP_NUM 32
#define MAX_REGISTERS 256
#define MAX_REG 256  // NOTE(review): duplicate of MAX_REGISTERS — confirm both are still needed
// ---------------------------------------------------------------------------
// Quantization kernel launchers (implementations in the companion .cu file).
// All four variants take a non-const `vals`, so they presumably quantize in
// place — confirm against the kernel definitions.
//   vals:        device buffer of T to quantize
//   total_count: total number of elements in `vals`
//   group_num:   number of quantization groups the elements are split into
//   num_bits:    target bit width of the quantized representation
//   stream:      CUDA stream the kernel is enqueued on (asynchronous)
// ---------------------------------------------------------------------------

// Quantization (symmetric variant, by contrast with the `_asym` overloads below).
template <typename T>
void launch_quantize_kernel(T* vals,
                            int total_count,
                            int group_num,
                            int num_bits,
                            cudaStream_t stream);

// "sr" variant — presumably stochastic-rounding quantization; confirm in the kernels.
template <typename T>
void launch_sr_quantize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               cudaStream_t stream);

// Asymmetric quantization.
template <typename T>
void launch_quantize_kernel_asym(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream);

// Asymmetric quantization with the "sr" (presumably stochastic-rounding) scheme.
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream);
// ---------------------------------------------------------------------------
// GELU activation launchers. `intermediate_size` and `batch_size` describe
// the 2-D extent of the data; exact layout is defined by the kernels.
// ---------------------------------------------------------------------------

// Fused bias add with gelu activation:
// writes output from input plus bias, then applies GELU (see kernel for exact layout).
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream);

// Plain GELU: output = gelu(input), no bias term.
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream);

// GELU backward: `d_output` is non-const, so the gradient is presumably
// updated in place from the forward `input` and `bias` — confirm in the kernel.
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream);
// Custom fused bias add with layer normalization.
// Normalizes `vals` (in place — it is the only non-const data pointer) after
// adding `residual`, scaling/shifting with `gamma`/`beta`.
//   epsilon:      numerical-stability term added to the variance
//   preLayerNorm: selects pre- vs post-layer-norm placement (semantics in the kernel)
//   training:     when true the kernel presumably saves statistics for backward
//   vars / means: output buffers for per-row variance / mean, used by the
//                 backward launchers below
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars,
                                     T* means);

// Overload that saves only `vars` (no `means`) — pairs with the invertible
// backward overloads, which recompute what they need from vals_hat.
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars);
// ---------------------------------------------------------------------------
// Layer-norm backward launchers. Each takes `cudaStream_t stream[2]` — two
// streams, presumably so the gamma/betta-gradient reduction and the input
// gradient can run concurrently (confirm in the kernels). Outputs are
// `gamma_grad`, `betta_grad`, and `inp_grad`.
// The `_fused_add` variants additionally fuse the sum of two incoming
// gradients (out_grad1 + out_grad2) into the backward pass.
// ---------------------------------------------------------------------------

// Fused-add backward using saved statistics: forward input `X_data` plus the
// saved `vars` and `means` from the two-output forward overload.
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* X_data,
                                         const T* vars,
                                         const T* means,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2]);

// Fused-add backward from the normalized output `vals_hat` instead of X_data
// and means. With invertible=true, `betta` must be supplied so the kernel can
// presumably undo the affine transform of vals_hat — confirm in the kernel.
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
                                         const T* out_grad2,
                                         const T* vals_hat,
                                         const T* vars,
                                         const T* gamma,
                                         T* gamma_grad,
                                         T* betta_grad,
                                         T* inp_grad,
                                         int batch_size,
                                         int hidden_dim,
                                         cudaStream_t stream[2],
                                         bool invertible = false,
                                         const T* betta = nullptr);

// Backward (single incoming gradient) using saved X_data / vars / means.
template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* X_data,
                               const T* vars,
                               const T* means,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2]);

// Backward (single incoming gradient) from vals_hat; same invertible/betta
// convention as the fused-add overload above.
template <typename T>
void launch_layerNorm_backward(const T* out_grad,
                               const T* vals_hat,
                               const T* vars,
                               const T* gamma,
                               T* gamma_grad,
                               T* betta_grad,
                               T* inp_grad,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream[2],
                               bool invertible = false,
                               const T* betta = nullptr);

// Non-reversible backward variant taking both the natural and transposed
// layouts of the gradient and values (out_grad_trans / vals_trans).
template <typename T>
void launch_layerNorm_backward_nreversible(const T* out_grad,
                                           const T* vals,
                                           const T* out_grad_trans,
                                           const T* vals_trans,
                                           const T* means,
                                           const T* vars,
                                           const T* gamma,
                                           T* gamma_grad,
                                           T* betta_grad,
                                           T* inp_grad,
                                           int batch_size,
                                           int hidden_dim,
                                           cudaStream_t stream[2]);
// Transposes a rows x cols matrix `inp_mat` into `out_mat` on `stream`
// (out-of-place: distinct input and output pointers).
template <typename T>
void Transpose(const T* inp_mat, T* out_mat, int rows, int cols, cudaStream_t stream);
// ---------------------------------------------------------------------------
// Attention softmax launchers. Data extent is batch_size x heads x
// seq_length; the exact in-memory layout is defined by the kernels.
// ---------------------------------------------------------------------------

// Softmax backward: `out_grad` is non-const and `soft_inp` is the saved
// softmax output, so the gradient is presumably updated in place — confirm.
template <typename T>
void launch_attn_softmax_backward(T* out_grad,
                                  const T* soft_inp,
                                  int batch_size,
                                  int heads,
                                  int seq_length,
                                  cudaStream_t stream);

// Alternate ("v2") implementation of the same softmax backward contract.
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream);

// Custom softmax with scaling and attention mask addition; operates on `vals`
// in place (only non-const data pointer), applying `attn_mask`.
template <typename T>
void launch_attn_softmax(T* vals,
                         const T* attn_mask,
                         int batch_size,
                         int heads,
                         int sequence_length,
                         cudaStream_t stream);
// ---------------------------------------------------------------------------
// Layout-transform launchers for attention tensors. "0213" names the axis
// permutation [0, 1, 2, 3] -> [0, 2, 1, 3] (e.g. swapping the seq and heads
// axes); exact layouts are defined by the kernels.
// ---------------------------------------------------------------------------

// Permute `vals` into `output` with the 0213 axis order.
template <typename T>
void launch_transform_0213(T* output,
                           const T* vals,
                           int batch_size,
                           int seq_length,
                           int hidden_dim,
                           int heads,
                           cudaStream_t stream);

// Custom bias add fused with the 0213 transform. `trans_count` presumably
// selects how many packed tensors (e.g. fused QKV) to transform — confirm.
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
                                    const T* vals,
                                    const T* bias,
                                    int batch_size,
                                    int seq_length,
                                    int hidden_dim,
                                    int heads,
                                    cudaStream_t stream,
                                    int trans_count);

// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3]
template <typename T>
void launch_transform4d_0213(T* out,
                             const T* in,
                             int batch_size,
                             int heads,
                             int seq_length,
                             int hidden_dim,
                             cudaStream_t stream,
                             int trans_count);
// ---------------------------------------------------------------------------
// Dropout launchers. `mask` is a uint8_t buffer recording which elements were
// kept/dropped (written in forward, consumed by the grad launchers);
// `ratio` is the dropout probability. Exact scaling convention (inverted
// dropout or not) is defined by the kernels — confirm there.
// ---------------------------------------------------------------------------

// In-place dropout on `vals` fused with a bias add (batch x dim extent).
template <typename T>
void launch_dropout(T* vals,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);

// Out-of-place dropout: vals -> vals_out over `total_count` elements.
// bwd=true presumably switches to the backward (mask-apply) path — confirm.
template <typename T>
void launch_dropout(T* vals_out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd = false);

// Dropout fused with residual and bias adds: out derived from vals, residual
// and bias (batch x dim extent).
template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream);

// Dropout backward, in place on `vals`, using the saved `mask`.
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);

// Dropout backward, out of place: vals -> vals_out using the saved `mask`.
template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         cudaStream_t stream);
// Fused transpose + bias reduction over a rows x cols input `inp`, writing to
// `out` — by the name, presumably a column-wise reduction producing a bias
// gradient; confirm the exact contract in the kernel.
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

// Copy/convert `size` float parameters from `input` into the __half buffer
// `output` on `stream`. The two variants differ only in their kernel
// implementation — see the .cu file for how `_half` differs (e.g. vectorized
// half2 stores); the declared contract is identical.
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);
|