multi_tensor_adam.cu

/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4

typedef enum {
    ADAM_MODE_0 = 0,  // L2 regularization mode
    ADAM_MODE_1 = 1   // Decoupled weight decay mode (AdamW)
} adamMode_t;

using MATH_T = float;
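
// Functor invoked by multi_tensor_apply: each CUDA block applies one Adam step
// to one chunk of one tensor. The four address slots in TensorListMetadata<4>
// hold, in order, the gradient (g), parameter (p), first moment (m), and
// second moment (v) of that tensor.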
template <typename T>
struct AdamFunctor {
    __device__ __forceinline__ void operator()(int chunk_size,
                                               volatile int* noop_gmem,
                                               TensorListMetadata<4>& tl,
                                               const float beta1,
                                               const float beta2,
                                               const float beta1_correction,
                                               const float beta2_correction,
                                               const float epsilon,
                                               const float lr,
                                               adamMode_t mode,
                                               const float decay)
    {
        // I'd like this kernel to propagate infs/nans.
        // if(*noop_gmem == 1)
        //   return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];

        // potentially use to pass in list of scalar
        // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* g = (T*)tl.addresses[0][tensor_loc];
        g += chunk_idx * chunk_size;

        T* p = (T*)tl.addresses[1][tensor_loc];
        p += chunk_idx * chunk_size;

        T* m = (T*)tl.addresses[2][tensor_loc];
        m += chunk_idx * chunk_size;

        T* v = (T*)tl.addresses[3][tensor_loc];
        v += chunk_idx * chunk_size;

        n -= chunk_idx * chunk_size;

        // see note in multi_tensor_scale_kernel.cu
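        // Stage ILP elements per thread in registers: within one outer iteration,
        // thread threadIdx.x handles elements i_start + threadIdx.x + ii * blockDim.x.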
        for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
            MATH_T r_g[ILP];
            MATH_T r_p[ILP];
            MATH_T r_m[ILP];
            MATH_T r_v[ILP];
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    r_g[ii] = g[i];
                    r_p[ii] = p[i];
                    r_m[ii] = m[i];
                    r_v[ii] = v[i];
                } else {
                    r_g[ii] = MATH_T(0);
                    r_p[ii] = MATH_T(0);
                    r_m[ii] = MATH_T(0);
                    r_v[ii] = MATH_T(0);
                }
            }
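            // ADAM_MODE_0 folds the weight decay into the gradient before the moment
            // updates (classic L2-regularized Adam); ADAM_MODE_1 adds decay * p to the
            // update itself (decoupled weight decay, AdamW).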
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                if (mode == ADAM_MODE_0) {  // L2
                    r_g[ii] = r_g[ii] + (decay * r_p[ii]);
                    r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                    r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
                    MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
                    MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
                    MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
                    MATH_T update = next_m_unbiased / denom;
                    r_p[ii] = r_p[ii] - (lr * update);
                } else {  // weight decay
                    r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                    r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
                    MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
                    MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
                    MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
                    MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
                    r_p[ii] = r_p[ii] - (lr * update);
                }
            }
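            // Write the updated parameter and moments back to global memory.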
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    p[i] = r_p[ii];
                    m[i] = r_m[ii];
                    v[i] = r_v[ii];
                }
            }
        }
    }
};
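
// Host-side entry point. tensor_lists is expected to contain four lists in the
// order consumed by AdamFunctor: gradients, parameters, exp_avg (m), and
// exp_avg_sq (v), all sharing a single dtype.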
void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay)
{
    using namespace at;

    // Handle bias correction mode
    float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
    if (bias_correction == 1) {
        bias_correction1 = 1 - std::pow(beta1, step);
        bias_correction2 = 1 - std::pow(beta2, step);
    }
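    // The kernel divides the first and second moments by these factors
    // (1 - beta^step) to obtain the bias-corrected estimates.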
    // Assume single type across p,g,m1,m2 now
    DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
                                   0,
                                   "adam",
                                   multi_tensor_apply<4>(BLOCK_SIZE,
                                                         chunk_size,
                                                         noop_flag,
                                                         tensor_lists,
                                                         AdamFunctor<scalar_t_0>(),
                                                         beta1,
                                                         beta2,
                                                         bias_correction1,
                                                         bias_correction2,
                                                         epsilon,
                                                         lr,
                                                         (adamMode_t)mode,
                                                         weight_decay);)

    AT_CUDA_CHECK(cudaGetLastError());
}
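
// A minimal sketch (not part of this file) of how multi_tensor_adam_cuda is
// typically exposed to Python through a pybind11 binding in the extension's
// .cpp file. The exported name "multi_tensor_adam" is an illustrative
// assumption, not necessarily the name used by the actual DeepSpeed binding:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       m.def("multi_tensor_adam",
//             &multi_tensor_adam_cuda,
//             "Fused multi-tensor Adam update applied chunk-wise over tensor lists (CUDA)");
//   }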