fused_lamb_cuda.cpp

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

// CUDA forward declaration
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff_val);
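
// Notes on the less obvious arguments: w_l2_i and u_l2_i are scratch tensors
// for the weight and update L2-norm reductions (see their allocation below),
// and lamb_coeff_val is a one-element tensor that receives the resulting LAMB
// trust ratio, which the lamb() entry point below returns to the caller.
// grad_scale and the mode / bias_correction flags are forwarded to the kernel
// unchanged; grad_scale presumably rescales gradients produced under
// mixed-precision loss scaling, while mode and bias_correction appear to
// select the Adam update variant applied inside the kernel.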

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// C++ interface
at::Tensor lamb(at::Tensor& p,
                at::Tensor& p_copy,
                at::Tensor& m,
                at::Tensor& v,
                at::Tensor& g,
                float lr,
                float beta1,
                float beta2,
                float max_coeff,
                float min_coeff,
                float eps,
                float grad_scale,
                int step,
                int mode,
                int bias_correction,
                float decay)
{
    CHECK_INPUT(p);
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);

    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(
        p_copy.numel() == num_elem || p_copy.numel() == 0,
        "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    // intermediate for weight L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behaviour is unexpected
    at::Tensor w_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    // intermediate for update L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behaviour is unexpected
    at::Tensor u_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    at::Tensor lamb_coeff_val = at::empty(
        {1},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));
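
    // Note: when the parameters are half precision, these intermediates are
    // allocated in FP32 (see the dtype selection above), presumably so the
    // norm reductions and the final trust ratio keep full precision. The
    // fixed {512} size is tied to the kernel's reduction layout, which is
    // what the threads-per-block requirement in the comments above refers to.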

    fused_lamb_cuda(p,
                    p_copy,
                    m,
                    v,
                    g,
                    lr,
                    beta1,
                    beta2,
                    max_coeff,
                    min_coeff,
                    eps,
                    grad_scale,
                    step,
                    mode,
                    bias_correction,
                    decay,
                    w_l2_i,
                    u_l2_i,
                    lamb_coeff_val);

    return lamb_coeff_val;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
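
// Illustrative sketch of what the returned coefficient represents, assuming
// the kernel follows the usual LAMB formulation (the names below refer only
// to the function arguments above, not to identifiers inside the CUDA
// kernel): with w the parameter tensor and u the bias-corrected,
// weight-decayed Adam update,
//
//     lamb_coeff = clamp(||w||_2 / ||u||_2, min_coeff, max_coeff)
//     w         <- w - lr * lamb_coeff * u
//
// i.e. roughly:  float lamb_coeff = std::min(std::max(w_norm / u_norm,
//                                                     min_coeff), max_coeff);
//
// The two 512-element scratch tensors feed the ||w||_2 and ||u||_2
// reductions, and the resulting coefficient is written into the one-element
// lamb_coeff_val tensor that lamb() returns.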