// fused_lion.cpp
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #include "cpu_lion.h"
  5. // C++ interface
  6. void multi_tensor_lion(int chunk_size,
  7. at::Tensor noop_flag,
  8. std::vector<std::vector<at::Tensor>> tensor_lists, /*gpmv*/
  9. const float lr,
  10. const float beta1,
  11. const float beta2,
  12. const int step,
  13. const int mode,
  14. const float weight_decay)
  15. {
  16. static bool initialized = false;
  17. if (!initialized) {
  18. create_lion_optimizer(0);
  19. initialized = true;
  20. }
  21. for (int i = 0; i < tensor_lists[0].size(); i++) {
  22. ds_lion_step(0,
  23. step,
  24. lr,
  25. beta1,
  26. beta2,
  27. weight_decay,
  28. tensor_lists[1][i],
  29. tensor_lists[0][i],
  30. tensor_lists[2][i]);
  31. }
  32. }
  33. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  34. {
  35. m.def("multi_tensor_lion",
  36. &multi_tensor_lion,
  37. "Compute and apply gradient update to parameters for Lion optimizer");
  38. }