cpu_adam.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#define NOMINMAX  // Windows idiosyncrasy
                  // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c

#include <stdio.h>
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
#endif

#define STEP(SPAN)                                             \
    void Step_##SPAN(float* _params,                           \
                     float* grads,                             \
                     float* _exp_avg,                          \
                     float* _exp_avg_sq,                       \
                     size_t _param_size,                       \
                     ds_half_precision_t* dev_param = nullptr, \
                     bool half_precision = false);
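
// Illustrative note (not part of the upstream file): STEP(SPAN) stamps out a
// declaration for one unroll width, so STEP(1) below expands to
//
//     void Step_1(float* _params,
//                 float* grads,
//                 float* _exp_avg,
//                 float* _exp_avg_sq,
//                 size_t _param_size,
//                 ds_half_precision_t* dev_param = nullptr,
//                 bool half_precision = false);
//
// and STEP(4)/STEP(8) declare the 4-wide and 8-wide variants.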
class Adam_Optimizer {
public:
    Adam_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float eps = 1e-8,
                   float weight_decay = 0,
                   bool adamw_mode = true)
        : _alpha(alpha),
          _betta1(betta1),
          _betta2(betta2),
          _eps(eps),
          _weight_decay(weight_decay),
          _betta1_t(1.0),
          _betta2_t(1.0),
          _step(0),
          _adamw_mode(adamw_mode)
    {
#if defined(__ENABLE_CUDA__)
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _streams[0] = TrainingContext::Instance().GetCurrentStream();
        _streams[1] = TrainingContext::Instance().GetNewStream();
        _buf_index = false;
#endif
    }
    ~Adam_Optimizer()
    {
#if defined(__ENABLE_CUDA__)
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
#endif
    }
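
    // Vectorized Adam/AdamW step over the first ROUND_DOWN(param_size,
    // SIMD_WIDTH * span) elements. The count of elements actually processed is
    // written to *rounded_size, presumably so the scalar Step_* entry points
    // can finish the unaligned tail.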
#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg,
                  float* _exp_avg_sq,
                  size_t param_size,
                  ds_half_precision_t* dev_param = nullptr,
                  bool half_precision = false);
#endif
    STEP(1)
    STEP(4)
    STEP(8)
#if defined(__ENABLE_CUDA__)
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
#endif
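
    // Maintains the invariants _betta1_t == betta1^step and
    // _betta2_t == betta2^step. On the common path (same betas, step advanced
    // by exactly one) it just multiplies the cached powers; any other
    // transition recomputes them with std::pow.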
    inline void IncrementStep(size_t step, float beta1, float beta2)
    {
        if (beta1 != _betta1 || beta2 != _betta2) {
            _step = step;
            _betta1 = beta1;
            _betta2 = beta2;
            _betta1_t = std::pow(_betta1, step);
            _betta2_t = std::pow(_betta2, step);
        } else {
            _step++;
            if (_step != step) {
                _betta1_t = std::pow(_betta1, step);
                _betta2_t = std::pow(_betta2, step);
                _step = step;
            } else {
                _betta1_t *= _betta1;
                _betta2_t *= _betta2;
            }
        }
    }
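
    // Refreshes the hyperparameters for the upcoming step. With bias
    // correction enabled this caches
    //     _bias_correction1 = 1 - betta1^step
    //     _bias_correction2 = 1 / sqrt(1 - betta2^step)
    // which Step_AVX folds into the step size and the denominator.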
    inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
    {
        _alpha = lr;
        _eps = epsilon;
        _weight_decay = weight_decay;

        _bias_correction1 = 1.0f;
        _bias_correction2 = 1.0f;
        if (bias_correction == 1) {
            _bias_correction1 = 1 - _betta1_t;
            _bias_correction2 = 1 / sqrt(1 - _betta2_t);
        }
    }
private:
    float _alpha;
    float _betta1;
    float _betta2;
    float _eps;
    float _weight_decay;

    float _betta1_t;
    float _betta2_t;
    size_t _step;

    float _bias_correction1;
    float _bias_correction2;

    bool _adamw_mode;

#if defined(__ENABLE_CUDA__)
    float* _doubled_buffer[2];
    cudaStream_t _streams[2];
    bool _buf_index;
#endif
};
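
// Illustrative driver sketch (an assumption, not part of the upstream file;
// the Python bindings normally own this sequence). Assumes fp32 buffers of
// length n and no CUDA staging:
//
//     Adam_Optimizer opt(/*alpha=*/1e-3f, /*betta1=*/0.9f, /*betta2=*/0.999f,
//                        /*eps=*/1e-8f, /*weight_decay=*/0.0f,
//                        /*adamw_mode=*/true);
//     opt.IncrementStep(step, 0.9f, 0.999f);
//     opt.update_state(lr, 1e-8f, 0.0f, /*bias_correction=*/true);
//     opt.Step_8(params, grads, exp_avg, exp_avg_sq, n);
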
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Adam_Optimizer::Step_AVX(size_t* rounded_size,
                              float* _params,
                              float* grads,
                              float* _exp_avg,
                              float* _exp_avg_sq,
                              size_t _param_size,
                              ds_half_precision_t* dev_params,
                              bool half_precision)
{
    size_t new_rounded_size = 0;
    int rshft = half_precision ? 1 : 0;

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
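
    // The work is processed in tiles of TILE elements so that, when CUDA is
    // enabled, each tile's updated parameters can be staged into one of the
    // two pinned host buffers and copied to the device on alternating streams
    // while the next tile is being computed; the sync below waits only when a
    // buffer is about to be reused (two tiles back).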
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + (i >> rshft), half_precision);

            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, _exp_avg + i, false);

            AVX_Data variance_4[span];
            simd_load<span>(variance_4, _exp_avg_sq + i, false);

            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + (i >> rshft), half_precision);

            if (_weight_decay > 0 && !_adamw_mode) {
                simd_fma<span>(grad_4, param_4, weight_decay4, grad_4);
            }

            simd_mul<span>(momentum_4, momentum_4, betta1_4);
            simd_fma<span>(momentum_4, grad_4, betta1_minus1_4, momentum_4);
            simd_mul<span>(variance_4, variance_4, betta2_4);
            simd_mul<span>(grad_4, grad_4, grad_4);
            simd_fma<span>(variance_4, grad_4, betta2_minus1_4, variance_4);
            simd_sqrt<span>(grad_4, variance_4);
            simd_fma<span>(grad_4, grad_4, bias2_sqrt, eps_4);
            simd_div<span>(grad_4, momentum_4, grad_4);

            if (_weight_decay > 0 && _adamw_mode) {
                simd_fma<span>(param_4, param_4, weight_decay4, param_4);
            }

            simd_fma<span>(param_4, grad_4, step_size_4, param_4);

            simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__)
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }
#endif
            simd_store<span>(_exp_avg + i, momentum_4, false);
            simd_store<span>(_exp_avg_sq + i, variance_4, false);
        }
#if defined(__ENABLE_CUDA__)
        if (dev_params) {
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
}
#endif
  221. #endif
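
// Reference sketch (illustrative, not part of the upstream file): per element,
// the SIMD loop above computes the usual Adam/AdamW update, with the cached
// bias corrections folded into the denominator and step size:
//
//     if (weight_decay > 0 && !adamw_mode) grad += weight_decay * param;  // classic L2
//     exp_avg    = betta1 * exp_avg    + (1 - betta1) * grad;
//     exp_avg_sq = betta2 * exp_avg_sq + (1 - betta2) * grad * grad;
//     denom = sqrt(exp_avg_sq) * bias_correction2 + eps;   // bias_correction2 = 1/sqrt(1 - betta2^step)
//     if (weight_decay > 0 && adamw_mode) param -= alpha * weight_decay * param;  // decoupled decay
//     param -= (alpha / bias_correction1) * exp_avg / denom;  // bias_correction1 = 1 - betta1^step
//
// step_size carries the negative sign, which is why the code adds it via FMA.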