// fused_adam_frontend.cpp
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #include <torch/extension.h>
// Forward declaration of the fused Adam kernel launcher; the definition lives
// in the companion CUDA source compiled into the same extension.
//
// chunk_size      : number of elements each kernel chunk processes.
// noop_flag       : device tensor; presumably a flag that, when set, turns the
//                   update into a no-op (e.g. on inf/nan grads) — confirm
//                   against the CUDA implementation.
// tensor_lists    : groups of per-parameter tensors to update in one fused
//                   launch; the expected ordering of the inner lists (grads,
//                   params, exp_avg, exp_avg_sq) is not visible here — verify
//                   against the .cu definition and the Python caller.
// lr              : learning rate.
// beta1, beta2    : Adam exponential moving-average coefficients.
// epsilon         : denominator term for numerical stability.
// step            : current optimizer step (used for bias correction).
// mode            : integer mode selector; semantics defined by the kernel —
//                   TODO confirm (commonly L2-reg vs. decoupled weight decay).
// bias_correction : nonzero to apply Adam bias correction.
// weight_decay    : weight-decay coefficient.
void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay);
// Python entry point for the extension. TORCH_EXTENSION_NAME is injected by
// the PyTorch build system, so the module name follows whatever name the
// Python-side loader requested.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Expose the CUDA launcher to Python as `multi_tensor_adam`; pybind11
    // derives the argument conversions from the C++ signature above.
    m.def("multi_tensor_adam",
          &multi_tensor_adam_cuda,
          "Compute and apply gradient update to parameters for Adam optimizer");
}