ds_transformer_cuda.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda_runtime_api.h>
#include <curand.h>
#include <array>    // std::array in the gemm_algos constructor parameter
#include <cstdint>  // uint8_t buffers in SetIntermediateBuffers()
#include <memory>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "dropout.h"
#include "feed_forward.h"
#include "gelu.h"
#include "general_kernels.h"
#include "normalize_layer.h"
#include "softmax.h"
#include "strided_batch_gemm.h"
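
// cuBLAS GEMM algorithm selections for the layer's dense and batched GEMMs;
// -1 means no specific algorithm has been chosen.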
struct BertGemmAlgos {
    int m_gemm_qkv_algo;
    int m_gemm_inter_algo;
    int m_gemm_output_algo;
    int m_gemm_batch1_algo;
    int m_gemm_batch2_algo;

    BertGemmAlgos()
        : m_gemm_qkv_algo(-1),
          m_gemm_inter_algo(-1),
          m_gemm_output_algo(-1),
          m_gemm_batch1_algo(-1),
          m_gemm_batch2_algo(-1)
    {
    }
};

template <typename T>
class BertTransformerLayer {
public:
    BertTransformerLayer(unsigned layer_id,
                         unsigned batch_size,
                         unsigned hidden_size,
                         unsigned num_heads,
                         unsigned intermediate_size,
                         unsigned seq_length,
                         float attn_dropout_ratio,
                         float hidden_output_dropout_ratio,
                         float layer_norm_eps,
                         bool pre_or_postLayerNorm,
                         const std::vector<std::array<int, 3>>& gemm_algos,
                         bool attn_dropout_checkpoint,
                         bool normalize_invertible,
                         bool gelu_checkpoint,
                         bool stochastic_mode);

    virtual ~BertTransformerLayer();
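
    // Runs the layer's forward pass. Besides the layer output (out_ptr), it
    // fills the intermediate activation buffers (QKV transforms, softmax
    // output, attention context, residual add, FFN inputs) that Backward()
    // reads back.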
    void Forward(unsigned bsz,
                 const T* input_ptr,
                 const T* input_mask_ptr,
                 const T* attn_qkvw_ptr,
                 const T* attn_qkvb_ptr,
                 const T* attn_ow_ptr,
                 const T* attn_ob_ptr,
                 const T* attn_nw_ptr,
                 const T* attn_nb_ptr,
                 const T* inter_w_ptr,
                 const T* inter_b_ptr,
                 const T* output_w_ptr,
                 const T* output_b_ptr,
                 const T* norm_w_ptr,
                 const T* norm_b_ptr,
                 T* out_ptr,
                 T* inp_norm_ptr,
                 T* q_tf_ptr,
                 T* k_tf_ptr,
                 T* v_tf_ptr,
                 T* softmax_output_ptr,
                 T* ctx_bufB_ptr,
                 T* attn_o_inp_ptr,
                 T* add_res_ptr,
                 T* ff1_inp_ptr,
                 T* gelu_inp_ptr,
                 T* ff2_inp_ptr);
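
    // Runs the backward pass, producing gradients w.r.t. the layer input and
    // all weights/biases from the activations saved by Forward().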
    void Backward(unsigned bsz,
                  const T* grad_output_ptr,
                  const T* input_ptr,
                  const T* output_ptr,
                  const T* inp_norm_ptr,
                  const T* q_tf_ptr,
                  const T* k_tf_ptr,
                  const T* v_tf_ptr,
                  const T* softmax_output_ptr,
                  const T* ctx_bufB_ptr,
                  const T* attn_o_inp_ptr,
                  const T* add_res_ptr,
                  const T* ff1_inp_ptr,
                  const T* gelu_inp_ptr,
                  const T* ff2_inp_ptr,
                  const T* input_mask_ptr,
                  const T* attn_qkvw_ptr,
                  const T* attn_ow_ptr,
                  const T* attn_nw_ptr,
                  const T* attn_nb_ptr,
                  const T* inter_w_ptr,
                  const T* inter_b_ptr,
                  const T* output_w_ptr,
                  const T* norm_w_ptr,
                  const T* norm_b_ptr,
                  T* grad_input_ptr,
                  T* grad_attn_qkvw_ptr,
                  T* grad_attn_qkvb_ptr,
                  T* grad_attn_ow_ptr,
                  T* grad_attn_ob_ptr,
                  T* grad_attn_nw_ptr,
                  T* grad_attn_nb_ptr,
                  T* grad_inter_w_ptr,
                  T* grad_inter_b_ptr,
                  T* grad_output_w_ptr,
                  T* grad_output_b_ptr,
                  T* grad_norm_w_ptr,
                  T* grad_norm_b_ptr);
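
    // Supplies externally allocated buffers for the dropout masks and the
    // layer-norm mean/variance statistics.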
    void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                uint8_t* attn_output_dropout_mask_ptr,
                                uint8_t* layer_output_dropout_mask_ptr,
                                T* layer_norm_var,
                                T* layer_norm_mean,
                                T* attn_layer_norm_var,
                                T* attn_layer_norm_mean);

    inline unsigned GetBatchSize() const { return _batch_size; }
    inline unsigned GetNumHeads() const { return _heads; }
    inline unsigned GetSeqLength() const { return _seq_length; }
    inline unsigned GetIntermediateSize() const { return _intermediate_size; }

    void SetSeqLength(unsigned seq_len);
    inline unsigned GetHiddenSize() const { return _hidden_size; }

    void SetTrainingMode(bool training);
    inline bool IsTrainingMode() const { return _training; }

    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }

private:
    void Initialize();
    size_t getWorkspaceSize(int maxBatchSize) const;

    // Params
    unsigned _layer_id;
    unsigned _batch_size;
    unsigned _hidden_size;
    unsigned _heads;
    unsigned _size_per_head;
    unsigned _intermediate_size;
    unsigned _seq_length;

    bool _pre_or_postLayerNorm;

    cublasHandle_t _cublasHandle;
    cudaStream_t _stream;

    // layers
    FeedForward<T> _qkv_linear;
    FeedForward<T> _attn_out_linear;
    Normalize_Layer<T> _attn_layer_norm;
    Normalize_Layer<T> _layer_norm;
    Normalize_Layer<T>* _last_normalize;
    FeedForward<T> _ff1, _ff2;
    Softmax<T> _softmax;
    Gelu<T> _gelu;
    Dropout<T> _attn_prob_dropout;
    Dropout<T> _attn_output_dropout;
    Dropout<T> _layer_output_dropout;
    StridedBatchGemm<T> _attn_scores;
    StridedBatchGemm<T> _attn_context;

    bool _training;

    // Memory saving flags
    bool _attn_dropout_checkpoint;
    bool _normalize_invertible;
    bool _gelu_checkpoint;

    // High Performance flags
    bool _stochastic_mode;
};
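
// Usage sketch (illustrative only; the hyperparameter values and the length of
// gemm_algos below are assumptions, not defaults defined by this header):
//
//   std::array<int, 3> default_algo = {-1, -1, -1};
//   std::vector<std::array<int, 3>> gemm_algos(5, default_algo);
//   BertTransformerLayer<float> layer(/*layer_id=*/0,
//                                     /*batch_size=*/8,
//                                     /*hidden_size=*/1024,
//                                     /*num_heads=*/16,
//                                     /*intermediate_size=*/4096,
//                                     /*seq_length=*/512,
//                                     /*attn_dropout_ratio=*/0.1f,
//                                     /*hidden_output_dropout_ratio=*/0.1f,
//                                     /*layer_norm_eps=*/1e-12f,
//                                     /*pre_or_postLayerNorm=*/true,
//                                     gemm_algos,
//                                     /*attn_dropout_checkpoint=*/false,
//                                     /*normalize_invertible=*/false,
//                                     /*gelu_checkpoint=*/false,
//                                     /*stochastic_mode=*/false);
//   layer.SetTrainingMode(true);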