// cpu_adagrad.h

#pragma once

#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <cassert>
#include "cuda.h"
#include "custom_cuda_layers.h"
#include "simd.h"
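
// STEP(SPAN) stamps out the Step_1 / Step_4 / Step_8 declarations used below;
// SPAN is the SIMD unroll factor of the corresponding update loop.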
#define STEP(SPAN)                                \
    void Step_##SPAN(float* _params,              \
                     float* grads,                \
                     float* _exp_avg_sq,          \
                     size_t _param_size,          \
                     __half* dev_param = nullptr, \
                     bool half_precision = false);
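
// Host-side Adagrad optimizer. Parameters are updated in TILE-sized chunks;
// when a device copy is requested, two pinned host buffers and two CUDA
// streams overlap the CPU update with asynchronous host-to-device copies.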
class Adagrad_Optimizer {
public:
    Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0)
        : _alpha(alpha), _eps(eps), _weight_decay(weight_decay), _buf_index(false)
    {
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _streams[0] = Context::Instance().GetCurrentStream();
        _streams[1] = Context::Instance().GetNewStream();
    }
    ~Adagrad_Optimizer()
    {
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
    }
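
    // Vectorized update path, compiled only when AVX-512 or AVX2 support is
    // detected (see simd.h). On return, *rounded_size holds the number of
    // elements that were processed, so a scalar path can finish the tail.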
#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg_sq,
                  size_t param_size,
                  __half* dev_param = nullptr,
                  bool half_precision = false);
#endif
    STEP(1)
    STEP(4)
    STEP(8)
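
    // Block until any outstanding work on both copy streams has finished.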
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
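
    // Keep the local step counter in sync with the caller-provided step.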
    inline void IncrementStep(size_t step)
    {
        _step++;
        if (_step != step) { _step = step; }
    }
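
    // Refresh the hyperparameters used by subsequent calls to Step_*.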
    inline void update_state(float lr, float epsilon, float weight_decay)
    {
        _alpha = lr;
        _eps = epsilon;
        _weight_decay = weight_decay;
    }

private:
    float _alpha;
    float _eps;
    float _weight_decay;
    float _betta1_t;
    float _betta2_t;
    size_t _step;

    float* _doubled_buffer[2];
    bool _buf_index;
    cudaStream_t _streams[2];
};

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
                                 float* _params,
                                 float* grads,
                                 float* _exp_avg_sq,
                                 size_t _param_size,
                                 __half* dev_params,
                                 bool half_precision)
{
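    // Broadcast the scalar hyperparameters across all SIMD lanes once,
    // outside the hot loop.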
    size_t new_rounded_size = 0;

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    AVX_Data weight_decay4;
    if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay);

    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
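
    // Walk the rounded region in TILE-sized chunks so the host-side update of
    // one tile can overlap with the device copy of the previous one.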
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
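
        // Before reusing a pinned buffer, make sure its previous async copy
        // to the device has drained.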
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }

#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + i, half_precision);

            // Keep a copy of the raw gradient: it is the numerator of the
            // update, while grad_4 below may have weight decay folded in.
            // (Loaded with half_precision so it reads the same layout as
            // grad_4; the original passed `false` here, which misreads the
            // buffer when the gradients are fp16.)
            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, grads + i, half_precision);

            AVX_Data variance_4[span];
            simd_load<span>(variance_4, _exp_avg_sq + i, false);

            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + i, half_precision);

            if (_weight_decay > 0) { simd_fma<span>(grad_4, param_4, weight_decay4, grad_4); }

            // v += g'^2;  update = g / (sqrt(v) + eps);  p += -alpha * update
            simd_fma<span>(variance_4, grad_4, grad_4, variance_4);
            simd_sqrt<span>(grad_4, variance_4);
            simd_add<span>(grad_4, grad_4, eps_4);
            simd_div<span>(grad_4, momentum_4, grad_4);
            simd_fma<span>(param_4, grad_4, step_size_4, param_4);

            simd_store<span>(_params + i, param_4, half_precision);
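            // Stage the updated parameters in pinned host memory so they can
            // be copied to the device asynchronously.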
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }

            simd_store<span>(_exp_avg_sq + i, variance_4, false);
        }
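
        // Launch the async device update for this tile and flip to the other
        // pinned buffer.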
        if (dev_params) {
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
    }
    *rounded_size = new_rounded_size;
}
#endif
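
// ---------------------------------------------------------------------------
// Reference notes (editorial, not part of the original header).
//
// Per-element form of the SIMD loop above, with g = grads[i], p = _params[i],
// v = _exp_avg_sq[i]:
//
//     g' = g + weight_decay * p             (only when _weight_decay > 0)
//     v  = v + g' * g'
//     p  = p - alpha * g / (sqrt(v) + eps)
//
// The numerator is the raw gradient; only the variance accumulates the
// weight-decayed gradient.
//
// A minimal sketch, assuming fp32 inputs for brevity, of how a Step_##SPAN
// definition could combine Step_AVX with a scalar tail for the elements left
// over by ROUND_DOWN. This is illustrative only, not the library's actual
// implementation:
//
//     void Adagrad_Optimizer::Step_1(float* _params,
//                                    float* grads,
//                                    float* _exp_avg_sq,
//                                    size_t _param_size,
//                                    __half* dev_params,
//                                    bool half_precision)
//     {
//         size_t rounded_size = 0;
//     #if defined(__AVX512__) or defined(__AVX256__)
//         Step_AVX<1>(&rounded_size, _params, grads, _exp_avg_sq,
//                     _param_size, dev_params, half_precision);
//     #endif
//         // Scalar tail: same math as the SIMD loop, one element at a time.
//         for (size_t i = rounded_size; i < _param_size; i++) {
//             float g = grads[i];
//             if (_weight_decay > 0) g += _weight_decay * _params[i];
//             _exp_avg_sq[i] += g * g;
//             _params[i] -= _alpha * grads[i] / (sqrtf(_exp_avg_sq[i]) + _eps);
//         }
//     }
// ---------------------------------------------------------------------------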