pt_binding.cpp

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cassert>
#include <vector>

#include "quantization.h"
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    // Guard mirrors the kernel's launch capacity:
    // ceil(elems_per_group / 4096) must not exceed 256.
    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    // Stochastic-rounding variant; launches only when
    // elems_per_group / (4 * 1024) <= 256.
    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
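
// The four wrappers above perform fake quantization: the tensor is quantized and
// immediately dequantized in place, so its dtype is unchanged. A minimal usage
// sketch, assuming the extension has been loaded into Python as `quantizer` (the
// name is illustrative, not part of this file). Arguments are positional since
// the bindings do not declare py::arg names:
//
//   x = torch.randn(4096, device='cuda', dtype=torch.float16)
//   quantizer.ds_quantize_fp16(x, 16, 8)  # 16 groups, 8 bits; modifies x in place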
std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
                                        int groups,
                                        int numBits,
                                        quantize::Type quantType)
{
    auto dtype = at::kFloat;
    auto params_options = at::TensorOptions()
                              .dtype(dtype)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    // Asymmetric quantization stores a scale and an offset per group;
    // symmetric stores only a scale.
    const int param_elems = (quantize::requires_offset(quantType)) ? 2 : 1;
    auto params = torch::empty({groups, param_elems}, params_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    auto output_sizes = input_vals.sizes().vec();
    // 4-bit values are packed two per byte, halving the innermost dimension.
    output_sizes[output_sizes.size() - 1] /= numBits == 8 ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    const int elems_per_group = at::numel(input_vals) / groups;

    launch_quant((int8_t*)output.data_ptr(),
                 (float*)params.data_ptr(),
                 (__half*)input_vals.data_ptr(),
                 groups,
                 elems_per_group,
                 numBits,
                 quantType,
                 at::cuda::getCurrentCUDAStream());

    return {output, params};
}
template <typename T>
at::Tensor dequantize(at::Tensor& quantized_data,
                      at::Tensor& params,
                      int groups,
                      int num_bits,
                      quantize::Type quant_type)
{
    auto dtype = (std::is_same<T, float>::value) ? torch::kFloat32 : torch::kFloat16;
    auto output_options = at::TensorOptions()
                              .dtype(dtype)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    auto output_sizes = quantized_data.sizes().vec();
    // Undo the 4-bit packing: each byte expands back to two values.
    output_sizes[output_sizes.size() - 1] *= num_bits == 8 ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    const int total_elems = at::numel(output);
    const int elems_per_group = total_elems / groups;

    launch_dequantize_kernel((T*)output.data_ptr(),
                             (const int8_t*)quantized_data.data_ptr(),
                             (const float*)params.data_ptr(),
                             quant_type,
                             num_bits,
                             elems_per_group,
                             total_elems,
                             at::cuda::getCurrentCUDAStream());

    return output;
}
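
// Round-trip sketch (illustrative): dequantize inverts quantize_kernel when
// called with the same groups/num_bits/quant_type. Here `q` stands for the
// loaded extension module (an assumed name):
//
//   out, params = q.quantize(x_fp16, groups, 8, q.Symmetric)
//   x_rec = q.dequantize(out, params, groups, 8, q.Symmetric)  # fp16 output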
at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
                                                at::Tensor& scale_buffer,
                                                at::Tensor& min_val_buffer,
                                                int num_group,
                                                int group_size)
{
    auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    auto output = torch::empty({num_group, group_size}, output_options);

    launch_dequantize_int4_to_half_experimental((uint8_t*)data_in.data_ptr(),
                                                (half*)output.data_ptr(),
                                                (half*)scale_buffer.data_ptr(),
                                                (half*)min_val_buffer.data_ptr(),
                                                num_group,
                                                group_size,
                                                at::cuda::getCurrentCUDAStream());

    return output;
}

at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
                                                at::Tensor& scale_buffer,
                                                at::Tensor& min_val_buffer,
                                                int num_group,
                                                int group_size)
{
    auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    auto output = torch::empty({num_group, group_size}, output_options);

    launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
                                                (half*)output.data_ptr(),
                                                (half*)scale_buffer.data_ptr(),
                                                (half*)min_val_buffer.data_ptr(),
                                                num_group,
                                                group_size,
                                                at::cuda::getCurrentCUDAStream());

    return output;
}
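
// Note on the experimental paths above: unlike dequantize(), which reads fp32
// scale/offset pairs from a single params tensor, these take separate
// half-precision per-group scale and min-value buffers and always produce a
// {num_group, group_size} fp16 tensor.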
std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
                                         int groups,
                                         int num_bits,
                                         quantize::Type quant_type,
                                         int pipeline_size,
                                         int nodes,
                                         int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    const int quantization_scalar = 8 / num_bits;
    const int compressed_vals = at::numel(input_vals) / quantization_scalar;

    auto output = torch::empty({compressed_vals}, output_options);
    const int elems_per_group = at::numel(input_vals) / groups;

    launch_swizzled_quant((int8_t*)output.data_ptr(),
                          (float*)scales.data_ptr(),
                          (__half*)input_vals.data_ptr(),
                          num_bits,
                          quant_type,
                          groups,
                          elems_per_group,
                          pipeline_size,
                          nodes,
                          devices_per_node,
                          at::cuda::getCurrentCUDAStream());

    return {output, scales};
}
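
// Why swizzle: the quantized groups are written in a permuted order determined
// by the pipeline_size x nodes x devices_per_node hierarchy, so bytes destined
// for the same rank land contiguously. This layout is intended for hierarchical
// all-to-all style communication; the exact permutation lives in
// launch_swizzled_quant.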
std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
                                            at::Tensor& input_scales,
                                            int in_groups,
                                            int out_groups,
                                            int num_bits,
                                            quantize::Type quant_type,
                                            int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({out_groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    std::vector<long int> sz(input_vals.sizes().begin(), input_vals.sizes().end());
    // The output keeps a single partition: 1/devices_per_node of the input,
    // where devices_per_node is the number of GPUs per node.
    sz[sz.size() - 1] = sz.back() / devices_per_node;

    const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;
    auto output = torch::empty(sz, output_options);

    const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
    const int elems_per_out_group = elems_per_in_tensor / out_groups;

    launch_dequant_reduce((int8_t*)output.data_ptr(),
                          (float*)scales.data_ptr(),
                          (const int8_t*)input_vals.data_ptr(),
                          (const float*)input_scales.data_ptr(),
                          devices_per_node,
                          num_bits,
                          quant_type,
                          out_groups,
                          elems_per_out_group,
                          elems_per_in_tensor,
                          in_groups / devices_per_node,
                          elems_per_in_group,
                          at::cuda::getCurrentCUDAStream());

    return {output, scales};
}
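
// Data-flow sketch, as implied by launch_dequant_reduce: the input concatenates
// devices_per_node quantized partitions, each paired with its own slice of
// input_scales. The kernel dequantizes the partitions, reduces them elementwise,
// and re-quantizes the result into out_groups fresh groups, returning the
// reduced partition together with its new scales.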
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_fp32",
          &ds_sr_quantize<float>,
          "DeepSpeed SR Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16",
          &ds_sr_quantize<__half>,
          "DeepSpeed SR Quantize with fp16 (CUDA)");
    m.def("ds_quantize_asym_fp32",
          &ds_quantize_asym<float>,
          "DeepSpeed Asymmetric Quantize with fp32 (CUDA)");
    m.def("ds_quantize_asym_fp16",
          &ds_quantize_asym<__half>,
          "DeepSpeed Asymmetric Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_asym_fp32",
          &ds_sr_quantize_asym<float>,
          "DeepSpeed SR Asymmetric Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16",
          &ds_sr_quantize_asym<__half>,
          "DeepSpeed SR Asymmetric Quantize with fp16 (CUDA)");

    pybind11::enum_<quantize::Type>(m, "QuantizationType")
        .value("Symmetric", quantize::Type::Symmetric)
        .value("Asymmetric", quantize::Type::Asymmetric)
        .export_values();

    m.def("quantize", &quantize_kernel);
    m.def("dequantize", &dequantize<__half>);
    m.def("dequantize_fp32", &dequantize<float>);
    m.def("dequantize_int4_to_half_experimental",
          &dequantize_int4_to_half_experimental,
          "Dequantize int4 to half (experimental)");
    m.def("dequantize_int8_to_half_experimental",
          &dequantize_int8_to_half_experimental,
          "Dequantize int8 to half (experimental)");
    m.def("swizzle_quant", &ds_swizzle_quant);
    m.def("quantized_reduction", &quantized_reduction);
}
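
// A hedged sketch of driving this extension from Python. DeepSpeed normally
// JIT-builds it through its op builder; treat the snippet as illustrative
// rather than canonical:
//
//   from deepspeed.ops.op_builder import QuantizerBuilder
//   q = QuantizerBuilder().load()
//   out, params = q.quantize(x_fp16, 16, 8, q.Symmetric)
//
// q.Symmetric / q.Asymmetric are available at module scope because the enum
// binding calls export_values().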