cpu_adam.h

#pragma once

#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <cassert>
#include <cmath>  // std::pow, sqrt
#include "cuda.h"
#include "custom_cuda_layers.h"
#include "simd.h"
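
// CPU-side implementation of the Adam/AdamW optimizer. Moment and parameter
// updates run on the host (vectorized with AVX-256/AVX-512 where available),
// and updated parameters are optionally streamed back to the GPU through two
// pinned staging buffers on alternating CUDA streams.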

#define STEP(SPAN)                                \
    void Step_##SPAN(float* _params,              \
                     float* grads,                \
                     float* _exp_avg,             \
                     float* _exp_avg_sq,          \
                     size_t _param_size,          \
                     __half* dev_param = nullptr, \
                     bool half_precision = false);

class Adam_Optimizer {
public:
    Adam_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float eps = 1e-8,
                   float weight_decay = 0,
                   bool adamw_mode = true)
        : _alpha(alpha),
          _betta1(betta1),
          _betta2(betta2),
          _eps(eps),
          _weight_decay(weight_decay),
          _betta1_t(1.0),
          _betta2_t(1.0),
          _step(0),
          _buf_index(false),
          _adamw_mode(adamw_mode)
    {
        // Pinned host buffers so copies to the device can run asynchronously
        // and overlap with the CPU-side update of the next tile.
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _streams[0] = Context::Instance().GetCurrentStream();
        _streams[1] = Context::Instance().GetNewStream();
    }
    ~Adam_Optimizer()
    {
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
    }
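
    // Vectorized inner kernel: processes `span` SIMD words per iteration and
    // reports (via rounded_size) how many elements it covered, so the scalar
    // Step_* entry points can finish the remainder.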
#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg,
                  float* _exp_avg_sq,
                  size_t param_size,
                  __half* dev_param = nullptr,
                  bool half_precision = false);
#endif

    STEP(1)
    STEP(4)
    STEP(8)

    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
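
    // Keeps the bias-correction powers beta1^t and beta2^t in sync with the
    // training step. In the common case (same betas, step advances by one)
    // the powers are updated with a single multiply; otherwise they are
    // recomputed from scratch with std::pow.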
    inline void IncrementStep(size_t step, float beta1, float beta2)
    {
        if (beta1 != _betta1 || beta2 != _betta2) {
            _step = step;
            _betta1 = beta1;
            _betta2 = beta2;
            _betta1_t = std::pow(_betta1, step);
            _betta2_t = std::pow(_betta2, step);
        } else {
            _step++;
            if (_step != step) {
                _betta1_t = std::pow(_betta1, step);
                _betta2_t = std::pow(_betta2, step);
                _step = step;
            } else {
                _betta1_t *= _betta1;
                _betta2_t *= _betta2;
            }
        }
    }
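
    // Refreshes the hyperparameters for the upcoming step. When bias
    // correction is enabled, _bias_correction1 = 1 - beta1^t scales the step
    // size and _bias_correction2 = 1 / sqrt(1 - beta2^t) rescales the
    // second-moment estimate, matching the standard Adam formulation.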
    inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
    {
        _alpha = lr;
        _eps = epsilon;
        _weight_decay = weight_decay;

        _bias_correction1 = 1.0f;
        _bias_correction2 = 1.0f;
        if (bias_correction) {
            _bias_correction1 = 1 - _betta1_t;
            _bias_correction2 = 1 / sqrt(1 - _betta2_t);
        }
    }

private:
    float _alpha;  // learning rate
    float _betta1; // first-moment decay (beta1)
    float _betta2; // second-moment decay (beta2)
    float _eps;
    float _weight_decay;

    float _betta1_t; // beta1^step
    float _betta2_t; // beta2^step
    size_t _step;

    float _bias_correction1;
    float _bias_correction2;

    float* _doubled_buffer[2]; // pinned staging buffers for device copies
    bool _buf_index;           // which staging buffer/stream is active
    bool _adamw_mode;          // true: decoupled weight decay (AdamW)
    cudaStream_t _streams[2];
};
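
// The vectorized step below implements, for each parameter element:
//
//     m  = beta1 * m + (1 - beta1) * g           (first moment)
//     v  = beta2 * v + (1 - beta2) * g * g       (second moment)
//     p -= alpha / (1 - beta1^t)
//              * m / (sqrt(v) / sqrt(1 - beta2^t) + eps)
//
// with weight decay either folded into the gradient (classic Adam,
// g += wd * p) or applied directly to the parameter (AdamW,
// p *= 1 - alpha * wd), depending on _adamw_mode.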
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Adam_Optimizer::Step_AVX(size_t* rounded_size,
                              float* _params,
                              float* grads,
                              float* _exp_avg,
                              float* _exp_avg_sq,
                              size_t _param_size,
                              __half* dev_params,
                              bool half_precision)
{
    size_t new_rounded_size = 0;

    // Broadcast the scalar hyperparameters across SIMD lanes once, up front.
    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);
    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
        // Wait until the staging buffer we are about to reuse has finished
        // its previous async copy to the device.
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + i, half_precision);
            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, _exp_avg + i, false);
            AVX_Data variance_4[span];
            simd_load<span>(variance_4, _exp_avg_sq + i, false);
            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + i, half_precision);

            // Classic (L2) weight decay: fold wd * param into the gradient.
            if (_weight_decay > 0 && !_adamw_mode) {
                simd_fma<span>(grad_4, param_4, weight_decay4, grad_4);
            }

            // m = beta1 * m + (1 - beta1) * g
            simd_mul<span>(momentum_4, momentum_4, betta1_4);
            simd_fma<span>(momentum_4, grad_4, betta1_minus1_4, momentum_4);
            // v = beta2 * v + (1 - beta2) * g^2
            simd_mul<span>(variance_4, variance_4, betta2_4);
            simd_mul<span>(grad_4, grad_4, grad_4);
            simd_fma<span>(variance_4, grad_4, betta2_minus1_4, variance_4);
            // g = m / (sqrt(v) * bias2_sqrt + eps)
            simd_sqrt<span>(grad_4, variance_4);
            simd_fma<span>(grad_4, grad_4, bias2_sqrt, eps_4);
            simd_div<span>(grad_4, momentum_4, grad_4);

            // AdamW: decoupled weight decay applied to the parameter itself.
            if (_weight_decay > 0 && _adamw_mode) {
                simd_fma<span>(param_4, param_4, weight_decay4, param_4);
            }
            // p += step_size * g   (step_size is negative)
            simd_fma<span>(param_4, grad_4, step_size_4, param_4);

            simd_store<span>(_params + i, param_4, half_precision);
            // Stage updated params in the pinned buffer for the async GPU copy.
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }
            simd_store<span>(_exp_avg + i, momentum_4, false);
            simd_store<span>(_exp_avg_sq + i, variance_4, false);
        }
        // Kick off the async copy of this tile to the device, then flip to
        // the other staging buffer/stream for the next tile.
        if (dev_params) {
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
    }
    *rounded_size = new_rounded_size;
}
#endif
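
// A minimal usage sketch (illustrative only: the training-loop pointers and
// step count below are hypothetical, and Context/TILE come from the companion
// headers included above):
//
//   Adam_Optimizer opt(/*alpha=*/1e-3f, /*betta1=*/0.9f, /*betta2=*/0.999f,
//                      /*eps=*/1e-8f, /*weight_decay=*/0.01f, /*adamw_mode=*/true);
//   for (size_t step = 1; step <= num_steps; ++step) {
//       opt.IncrementStep(step, 0.9f, 0.999f);
//       opt.update_state(1e-3f, 1e-8f, 0.01f, /*bias_correction=*/true);
//       opt.Step_8(params, grads, exp_avg, exp_avg_sq, param_count,
//                  dev_params, /*half_precision=*/false);
//   }
//   opt.SynchronizeStreams();  // ensure all async device copies have completed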