cpu_adagrad.cpp

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "cpu_adagrad.h"
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface
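
// Step_1: scalar Adagrad update. Step_AVX<1> first vectorizes as many
// elements as it can and reports that count through rounded_size; the loop
// below finishes the remaining tail element by element, walking the
// parameters in TILE-sized chunks so that, when CUDA is enabled, each chunk
// of updated values can be staged in _doubled_buffer and copied to
// dev_params asynchronously while the next chunk is computed.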
void Adagrad_Optimizer::Step_1(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
                               ds_half_precision_t* dev_params,
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<1>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size) {
        float step_size = -1 * _alpha;
        ds_half_precision_t* grads_cast_h = nullptr;
        ds_half_precision_t* params_cast_h = nullptr;
        if (half_precision) {
            grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
            params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
        }
        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
            // Wait until the async device copy that last used this staging
            // buffer has drained before overwriting it.
            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
                float param = half_precision ? (float)params_cast_h[k] : _params[k];
                // Keep the raw gradient as the update numerator; read it via
                // the precision-aware `grad` so the half path stays correct.
                float momentum = grad;
                float variance = _exp_avg_sq[k];
                if (_weight_decay > 0) { grad = param * _weight_decay + grad; }

                variance += grad * grad;

                grad = sqrt(variance);
                grad += _eps;
                grad = momentum / grad;
                param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
                if (half_precision)
                    params_cast_h[k] = (ds_half_precision_t)param;
                else
                    _params[k] = param;
                // STORE UPDATE TERM TO GRAD'S MEMORY
                grads[k] = grad * step_size;
                _exp_avg_sq[k] = variance;
            }
#if defined(__ENABLE_CUDA__)
            if (dev_params) {
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
                _buf_index = !_buf_index;
            }
#endif
        }
    }
}
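
// Step_4 and Step_8 (below) form an unrolling cascade: each runs the wider
// AVX kernel over as much of the buffer as it can, then hands the remaining
// tail to the next-narrower step function, bottoming out at the scalar Step_1.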
void Adagrad_Optimizer::Step_4(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
                               ds_half_precision_t* dev_params,
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<4>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size)
        Step_1((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}
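
// Optimizer instances are parked in s_optimizers under an integer handle;
// the Python caller passes that handle back on every step, so per-optimizer
// state (hyperparameters, step count, and the CUDA staging buffers/streams
// when enabled) persists across calls.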
int create_adagrad_optimizer(int optimizer_id,
                             float alpha = 1e-2,
                             float eps = 1e-8,
                             float weight_decay = 0,
                             bool should_log = false)
{
    auto opt = std::make_shared<Adagrad_Optimizer>(alpha, eps, weight_decay);

    s_optimizers[optimizer_id] = opt;

    if (should_log) {
        std::string avx_type = "";
#if defined(__AVX512__)
        avx_type = "AVX512";
#else
#if defined(__AVX256__)
        avx_type = "AVX2";
#else
        avx_type = "scalar";
#endif
#endif

        printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n",
               optimizer_id,
               avx_type.c_str());
        printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay);
    }

    return 0;
}
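
// Step_8: the widest entry point of the cascade and the one the torch
// bindings below invoke.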
void Adagrad_Optimizer::Step_8(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
                               ds_half_precision_t* dev_params,
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<8>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size)
        Step_4((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}
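
// ds_adagrad_step: in-place FP32 Adagrad update on host tensors. The update
// runs on the .contiguous() views, so callers should pass contiguous tensors;
// for a non-contiguous input, contiguous() returns a copy and the updates
// would not propagate back to the original tensor.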
int ds_adagrad_step(int optimizer_id,
                    size_t step,
                    float lr,
                    float epsilon,
                    float weight_decay,
                    torch::Tensor& params,
                    torch::Tensor& grads,
                    torch::Tensor& exp_avg_sq)
{
    auto params_c = params.contiguous();
    auto grads_c = grads.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adagrad_Optimizer> opt =
        std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step);
    opt->update_state(lr, epsilon, weight_decay);
    opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());

#if defined(__ENABLE_CUDA__)
    opt->SynchronizeStreams();
#endif
    return 0;
}
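
// ds_adagrad_step_plus_copy: same update as ds_adagrad_step, but also mirrors
// the updated parameters into gpu_params (a device-resident, possibly FP16
// copy) through the overlapped launch_param_update path. Compiled only under
// __ENABLE_CUDA__; otherwise it simply asserts.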
int ds_adagrad_step_plus_copy(int optimizer_id,
                              size_t step,
                              float lr,
                              float epsilon,
                              float weight_decay,
                              torch::Tensor& params,
                              torch::Tensor& grads,
                              torch::Tensor& exp_avg_sq,
                              torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
    auto params_c = params.contiguous();
    auto gpu_params_c = gpu_params.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();
    auto grads_c = grads.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adagrad_Optimizer> opt =
        std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step);
    opt->update_state(lr, epsilon, weight_decay);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_sq_ptr,
                params_c.numel(),
                gpu_params_ptr,
                (params.options().dtype() == at::kHalf));

    opt->SynchronizeStreams();
#else
    assert(false);
#endif
    return 0;
}
int destroy_adagrad_optimizer(int optimizer_id)
{
    s_optimizers.erase(optimizer_id);

    return 0;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)");
    m.def("adagrad_update_copy",
          &ds_adagrad_step_plus_copy,
          "DeepSpeed CPU Adagrad update and param copy (C++)");
    m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)");
    m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)");
}
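
// Example usage from Python (a minimal sketch; how the extension is built and
// loaded depends on the surrounding project -- the loader below is a
// hypothetical stand-in, while the four op names are the ones registered
// above):
//
//   ada = load_cpu_adagrad_extension()            # hypothetical loader
//   ada.create_adagrad(0, 1e-2, 1e-8, 0.0, True)  # handle 0, alpha, eps, wd, log
//   ada.adagrad_update(0, step, lr, eps, wd, params, grads, exp_avg_sq)
//   ada.destroy_adagrad(0)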