quantization.h 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #pragma once
  5. #include <cuda_fp16.h>
  6. #include "ds_kernel_utils.h"
  namespace quantize {

  // Quantization scheme selector shared by the launchers declared in this header.
  enum class Type {
      Symmetric,   // requires_offset() reports false for this scheme
      Asymmetric,  // requires_offset() reports true for this scheme
  };

  // Two signed 4-bit lanes packed into int8_t bit-fields; each lane holds -8..7.
  // NOTE(review): bit-field allocation order is implementation-defined. With
  // GCC/Clang/NVCC on little-endian targets the first-declared field (`high`)
  // typically occupies the LOW-order nibble, i.e. the opposite of what the names
  // suggest -- confirm against the pack/unpack code before relying on the names.
  // The declaration order is deliberately preserved so the in-memory layout is
  // unchanged.
  struct PackedInt4 {
      int8_t high : 4;
      int8_t low : 4;
  };

  // Reports whether quantization scheme `qType` involves an offset term
  // (presumably a zero point stored alongside the scale -- confirm in the
  // kernels). Only Type::Asymmetric does; any other value yields false.
  DS_HD_INLINE bool requires_offset(Type qType)
  {
      switch (qType) {
          case Type::Asymmetric: return true;
          default: return false;
      }
  }

  }  // namespace quantize
  /*
   * launch_quant -- host-side launcher that quantizes fp16 input into int8_t storage.
   * Declared here; the definition is not part of this header.
   *
   *   output_data     destination buffer for the quantized values (int8_t storage;
   *                   for num_bits < 8 presumably packed, e.g. via
   *                   quantize::PackedInt4 -- confirm in the definition)
   *   params          non-const float buffer, presumably receiving per-group
   *                   quantization parameters (scale, plus an offset when
   *                   quantize::requires_offset(quant_type) -- confirm layout)
   *   input_data      fp16 source values, presumably groups * elems_per_group of them
   *   groups          number of quantization groups
   *   elems_per_group number of elements in each group
   *   num_bits        target bit width
   *   quant_type      symmetric vs. asymmetric scheme
   *   stream          CUDA stream the work is presumably enqueued on (async
   *                   w.r.t. the host -- confirm in the definition)
   *
   * Top-level `const` on the by-value parameters is omitted: it is not part of
   * the function type, so this declaration matches a definition that uses it.
   */
  void launch_quant(int8_t* output_data,
                    float* params,
                    const __half* input_data,
                    int groups,
                    int elems_per_group,
                    int num_bits,
                    quantize::Type quant_type,
                    cudaStream_t stream);
  /*
   * launch_dequantize_kernel -- host-side launcher that expands int8_t-stored
   * quantized data back into element type T. Declared here; defined elsewhere.
   *
   *   dequant_data    destination buffer of T
   *   q_data          quantized source (read-only)
   *   q_params        per-group quantization parameters (read-only; presumably
   *                   in the layout written by launch_quant's `params` -- confirm)
   *   q_type          symmetric vs. asymmetric scheme
   *   num_bits        bit width of the stored values
   *   elems_per_group number of elements in each quantization group
   *   total_elems     total number of elements to dequantize
   *   stream          CUDA stream the work is presumably enqueued on
   *
   * NOTE(review): total_elems is a plain int, so callers must keep the element
   * count below INT_MAX; widening it would change the signature of the definition.
   */
  template <class T>
  void launch_dequantize_kernel(T* dequant_data,
                                const int8_t* q_data,
                                const float* q_params,
                                quantize::Type q_type,
                                int num_bits,
                                int elems_per_group,
                                int total_elems,
                                cudaStream_t stream);
  // Host launcher for fake quantization of `vals`. There is no separate output
  // pointer, so it presumably rewrites `vals` in place (quantize + dequantize
  // round trip -- confirm). Judging by the `_asym` sibling, presumably the
  // symmetric variant. `total_count` elements are split into `group_num`
  // groups; `num_bits` is the simulated bit width; work is presumably enqueued
  // on `stream`.
  template <class T>
  void launch_fake_quantize_kernel(T* vals, int total_count, int group_num, int num_bits, cudaStream_t stream);
  // Host launcher for the `sr_` fake-quantization variant over `vals`
  // (presumably stochastic rounding, operating in place since there is no
  // output pointer -- confirm). Parameters mirror launch_fake_quantize_kernel:
  // `total_count` elements in `group_num` groups at `num_bits` bits, with work
  // presumably enqueued on `stream`.
  template <class T>
  void launch_sr_fake_quantize_kernel(T* vals, int total_count, int group_num, int num_bits, cudaStream_t stream);
  // Host launcher for asymmetric fake quantization of `vals` (presumably in
  // place, as there is no separate output pointer -- confirm). `total_count`
  // elements are split into `group_num` groups quantized at `num_bits` bits;
  // work is presumably enqueued on `stream`.
  template <class T>
  void launch_fake_quantize_kernel_asym(T* vals, int total_count, int group_num, int num_bits, cudaStream_t stream);
  // Host launcher for the asymmetric `sr_` fake-quantization variant over
  // `vals` (presumably stochastic rounding, in place -- confirm). `total_count`
  // elements in `group_num` groups at `num_bits` bits; work is presumably
  // enqueued on `stream`.
  template <class T>
  void launch_sr_fake_quantize_kernel_asym(T* vals, int total_count, int group_num, int num_bits, cudaStream_t stream);