dequantize.cu 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. /*
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. */
  4. #include "dequantization_utils.h"
  5. #include "memory_access_utils.h"
  6. namespace cg = cooperative_groups;
  // Global-function entry point for dequantization.
  //
  // The kernel performs no work of its own: it hands every argument, together with
  // the same compile-time configuration, to dequantize::to_global (declared outside
  // this file, presumably in dequantization_utils.h).
  //
  // Template parameters (all forwarded unchanged to dequantize::to_global):
  //   numBits - bit-width selector; the launcher in this file instantiates 4 and 8.
  //   qType   - quantization-scheme selector.
  //   unroll  - unroll factor chosen by the launcher.
  //   threads - the launcher passes the same value it uses as the block size.
  //
  // Arguments (roles inferred from names only - TODO confirm against to_global):
  //   dequant_data    - presumably the __half output buffer.
  //   q_data          - presumably the packed quantized input.
  //   q_params        - presumably the per-group quantization parameters.
  //   elems_per_group - forwarded as-is.
  //   total_elems     - forwarded as-is.
  template <int numBits, dequantize::Type qType, int unroll, int threads>
  __global__ void dequantize_kernel(__half* __restrict__ dequant_data,
                                    const int8_t* __restrict__ q_data,
                                    const float* __restrict__ q_params,
                                    int elems_per_group,
                                    int total_elems)
  {
      // Pure forwarding call; argument order matches this kernel's parameter list.
      dequantize::to_global<numBits, qType, unroll, threads>(dequant_data,
                                                             q_data,
                                                             q_params,
                                                             elems_per_group,
                                                             total_elems);
  }
  // Expands to a launch of dequantize_kernel<num_bits, q_type, unroll, threads>.
  //
  // The macro is not self-contained: it textually captures `unroll`, `threads`,
  // `grid`, `block`, `stream`, `dequant_data`, `q_data`, `q_params`,
  // `elems_per_group` and `total_elems` from the scope in which it is expanded.
  //
  // The expansion is a single expression with no trailing semicolon, so the call
  // site supplies its own `;` and the macro behaves like an ordinary statement
  // (safe inside unbraced if/else, no stray empty statements).
  #define LAUNCH_DEQUANT_KERNEL(num_bits, q_type)                                       \
      dequantize_kernel<num_bits, q_type, unroll, threads><<<grid, block, 0, stream>>>( \
          dequant_data, q_data, q_params, elems_per_group, total_elems)
  // Host-side launcher: selects the dequantize_kernel instantiation that matches
  // the runtime (num_bits, q_type) pair and enqueues it on `stream`.
  //
  // Parameters:
  //   dequant_data    - forwarded to the kernel (presumably the __half output
  //                     buffer - TODO confirm).
  //   q_data          - forwarded to the kernel (presumably the quantized input).
  //   q_params        - forwarded to the kernel (presumably quantization params).
  //   q_type          - quantize::Type::Symmetric or quantize::Type::Asymmetric.
  //   num_bits        - 4 or 8.
  //   elems_per_group - forwarded to the kernel unchanged.
  //   total_elems     - element count used to size the grid; also forwarded.
  //   stream          - CUDA stream the kernel is launched on.
  //
  // Behavior notes:
  //   * Nothing is launched when total_elems <= 0. A zero-sized grid is an
  //     invalid launch configuration and would otherwise leave a launch error
  //     behind for an unrelated, later cudaGetLastError() check to pick up.
  //   * NOTE(review): any (num_bits, q_type) pair other than the four handled
  //     below launches nothing and leaves dequant_data untouched, with no
  //     diagnostic. Callers must validate their inputs.
  //   * The launch is asynchronous and no error check is performed here.
  void launch_dequantize_kernel(__half* dequant_data,
                                const int8_t* q_data,
                                const float* q_params,
                                quantize::Type q_type,
                                int num_bits,
                                int elems_per_group,
                                int total_elems,
                                cudaStream_t stream)
  {
      // Empty (or nonsensical negative) problem size: there is no work to enqueue.
      if (total_elems <= 0) { return; }

      // NOTE: `unroll`, `threads`, `grid` and `block` are referenced by name from
      // LAUNCH_DEQUANT_KERNEL, so these locals must keep exactly these names.
      constexpr int unroll = 8;
      constexpr int threads = 512;
      // Elements covered by one block; dequantize::h_per_chunk is defined outside
      // this file (presumably the number of __half values per chunk - TODO confirm).
      constexpr int elems_per_block = unroll * threads * dequantize::h_per_chunk;

      // Ceil-division done in 64-bit so that `total_elems + elems_per_block - 1`
      // cannot overflow `int` when total_elems is close to INT_MAX.
      const int64_t num_blocks =
          (static_cast<int64_t>(total_elems) + elems_per_block - 1) / elems_per_block;

      const dim3 block(threads);
      const dim3 grid(static_cast<unsigned int>(num_blocks));

      // TODO(cmikeh2): It may make sense to tune unroll, there is perf benefit for large
      // problem sizes with this large unroll value.
      if (num_bits == 8 && q_type == quantize::Type::Symmetric) {
          LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Symmetric);
      } else if (num_bits == 8 && q_type == quantize::Type::Asymmetric) {
          LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Asymmetric);
      } else if (num_bits == 4 && q_type == quantize::Type::Symmetric) {
          LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Symmetric);
      } else if (num_bits == 4 && q_type == quantize::Type::Asymmetric) {
          LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Asymmetric);
      }
      // Any other combination intentionally falls through without launching.
  }