cpu_adam.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#define NOMINMAX  // Windows idiosyncrasy
                  // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c

#include <stdio.h>
#include <torch/extension.h>
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
#endif
#define STEP(SPAN)                                             \
    void Step_##SPAN(float* _params,                           \
                     float* grads,                             \
                     float* _exp_avg,                          \
                     float* _exp_avg_sq,                       \
                     size_t _param_size,                       \
                     ds_half_precision_t* dev_param = nullptr, \
                     bool half_precision = false);
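
// Illustrative expansion (added for clarity, not part of the original header): inside the
// class body below, STEP(1) declares a member of the form
//
//     void Step_1(float* _params,
//                 float* grads,
//                 float* _exp_avg,
//                 float* _exp_avg_sq,
//                 size_t _param_size,
//                 ds_half_precision_t* dev_param = nullptr,
//                 bool half_precision = false);
//
// and likewise for STEP(4) and STEP(8). The SPAN value corresponds to the `span` unroll
// factor used by Step_AVX, which processes SIMD_WIDTH * span elements per loop iteration.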
class Adam_Optimizer {
public:
    Adam_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float eps = 1e-8,
                   float weight_decay = 0,
                   bool adamw_mode = true)
        : _alpha(alpha),
          _betta1(betta1),
          _betta2(betta2),
          _eps(eps),
          _weight_decay(weight_decay),
          _betta1_t(1.0),
          _betta2_t(1.0),
          _step(0),
          _adamw_mode(adamw_mode)
    {
#if defined(__ENABLE_CUDA__)
        // Pinned host staging buffers (double-buffered) used to copy updated parameter
        // tiles to the device asynchronously.
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _streams[0] = TrainingContext::Instance().GetCurrentStream();
        _streams[1] = TrainingContext::Instance().GetNewStream();
        _buf_index = false;
#elif defined(__ENABLE_CANN__)
        aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _buf_index = false;
#endif
    }
    ~Adam_Optimizer()
    {
#if defined(__ENABLE_CUDA__)
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
        aclrtFreeHost(_doubled_buffer[0]);
        aclrtFreeHost(_doubled_buffer[1]);
#endif
    }
#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg,
                  float* _exp_avg_sq,
                  size_t param_size,
                  ds_half_precision_t* dev_param = nullptr,
                  bool half_precision = false);
#endif
    STEP(1)
    STEP(4)
    STEP(8)
#if defined(__ENABLE_CUDA__)
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
#elif defined(__ENABLE_CANN__)
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
    }
#endif
    // Track beta1^step and beta2^step incrementally; recompute them from scratch when the
    // betas change or the requested step is not consecutive with the cached one.
    inline void IncrementStep(size_t step, float beta1, float beta2)
    {
        if (beta1 != _betta1 || beta2 != _betta2) {
            _step = step;
            _betta1 = beta1;
            _betta2 = beta2;
            _betta1_t = std::pow(_betta1, step);
            _betta2_t = std::pow(_betta2, step);
        } else {
            _step++;
            if (_step != step) {
                _betta1_t = std::pow(_betta1, step);
                _betta2_t = std::pow(_betta2, step);
                _step = step;
            } else {
                _betta1_t *= _betta1;
                _betta2_t *= _betta2;
            }
        }
    }
    inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
    {
        _alpha = lr;
        _eps = epsilon;
        _weight_decay = weight_decay;

        _bias_correction1 = 1.0f;
        _bias_correction2 = 1.0f;
        if (bias_correction == 1) {
            _bias_correction1 = 1 - _betta1_t;
            _bias_correction2 = 1 / sqrt(1 - _betta2_t);
        }
    }
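
    // Note (added for clarity): despite its name, _bias_correction2 stores the reciprocal
    // square root 1/sqrt(1 - beta2^t). With bias correction enabled, the update applied in
    // Step_AVX is
    //   param -= (alpha / (1 - beta1^t)) * m / (sqrt(v) / sqrt(1 - beta2^t) + eps),
    // which is the standard bias-corrected Adam step
    //   param -= alpha * m_hat / (sqrt(v_hat) + eps),
    // with m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t).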
private:
    float _alpha;
    float _betta1;
    float _betta2;

    float _eps;
    float _weight_decay;

    float _betta1_t;
    float _betta2_t;
    size_t _step;

    float _bias_correction1;
    float _bias_correction2;

    bool _adamw_mode;

#if defined(__ENABLE_CUDA__)
    float* _doubled_buffer[2];
    cudaStream_t _streams[2];
    bool _buf_index;
#elif defined(__ENABLE_CANN__)
    float* _doubled_buffer[2];
    c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
                                      c10_npu::getNPUStreamFromPool()};
    bool _buf_index;
#endif
};
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Adam_Optimizer::Step_AVX(size_t* rounded_size,
                              float* _params,
                              float* grads,
                              float* _exp_avg,
                              float* _exp_avg_sq,
                              size_t _param_size,
                              ds_half_precision_t* dev_params,
                              bool half_precision)
{
    size_t new_rounded_size = 0;
    int rshft = half_precision ? 1 : 0;

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
        if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + (i >> rshft), half_precision);

            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, _exp_avg + i, false);

            AVX_Data variance_4[span];
            simd_load<span>(variance_4, _exp_avg_sq + i, false);

            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + (i >> rshft), half_precision);

            // Classic (non-decoupled) L2 weight decay: fold weight_decay * param into the gradient.
            if (_weight_decay > 0 && !_adamw_mode) {
                simd_fma<span>(grad_4, param_4, weight_decay4, grad_4);
            }

            // First/second moment updates: m = beta1*m + (1-beta1)*g, v = beta2*v + (1-beta2)*g^2.
            simd_mul<span>(momentum_4, momentum_4, betta1_4);
            simd_fma<span>(momentum_4, grad_4, betta1_minus1_4, momentum_4);
            simd_mul<span>(variance_4, variance_4, betta2_4);
            simd_mul<span>(grad_4, grad_4, grad_4);
            simd_fma<span>(variance_4, grad_4, betta2_minus1_4, variance_4);

            // grad_4 becomes the update direction: m / (sqrt(v) * bias_correction2 + eps).
            simd_sqrt<span>(grad_4, variance_4);
            simd_fma<span>(grad_4, grad_4, bias2_sqrt, eps_4);
            simd_div<span>(grad_4, momentum_4, grad_4);

            // Decoupled weight decay (AdamW): param += (-alpha * weight_decay) * param.
            if (_weight_decay > 0 && _adamw_mode) {
                simd_fma<span>(param_4, param_4, weight_decay4, param_4);
            }
            // param += (-alpha / bias_correction1) * update.
            simd_fma<span>(param_4, grad_4, step_size_4, param_4);

            simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
            // Stage the updated parameters in the pinned host buffer for the device copy below.
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }
#endif
            simd_store<span>(_exp_avg + i, momentum_4, false);
            simd_store<span>(_exp_avg_sq + i, variance_4, false);
        }
#if defined(__ENABLE_CUDA__)
        // Copy this tile's updated parameters to the device on the alternate stream, then
        // flip the double buffer.
        if (dev_params) {
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
#elif defined(__ENABLE_CANN__)
        if (dev_params) {
            size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
            if (half_precision) memcpy_size /= 2;
            aclrtMemcpy(dev_params + t,
                        memcpy_size,
                        _doubled_buffer[_buf_index],
                        memcpy_size,
                        aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
}
#endif
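
/*
 * Reference sketch (added for illustration, not part of the original header): the scalar
 * update each SIMD lane in Step_AVX performs. `adamw_mode` applies decoupled weight decay
 * (AdamW); otherwise the decay is folded into the gradient (L2-regularized Adam).
 *
 *   void adam_step_scalar(float& param, float grad, float& m, float& v,
 *                         float alpha, float beta1, float beta2, float eps,
 *                         float weight_decay, bool adamw_mode,
 *                         float bias_correction1, float bias_correction2)
 *   {
 *       if (weight_decay > 0 && !adamw_mode) grad += param * weight_decay;
 *       m = beta1 * m + (1 - beta1) * grad;
 *       v = beta2 * v + (1 - beta2) * grad * grad;
 *       float denom = sqrtf(v) * bias_correction2 + eps;  // bias_correction2 = 1/sqrt(1-beta2^t)
 *       float update = m / denom;
 *       if (weight_decay > 0 && adamw_mode) param += (-alpha * weight_decay) * param;
 *       param += (-alpha / bias_correction1) * update;    // bias_correction1 = 1-beta1^t
 *   }
 */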
int create_adam_optimizer(int optimizer_id,
                          float alpha = 1e-3,
                          float betta1 = 0.9,
                          float betta2 = 0.999,
                          float eps = 1e-8,
                          float weight_decay = 0,
                          bool adamw_mode = true,
                          bool should_log = false);

int ds_adam_step(int optimizer_id,
                 size_t step,
                 float lr,
                 float beta1,
                 float beta2,
                 float epsilon,
                 float weight_decay,
                 bool bias_correction,
                 torch::Tensor& params,
                 torch::Tensor& grads,
                 torch::Tensor& exp_avg,
                 torch::Tensor& exp_avg_sq);

int ds_adam_step_plus_copy(int optimizer_id,
                           size_t step,
                           float lr,
                           float beta1,
                           float beta2,
                           float epsilon,
                           float weight_decay,
                           bool bias_correction,
                           torch::Tensor& params,
                           torch::Tensor& grads,
                           torch::Tensor& exp_avg,
                           torch::Tensor& exp_avg_sq,
                           torch::Tensor& gpu_params);

int destroy_adam_optimizer(int optimizer_id);
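
/*
 * Minimal usage sketch (added for illustration; the tensor shapes, dtypes, and hyperparameter
 * values are assumptions, not taken from this header). A caller would typically drive the
 * exported entry points with contiguous CPU tensors of equal numel roughly like this:
 *
 *   create_adam_optimizer(0, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.01f, true, false);
 *   for (size_t step = 1; step <= num_steps; ++step) {
 *       ds_adam_step(0, step, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.01f, true,
 *                    params, grads, exp_avg, exp_avg_sq);
 *   }
 *   destroy_adam_optimizer(0);
 *
 * In DeepSpeed these functions are exposed to Python through the torch extension built from
 * the accompanying source file; the Python-side binding names are not defined in this header.
 */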