#include <torch/extension.h>

#include <cuda.h>
#include <cuda_fp16.h>

#include <array>
#include <cassert>
#include <cmath>
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"

static std::unordered_map<unsigned, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(
    unsigned layer_id, unsigned batch_size, unsigned hidden_size, unsigned num_heads,
    unsigned intermediate_size, unsigned seq_length, float attn_prob_dropout_ratio,
    float hidden_output_dropout_ratio, float layer_norm_eps, bool pre_or_postLayerNorm,
    const std::vector<std::array<int, 3>>& gemm_algos, bool attn_dropout_checkpoint,
    bool normalize_invertible, bool gelu_checkpoint, bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(
          batch_size * seq_length, 3 * hidden_size, hidden_size, gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(
          batch_size * seq_length, hidden_size, hidden_size, gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(
          batch_size, seq_length, hidden_size, layer_norm_eps, true, !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(
          batch_size, seq_length, hidden_size, layer_norm_eps, true, !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(
          batch_size * seq_length, _intermediate_size, hidden_size, gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(
          batch_size * seq_length, hidden_size, _intermediate_size, gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(
          _batch_size * _heads, _seq_length, _seq_length, _hidden_size / _heads,
          (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T, CUBLAS_OP_N,
          gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(
          _batch_size * _heads, _hidden_size / _heads, _seq_length, _seq_length,
          T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N, gemm_algos[4]))
{
    assert(_hidden_size % _heads == 0);

    Initialize();
}

template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}

template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
    if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}

template <typename T>
void BertTransformerLayer<T>::Forward(
    unsigned bsz, const T* input_ptr, const T* input_mask_ptr, const T* attn_qkvw_ptr,
    const T* attn_qkvb_ptr, const T* attn_ow_ptr, const T* attn_ob_ptr, const T* attn_nw_ptr,
    const T* attn_nb_ptr, const T* inter_w_ptr, const T* inter_b_ptr, const T* output_w_ptr,
    const T* output_b_ptr, const T* norm_w_ptr, const T* norm_b_ptr, T* out_ptr, T* inp_norm_ptr,
    T* q_tf_ptr, T* k_tf_ptr, T* v_tf_ptr, T* soft_out_ptr, T* ctx_bufB_ptr, T* attn_o_inp_ptr,
    T* add_res_ptr, T* ff1_inp_ptr, T* gelu_inp_ptr, T* ff2_inp_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) buf_2 += small_buf_size;
    if (_attn_dropout_checkpoint)
        ctx_bufB_ptr =
            (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                              : (buf_1 + 4 * small_buf_size));

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    if (_pre_or_postLayerNorm)
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    else
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    if (_pre_or_postLayerNorm)
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
    else
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);

    // attn output dropout.
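    // The dropout below also folds in the attention-output bias and the residual add with the
    // layer input: the attention projection (buf_1 in pre-LayerNorm mode, ff1_inp_ptr in
    // post-LayerNorm mode) is combined with input_ptr into add_res_ptr, which the attention
    // LayerNorm then normalizes into ff1_inp_ptr as the input to the first feed-forward GEMM.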
    if (_pre_or_postLayerNorm)
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
    else
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    _ff1.Forward(bsz_seq, ff1_inp_ptr, inter_w_ptr,
                 (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr,
                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), _stream);

    _ff2.Forward(
        bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}

template <typename T>
void BertTransformerLayer<T>::Backward(
    unsigned bsz, const T* grad_output_ptr, const T* input_ptr, const T* output_ptr,
    const T* inp_norm_ptr, const T* q_tf_ptr, const T* k_tf_ptr, const T* v_tf_ptr,
    const T* soft_out_ptr, const T* ctx_bufB_ptr, const T* attn_o_inp_ptr, const T* add_res_ptr,
    const T* ff1_inp_ptr, const T* gelu_inp_ptr, const T* ff2_inp_ptr, const T* input_mask_ptr,
    const T* attn_qkvw_ptr, const T* attn_ow_ptr, const T* attn_nw_ptr, const T* attn_nb_ptr,
    const T* inter_w_ptr, const T* inter_b_ptr, const T* output_w_ptr, const T* norm_w_ptr,
    const T* norm_b_ptr, T* grad_input_ptr, T* grad_attn_qkvw_ptr, T* grad_attn_qkvb_ptr,
    T* grad_attn_ow_ptr, T* grad_attn_ob_ptr, T* grad_attn_nw_ptr, T* grad_attn_nb_ptr,
    T* grad_inter_w_ptr, T* grad_inter_b_ptr, T* grad_output_w_ptr, T* grad_output_b_ptr,
    T* grad_norm_w_ptr, T* grad_norm_b_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ?
                      buf_3 + (bsz * _seq_length * _intermediate_size) :
                      buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    cudaStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.Backward(bsz_seq, grad_output_ptr, norm_w_ptr, grad_norm_w_ptr,
                                 grad_norm_b_ptr, streams, buf_1, inp_norm_ptr);
        else
            _layer_norm.Backward(bsz_seq, grad_output_ptr, norm_w_ptr, norm_b_ptr,
                                 grad_norm_w_ptr, grad_norm_b_ptr, streams, buf_1, output_ptr);
    }

    if (_pre_or_postLayerNorm)
        _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
    else
        _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);

    const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
                                     ? buf_0
                                     : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);

    if (_gelu_checkpoint)
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
    _ff2.Backward(bsz_seq, layer_dropout_buf, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                  output_w_ptr, grad_output_w_ptr, grad_output_b_ptr, _cublasHandle, _stream,
                  ff2_buf);

    _gelu.Backward(
        bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq, ff2_buf, ff1_inp_ptr, inter_w_ptr, grad_inter_w_ptr, grad_inter_b_ptr,
                  _cublasHandle, _stream, buf_3);

    if (!_pre_or_postLayerNorm)
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.BackwardFusedAdd(bsz_seq, buf_3, grad_output_ptr, attn_nw_ptr,
                                              grad_attn_nw_ptr, grad_attn_nb_ptr, streams, buf_0,
                                              add_res_ptr);
        else
            _attn_layer_norm.BackwardFusedAdd(bsz_seq, buf_3, grad_output_ptr, attn_nw_ptr,
                                              attn_nb_ptr, grad_attn_nw_ptr, grad_attn_nb_ptr,
                                              streams, buf_0, ff1_inp_ptr);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.Backward(bsz_seq, buf_2, attn_nw_ptr, grad_attn_nw_ptr,
                                      grad_attn_nb_ptr, streams, buf_0, add_res_ptr);
        else
            _attn_layer_norm.Backward(bsz_seq, buf_2, attn_nw_ptr, attn_nb_ptr, grad_attn_nw_ptr,
                                      grad_attn_nb_ptr, streams, buf_0, ff1_inp_ptr);
    }

    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq, attn_output_dropout_buf, attn_o_inp_ptr, attn_ow_ptr,
                              grad_attn_ow_ptr, grad_attn_ob_ptr, _cublasHandle, _stream, buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    if (_attn_prob_dropout.HasDropout()) {
        if (_attn_dropout_checkpoint)
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);

        _attn_context.Backward(bsz_heads, buf_2, v_tf_ptr,
                               (_attn_dropout_checkpoint ?
                                    ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle, buf_3, ff2_buf);
    } else
        _attn_context.Backward(
            bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213<T>(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    if (_pre_or_postLayerNorm)
        _qkv_linear.Backward(bsz_seq, ff2_buf, inp_norm_ptr, attn_qkvw_ptr, grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr, _cublasHandle, _stream, buf_2);
    else
        _qkv_linear.Backward(bsz_seq, ff2_buf, input_ptr, attn_qkvw_ptr, grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr, _cublasHandle, _stream, buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.BackwardFusedAdd(bsz_seq, buf_2, buf_0, norm_w_ptr, grad_norm_w_ptr,
                                         grad_norm_b_ptr, streams, grad_input_ptr, input_ptr);
        else
            _layer_norm.BackwardFusedAdd(bsz_seq, buf_2, buf_0, norm_w_ptr, norm_b_ptr,
                                         grad_norm_w_ptr, grad_norm_b_ptr, streams, grad_input_ptr,
                                         inp_norm_ptr);
    } else
        launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}

template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Dropout will be skipped when not in training mode.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}

template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(
    uint8_t* attn_prob_dropout_mask_ptr, uint8_t* attn_output_dropout_mask_ptr,
    uint8_t* layer_output_dropout_mask_ptr, T* attn_layer_norm_var, T* attn_layer_norm_mean,
    T* layer_norm_var, T* layer_norm_mean)
{
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}

template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);
    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}

template <typename T>
int create_transformer_layer(unsigned layer_id, unsigned batch_size, unsigned hidden_dim,
                             unsigned num_heads, unsigned intermediate_size,
                             float attn_dropout_ratio, float hidden_dropout_ratio,
                             float layer_norm_eps, int seed, bool pre_or_postLayerNorm,
                             bool test_gemm, bool attn_dropout_checkpoint,
                             bool normalize_invertible, bool gelu_checkpoint, bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

    auto layer = std::make_shared<BertTransformerLayer<T>>(
        layer_id, batch_size, hidden_dim, num_heads, intermediate_size, init_seq_length,
        attn_dropout_ratio, hidden_dropout_ratio, layer_norm_eps, pre_or_postLayerNorm,
        Context::Instance().GetGemmAlgos(), attn_dropout_checkpoint, normalize_invertible,
        gelu_checkpoint, stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
    std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]."
              << std::endl;

    return 0;
}

template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(
    unsigned layer_id, const torch::Tensor& input, const torch::Tensor& input_mask,
    const torch::Tensor& attn_qkvw, const torch::Tensor& attn_qkvb, const torch::Tensor& attn_ow,
    const torch::Tensor& attn_ob, const torch::Tensor& attn_nw, const torch::Tensor& attn_nb,
    const torch::Tensor& inter_w, const torch::Tensor& inter_b, const torch::Tensor& output_w,
    const torch::Tensor& output_b, const torch::Tensor& norm_w, const torch::Tensor& norm_b,
    bool training_mode, bool prelayernorm, bool attn_dropout_checkpoint, bool normalize_invertible,
    bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty(
        {get_workspace_size<T>(bsz, seq_len, layer->GetHiddenSize(), layer->GetIntermediateSize(),
                               layer->GetNumHeads(), layer->IsTrainingMode(),
                               layer->GeluCheckpoint())},
        options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ?
                        inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz, input_ptr, input_mask_ptr, attn_qkvw_ptr, attn_qkvb_ptr, attn_ow_ptr,
                   attn_ob_ptr, attn_nw_ptr, attn_nb_ptr, inter_w_ptr, inter_b_ptr, output_w_ptr,
                   output_b_ptr, norm_w_ptr, norm_b_ptr, out_ptr, inp_norm_ptr, q_tf_ptr, k_tf_ptr,
                   v_tf_ptr, soft_out_ptr, ctx_bufB_ptr, attn_o_inp_ptr, add_res_ptr, ff1_inp_ptr,
                   gelu_inp_ptr, ff2_inp_ptr);

    return {output, inp_norm, qkv_tf, soft_out, ctx_bufB, attn_o_inp, add_res, ff1_inp, gelu_inp,
            ff2_inp, attn_prob_dropout_mask, attn_output_dropout_mask, layer_output_dropout_mask,
            attn_layer_norm_var, attn_layer_norm_mean, layer_norm_var, layer_norm_mean};
}

template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(
    unsigned layer_id, const torch::Tensor& grad_output, const torch::Tensor& output,
    const torch::Tensor& inp_norm, const torch::Tensor& qkv_tf, const torch::Tensor& soft_out,
    const torch::Tensor& ctx_bufB, const torch::Tensor& attn_o_inp, const torch::Tensor& add_res,
    const torch::Tensor& ff1_inp, const torch::Tensor& gelu_inp, const torch::Tensor& ff2_inp,
    const torch::Tensor& attn_prob_dropout_mask, const torch::Tensor& attn_output_dropout_mask,
    const torch::Tensor& layer_output_dropout_mask, const torch::Tensor& attn_layer_norm_var,
    const torch::Tensor& attn_layer_norm_mean, const torch::Tensor& layer_norm_var,
    const torch::Tensor& layer_norm_mean, const torch::Tensor& input,
    const torch::Tensor& input_mask, const torch::Tensor& attn_qkvw,
    const torch::Tensor& attn_qkvb, const torch::Tensor& attn_ow, const torch::Tensor& attn_ob,
    const torch::Tensor& attn_nw, const torch::Tensor& attn_nb, const torch::Tensor& inter_w,
    const torch::Tensor& inter_b, const torch::Tensor& output_w, const torch::Tensor& output_b,
    const torch::Tensor& norm_w, const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty(
        {get_workspace_size<T>(bsz, seq_len, layer->GetHiddenSize(), layer->GetIntermediateSize(),
                               layer->GetNumHeads(), layer->IsTrainingMode(),
                               layer->GeluCheckpoint())},
        options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
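    // Q, K, and V activations are packed back to back in qkv_tf, so the K and V pointers below
    // are offsets of bsz * seq_len * hidden_size elements from the Q pointer (output_w.size(0)
    // is the hidden size).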
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();

    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz, grad_output_ptr, input_ptr, output_ptr, inp_norm_ptr, q_tf_ptr, k_tf_ptr,
                    v_tf_ptr, soft_out_ptr, ctx_bufB_ptr, attn_o_inp_ptr, add_res_ptr, ff1_inp_ptr,
                    gelu_inp_ptr, ff2_inp_ptr, input_mask_ptr, attn_qkvw_ptr, attn_ow_ptr,
                    attn_nw_ptr, attn_nb_ptr, inter_w_ptr, inter_b_ptr, output_w_ptr, norm_w_ptr,
                    norm_b_ptr, grad_input_ptr, grad_attn_qkvw_ptr, grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr, grad_attn_ob_ptr, grad_attn_nw_ptr, grad_attn_nb_ptr,
                    grad_inter_w_ptr, grad_inter_b_ptr, grad_output_w_ptr, grad_output_b_ptr,
                    grad_norm_w_ptr, grad_norm_b_ptr);

    return {grad_input, grad_attn_qkvw, grad_attn_qkvb, grad_attn_ow, grad_attn_ob, grad_attn_nw,
            grad_attn_nb, grad_inter_w, grad_inter_b, grad_output_w, grad_output_b, grad_norm_w,
            grad_norm_b};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Layer with fp16 (CUDA)");
}
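// The extension exposes three operations, each in an fp32 and an fp16 variant:
//   - create_transformer_layer_*: construct a BertTransformerLayer and register it in
//     s_transformer_layers under its layer_id.
//   - forward_*: run the layer's forward pass and return the output together with the
//     activations and dropout masks the backward pass needs.
//   - backward_*: consume those saved tensors and return the input and parameter gradients.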