cpu_lion_impl.cpp

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>
#include <cassert>
#include <cmath>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cpu_lion.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface
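
// Note on the update rule: the scalar fallback in Step_1 below applies the Lion
// update element-wise. With g = gradient, m = exp_avg (momentum), p = parameter,
// lr = _alpha and wd = _weight_decay, each iteration of the inner loop computes
//
//     update = sign(betta1 * m + (1 - betta1) * g)
//     p      = p * (1 - lr * wd) - lr * update     // decay factor skipped when wd == 0
//     m      = betta2 * m + (1 - betta2) * g
//
// The sign-and-scale step is realized with -std::copysignf(alpha, tmp), which
// copies the sign of the blended gradient onto the learning rate in a single call.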
void Lion_Optimizer::Step_1(float* _params,
                            float* grads,
                            float* _exp_avg,
                            size_t _param_size,
                            ds_half_precision_t* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
#endif

    if (_param_size > rounded_size) {
        float betta1_minus1 = 1 - _betta1;
        float betta2_minus1 = 1 - _betta2;

        float alpha = _alpha;
        float after_decay = 1 - alpha * _weight_decay;

        ds_half_precision_t* grads_cast_h;
        ds_half_precision_t* params_cast_h;
        if (half_precision) {
            grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
            params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
        }
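
        // The tail not covered by the AVX path is processed in TILE-sized chunks.
        // When a device copy is requested (dev_params != nullptr), each chunk is
        // first staged in one of the two host-side buffers (_doubled_buffer) and
        // then pushed to the device (launch_param_update on a CUDA stream, or
        // aclrtMemcpy for CANN). _buf_index flips after every copy, so the
        // synchronize call at the top of the loop waits on the stream whose
        // buffer is about to be reused, i.e. the copy issued two tiles earlier.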
        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
            if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
                float param = half_precision ? (float)params_cast_h[k] : _params[k];
                float momentum = _exp_avg[k];
                float tmp = momentum * _betta1;
                tmp = grad * betta1_minus1 + tmp;
                // Rely on portable C++ methods to manipulate the sign bit of a floating-point
                // number.
                tmp = -std::copysignf(alpha, tmp);
                if (_weight_decay > 0) {
                    param = param * after_decay + tmp;
                } else {
                    param = param + tmp;
                }
                momentum = momentum * _betta2;
                momentum = grad * betta2_minus1 + momentum;
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
                if (half_precision)
                    params_cast_h[k] = (ds_half_precision_t)param;
                else
                    _params[k] = param;
                _exp_avg[k] = momentum;
            }
#if defined(__ENABLE_CUDA__)
            if (dev_params) {
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);

                _buf_index = !_buf_index;
            }
#elif defined(__ENABLE_CANN__)
            if (dev_params) {
                size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
                aclrtMemcpy(dev_params + t,
                            memcpy_size,
                            _doubled_buffer[_buf_index],
                            memcpy_size,
                            aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

                _buf_index = !_buf_index;
            }
#endif
        }
    }
}
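
// Step_4 and Step_8 exist so the vectorized kernel can be instantiated with wider
// unroll factors (Step_AVX<4> and Step_AVX<8>). Whatever tail the wider kernel
// leaves uncovered (rounded_size < _param_size) is forwarded, with adjusted
// pointers and size, to the next narrower variant: Step_8 falls back to Step_4,
// which falls back to the scalar Step_1 above.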
void Lion_Optimizer::Step_4(float* _params,
                            float* grads,
                            float* _exp_avg,
                            size_t _param_size,
                            ds_half_precision_t* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
#endif

    if (_param_size > rounded_size)
        Step_1((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}
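
// Optimizer instances live in the file-scope s_optimizers map declared above,
// keyed by the integer id chosen by the caller; ds_lion_step and
// ds_lion_step_plus_copy look the instance up again through the same id.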
int create_lion_optimizer(int optimizer_id,
                          float alpha,
                          float betta1,
                          float betta2,
                          float weight_decay,
                          bool should_log)
{
    auto opt = std::make_shared<Lion_Optimizer>(alpha, betta1, betta2, weight_decay);

    s_optimizers[optimizer_id] = opt;

    if (should_log) {
        std::string avx_type = "";
#if defined(__AVX512__)
        avx_type = "AVX512";
#else
#if defined(__AVX256__)
        avx_type = "AVX2";
#else
        avx_type = "scalar";
#endif
#endif

        printf("Lion Optimizer #%d is created with %s arithmetic capability.\n",
               optimizer_id,
               avx_type.c_str());
        printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f\n",
               alpha,
               betta1,
               betta2,
               weight_decay);
    }

    return 0;
}
void Lion_Optimizer::Step_8(float* _params,
                            float* grads,
                            float* _exp_avg,
                            size_t _param_size,
                            ds_half_precision_t* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
#endif

    if (_param_size > rounded_size)
        Step_4((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}
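
// Entry point for a plain CPU step: make the tensors contiguous, look up the
// optimizer by id, refresh the step count, betas, lr and weight decay, and run
// Step_8 over every element. No device pointer is passed, so no parameter copy
// is issued.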
int ds_lion_step(int optimizer_id,
                 size_t step,
                 float lr,
                 float beta1,
                 float beta2,
                 float weight_decay,
                 torch::Tensor& params,
                 torch::Tensor& grads,
                 torch::Tensor& exp_avg)
{
    auto params_c = params.contiguous();
    auto grads_c = grads.contiguous();
    auto exp_avg_c = exp_avg.contiguous();

    // assert(params.options().dtype() == grads.options().dtype());

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();

    std::shared_ptr<Lion_Optimizer> opt =
        std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, weight_decay);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_ptr,
                params_c.numel(),
                nullptr,
                (params.options().dtype() == at::kHalf));

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
    opt->SynchronizeStreams();
#endif
    return 0;
}
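
// Same as ds_lion_step, but additionally streams the updated parameters back to
// gpu_params on the device via the tiled copy path in Step_1. Only meaningful
// when the extension is built with __ENABLE_CUDA__ or __ENABLE_CANN__;
// otherwise it asserts.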
int ds_lion_step_plus_copy(int optimizer_id,
                           size_t step,
                           float lr,
                           float beta1,
                           float beta2,
                           float weight_decay,
                           torch::Tensor& params,
                           torch::Tensor& grads,
                           torch::Tensor& exp_avg,
                           torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
    auto params_c = params.contiguous();
    auto gpu_params_c = gpu_params.contiguous();
    auto exp_avg_c = exp_avg.contiguous();
    auto grads_c = grads.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();

    std::shared_ptr<Lion_Optimizer> opt =
        std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, weight_decay);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_ptr,
                params_c.numel(),
                gpu_params_ptr,
                (params.options().dtype() == at::kHalf));

    opt->SynchronizeStreams();
#else
    assert(false);
#endif
    return 0;
}
int destroy_lion_optimizer(int optimizer_id)
{
    s_optimizers.erase(optimizer_id);

    return 0;
}
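
/*
 * The functions above are meant to be exposed to Python as a torch extension;
 * the binding itself is not part of this file. A minimal sketch of such a
 * binding (the exported names below are illustrative, not taken from this file)
 * could look like:
 *
 *     PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 *     {
 *         m.def("create_lion", &create_lion_optimizer, "Create a CPU Lion optimizer");
 *         m.def("lion_update", &ds_lion_step, "Run one CPU Lion step");
 *         m.def("lion_update_copy", &ds_lion_step_plus_copy, "CPU Lion step + copy to device params");
 *         m.def("destroy_lion", &destroy_lion_optimizer, "Destroy a CPU Lion optimizer");
 *     }
 */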