quantizer.py 1.1 KB

123456789101112131415161718192021222324252627
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
  4. import torch
  5. from deepspeed.ops.op_builder import QuantizerBuilder
  6. # Cuda modules will be imported if needed
  7. quantizer_cuda_module = None
  8. def ds_quantizer(input, groups=1, bit_num=8, sr=False, asym=False):
  9. # Load cuda modules if needed
  10. global quantizer_cuda_module
  11. if quantizer_cuda_module is None:
  12. quantizer_cuda_module = QuantizerBuilder().load()
  13. if sr:
  14. if asym:
  15. quantize_func = quantizer_cuda_module.ds_sr_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_asym_fp32
  16. else:
  17. quantize_func = quantizer_cuda_module.ds_sr_quantize_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_fp32
  18. else:
  19. if asym:
  20. quantize_func = quantizer_cuda_module.ds_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_quantize_asym_fp32
  21. else:
  22. quantize_func = quantizer_cuda_module.ds_quantize_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_quantize_fp32
  23. return quantize_func(input, groups, bit_num)