normalize_layer.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>

#include <cstdint>
#include <fstream>
#include <stdexcept>

#include "custom_cuda_layers.h"

using namespace std;
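
// Fused bias + residual layer normalization. This class only wraps the
// launch_* CUDA kernel entry points pulled in through custom_cuda_layers.h
// and keeps pointers to externally allocated buffers that hold the per-row
// statistics (variance, and optionally mean) saved by the forward pass and
// reused by the backward pass.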
template <typename T>
class Normalize_Layer {
public:
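    // Problem shape and normalization options. useMean indicates whether the
    // per-row mean is tracked alongside the variance.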
    struct Config {
        uint32_t batchSize;
        uint32_t seqLength;
        uint32_t hiddenDim;
        float epsilon;
        bool training;
        bool useMean;

        Config(uint32_t batch,
               uint32_t seq,
               uint32_t h,
               float epsilon = 1e-12,
               bool training = true,
               bool useMean = true)
            : batchSize(batch),
              seqLength(seq),
              hiddenDim(h),
              epsilon(epsilon),
              training(training),
              useMean(useMean)
        {
        }
    };

    Normalize_Layer(Config config)
        : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
    {
    }

    ~Normalize_Layer() {}
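
    // Forward pass that saves both the per-row variance and mean for the
    // backward pass. bsz is the number of rows (batch * seq).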
    void ForwardCheckpoint(int bsz,  // batch * seq
                           T* vals,
                           const T* residual,
                           const T* gamma,
                           const T* betta,
                           cudaStream_t& stream,
                           bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars,
                                        means);
    }
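
    // Forward pass that saves only the per-row variance (the mean is not
    // tracked on this path).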
    void Forward(int bsz,
                 T* vals,
                 const T* residual,
                 const T* gamma,
                 const T* betta,
                 cudaStream_t& stream,
                 bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars);
    }
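
    // Backward pass using the saved mean and variance together with the
    // layer-norm input (norm_in).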
    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  T* gamma_grad,
                  T* betta_grad,
                  cudaStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_in = nullptr)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_in,
                                  vars,
                                  means,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream);
    }
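
    // Backward pass variant that works from the layer-norm output (norm_out)
    // and betta instead of the input and the saved mean (used when useMean is
    // false).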
    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  const T* betta,
                  T* gamma_grad,
                  T* betta_grad,
                  cudaStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_out)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_out,
                                  vars,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream,
                                  !config_.useMean,
                                  betta);
    }
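
    // Same as the first Backward overload, but additionally fuses a second
    // incoming gradient (out_grad2) into the computed input gradient.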
    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          T* gamma_grad,
                          T* betta_grad,
                          cudaStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_in = nullptr)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_in,
                                            vars,
                                            means,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream);
    }
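
    // Fused-add backward for the mean-free path (works from norm_out and
    // betta, as in the second Backward overload).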
    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          const T* betta,
                          T* gamma_grad,
                          T* betta_grad,
                          cudaStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_out)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_out,
                                            vars,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream,
                                            !config_.useMean,
                                            betta);
    }
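
    // UseMean reports which statistics are tracked; SetVar/SetMean attach the
    // externally allocated buffers that receive the saved statistics.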
    inline bool UseMean() const { return config_.useMean; }

    inline void SetVar(T* variance)
    {
        if (!variance) { throw std::runtime_error("Normalize variance is null."); }
        vars = variance;
    }

    inline void SetMean(T* mean)
    {
        if (!mean) { throw std::runtime_error("Normalize mean is null."); }
        means = mean;
    }

private:
    Config config_;
    T* vars;
    T* means;
    T* vals_hat;
};
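
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original interface). It
// assumes device buffers of suitable size have already been allocated; the
// names d_vals, d_residual, d_gamma, d_betta, d_var, d_mean, d_norm_in and
// the gradient buffers are hypothetical.
//
//   Normalize_Layer<float>::Config cfg(batch, seq, hidden, 1e-12, true, true);
//   Normalize_Layer<float> norm(cfg);
//   norm.SetVar(d_var);    // per-row variance, batch * seq elements
//   norm.SetMean(d_mean);  // per-row mean, batch * seq elements
//   norm.ForwardCheckpoint(batch * seq, d_vals, d_residual, d_gamma, d_betta,
//                          stream);
//   // later, given output gradients and two CUDA streams:
//   // norm.Backward(batch * seq, d_out_grad, d_gamma, d_gamma_grad,
//   //               d_betta_grad, streams, d_inp_grad, d_norm_in);
// ---------------------------------------------------------------------------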