cpu_lion.h
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team
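
// CPU (AVX256/AVX512) implementation of the Lion optimizer step. Declares the
// Lion_Optimizer class and the C-level entry points that the extension exposes.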
#pragma once

#define NOMINMAX  // Windows idiosyncrasy
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c

#include <stdio.h>
#include <torch/extension.h>
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
#endif
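
// STEP(SPAN) declares Step_<SPAN>, which applies the Lion update to _param_size
// elements with SPAN as the SIMD unroll factor; the definitions are provided in
// the corresponding source file.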
#define STEP(SPAN)                                             \
    void Step_##SPAN(float* _params,                           \
                     float* grads,                             \
                     float* _exp_avg,                          \
                     size_t _param_size,                       \
                     ds_half_precision_t* dev_param = nullptr, \
                     bool half_precision = false);
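
// Holds the Lion hyperparameters (learning rate, betas, weight decay) and, on
// CUDA/CANN builds, two pinned host staging buffers plus a pair of streams so an
// updated parameter tile can be copied to the device while the CPU works on the
// next tile.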
class Lion_Optimizer {
public:
    Lion_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float weight_decay = 0)
        : _alpha(alpha), _betta1(betta1), _betta2(betta2), _weight_decay(weight_decay), _step(0)
    {
#if defined(__ENABLE_CUDA__)
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _streams[0] = TrainingContext::Instance().GetCurrentStream();
        _streams[1] = TrainingContext::Instance().GetNewStream();
        _buf_index = false;
#elif defined(__ENABLE_CANN__)
        aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _buf_index = false;
#endif
    }
    ~Lion_Optimizer()
    {
#if defined(__ENABLE_CUDA__)
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
        aclrtFreeHost(_doubled_buffer[0]);
        aclrtFreeHost(_doubled_buffer[1]);
#endif
    }
#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg,
                  size_t param_size,
                  ds_half_precision_t* dev_param = nullptr,
                  bool half_precision = false);
#endif
    STEP(1)
    STEP(4)
    STEP(8)
#if defined(__ENABLE_CUDA__)
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
#elif defined(__ENABLE_CANN__)
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
    }
#endif
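    // Keep the locally tracked step and betas in sync with the values passed in by
    // the caller; if they diverge, adopt the caller's values.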
    inline void IncrementStep(size_t step, float beta1, float beta2)
    {
        _step++;
        if (_step != step || beta1 != _betta1 || beta2 != _betta2) {
            _step = step;
            _betta1 = beta1;
            _betta2 = beta2;
        }
    }
    inline void update_state(float lr, float weight_decay)
    {
        _alpha = lr;
        _weight_decay = weight_decay;
    }

private:
    float _alpha;
    float _betta1;
    float _betta2;
    float _weight_decay;
    size_t _step;

#if defined(__ENABLE_CUDA__)
    float* _doubled_buffer[2];
    cudaStream_t _streams[2];
    bool _buf_index;
#elif defined(__ENABLE_CANN__)
    float* _doubled_buffer[2];
    c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
                                      c10_npu::getNPUStreamFromPool()};
    bool _buf_index;
#endif
};
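
// Vectorized Lion step. For the first ROUND_DOWN(_param_size, SIMD_WIDTH * span)
// elements it applies, per element:
//   c       = betta1 * exp_avg + (1 - betta1) * grad
//   param   = param * (1 - alpha * weight_decay) - alpha * sign(c)
//   exp_avg = betta2 * exp_avg + (1 - betta2) * grad
// and reports the processed count through *rounded_size so the caller can handle
// the remaining tail.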
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Lion_Optimizer::Step_AVX(size_t* rounded_size,
                              float* _params,
                              float* grads,
                              float* _exp_avg,
                              size_t _param_size,
                              ds_half_precision_t* dev_params,
                              bool half_precision)
{
    size_t new_rounded_size = 0;
    int rshft = half_precision ? 1 : 0;

    constexpr float neg1 = -1.0f;
    AVX_Data neg1_4;
    neg1_4.data = SIMD_SET(neg1);

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    float step_size = -_alpha;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float after_decay = 1.0f - _alpha * _weight_decay;
    AVX_Data after_decay_4;
    if (_weight_decay > 0) after_decay_4.data = SIMD_SET(after_decay);

    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
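    // Process the parameters in TILE-sized chunks; on CUDA/CANN builds this lets the
    // copy of one tile's updated parameters to the device overlap with the CPU update
    // of the next tile (double-buffered through _doubled_buffer).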
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
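        // Before reusing a pinned staging buffer, wait for the device copy that was
        // issued from it two tiles ago.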
#if defined(__ENABLE_CUDA__)
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
        if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + (i >> rshft), half_precision);
            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, _exp_avg + i, false);
            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + (i >> rshft), half_precision);
            AVX_Data tmp_4[span];

            // Interpolated update direction: tmp = betta1 * momentum + (1 - betta1) * grad.
            simd_mul<span>(tmp_4, momentum_4, betta1_4);
            simd_fma<span>(tmp_4, grad_4, betta1_minus1_4, tmp_4);
            // We already used intrinsics, so consider the machine representation fixed:
            // the bitwise AND/XOR pair applies the sign of tmp to the negated step size,
            // producing Lion's sign-based update.
            simd_and<span>(tmp_4, tmp_4, neg1_4);
            simd_xor<span>(tmp_4, tmp_4, step_size_4);
            // Apply decoupled weight decay (if any) together with the signed step.
            if (_weight_decay > 0) {
                simd_fma<span>(param_4, param_4, after_decay_4, tmp_4);
            } else {
                simd_add<span>(param_4, param_4, tmp_4);
            }
            // Momentum update: exp_avg = betta2 * exp_avg + (1 - betta2) * grad.
            simd_mul<span>(momentum_4, momentum_4, betta2_4);
            simd_fma<span>(momentum_4, grad_4, betta2_minus1_4, momentum_4);
            simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }
#endif
            simd_store<span>(_exp_avg + i, momentum_4, false);
        }
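        // When a device copy is requested, push this tile's updated parameters from the
        // pinned buffer that was just filled, then switch to the other buffer so the
        // next tile can proceed.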
#if defined(__ENABLE_CUDA__)
        if (dev_params) {
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
#elif defined(__ENABLE_CANN__)
        if (dev_params) {
            size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
            if (half_precision) memcpy_size /= 2;
            aclrtMemcpy(dev_params + t,
                        memcpy_size,
                        _doubled_buffer[_buf_index],
                        memcpy_size,
                        aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
}
#endif
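
// C-level entry points: create/destroy a Lion_Optimizer instance identified by
// optimizer_id, and run a Lion step, optionally fused with the copy of the updated
// parameters back to the accelerator (ds_lion_step_plus_copy).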
int create_lion_optimizer(int optimizer_id,
                          float alpha = 1e-3,
                          float betta1 = 0.9,
                          float betta2 = 0.999,
                          float weight_decay = 0,
                          bool should_log = false);

int ds_lion_step(int optimizer_id,
                 size_t step,
                 float lr,
                 float beta1,
                 float beta2,
                 float weight_decay,
                 torch::Tensor& params,
                 torch::Tensor& grads,
                 torch::Tensor& exp_avg);

int ds_lion_step_plus_copy(int optimizer_id,
                           size_t step,
                           float lr,
                           float beta1,
                           float beta2,
                           float weight_decay,
                           torch::Tensor& params,
                           torch::Tensor& grads,
                           torch::Tensor& exp_avg,
                           torch::Tensor& gpu_params);

int destroy_lion_optimizer(int optimizer_id);