// multi_tensor_adam.cu
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include <cmath>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 512
#define ILP 4
  19. typedef enum {
  20. ADAM_MODE_0 = 0, // L2 regularization mode
  21. ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
  22. } adamMode_t;
  23. using MATH_T = float;
  24. template <typename T>
  25. struct AdamFunctor {
  26. __device__ __forceinline__ void operator()(int chunk_size,
  27. volatile int* noop_gmem,
  28. TensorListMetadata<4>& tl,
  29. const float beta1,
  30. const float beta2,
  31. const float beta1_correction,
  32. const float beta2_correction,
  33. const float epsilon,
  34. const float lr,
  35. adamMode_t mode,
  36. const float decay)
  37. {
  38. // I'd like this kernel to propagate infs/nans.
  39. // if(*noop_gmem == 1)
  40. // return;
  41. int tensor_loc = tl.block_to_tensor[blockIdx.x];
  42. // potentially use to pass in list of scalar
  43. // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
  44. int chunk_idx = tl.block_to_chunk[blockIdx.x];
  45. int n = tl.sizes[tensor_loc];
  46. T* g = (T*)tl.addresses[0][tensor_loc];
  47. g += chunk_idx * chunk_size;
  48. T* p = (T*)tl.addresses[1][tensor_loc];
  49. p += chunk_idx * chunk_size;
  50. T* m = (T*)tl.addresses[2][tensor_loc];
  51. m += chunk_idx * chunk_size;
  52. T* v = (T*)tl.addresses[3][tensor_loc];
  53. v += chunk_idx * chunk_size;
  54. n -= chunk_idx * chunk_size;
  55. // see note in multi_tensor_scale_kernel.cu
  56. for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
  57. MATH_T r_g[ILP];
  58. MATH_T r_p[ILP];
  59. MATH_T r_m[ILP];
  60. MATH_T r_v[ILP];
  61. #pragma unroll
  62. for (int ii = 0; ii < ILP; ii++) {
  63. int i = i_start + threadIdx.x + ii * blockDim.x;
  64. if (i < n && i < chunk_size) {
  65. r_g[ii] = g[i];
  66. r_p[ii] = p[i];
  67. r_m[ii] = m[i];
  68. r_v[ii] = v[i];
  69. } else {
  70. r_g[ii] = MATH_T(0);
  71. r_p[ii] = MATH_T(0);
  72. r_m[ii] = MATH_T(0);
  73. r_v[ii] = MATH_T(0);
  74. }
  75. }
  76. #pragma unroll
  77. for (int ii = 0; ii < ILP; ii++) {
  78. if (mode == ADAM_MODE_0) { // L2
  79. r_g[ii] = r_g[ii] + (decay * r_p[ii]);
  80. r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
  81. r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
  82. MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
  83. MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
  84. MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
  85. MATH_T update = next_m_unbiased / denom;
  86. r_p[ii] = r_p[ii] - (lr * update);
  87. } else { // weight decay
  88. r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
  89. r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
  90. MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
  91. MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
  92. MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
  93. MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
  94. r_p[ii] = r_p[ii] - (lr * update);
  95. }
  96. }
  97. #pragma unroll
  98. for (int ii = 0; ii < ILP; ii++) {
  99. int i = i_start + threadIdx.x + ii * blockDim.x;
  100. if (i < n && i < chunk_size) {
  101. p[i] = r_p[ii];
  102. m[i] = r_m[ii];
  103. v[i] = r_v[ii];
  104. }
  105. }
  106. }
  107. }
  108. };
  109. void multi_tensor_adam_cuda(int chunk_size,
  110. at::Tensor noop_flag,
  111. std::vector<std::vector<at::Tensor>> tensor_lists,
  112. const float lr,
  113. const float beta1,
  114. const float beta2,
  115. const float epsilon,
  116. const int step,
  117. const int mode,
  118. const int bias_correction,
  119. const float weight_decay)
  120. {
  121. using namespace at;
  122. // Handle bias correction mode
  123. float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  124. if (bias_correction == 1) {
  125. bias_correction1 = 1 - std::pow(beta1, step);
  126. bias_correction2 = 1 - std::pow(beta2, step);
  127. }
  128. // Assume single type across p,g,m1,m2 now
  129. DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
  130. 0,
  131. "adam",
  132. multi_tensor_apply<4>(BLOCK_SIZE,
  133. chunk_size,
  134. noop_flag,
  135. tensor_lists,
  136. AdamFunctor<scalar_t_0>(),
  137. beta1,
  138. beta2,
  139. bias_correction1,
  140. bias_correction2,
  141. epsilon,
  142. lr,
  143. (adamMode_t)mode,
  144. weight_decay);)
  145. AT_CUDA_CHECK(cudaGetLastError());
  146. }