fused_adam_frontend.cpp

#include <torch/extension.h>

// Forward declaration of the CUDA implementation of the multi-tensor fused Adam update.
void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay);

// Expose the CUDA kernel launcher to Python as "multi_tensor_adam".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("multi_tensor_adam",
          &multi_tensor_adam_cuda,
          "Compute and apply gradient update to parameters for Adam optimizer");
}
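
How the exposed binding might be called from Python is sketched below. This is a minimal, non-authoritative example: the module name "fused_adam_cuda" is a placeholder for however the extension is built and imported, and the tensor_lists ordering [grads, params, exp_avgs, exp_avg_sqs] and the meaning of mode (0 = L2 weight decay, 1 = decoupled/AdamW-style) are assumptions borrowed from the Apex/DeepSpeed multi-tensor Adam convention, not stated in this file.

# Usage sketch for the binding declared above (assumptions noted in the lead-in).
import torch
import fused_adam_cuda  # hypothetical name of the compiled extension module

params = [torch.randn(1024, device="cuda")]
grads = [torch.randn_like(p) for p in params]
exp_avgs = [torch.zeros_like(p) for p in params]
exp_avg_sqs = [torch.zeros_like(p) for p in params]

# A nonzero noop_flag signals the kernel to skip the update (e.g. after a grad overflow).
noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")

fused_adam_cuda.multi_tensor_adam(
    2048 * 32,                                  # chunk_size
    noop_flag,
    [grads, params, exp_avgs, exp_avg_sqs],     # tensor_lists (ordering is an assumption)
    1e-3,                                       # lr
    0.9,                                        # beta1
    0.999,                                      # beta2
    1e-8,                                       # epsilon
    1,                                          # step
    0,                                          # mode (assumed: 0 = L2, 1 = decoupled decay)
    1,                                          # bias_correction
    0.0,                                        # weight_decay
)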