// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team
  4. #include "cpu_lion.h"
  5. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  6. {
  7. m.def("lion_update", &ds_lion_step, "DeepSpeed CPU Lion update (C++)");
  8. m.def("lion_update_copy",
  9. &ds_lion_step_plus_copy,
  10. "DeepSpeed CPU Lion update and param copy (C++)");
  11. m.def("create_lion", &create_lion_optimizer, "DeepSpeed CPU Lion (C++)");
  12. m.def("destroy_lion", &destroy_lion_optimizer, "DeepSpeed CPU Lion destroy (C++)");
  13. }