#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_cuda_layers.h"
- template <typename T>
- at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
- {
- auto t_size = vals.sizes();
- int size = 1;
- for (auto dim : t_size) size *= dim;
- if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
- launch_quantize_kernel(
- (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
- }
- return vals;
- }
- template <typename T>
- at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
- {
- auto t_size = vals.sizes();
- int size = 1;
- for (auto dim : t_size) size *= dim;
- if (((size / groups) / 4 / 1024) <= 256) {
- launch_sr_quantize_kernel(
- (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
- }
- return vals;
- }
- template <typename T>
- at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
- {
- auto t_size = vals.sizes();
- int size = 1;
- for (auto dim : t_size) size *= dim;
- if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
- launch_quantize_kernel_asym(
- (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
- }
- return vals;
- }
- template <typename T>
- at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
- {
- auto t_size = vals.sizes();
- int size = 1;
- for (auto dim : t_size) size *= dim;
- if (((size / groups) / 4 / 1024) <= 256) {
- launch_sr_quantize_kernel_asym(
- (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
- }
- return vals;
- }
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
- {
- m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
- m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
- m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
- m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
- m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
- m.def(
- "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
- m.def("ds_sr_quantize_asym_fp32",
- &ds_sr_quantize_asym<float>,
- "DeepSpeed Quantize with fp32 (CUDA)");
- m.def("ds_sr_quantize_asym_fp16",
- &ds_sr_quantize_asym<__half>,
- "DeepSpeed Quantize with fp16 (CUDA)");
- }