cpu_adam.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #include "cpu_adam.h"
  5. #include <torch/extension.h>
  6. #include <cassert>
  7. #include <iostream>
  8. #include <memory>
  9. #include <type_traits>
  10. #include <unordered_map>
  11. #if defined(__ENABLE_CUDA__)
  12. #include <cuda_runtime_api.h>
  13. #include "cublas_v2.h"
  14. #include "cuda.h"
  15. #include "curand.h"
  16. #include "custom_cuda_layers.h"
  17. #endif
  // Registry of live optimizers, keyed by the id passed to create_adam_optimizer.
  // Values are stored type-erased as shared_ptr<void>; the step functions cast
  // them back with std::static_pointer_cast<Adam_Optimizer>.
  using OptimizerRegistry = std::unordered_map<int, std::shared_ptr<void>>;
  static OptimizerRegistry s_optimizers;
  // C++ interface
  20. void Adam_Optimizer::Step_1(float* _params,
  21. float* grads,
  22. float* _exp_avg,
  23. float* _exp_avg_sq,
  24. size_t _param_size,
  25. ds_half_precision_t* dev_params,
  26. bool half_precision)
  27. {
  28. size_t rounded_size = 0;
  29. #if defined(__AVX512__) or defined(__AVX256__)
  30. Step_AVX<1>(&rounded_size,
  31. _params,
  32. grads,
  33. _exp_avg,
  34. _exp_avg_sq,
  35. _param_size,
  36. dev_params,
  37. half_precision);
  38. #endif
  39. if (_param_size > rounded_size) {
  40. float betta1_minus1 = 1 - _betta1;
  41. float betta2_minus1 = 1 - _betta2;
  42. float step_size = -1 * _alpha / _bias_correction1;
  43. float w_decay = -1 * _alpha * _weight_decay;
  44. ds_half_precision_t* grads_cast_h;
  45. ds_half_precision_t* params_cast_h;
  46. if (half_precision) {
  47. grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
  48. params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
  49. }
  50. for (size_t t = rounded_size; t < _param_size; t += TILE) {
  51. size_t copy_size = TILE;
  52. if ((t + TILE) > _param_size) copy_size = _param_size - t;
  53. size_t offset = copy_size + t;
  54. #if defined(__ENABLE_CUDA__)
  55. if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
  56. #endif
  57. #pragma omp parallel for
  58. for (size_t k = t; k < offset; k++) {
  59. float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
  60. float param = half_precision ? (float)params_cast_h[k] : _params[k];
  61. float momentum = _exp_avg[k];
  62. float variance = _exp_avg_sq[k];
  63. if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
  64. momentum = momentum * _betta1;
  65. momentum = grad * betta1_minus1 + momentum;
  66. variance = variance * _betta2;
  67. grad = grad * grad;
  68. variance = grad * betta2_minus1 + variance;
  69. grad = sqrt(variance);
  70. grad = grad * _bias_correction2 + _eps;
  71. grad = momentum / grad;
  72. if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
  73. param = grad * step_size + param;
  74. #if defined(__ENABLE_CUDA__)
  75. if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
  76. #endif
  77. if (half_precision)
  78. params_cast_h[k] = (ds_half_precision_t)param;
  79. else
  80. _params[k] = param;
  81. _exp_avg[k] = momentum;
  82. _exp_avg_sq[k] = variance;
  83. }
  84. #if defined(__ENABLE_CUDA__)
  85. if (dev_params) {
  86. launch_param_update(
  87. _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
  88. _buf_index = !_buf_index;
  89. }
  90. #endif
  91. }
  92. }
  93. }
  94. void Adam_Optimizer::Step_4(float* _params,
  95. float* grads,
  96. float* _exp_avg,
  97. float* _exp_avg_sq,
  98. size_t _param_size,
  99. ds_half_precision_t* dev_params,
  100. bool half_precision)
  101. {
  102. size_t rounded_size = 0;
  103. #if defined(__AVX512__) or defined(__AVX256__)
  104. Step_AVX<4>(&rounded_size,
  105. _params,
  106. grads,
  107. _exp_avg,
  108. _exp_avg_sq,
  109. _param_size,
  110. dev_params,
  111. half_precision);
  112. #endif
  113. if (_param_size > rounded_size)
  114. Step_1((_params + rounded_size),
  115. (grads + rounded_size),
  116. (_exp_avg + rounded_size),
  117. (_exp_avg_sq + rounded_size),
  118. (_param_size - rounded_size),
  119. (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
  120. half_precision);
  121. }
  122. int create_adam_optimizer(int optimizer_id,
  123. float alpha = 1e-3,
  124. float betta1 = 0.9,
  125. float betta2 = 0.999,
  126. float eps = 1e-8,
  127. float weight_decay = 0,
  128. bool adamw_mode = true,
  129. bool should_log = false)
  130. {
  131. auto opt =
  132. std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);
  133. s_optimizers[optimizer_id] = opt;
  134. if (should_log) {
  135. std::string avx_type = "";
  136. #if defined(__AVX512__)
  137. avx_type = "AVX512";
  138. #else
  139. #if defined(__AVX256__)
  140. avx_type = "AVX2";
  141. #else
  142. avx_type = "scalar";
  143. #endif
  144. #endif
  145. printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
  146. optimizer_id,
  147. avx_type.c_str());
  148. printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
  149. alpha,
  150. betta1,
  151. betta2,
  152. weight_decay,
  153. (int)adamw_mode);
  154. }
  155. return 0;
  156. }
  157. void Adam_Optimizer::Step_8(float* _params,
  158. float* grads,
  159. float* _exp_avg,
  160. float* _exp_avg_sq,
  161. size_t _param_size,
  162. ds_half_precision_t* dev_params,
  163. bool half_precision)
  164. {
  165. size_t rounded_size = 0;
  166. #if defined(__AVX512__) or defined(__AVX256__)
  167. Step_AVX<8>(&rounded_size,
  168. _params,
  169. grads,
  170. _exp_avg,
  171. _exp_avg_sq,
  172. _param_size,
  173. dev_params,
  174. half_precision);
  175. #endif
  176. if (_param_size > rounded_size)
  177. Step_4((_params + rounded_size),
  178. (grads + rounded_size),
  179. (_exp_avg + rounded_size),
  180. (_exp_avg_sq + rounded_size),
  181. (_param_size - rounded_size),
  182. (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
  183. half_precision);
  184. }
  185. int ds_adam_step(int optimizer_id,
  186. size_t step,
  187. float lr,
  188. float beta1,
  189. float beta2,
  190. float epsilon,
  191. float weight_decay,
  192. bool bias_correction,
  193. torch::Tensor& params,
  194. torch::Tensor& grads,
  195. torch::Tensor& exp_avg,
  196. torch::Tensor& exp_avg_sq)
  197. {
  198. auto params_c = params.contiguous();
  199. auto grads_c = grads.contiguous();
  200. auto exp_avg_c = exp_avg.contiguous();
  201. auto exp_avg_sq_c = exp_avg_sq.contiguous();
  202. // assert(params.options().dtype() == grads.options().dtype());
  203. float* params_ptr = (float*)params_c.data_ptr();
  204. float* grads_ptr = (float*)grads_c.data_ptr();
  205. float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
  206. float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
  207. std::shared_ptr<Adam_Optimizer> opt =
  208. std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
  209. opt->IncrementStep(step, beta1, beta2);
  210. opt->update_state(lr, epsilon, weight_decay, bias_correction);
  211. opt->Step_8(params_ptr,
  212. grads_ptr,
  213. exp_avg_ptr,
  214. exp_avg_sq_ptr,
  215. params_c.numel(),
  216. nullptr,
  217. (params.options().dtype() == at::kHalf));
  218. #if defined(__ENABLE_CUDA__)
  219. opt->SynchronizeStreams();
  220. #endif
  221. return 0;
  222. }
  223. int ds_adam_step_plus_copy(int optimizer_id,
  224. size_t step,
  225. float lr,
  226. float beta1,
  227. float beta2,
  228. float epsilon,
  229. float weight_decay,
  230. bool bias_correction,
  231. torch::Tensor& params,
  232. torch::Tensor& grads,
  233. torch::Tensor& exp_avg,
  234. torch::Tensor& exp_avg_sq,
  235. torch::Tensor& gpu_params)
  236. {
  237. #if defined(__ENABLE_CUDA__)
  238. auto params_c = params.contiguous();
  239. auto gpu_params_c = gpu_params.contiguous();
  240. auto exp_avg_c = exp_avg.contiguous();
  241. auto exp_avg_sq_c = exp_avg_sq.contiguous();
  242. auto grads_c = grads.contiguous();
  243. float* params_ptr = (float*)params_c.data_ptr();
  244. float* grads_ptr = (float*)grads_c.data_ptr();
  245. ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
  246. float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
  247. float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
  248. std::shared_ptr<Adam_Optimizer> opt =
  249. std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
  250. opt->IncrementStep(step, beta1, beta2);
  251. opt->update_state(lr, epsilon, weight_decay, bias_correction);
  252. opt->Step_8(params_ptr,
  253. grads_ptr,
  254. exp_avg_ptr,
  255. exp_avg_sq_ptr,
  256. params_c.numel(),
  257. gpu_params_ptr,
  258. (params.options().dtype() == at::kHalf));
  259. opt->SynchronizeStreams();
  260. #else
  261. assert(false);
  262. #endif
  263. return 0;
  264. }
  265. int destroy_adam_optimizer(int optimizer_id)
  266. {
  267. s_optimizers.erase(optimizer_id);
  268. return 0;
  269. }
  270. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  271. {
  272. m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
  273. m.def("adam_update_copy",
  274. &ds_adam_step_plus_copy,
  275. "DeepSpeed CPU Adam update and param copy (C++)");
  276. m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
  277. m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
  278. }