// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <cassert>
#include <vector>

#include "quantization.h"
  9. template <typename T>
  10. at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
  11. {
  12. auto t_size = vals.sizes();
  13. int size = 1;
  14. for (auto dim : t_size) size *= dim;
  15. if ((((size / groups) - 1) / 4096 + 1) <= 256) {
  16. launch_fake_quantize_kernel(
  17. (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
  18. }
  19. return vals;
  20. }
  21. template <typename T>
  22. at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
  23. {
  24. auto t_size = vals.sizes();
  25. int size = 1;
  26. for (auto dim : t_size) size *= dim;
  27. if (((size / groups) / 4 / 1024) <= 256) {
  28. launch_sr_fake_quantize_kernel(
  29. (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
  30. }
  31. return vals;
  32. }
  33. template <typename T>
  34. at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
  35. {
  36. auto t_size = vals.sizes();
  37. int size = 1;
  38. for (auto dim : t_size) size *= dim;
  39. if ((((size / groups) - 1) / 4096 + 1) <= 256) {
  40. launch_fake_quantize_kernel_asym(
  41. (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
  42. }
  43. return vals;
  44. }
  45. template <typename T>
  46. at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
  47. {
  48. auto t_size = vals.sizes();
  49. int size = 1;
  50. for (auto dim : t_size) size *= dim;
  51. if (((size / groups) / 4 / 1024) <= 256) {
  52. launch_sr_fake_quantize_kernel_asym(
  53. (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
  54. }
  55. return vals;
  56. }
  57. std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
  58. int groups,
  59. int numBits,
  60. quantize::Type quantType)
  61. {
  62. auto dtype = at::kFloat;
  63. auto params_options = at::TensorOptions()
  64. .dtype(dtype)
  65. .layout(at::kStrided)
  66. .device(at::kCUDA)
  67. .requires_grad(false);
  68. const int param_elems = (quantize::requires_offset(quantType)) ? 2 : 1;
  69. auto params = torch::empty({groups, param_elems}, params_options);
  70. auto output_options = at::TensorOptions()
  71. .dtype(at::kChar)
  72. .layout(at::kStrided)
  73. .device(at::kCUDA)
  74. .requires_grad(false);
  75. auto output_sizes = input_vals.sizes().vec();
  76. output_sizes[output_sizes.size() - 1] /= numBits == 8 ? 1 : 2;
  77. auto output = torch::empty(output_sizes, output_options);
  78. const int elems_per_group = at::numel(input_vals) / groups;
  79. launch_quant((int8_t*)output.data_ptr(),
  80. (float*)params.data_ptr(),
  81. (__half*)input_vals.data_ptr(),
  82. groups,
  83. elems_per_group,
  84. numBits,
  85. quantType,
  86. at::cuda::getCurrentCUDAStream());
  87. return {output, params};
  88. }
  89. template <typename T>
  90. at::Tensor dequantize(at::Tensor& quantized_data,
  91. at::Tensor& params,
  92. int groups,
  93. int num_bits,
  94. quantize::Type quant_type)
  95. {
  96. auto dtype = (std::is_same<T, float>::value) ? torch::kFloat32 : torch::kFloat16;
  97. auto output_options = at::TensorOptions()
  98. .dtype(dtype)
  99. .layout(at::kStrided)
  100. .device(at::kCUDA)
  101. .requires_grad(false);
  102. auto output_sizes = quantized_data.sizes().vec();
  103. output_sizes[output_sizes.size() - 1] *= num_bits == 8 ? 1 : 2;
  104. auto output = torch::empty(output_sizes, output_options);
  105. const int total_elems = at::numel(output);
  106. const int elems_per_group = total_elems / groups;
  107. launch_dequantize_kernel((T*)output.data_ptr(),
  108. (const int8_t*)quantized_data.data_ptr(),
  109. (const float*)params.data_ptr(),
  110. quant_type,
  111. num_bits,
  112. elems_per_group,
  113. total_elems,
  114. at::cuda::getCurrentCUDAStream());
  115. return output;
  116. }
  117. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  118. {
  119. m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
  120. m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
  121. m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
  122. m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
  123. m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
  124. m.def(
  125. "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
  126. m.def("ds_sr_quantize_asym_fp32",
  127. &ds_sr_quantize_asym<float>,
  128. "DeepSpeed Quantize with fp32 (CUDA)");
  129. m.def("ds_sr_quantize_asym_fp16",
  130. &ds_sr_quantize_asym<__half>,
  131. "DeepSpeed Quantize with fp16 (CUDA)");
  132. pybind11::enum_<quantize::Type>(m, "QuantizationType")
  133. .value("Symmetric", quantize::Type::Symmetric)
  134. .value("Asymmetric", quantize::Type::Asymmetric)
  135. .export_values();
  136. m.def("quantize", &quantize_kernel);
  137. m.def("dequantize", &dequantize<__half>);
  138. m.def("dequantize_fp32", &dequantize<float>);
  139. }