cpu_adagrad.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#define NOMINMAX  // Windows idiosyncrasy
                  // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c

#include <stdio.h>
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#else
typedef unsigned short ds_half_precision_t;  // 16-bit stand-in when CUDA half types are unavailable
#endif
#define STEP(SPAN)                                             \
    void Step_##SPAN(float* _params,                           \
                     float* grads,                             \
                     float* _exp_avg_sq,                       \
                     size_t _param_size,                       \
                     ds_half_precision_t* dev_param = nullptr, \
                     bool half_precision = false);
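// For reference (editor's note, derived from the macro above): STEP(1) expands to
//
//     void Step_1(float* _params,
//                 float* grads,
//                 float* _exp_avg_sq,
//                 size_t _param_size,
//                 ds_half_precision_t* dev_param = nullptr,
//                 bool half_precision = false);
//
// and likewise for Step_4 and Step_8; the numeric suffix presumably corresponds to
// the <span> argument of the Step_AVX template declared below.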
class Adagrad_Optimizer {
public:
    Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0)
        : _alpha(alpha), _eps(eps), _weight_decay(weight_decay)
    {
#if defined(__ENABLE_CUDA__)
        // Pinned host buffers used to stage updated parameters for asynchronous copies to the device.
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _streams[0] = TrainingContext::Instance().GetCurrentStream();
        _streams[1] = TrainingContext::Instance().GetNewStream();
        _buf_index = false;
#endif
    }
    ~Adagrad_Optimizer()
    {
#if defined(__ENABLE_CUDA__)
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
#endif
    }
#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg_sq,
                  size_t param_size,
                  ds_half_precision_t* dev_param = nullptr,
                  bool half_precision = false);
#endif
    STEP(1)
    STEP(4)
    STEP(8)
#if defined(__ENABLE_CUDA__)
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
#endif
    inline void IncrementStep(size_t step)
    {
        _step++;
        if (_step != step) { _step = step; }
    }
    inline void update_state(float lr, float epsilon, float weight_decay)
    {
        _alpha = lr;
        _eps = epsilon;
        _weight_decay = weight_decay;
    }
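    // Typical per-step call sequence from the host-side driver (editor's sketch,
    // inferred from this interface rather than stated in the original header):
    //
    //     opt.update_state(lr, eps, weight_decay);   // refresh hyperparameters
    //     opt.IncrementStep(step);                   // keep _step in sync with the caller
    //     opt.Step_8(params, grads, exp_avg_sq, n);  // or Step_4 / Step_1
    //     opt.SynchronizeStreams();                  // only when __ENABLE_CUDA__ staging is used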
private:
    float _alpha;
    float _eps;
    float _weight_decay;

    float _betta1_t;
    float _betta2_t;
    size_t _step;

#if defined(__ENABLE_CUDA__)
    bool _buf_index;
    float* _doubled_buffer[2];
    cudaStream_t _streams[2];
#endif
};
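// Scalar reference for the update that Step_AVX vectorizes (editor's sketch, not
// part of the original header; compiled out). It mirrors the vector code below,
// including the detail that the numerator divides the gradient as loaded, before
// the weight-decay term is folded in. Requires <cmath> if actually built.
#if 0
inline void adagrad_step_reference(float* params,
                                   const float* grads,
                                   float* exp_avg_sq,
                                   size_t param_size,
                                   float alpha,
                                   float eps,
                                   float weight_decay)
{
    for (size_t i = 0; i < param_size; ++i) {
        float grad = grads[i];
        float decayed = (weight_decay > 0) ? params[i] * weight_decay + grad : grad;
        exp_avg_sq[i] += decayed * decayed;        // accumulate squared gradient
        float denom = sqrtf(exp_avg_sq[i]) + eps;  // adaptive scaling term
        params[i] += -alpha * (grad / denom);      // step along the scaled gradient
    }
}
#endif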
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
                                 float* _params,
                                 float* grads,
                                 float* _exp_avg_sq,
                                 size_t _param_size,
                                 ds_half_precision_t* dev_params,
                                 bool half_precision)
{
    size_t new_rounded_size = 0;
    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    AVX_Data weight_decay4;
    if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay);

    // Only the SIMD-aligned prefix is handled here; the caller processes the remainder.
    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
        // Double buffering: make sure the staging buffer we are about to reuse has drained.
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + i, half_precision);

            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, grads + i, false);

            AVX_Data variance_4[span];
            simd_load<span>(variance_4, _exp_avg_sq + i, false);

            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + i, half_precision);

            if (_weight_decay > 0) { simd_fma<span>(grad_4, param_4, weight_decay4, grad_4); }

            simd_fma<span>(variance_4, grad_4, grad_4, variance_4);  // exp_avg_sq += grad^2
            simd_sqrt<span>(grad_4, variance_4);
            simd_add<span>(grad_4, grad_4, eps_4);
            simd_div<span>(grad_4, momentum_4, grad_4);             // grad / (sqrt(exp_avg_sq) + eps)
            simd_fma<span>(param_4, grad_4, step_size_4, param_4);  // param -= alpha * update

            simd_store<span>(_params + i, param_4, half_precision);
#if defined(__ENABLE_CUDA__)
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }
#endif
            simd_store<span>(_exp_avg_sq + i, variance_4, false);
        }
#if defined(__ENABLE_CUDA__)
        if (dev_params) {
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
}
#endif
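// How the Step_* entry points are expected to use Step_AVX (editor's sketch; the
// real definitions live in the companion cpu_adagrad .cpp file and may differ).
// Step_AVX handles the SIMD-aligned prefix and reports its extent through
// rounded_size; the caller then finishes the scalar tail over
// [rounded_size, _param_size).
#if 0
void Adagrad_Optimizer::Step_1(float* _params,
                               float* grads,
                               float* _exp_avg_sq,
                               size_t _param_size,
                               ds_half_precision_t* dev_params,
                               bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<1>(
        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
    // A scalar remainder loop over [rounded_size, _param_size) would follow here,
    // applying the same update as the reference sketch above.
}
#endif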