multi_tensor_lion.cu 4.3 KB

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 512
#define ILP 4

using MATH_T = float;

template <typename T>
struct LionFunctor {
    __device__ __forceinline__ void operator()(int chunk_size,
                                               volatile int* noop_gmem,
                                               TensorListMetadata<3>& tl,
                                               const float beta1,
                                               const float beta2,
                                               const float lr,
                                               const float decay)
    {
        // I'd like this kernel to propagate infs/nans.
        // if(*noop_gmem == 1)
        //   return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];

        // potentially use to pass in list of scalar
        // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* g = (T*)tl.addresses[0][tensor_loc];
        g += chunk_idx * chunk_size;

        T* p = (T*)tl.addresses[1][tensor_loc];
        p += chunk_idx * chunk_size;

        T* m = (T*)tl.addresses[2][tensor_loc];
        m += chunk_idx * chunk_size;

        n -= chunk_idx * chunk_size;

        MATH_T after_decay = 1.0f - lr * decay;

        // see note in multi_tensor_scale_kernel.cu
        for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
            MATH_T r_g[ILP];
            MATH_T r_p[ILP];
            MATH_T r_m[ILP];
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    r_g[ii] = g[i];
                    r_p[ii] = p[i];
                    r_m[ii] = m[i];
                } else {
                    r_g[ii] = MATH_T(0);
                    r_p[ii] = MATH_T(0);
                    r_m[ii] = MATH_T(0);
                }
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                MATH_T c = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                MATH_T update = c > 0 ? (-lr) : lr;
                r_p[ii] = r_p[ii] * after_decay + update;
                r_m[ii] = beta2 * r_m[ii] + (1 - beta2) * r_g[ii];
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    p[i] = r_p[ii];
                    m[i] = r_m[ii];
                }
            }
        }
    }
};
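// A minimal host-side sketch (an editorial assumption, not part of the adapted apex
// code) of the per-element update that LionFunctor performs, assuming fp32 throughout.
// The helper name lion_reference_update is illustrative and intended only for checking
// kernel output against a scalar reference:
//
//   c = beta1 * m + (1 - beta1) * g
//   p = p * (1 - lr * decay) - lr * sign(c)
//   m = beta2 * m + (1 - beta2) * g
static void lion_reference_update(float* p,
                                  float* m,
                                  const float* g,
                                  int n,
                                  float lr,
                                  float beta1,
                                  float beta2,
                                  float decay)
{
    for (int i = 0; i < n; i++) {
        float c = beta1 * m[i] + (1.0f - beta1) * g[i];
        // Same branch as the kernel: -lr when c > 0, +lr otherwise
        // (so c == 0 contributes +lr rather than 0).
        float update = c > 0.0f ? -lr : lr;
        p[i] = p[i] * (1.0f - lr * decay) + update;
        m[i] = beta2 * m[i] + (1.0f - beta2) * g[i];
    }
}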
void multi_tensor_lion_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const int step,
                            const float weight_decay)
{
    using namespace at;

    // Assume a single dtype across p, g, and m for now.
    // Note: `step` is not used by the Lion update below.
    DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
                                   0,
                                   "lion",
                                   multi_tensor_apply<3>(BLOCK_SIZE,
                                                         chunk_size,
                                                         noop_flag,
                                                         tensor_lists,
                                                         LionFunctor<scalar_t_0>(),
                                                         beta1,
                                                         beta2,
                                                         lr,
                                                         weight_decay);)

    AT_CUDA_CHECK(cudaGetLastError());
}
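
For context, a minimal sketch of how this launcher could be driven from C++ with ATen. Everything here other than the multi_tensor_lion_cuda signature (the lion_step_example wrapper, the chunk size of 2048 * 32, and the hyperparameter values) is illustrative rather than part of DeepSpeed; the three inner lists must be ordered grads, params, momentum to match the functor's addresses[0..2] indexing.

#include <ATen/ATen.h>
#include <vector>

// Declaration matching the definition above (in DeepSpeed this is typically
// exposed to Python through a pybind11 extension rather than called directly).
void multi_tensor_lion_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const int step,
                            const float weight_decay);

void lion_step_example()
{
    auto cuda_f32 = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
    at::Tensor p = at::randn({1024}, cuda_f32);  // parameters
    at::Tensor g = at::randn({1024}, cuda_f32);  // gradients
    at::Tensor m = at::zeros({1024}, cuda_f32);  // exponential moving average
    at::Tensor noop_flag =
        at::zeros({1}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA));

    // Order mirrors the functor: addresses[0] = grads, addresses[1] = params,
    // addresses[2] = momentum.
    std::vector<std::vector<at::Tensor>> tensor_lists = {{g}, {p}, {m}};

    multi_tensor_lion_cuda(/*chunk_size=*/2048 * 32,
                           noop_flag,
                           tensor_lists,
                           /*lr=*/1e-4f,
                           /*beta1=*/0.9f,
                           /*beta2=*/0.99f,
                           /*step=*/1,
                           /*weight_decay=*/0.0f);
}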