dequantize.cu 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  #include "dequantization_utils.h"
  #include "memory_access_utils.h"

  #include <stdexcept>
  6. namespace cg = cooperative_groups;
  // Device entry point for dequantization: a thin wrapper that forwards all arguments to
  // dequantize::to_global with the same compile-time configuration.
  //
  // Template parameters:
  //   T        output element type (this file instantiates the launcher for __half and float)
  //   numBits  quantized bit width (the launcher dispatches 4 and 8)
  //   qType    quantization type (the launcher dispatches Symmetric and Asymmetric)
  //   unroll   per-thread unroll factor forwarded to to_global
  //   threads  block size; the launcher in this file launches a 1-D grid with
  //            blockDim.x == threads (presumably to_global relies on this — confirm there)
  //
  // __launch_bounds__(threads) tells the compiler the maximum block size this kernel is
  // launched with. The compiler then limits per-thread register usage so that a
  // `threads`-sized block always fits on an SM. Without it, a register-heavy unrolled body
  // could make the launch fail with "too many resources requested".
  //
  // NOTE(review): tail handling for the last, partially filled block is assumed to happen
  // inside to_global (it receives total_elems); verify against dequantization_utils.h.
  template <typename T, int numBits, dequantize::Type qType, int unroll, int threads>
  __global__ void __launch_bounds__(threads) dequantize_kernel(T* __restrict__ dequant_data,
                                                               const int8_t* __restrict__ q_data,
                                                               const float* __restrict__ q_params,
                                                               int elems_per_group,
                                                               int total_elems)
  {
      dequantize::to_global<T, numBits, qType, unroll, threads>(
          dequant_data, q_data, q_params, elems_per_group, total_elems);
  }
  // LAUNCH_DEQUANT_KERNEL(num_bits, q_type): enqueue the dequantize_kernel instantiation for
  // a compile-time bit width and quantization type on `stream`, with no dynamic shared memory.
  // The expansion refers by name to T, unroll, threads, grid, block, stream, dequant_data,
  // q_data, q_params, elems_per_group and total_elems. It must therefore be expanded where
  // all of those are in scope (i.e. inside launch_dequantize_kernel).
  // The do { } while (0) wrapper makes each expansion exactly one statement. The call site's
  // trailing ';' then terminates it, and the macro is safe inside an unbraced if/else.
  #define LAUNCH_DEQUANT_KERNEL(num_bits, q_type)                                \
      do {                                                                       \
          dequantize_kernel<T, num_bits, q_type, unroll, threads>                \
              <<<grid, block, 0, stream>>>(                                      \
                  dequant_data, q_data, q_params, elems_per_group, total_elems); \
      } while (0)
  // Host-side launcher for dequantize_kernel.
  //
  // Selects the kernel instantiation that matches the runtime (num_bits, q_type) pair and
  // enqueues it on `stream`. The launch is asynchronous: this function performs no
  // synchronization and does not query the CUDA error state.
  //
  // Parameters:
  //   dequant_data     output buffer (presumably total_elems elements of T — confirm with caller)
  //   q_data           quantized input
  //   q_params         quantization parameters (presumably per-group scale/offset — TODO confirm)
  //   q_type           quantize::Type::Symmetric or quantize::Type::Asymmetric
  //   num_bits         4 or 8
  //   elems_per_group  forwarded unchanged to the kernel
  //   total_elems      number of elements to dequantize; a value <= 0 is a no-op
  //   stream           CUDA stream the kernel is launched on
  //
  // Throws std::invalid_argument for any other (num_bits, q_type) combination. Previously
  // such a combination silently launched nothing and left dequant_data unwritten.
  template <typename T>
  void launch_dequantize_kernel(T* dequant_data,
                                const int8_t* q_data,
                                const float* q_params,
                                quantize::Type q_type,
                                int num_bits,
                                int elems_per_group,
                                int total_elems,
                                cudaStream_t stream)
  {
      constexpr int unroll = 8;
      constexpr int threads = 512;
      // Elements covered by one block: `threads` threads, each doing `unroll` accesses of
      // dequantize::granularity units. The units are assumed to be bytes, hence the
      // division by sizeof(T).
      constexpr int elems_per_block = unroll * threads * dequantize::granularity / (sizeof(T));

      // Nothing to dequantize. A zero-block grid is also an invalid launch configuration.
      // It would leave a pending launch error that a later, unrelated error check would report.
      if (total_elems <= 0) { return; }

      const dim3 block(threads);
      // Ceil-divide without forming total_elems + elems_per_block - 1. That sum overflows
      // int (undefined behavior) when total_elems is within elems_per_block of INT_MAX.
      const dim3 grid((total_elems - 1) / elems_per_block + 1);

      // TODO(cmikeh2): It may make sense to tune unroll, there is perf benefit for large
      // problem sizes with this large unroll value.
      if (num_bits == 8 && q_type == quantize::Type::Symmetric) {
          LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Symmetric);
      } else if (num_bits == 8 && q_type == quantize::Type::Asymmetric) {
          LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Asymmetric);
      } else if (num_bits == 4 && q_type == quantize::Type::Symmetric) {
          LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Symmetric);
      } else if (num_bits == 4 && q_type == quantize::Type::Asymmetric) {
          LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Asymmetric);
      } else {
          throw std::invalid_argument(
              "launch_dequantize_kernel: unsupported configuration; num_bits must be 4 or 8 "
              "and q_type must be Symmetric or Asymmetric");
      }
  }
  // Explicit instantiations of the launcher template for __half and float output.
  // They emit these two specializations in this translation unit, so code that only sees
  // a declaration of launch_dequantize_kernel can link against them.
  template void launch_dequantize_kernel<__half>(__half*,
                                                 const int8_t*,
                                                 const float*,
                                                 quantize::Type,
                                                 int,
                                                 int,
                                                 int,
                                                 cudaStream_t);

  template void launch_dequantize_kernel<float>(float*,
                                                const int8_t*,
                                                const float*,
                                                quantize::Type,
                                                int,
                                                int,
                                                int,
                                                cudaStream_t);