normalize_layer.h

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include <stdexcept>  // std::runtime_error in SetVar/SetMean
#include "custom_cuda_layers.h"

using namespace std;
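
// Host-side wrapper around the fused bias + residual LayerNorm CUDA kernels
// (the launch_* entry points, presumably declared in custom_cuda_layers.h).
// The layer owns no device memory: the buffers that checkpoint the forward
// statistics (variance, and optionally mean) are attached via SetVar/SetMean.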
template <typename T>
class Normalize_Layer {
public:
    struct Config {
        uint32_t batchSize;
        uint32_t seqLength;
        uint32_t hiddenDim;
        float epsilon;
        bool training;
        bool useMean;
        Config(uint32_t batch,
               uint32_t seq,
               uint32_t h,
               float epsilon = 1e-12f,
               bool training = true,
               bool useMean = true)
            : batchSize(batch),
              seqLength(seq),
              hiddenDim(h),
              epsilon(epsilon),
              training(training),
              useMean(useMean)
        {
        }
    };

    Normalize_Layer(Config config)
        : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
    {
    }

    ~Normalize_Layer() {}
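
    // Forward pass that checkpoints both the per-row mean and variance so the
    // mean-based Backward/BackwardFusedAdd overloads can reuse them. bsz is the
    // number of rows (batch * seq); each row is normalized over hiddenDim.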
    void ForwardCheckpoint(int bsz,  // batch * seq
                           T* vals,
                           const T* residual,
                           const T* gamma,
                           const T* betta,
                           cudaStream_t& stream,
                           bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars,
                                        means);
    }
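
    // Forward pass that checkpoints only the per-row variance, matching the
    // useMean == false backward overloads below, which work from the layer-norm
    // output rather than a stored mean.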
    void Forward(int bsz,
                 T* vals,
                 const T* residual,
                 const T* gamma,
                 const T* betta,
                 cudaStream_t& stream,
                 bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars);
    }
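
    // Backward pass for the mean + variance path: norm_in is the activation
    // that was fed into the layer norm. The two streams presumably allow the
    // parameter-gradient and input-gradient kernels to overlap.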
    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  T* gamma_grad,
                  T* betta_grad,
                  cudaStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_in = nullptr)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_in,
                                  vars,
                                  means,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream);
    }
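
    // Backward overload for the variance-only path: it takes the layer-norm
    // output together with betta, so the kernel can recover the normalized
    // activations from the output (roughly (norm_out - betta) / gamma) instead
    // of reading a stored mean; !config_.useMean selects that kernel variant.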
    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  const T* betta,
                  T* gamma_grad,
                  T* betta_grad,
                  cudaStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_out)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_out,
                                  vars,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream,
                                  !config_.useMean,
                                  betta);
    }
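
    // Same as the mean-based Backward, but fuses the addition of two incoming
    // gradients (e.g. a residual branch and the main branch) into the backward
    // kernel instead of materializing their sum in a separate pass.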
    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          T* gamma_grad,
                          T* betta_grad,
                          cudaStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_in = nullptr)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_in,
                                            vars,
                                            means,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream);
    }
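
    // Fused-add variant of the variance-only backward path: two incoming
    // gradients, with the normalized activations recovered from norm_out and
    // betta as above.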
    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          const T* betta,
                          T* gamma_grad,
                          T* betta_grad,
                          cudaStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_out)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_out,
                                            vars,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream,
                                            !config_.useMean,
                                            betta);
    }
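
    // Accessors: UseMean reports which backward path the layer was configured
    // for; SetVar/SetMean attach caller-allocated device buffers for the saved
    // statistics and reject null pointers up front.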
    inline bool UseMean() const { return config_.useMean; }

    inline void SetVar(T* variance)
    {
        if (!variance) { throw std::runtime_error("Normalize variance is null."); }
        vars = variance;
    }

    inline void SetMean(T* mean)
    {
        if (!mean) { throw std::runtime_error("Normalize mean is null."); }
        means = mean;
    }

private:
    Config config_;
    T* vars;      // per-row variance saved by the forward pass
    T* means;     // per-row mean saved by ForwardCheckpoint
    T* vals_hat;  // reserved for normalized values; unused in this header
};
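
// Usage sketch (illustrative only; the buffer and stream names below are
// assumptions, and allocation/kernel details live elsewhere):
//
//   Normalize_Layer<__half>::Config cfg(batch, seq, hidden);
//   Normalize_Layer<__half> norm(cfg);
//   norm.SetVar(var_buf);    // device buffer: one value per row
//   norm.SetMean(mean_buf);  // needed only on the useMean path
//   norm.ForwardCheckpoint(batch * seq, vals, residual, gamma, betta, stream);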