multi_tensor_adam.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 512
#define ILP 4
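
// Each thread block runs BLOCK_SIZE (512) threads, and each thread handles ILP (4)
// elements per iteration of the main loop below, so a block covers up to
// 512 * 4 = 2048 elements of its chunk per pass.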
typedef enum : int {
    ADAM_MODE_0 = 0,  // L2 regularization mode
    ADAM_MODE_1 = 1   // Decoupled weight decay mode (AdamW)
} adamMode_t;

using MATH_T = float;
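
// AdamFunctor applies one Adam step to a single chunk of one tensor. For each
// element it computes, in MATH_T (fp32) precision:
//   m = beta1 * m + (1 - beta1) * g
//   v = beta2 * v + (1 - beta2) * g * g
//   p = p - lr * (m / beta1_correction) / (sqrt(v / beta2_correction) + epsilon)
// In ADAM_MODE_0 the weight decay term is folded into the gradient first
// (g += decay * p, classic L2 regularization); in ADAM_MODE_1 it is added to
// the update directly (decoupled weight decay, AdamW).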
template <typename T, typename index_t>
struct AdamFunctor {
    __device__ __forceinline__ void operator()(int chunk_size,
                                               volatile int* noop_gmem,
                                               TensorListMetadata<4>& tl,
                                               const float beta1,
                                               const float beta2,
                                               const float beta1_correction,
                                               const float beta2_correction,
                                               const float epsilon,
                                               const float lr,
                                               adamMode_t mode,
                                               const float decay)
    {
        // I'd like this kernel to propagate infs/nans.
        // if(*noop_gmem == 1)
        //     return;

        index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

        // potentially use to pass in list of scalar
        // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

        index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
        index_t n = tl.sizes[tensor_loc];

        T* g = (T*)tl.addresses[0][tensor_loc];
        g += chunk_idx * chunk_size;

        T* p = (T*)tl.addresses[1][tensor_loc];
        p += chunk_idx * chunk_size;

        T* m = (T*)tl.addresses[2][tensor_loc];
        m += chunk_idx * chunk_size;

        T* v = (T*)tl.addresses[3][tensor_loc];
        v += chunk_idx * chunk_size;

        n -= chunk_idx * chunk_size;

        // see note in multi_tensor_scale_kernel.cu
        for (index_t i_start = 0; i_start < n && i_start < chunk_size;
             i_start += blockDim.x * ILP) {
            MATH_T r_g[ILP];
            MATH_T r_p[ILP];
            MATH_T r_m[ILP];
            MATH_T r_v[ILP];
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    r_g[ii] = g[i];
                    r_p[ii] = p[i];
                    r_m[ii] = m[i];
                    r_v[ii] = v[i];
                } else {
                    r_g[ii] = MATH_T(0);
                    r_p[ii] = MATH_T(0);
                    r_m[ii] = MATH_T(0);
                    r_v[ii] = MATH_T(0);
                }
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                if (mode == ADAM_MODE_0) {  // L2
                    r_g[ii] = r_g[ii] + (decay * r_p[ii]);
                    r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                    r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
                    MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
                    MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
                    MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
                    MATH_T update = next_m_unbiased / denom;
                    r_p[ii] = r_p[ii] - (lr * update);
                } else {  // weight decay
                    r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                    r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
                    MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
                    MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
                    MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
                    MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
                    r_p[ii] = r_p[ii] - (lr * update);
                }
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    p[i] = r_p[ii];
                    m[i] = r_m[ii];
                    v[i] = r_v[ii];
                }
            }
        }
    }
};
void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay)
{
    using namespace at;

    // Handle bias correction mode
    float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
    if (bias_correction == 1) {
        bias_correction1 = 1 - std::pow(beta1, step);
        bias_correction2 = 1 - std::pow(beta2, step);
    }
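    // With bias correction enabled, the kernel divides the first moment by
    // 1 - beta1^step and the second moment by 1 - beta2^step, removing the
    // zero-initialization bias of the running averages. With it disabled,
    // the divisors stay at 1 and the raw moments are used.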

    size_t max_size = 0;
    bool requires_64bit_indexing = false;
    for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
        for (auto it2 = it->begin(); it2 != it->end(); it2++) {
            if (it2->numel() > max_size) {
                max_size = it2->numel();
                if (max_size >= INT_MAX) {
                    requires_64bit_indexing = true;
                    break;
                }
            }
        }
        if (requires_64bit_indexing) { break; }
    }
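    // Tensors with at least INT_MAX elements cannot be addressed with 32-bit
    // offsets, so the kernel is instantiated with int64_t indices in that case
    // and with int32_t indices otherwise.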
    // Assume single type across p,g,m1,m2 now
    if (requires_64bit_indexing) {
        DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
                                       0,
                                       "adam",
                                       multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
                                                             (int64_t)chunk_size,
                                                             noop_flag,
                                                             tensor_lists,
                                                             AdamFunctor<scalar_t_0, int64_t>(),
                                                             beta1,
                                                             beta2,
                                                             bias_correction1,
                                                             bias_correction2,
                                                             epsilon,
                                                             lr,
                                                             (adamMode_t)mode,
                                                             weight_decay);)
    } else {
        DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
                                       0,
                                       "adam",
                                       multi_tensor_apply<4>(BLOCK_SIZE,
                                                             chunk_size,
                                                             noop_flag,
                                                             tensor_lists,
                                                             AdamFunctor<scalar_t_0, int32_t>(),
                                                             beta1,
                                                             beta2,
                                                             bias_correction1,
                                                             bias_correction2,
                                                             epsilon,
                                                             lr,
                                                             (adamMode_t)mode,
                                                             weight_decay);)
    }

    AT_CUDA_CHECK(cudaGetLastError());
}
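
// A minimal sketch of how this launcher is typically exposed to Python through a
// separate pybind11 frontend translation unit. The binding file, module name, and
// exported function name shown here are assumptions for illustration, not part of
// this file:
//
//   #include <torch/extension.h>
//
//   void multi_tensor_adam_cuda(int chunk_size,
//                               at::Tensor noop_flag,
//                               std::vector<std::vector<at::Tensor>> tensor_lists,
//                               const float lr,
//                               const float beta1,
//                               const float beta2,
//                               const float epsilon,
//                               const int step,
//                               const int mode,
//                               const int bias_correction,
//                               const float weight_decay);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       // tensor_lists is expected to hold [grads, params, exp_avgs, exp_avg_sqs],
//       // matching the order of tl.addresses[0..3] in AdamFunctor above.
//       m.def("multi_tensor_adam",
//             &multi_tensor_adam_cuda,
//             "Fused multi-tensor Adam step (CUDA)");
//   }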