pt_binding.cpp

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_cuda_layers.h"

// Symmetric quantization of `vals` into `bits` bits over `groups` groups, in place.
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    // Launch only if each group spans at most MAX_REG chunks of 4096 elements;
    // otherwise the tensor is returned unmodified.
    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

// Quantization with stochastic rounding, in place on `vals`.
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    // Launch only if each group holds at most roughly 4 * 1024 * 256 elements;
    // otherwise the tensor is returned unmodified.
    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

// Asymmetric quantization, in place on `vals`; same size guard as ds_quantize.
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

// Asymmetric quantization with stochastic rounding, in place on `vals`;
// same size guard as ds_sr_quantize.
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

// Expose fp32 and fp16 entry points for each quantization variant.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_asym_fp32", &ds_sr_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16", &ds_sr_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
}
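
Usage sketch (Python): a minimal example of how these bindings could be built and called via torch.utils.cpp_extension.load. The module name, kernel source filename, and include path below are illustrative assumptions; only the exported function names and their (tensor, groups, bits) signatures come from the binding above.

    import torch
    from torch.utils.cpp_extension import load

    quantizer = load(
        name="ds_quantizer",                   # illustrative module name, not mandated by the binding
        sources=["pt_binding.cpp",
                 "quantize_kernel.cu"],        # hypothetical .cu file providing the launch_* kernels
        extra_include_paths=["includes"],      # assumed location of custom_cuda_layers.h
    )

    x = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    # Symmetric 8-bit quantization over 64 groups; runs in place and returns the same
    # tensor, provided the per-group size passes the guard in ds_quantize.
    y = quantizer.ds_quantize_fp16(x, 64, 8)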