// ds_transformer_cuda.h
  1. #pragma once
  2. #include <cuda_runtime_api.h>
  3. #include <curand.h>
  4. #include <memory>
  5. #include <vector>
  6. #include "cublas_v2.h"
  7. #include "cuda.h"
  8. #include "dropout.h"
  9. #include "feed_forward.h"
  10. #include "gelu.h"
  11. #include "general_kernels.h"
  12. #include "normalize_layer.h"
  13. #include "softmax.h"
  14. #include "strided_batch_gemm.h"
  15. struct BertGemmAlgos {
  16. int m_gemm_qkv_algo;
  17. int m_gemm_inter_algo;
  18. int m_gemm_output_algo;
  19. int m_gemm_batch1_algo;
  20. int m_gemm_batch2_algo;
  21. BertGemmAlgos()
  22. : m_gemm_qkv_algo(-1),
  23. m_gemm_inter_algo(-1),
  24. m_gemm_output_algo(-1),
  25. m_gemm_batch1_algo(-1),
  26. m_gemm_batch2_algo(-1)
  27. {
  28. }
  29. };
// One fused BERT transformer layer (self-attention + feed-forward) running
// on the GPU. The class owns the sub-module objects declared below and
// exposes Forward/Backward entry points that operate entirely on
// caller-provided device buffers; the layer holds configuration and
// sub-module state, not the parameter tensors themselves.
// T is the element type (fixed at instantiation; not visible in this header).
template <typename T>
class BertTransformerLayer {
public:
    // Builds the layer and its sub-modules from model hyperparameters.
    //   layer_id                    : index of this layer within the model
    //   batch_size, hidden_size, num_heads,
    //   intermediate_size, seq_length : BERT layer dimensions
    //   attn_dropout_ratio          : dropout ratio for attention probabilities
    //   hidden_output_dropout_ratio : dropout ratio for the attention-output
    //                                 and layer-output paths (one shared ratio)
    //   layer_norm_eps              : epsilon for the layer norms
    //   pre_or_postLayerNorm        : selects pre-LN vs post-LN layout
    //                                 (which bool value maps to which layout
    //                                 is not visible in this header — see .cu)
    //   gemm_algos                  : per-GEMM cuBLAS algorithm triples
    //   attn_dropout_checkpoint, normalize_invertible,
    //   gelu_checkpoint             : memory-saving options (flags stored below)
    //   stochastic_mode             : performance flag (stored below)
    BertTransformerLayer(unsigned layer_id,
                         unsigned batch_size,
                         unsigned hidden_size,
                         unsigned num_heads,
                         unsigned intermediate_size,
                         unsigned seq_length,
                         float attn_dropout_ratio,
                         float hidden_output_dropout_ratio,
                         float layer_norm_eps,
                         bool pre_or_postLayerNorm,
                         const std::vector<std::array<int, 3>>& gemm_algos,
                         bool attn_dropout_checkpoint,
                         bool normalize_invertible,
                         bool gelu_checkpoint,
                         bool stochastic_mode);

    virtual ~BertTransformerLayer();

    // Forward pass for a batch of `bsz` sequences. Reads the input, the
    // attention mask, and the layer's weight/bias device pointers
    // (attn_qkv*, attn_o*, attn_n*, inter_*, output_*, norm_*); writes the
    // layer output to `out_ptr` and fills the intermediate activation
    // buffers (inp_norm, q/k/v transforms, softmax output, attention
    // context, residual sum, FF inputs) that Backward consumes later.
    // All pointers are assumed to be device memory — confirm at call sites.
    void Forward(unsigned bsz,
                 const T* input_ptr,
                 const T* input_mask_ptr,
                 const T* attn_qkvw_ptr,
                 const T* attn_qkvb_ptr,
                 const T* attn_ow_ptr,
                 const T* attn_ob_ptr,
                 const T* attn_nw_ptr,
                 const T* attn_nb_ptr,
                 const T* inter_w_ptr,
                 const T* inter_b_ptr,
                 const T* output_w_ptr,
                 const T* output_b_ptr,
                 const T* norm_w_ptr,
                 const T* norm_b_ptr,
                 T* out_ptr,
                 T* inp_norm_ptr,
                 T* q_tf_ptr,
                 T* k_tf_ptr,
                 T* v_tf_ptr,
                 T* softmax_output_ptr,
                 T* ctx_bufB_ptr,
                 T* attn_o_inp_ptr,
                 T* add_res_ptr,
                 T* ff1_inp_ptr,
                 T* gelu_inp_ptr,
                 T* ff2_inp_ptr);

    // Backward pass for a batch of `bsz` sequences. Consumes the output
    // gradient, the activations saved by Forward, the mask, and the layer
    // weights; fills grad_input_ptr plus a gradient buffer for every
    // weight/bias of the layer (the grad_* pointers).
    void Backward(unsigned bsz,
                  const T* grad_output_ptr,
                  const T* input_ptr,
                  const T* output_ptr,
                  const T* inp_norm_ptr,
                  const T* q_tf_ptr,
                  const T* k_tf_ptr,
                  const T* v_tf_ptr,
                  const T* softmax_output_ptr,
                  const T* ctx_bufB_ptr,
                  const T* attn_o_inp_ptr,
                  const T* add_res_ptr,
                  const T* ff1_inp_ptr,
                  const T* gelu_inp_ptr,
                  const T* ff2_inp_ptr,
                  const T* input_mask_ptr,
                  const T* attn_qkvw_ptr,
                  const T* attn_ow_ptr,
                  const T* attn_nw_ptr,
                  const T* attn_nb_ptr,
                  const T* inter_w_ptr,
                  const T* inter_b_ptr,
                  const T* output_w_ptr,
                  const T* norm_w_ptr,
                  const T* norm_b_ptr,
                  T* grad_input_ptr,
                  T* grad_attn_qkvw_ptr,
                  T* grad_attn_qkvb_ptr,
                  T* grad_attn_ow_ptr,
                  T* grad_attn_ob_ptr,
                  T* grad_attn_nw_ptr,
                  T* grad_attn_nb_ptr,
                  T* grad_inter_w_ptr,
                  T* grad_inter_b_ptr,
                  T* grad_output_w_ptr,
                  T* grad_output_b_ptr,
                  T* grad_norm_w_ptr,
                  T* grad_norm_b_ptr);

    // Points the layer at externally allocated buffers for the three dropout
    // masks and the mean/variance statistics of the two layer norms.
    // The layer does not appear to take ownership — confirm in the .cu.
    void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                uint8_t* attn_output_dropout_mask_ptr,
                                uint8_t* layer_output_dropout_mask_ptr,
                                T* layer_norm_var,
                                T* layer_norm_mean,
                                T* attn_layer_norm_var,
                                T* attn_layer_norm_mean);

    // Trivial accessors for the configured dimensions/flags.
    inline unsigned GetBatchSize() const { return _batch_size; }
    inline unsigned GetNumHeads() const { return _heads; }
    inline unsigned GetSeqLength() const { return _seq_length; }
    inline unsigned GetIntermediateSize() const { return _intermediate_size; }

    // Updates the sequence length (implementation not visible in this header;
    // presumably reconfigures the sub-modules that depend on it — confirm).
    void SetSeqLength(unsigned seq_len);

    inline unsigned GetHiddenSize() const { return _hidden_size; }

    // Toggles training vs. inference behavior (e.g. for the dropouts).
    void SetTrainingMode(bool training);
    inline bool IsTrainingMode() const { return _training; }
    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }

private:
    // One-time setup of internal state (implementation not visible here).
    void Initialize();
    // Bytes of scratch workspace required for the given maximum batch size.
    size_t getWorkspaceSize(int maxBatchSize) const;

    // Params
    unsigned _layer_id;
    unsigned _batch_size;
    unsigned _hidden_size;
    unsigned _heads;
    unsigned _size_per_head;  // presumably hidden_size / heads — confirm in the .cu
    unsigned _intermediate_size;
    unsigned _seq_length;
    bool _pre_or_postLayerNorm;
    cublasHandle_t _cublasHandle;
    cudaStream_t _stream;

    // layers
    FeedForward<T> _qkv_linear;       // fused Q/K/V projection
    FeedForward<T> _attn_out_linear;  // attention output projection
    Normalize_Layer<T> _attn_layer_norm;
    Normalize_Layer<T> _layer_norm;
    Normalize_Layer<T>* _last_normalize;  // raw pointer; which norm it refers to is set elsewhere
    FeedForward<T> _ff1, _ff2;  // feed-forward intermediate / output linears
    Softmax<T> _softmax;
    Gelu<T> _gelu;
    Dropout<T> _attn_prob_dropout;
    Dropout<T> _attn_output_dropout;
    Dropout<T> _layer_output_dropout;
    StridedBatchGemm<T> _attn_scores;   // batched GEMM (presumably Q*K^T — confirm)
    StridedBatchGemm<T> _attn_context;  // batched GEMM (presumably probs*V — confirm)
    bool _training;

    // Memory saving flags
    bool _attn_dropout_checkpoint;
    bool _normalize_invertible;
    bool _gelu_checkpoint;

    // High Performance flags
    bool _stochastic_mode;
};