#include #include #include #include "custom_cuda_layers.h" template at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits) { auto t_size = vals.sizes(); int size = 1; for (auto dim : t_size) size *= dim; if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) { launch_quantize_kernel( (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream()); } return vals; } template at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits) { auto t_size = vals.sizes(); int size = 1; for (auto dim : t_size) size *= dim; if (((size / groups) / 4 / 1024) <= 256) { launch_sr_quantize_kernel( (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream()); } return vals; } template at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits) { auto t_size = vals.sizes(); int size = 1; for (auto dim : t_size) size *= dim; if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) { launch_quantize_kernel_asym( (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream()); } return vals; } template at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits) { auto t_size = vals.sizes(); int size = 1; for (auto dim : t_size) size *= dim; if (((size / groups) / 4 / 1024) <= 256) { launch_sr_quantize_kernel_asym( (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream()); } return vals; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("ds_quantize_fp32", &ds_quantize, "DeepSpeed Quantize with fp32 (CUDA)"); m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)"); m.def("ds_sr_quantize_fp32", &ds_sr_quantize, "DeepSpeed Quantize with fp32 (CUDA)"); m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)"); m.def("ds_quantize_asym_fp32", &ds_quantize_asym, "DeepSpeed Quantize with fp32 (CUDA)"); m.def( "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)"); m.def("ds_sr_quantize_asym_fp32", &ds_sr_quantize_asym, "DeepSpeed Quantize with fp32 (CUDA)"); m.def("ds_sr_quantize_asym_fp16", &ds_sr_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)"); }