quantization.h 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. #pragma once
  2. #include <cuda_fp16.h>
  3. #include "ds_kernel_utils.h"
  4. namespace quantize {
  5. enum class Type { Symmetric, Asymmetric, IntegerSymmetric };
  6. struct PackedInt4 {
  7. int8_t high : 4;
  8. int8_t low : 4;
  9. };
  10. DS_HD_INLINE bool requires_offset(Type qType) { return qType == Type::Asymmetric; }
  11. } // namespace quantize
  12. template <int numBits, quantize::Type qType>
  13. void launch_quant(int8_t* output_data,
  14. float* params,
  15. const __half* input_data,
  16. int groups,
  17. int elems_per_group,
  18. cudaStream_t stream);
  19. void launch_dequantize_kernel(__half* dequant_data,
  20. const int8_t* q_data,
  21. const float* q_params,
  22. quantize::Type q_type,
  23. int num_bits,
  24. int elems_per_group,
  25. int total_elems,
  26. cudaStream_t stream);
  27. template <typename T>
  28. void launch_fake_quantize_kernel(T* vals,
  29. int total_count,
  30. int group_num,
  31. int num_bits,
  32. cudaStream_t stream);
  33. template <typename T>
  34. void launch_sr_fake_quantize_kernel(T* vals,
  35. int total_count,
  36. int group_num,
  37. int num_bits,
  38. cudaStream_t stream);
  39. template <typename T>
  40. void launch_fake_quantize_kernel_asym(T* vals,
  41. int total_count,
  42. int group_num,
  43. int num_bits,
  44. cudaStream_t stream);
  45. template <typename T>
  46. void launch_sr_fake_quantize_kernel_asym(T* vals,
  47. int total_count,
  48. int group_num,
  49. int num_bits,
  50. cudaStream_t stream);