// cpu_adam.cpp

#include "cpu_adam.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
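
// Registry of live optimizer instances, keyed by the integer id handed out to
// the Python side. Stored as shared_ptr<void> so the map itself does not need
// to know the concrete optimizer type.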
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
// C++ interface
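
// Scalar Adam step. When AVX512/AVX2 is compiled in, Step_AVX<1> consumes the
// vectorizable bulk of the range and reports how much it rounded off; this loop
// then handles only the tail (or the whole range in a pure scalar build). Work
// is tiled so that, when dev_params is given, each updated chunk can be staged
// in a host-side double buffer and copied to the GPU asynchronously.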
void Adam_Optimizer::Step_1(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            __half* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<1>(&rounded_size,
                _params,
                grads,
                _exp_avg,
                _exp_avg_sq,
                _param_size,
                dev_params,
                half_precision);
#endif
    if (_param_size > rounded_size) {
        float betta1_minus1 = 1 - _betta1;
        float betta2_minus1 = 1 - _betta2;
        float step_size = -1 * _alpha / _bias_correction1;
        float w_decay = -1 * _alpha * _weight_decay;
        __half* grads_cast_h;
        __half* params_cast_h;
        if (half_precision) {
            grads_cast_h = reinterpret_cast<__half*>(grads);
            params_cast_h = reinterpret_cast<__half*>(_params);
        }
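
        // Walk the remaining parameters in TILE-sized chunks; the tiling is
        // what lets the CPU compute of one chunk overlap the device copy of
        // the previous one.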
        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;
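            // The two staging buffers alternate; from the third tile onward we
            // are about to reuse a buffer, so wait for the async copy that last
            // used it to drain before overwriting it.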
            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
                float param = half_precision ? (float)params_cast_h[k] : _params[k];
                float momentum = _exp_avg[k];
                float variance = _exp_avg_sq[k];
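                // Classic Adam: m = b1*m + (1-b1)*g, v = b2*v + (1-b2)*g^2,
                // p -= alpha * m_hat / (sqrt(v_hat) + eps). The bias corrections
                // are precomputed class state (refreshed per step before these
                // kernels run): _bias_correction1 is folded into step_size and
                // _bias_correction2 rescales sqrt(v). In plain Adam mode, L2
                // weight decay is folded into the gradient here; in AdamW mode
                // it is applied to the parameter directly, further below.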
                if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
                momentum = momentum * _betta1;
                momentum = grad * betta1_minus1 + momentum;
                variance = variance * _betta2;
                grad = grad * grad;
                variance = grad * betta2_minus1 + variance;
                grad = sqrt(variance);
                grad = grad * _bias_correction2 + _eps;
                grad = momentum / grad;
                if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
                param = grad * step_size + param;
                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
                if (half_precision)
                    params_cast_h[k] = (__half)param;
                else
                    _params[k] = param;
                _exp_avg[k] = momentum;
                _exp_avg_sq[k] = variance;
            }
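            // Kick off the async host-to-device copy of this freshly updated
            // chunk, then flip to the other staging buffer for the next tile.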
            if (dev_params) {
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
                _buf_index = !_buf_index;
            }
        }
    }
}
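
// Four-wide variant: run the 4-way unrolled AVX path over as much of the range
// as rounds evenly, then let Step_1 finish whatever is left.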
void Adam_Optimizer::Step_4(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            __half* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<4>(&rounded_size,
                _params,
                grads,
                _exp_avg,
                _exp_avg_sq,
                _param_size,
                dev_params,
                half_precision);
#endif
    if (_param_size > rounded_size)
        Step_1((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}
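
// Allocates an optimizer, registers it in s_optimizers under optimizer_id, and
// optionally logs which SIMD path this build was compiled with along with the
// hyperparameters.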
int create_adam_optimizer(int optimizer_id,
                          float alpha = 1e-3,
                          float betta1 = 0.9,
                          float betta2 = 0.999,
                          float eps = 1e-8,
                          float weight_decay = 0,
                          bool adamw_mode = true,
                          bool should_log = false)
{
    auto opt =
        std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);
    s_optimizers[optimizer_id] = opt;
    if (should_log) {
        std::string avx_type = "";
#if defined(__AVX512__)
        avx_type = "AVX512";
#elif defined(__AVX256__)
        avx_type = "AVX2";
#else
        avx_type = "scalar";
#endif
        printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
               optimizer_id,
               avx_type.c_str());
        printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
               alpha,
               betta1,
               betta2,
               weight_decay,
               (int)adamw_mode);
    }
    return 0;
}
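
// Widest entry point: try the 8-way unrolled AVX path first, then hand the
// remainder to Step_4 (and transitively Step_1).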
void Adam_Optimizer::Step_8(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            __half* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<8>(&rounded_size,
                _params,
                grads,
                _exp_avg,
                _exp_avg_sq,
                _param_size,
                dev_params,
                half_precision);
#endif
    if (_param_size > rounded_size)
        Step_4((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}
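
// Python-facing optimizer step without a device copy. The tensors are expected
// to be flat (size(0) is used as the element count); half-precision tensors
// travel through the same float* interface and are reinterpreted inside Step_1.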
int ds_adam_step(int optimizer_id,
                 size_t step,
                 float lr,
                 float beta1,
                 float beta2,
                 float epsilon,
                 float weight_decay,
                 bool bias_correction,
                 torch::Tensor& params,
                 torch::Tensor& grads,
                 torch::Tensor& exp_avg,
                 torch::Tensor& exp_avg_sq)
{
    auto params_c = params.contiguous();
    auto grads_c = grads.contiguous();
    auto exp_avg_c = exp_avg.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();
    // assert(params.options().dtype() == grads.options().dtype());

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adam_Optimizer> opt =
        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, epsilon, weight_decay, bias_correction);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_ptr,
                exp_avg_sq_ptr,
                params_c.size(0),
                nullptr,
                (params.options().dtype() == at::kHalf));

    opt->SynchronizeStreams();
    return 0;
}
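
// Same as ds_adam_step, but additionally mirrors the updated parameters into
// gpu_params (fp16 on the device) via the tiled async copies inside Step_1.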
int ds_adam_step_plus_copy(int optimizer_id,
                           size_t step,
                           float lr,
                           float beta1,
                           float beta2,
                           float epsilon,
                           float weight_decay,
                           bool bias_correction,
                           torch::Tensor& params,
                           torch::Tensor& grads,
                           torch::Tensor& exp_avg,
                           torch::Tensor& exp_avg_sq,
                           torch::Tensor& gpu_params)
{
    auto params_c = params.contiguous();
    auto gpu_params_c = gpu_params.contiguous();
    auto exp_avg_c = exp_avg.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();
    auto grads_c = grads.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adam_Optimizer> opt =
        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, epsilon, weight_decay, bias_correction);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_ptr,
                exp_avg_sq_ptr,
                params_c.size(0),
                gpu_params_ptr,
                (params.options().dtype() == at::kHalf));

    opt->SynchronizeStreams();
    return 0;
}
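
// Erasing the shared_ptr from the registry drops the owning reference held
// here and destroys the optimizer.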
int destroy_adam_optimizer(int optimizer_id)
{
    s_optimizers.erase(optimizer_id);
    return 0;
}
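
// Python bindings. The wrapper module is expected to drive these in order,
// e.g. (hypothetical, module name depends on the build):
//   module.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, log)
//   module.adam_update(0, step, ...) per step, then module.destroy_adam(0).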
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
    m.def("adam_update_copy",
          &ds_adam_step_plus_copy,
          "DeepSpeed CPU Adam update and param copy (C++)");
    m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
    m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}