// pt_binding.cpp
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "quantization.h"
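
// Simulated ("fake") quantization, in place: vals is quantized to `bits` bits
// over `groups` groups and dequantized back to T, so the returned tensor is
// vals itself with quantization error applied. The launch is guarded by
// ceil((size / groups) / 4096) <= 256 (presumably the kernel's grid limit);
// oversized groups silently skip the kernel and vals is returned unchanged.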
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
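
// Stochastic-rounding (sr) variant of ds_quantize, also in place. Note that
// the guard here uses truncating division, (size / groups) / 4 / 1024 <= 256,
// rather than the ceiling form above; oversized groups again skip the launch.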
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
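
// Asymmetric fake quantization (scale plus offset rather than scale only),
// otherwise identical in shape and guard to ds_quantize.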
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
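
// Asymmetric fake quantization with stochastic rounding; same guard as
// ds_sr_quantize.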
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
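
// Expands to a single switch case that calls launch_quant with the bit width
// and quantization type as template arguments. It deliberately captures
// `output`, `params`, `input_vals`, `groups`, and `elems_per_group` from the
// enclosing scope, so it is only usable inside quantize_kernel below.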
#define QUANTIZATION_CASE(TYPE, BITS)                                \
    case TYPE:                                                       \
        launch_quant<BITS, TYPE>((int8_t*)output.data_ptr(),         \
                                 (float*)params.data_ptr(),          \
                                 (__half*)input_vals.data_ptr(),     \
                                 groups,                             \
                                 elems_per_group,                    \
                                 at::cuda::getCurrentCUDAStream());  \
        break;
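
// Real quantization (as opposed to the fake-quantize paths above): returns
// {output, params}. `output` is int8; for 4-bit quantization two values are
// packed per byte, so the last dimension is halved. `params` holds one scale
// per group, plus an offset element when the type requires one, stored as
// int32 for IntegerSymmetric and float otherwise. The __half cast in
// QUANTIZATION_CASE implies the input is expected to be fp16.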
std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
                                        int groups,
                                        int numBits,
                                        quantize::Type quantType)
{
    auto dtype = (quantType == quantize::Type::IntegerSymmetric) ? torch::kInt32 : at::kFloat;
    auto params_options = at::TensorOptions()
                              .dtype(dtype)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int param_elems = (quantize::requires_offset(quantType)) ? 2 : 1;
    auto params = torch::empty({groups, param_elems}, params_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    auto output_sizes = input_vals.sizes().vec();
    output_sizes[output_sizes.size() - 1] /= numBits == 8 ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    const int elems_per_group = at::numel(input_vals) / groups;

    if (numBits == 4) {
        switch (quantType) {
            QUANTIZATION_CASE(quantize::Type::Symmetric, 4)
            QUANTIZATION_CASE(quantize::Type::Asymmetric, 4)
            QUANTIZATION_CASE(quantize::Type::IntegerSymmetric, 4)
        }
    } else {
        switch (quantType) {
            QUANTIZATION_CASE(quantize::Type::Symmetric, 8)
            QUANTIZATION_CASE(quantize::Type::Asymmetric, 8)
            QUANTIZATION_CASE(quantize::Type::IntegerSymmetric, 8)
        }
    }

    return {output, params};
}
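
// Inverse of quantize_kernel: expands the int8 payload back to an fp16 tensor
// using the per-group params. For 4-bit data the last dimension doubles back,
// since two values were packed per byte.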
at::Tensor dequantize(at::Tensor& quantized_data,
                      at::Tensor& params,
                      int groups,
                      int num_bits,
                      quantize::Type quant_type)
{
    auto output_options = at::TensorOptions()
                              .dtype(torch::kFloat16)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    auto output_sizes = quantized_data.sizes().vec();
    output_sizes[output_sizes.size() - 1] *= num_bits == 8 ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    const int total_elems = at::numel(output);
    const int elems_per_group = total_elems / groups;

    launch_dequantize_kernel((__half*)output.data_ptr(),
                             (const int8_t*)quantized_data.data_ptr(),
                             (const float*)params.data_ptr(),
                             quant_type,
                             num_bits,
                             elems_per_group,
                             total_elems,
                             at::cuda::getCurrentCUDAStream());

    return output;
}
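
// Python bindings: the fake-quantize entry points are instantiated for fp32
// and fp16, quantize::Type is exposed as QuantizationType (its values are
// also exported at module scope), and the quantize/dequantize pair is bound
// directly.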
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def(
        "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_asym_fp32",
          &ds_sr_quantize_asym<float>,
          "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16",
          &ds_sr_quantize_asym<__half>,
          "DeepSpeed Quantize with fp16 (CUDA)");

    pybind11::enum_<quantize::Type>(m, "QuantizationType")
        .value("Symmetric", quantize::Type::Symmetric)
        .value("Asymmetric", quantize::Type::Asymmetric)
        .value("IntegerSymmetric", quantize::Type::IntegerSymmetric)
        .export_values();

    m.def("quantize", &quantize_kernel);
    m.def("dequantize", &dequantize);
}
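
// A minimal usage sketch from Python, assuming the extension was built under
// the module name "quantizer" (the actual name is whatever
// TORCH_EXTENSION_NAME expands to at build time):
//
//   import torch
//   import quantizer
//
//   x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
//   out, params = quantizer.quantize(x, 4, 8, quantizer.Symmetric)
//   y = quantizer.dequantize(out, params, 4, 8, quantizer.Symmetric)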