cpu_adagrad.cpp

#include "cpu_adagrad.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"

static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface
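// Optimizers are created from the Python side, stored in s_optimizers under an
// integer id, and looked up again on every step. Step_8 / Step_4 / Step_1 form a
// dispatch chain: each first hands as many elements as possible to its AVX-unrolled
// kernel (when AVX512/AVX2 is available) and forwards the remainder to the
// next-smaller step function, with Step_1 providing the scalar fallback loop below.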
void Adagrad_Optimizer::Step_1(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
                               __half* dev_params,
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<1>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size) {
        float step_size = -1 * _alpha;
        __half* grads_cast_h;
        __half* params_cast_h;
        if (half_precision) {
            grads_cast_h = reinterpret_cast<__half*>(grads);
            params_cast_h = reinterpret_cast<__half*>(_params);
        }
        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;
            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
                float param = half_precision ? (float)params_cast_h[k] : _params[k];
                // Keep the raw (pre-weight-decay) gradient as the update numerator.
                // Reusing the already-converted value also stays correct when grads
                // actually holds __half data.
                float momentum = grad;
                float variance = _exp_avg_sq[k];
                if (_weight_decay > 0) { grad = param * _weight_decay + grad; }

                variance += grad * grad;

                grad = sqrt(variance);
                grad += _eps;
                grad = momentum / grad;
                param = grad * step_size + param;
                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;

                if (half_precision)
                    params_cast_h[k] = (__half)param;
                else
                    _params[k] = param;
                // STORE UPDATE TERM TO GRAD'S MEMORY
                // (write through the matching type when grads holds __half data)
                if (half_precision)
                    grads_cast_h[k] = (__half)(grad * step_size);
                else
                    grads[k] = grad * step_size;
                _exp_avg_sq[k] = variance;
            }
            if (dev_params) {
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);

                _buf_index = !_buf_index;
            }
        }
    }
}
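// Per element, the scalar loop above performs the Adagrad update
//   g  = grad + weight_decay * param      (only when weight_decay > 0)
//   v += g * g
//   p -= alpha * grad / (sqrt(v) + eps)
// i.e. only the accumulated variance v sees the weight-decayed gradient, while the
// numerator uses the raw gradient. The applied update term, -alpha * grad / (sqrt(v) + eps),
// is also written back into the gradient buffer.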
void Adagrad_Optimizer::Step_4(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
                               __half* dev_params,
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<4>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size)
        Step_1((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}
int create_adagrad_optimizer(int optimizer_id,
                             float alpha = 1e-2,
                             float eps = 1e-8,
                             float weight_decay = 0,
                             bool should_log = false)
{
    auto opt = std::make_shared<Adagrad_Optimizer>(alpha, eps, weight_decay);

    s_optimizers[optimizer_id] = opt;

    if (should_log) {
        std::string avx_type = "";
#if defined(__AVX512__)
        avx_type = "AVX512";
#else
#if defined(__AVX256__)
        avx_type = "AVX2";
#else
        avx_type = "scalar";
#endif
#endif

        printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n",
               optimizer_id,
               avx_type.c_str());
        printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay);
    }

    return 0;
}
void Adagrad_Optimizer::Step_8(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
                               __half* dev_params,
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<8>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size)
        Step_4((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}
int ds_adagrad_step(int optimizer_id,
                    size_t step,
                    float lr,
                    float epsilon,
                    float weight_decay,
                    torch::Tensor& params,
                    torch::Tensor& grads,
                    torch::Tensor& exp_avg_sq)
{
    auto params_c = params.contiguous();
    auto grads_c = grads.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adagrad_Optimizer> opt =
        std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step);
    opt->update_state(lr, epsilon, weight_decay);
    opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.size(0));

    opt->SynchronizeStreams();
    return 0;
}
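// Note: the Step_8 call above passes only four arguments, so it relies on default
// values for the trailing parameters (presumably dev_params = nullptr and
// half_precision = false in cpu_adagrad.h); this is the CPU-only path with no
// device-side parameter copy.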
int ds_adagrad_step_plus_copy(int optimizer_id,
                              size_t step,
                              float lr,
                              float epsilon,
                              float weight_decay,
                              torch::Tensor& params,
                              torch::Tensor& grads,
                              torch::Tensor& exp_avg_sq,
                              torch::Tensor& gpu_params)
{
    auto params_c = params.contiguous();
    auto gpu_params_c = gpu_params.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();
    auto grads_c = grads.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adagrad_Optimizer> opt =
        std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step);
    opt->update_state(lr, epsilon, weight_decay);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_sq_ptr,
                params_c.size(0),
                gpu_params_ptr,
                (params.options().dtype() == at::kHalf));

    opt->SynchronizeStreams();
    return 0;
}
int destroy_adagrad_optimizer(int optimizer_id)
{
    s_optimizers.erase(optimizer_id);

    return 0;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)");
    m.def("adagrad_update_copy",
          &ds_adagrad_step_plus_copy,
          "DeepSpeed CPU Adagrad update and param copy (C++)");
    m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)");
    m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)");
}
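// Rough sketch of the expected call sequence from the Python side; the extension
// module name and argument values are illustrative, only the bound function names
// and argument order come from this file:
//
//   opt_id = 0
//   ext.create_adagrad(opt_id, lr, eps, weight_decay, True)
//   ext.adagrad_update(opt_id, step, lr, eps, weight_decay,
//                      params, grads, exp_avg_sq)
//   ext.destroy_adagrad(opt_id)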