quantizer.py 1.2 KB

1234567891011121314151617181920212223242526272829
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from deepspeed.ops.op_builder import QuantizerBuilder
# Cached handle to the loaded quantizer extension module. It stays None until
# the first ds_quantizer() call, which fills it in with the result of
# QuantizerBuilder().load() so later calls reuse the same module object.
quantizer_cuda_module = None
  8. def ds_quantizer(input, groups=1, bit_num=8, sr=False, asym=False):
  9. # Load cuda modules if needed
  10. global quantizer_cuda_module
  11. if quantizer_cuda_module is None:
  12. quantizer_cuda_module = QuantizerBuilder().load()
  13. if sr:
  14. if asym:
  15. quantize_func = quantizer_cuda_module.ds_sr_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_asym_fp32
  16. else:
  17. quantize_func = quantizer_cuda_module.ds_sr_quantize_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_fp32
  18. else:
  19. if asym:
  20. quantize_func = quantizer_cuda_module.ds_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_quantize_asym_fp32
  21. else:
  22. quantize_func = quantizer_cuda_module.ds_quantize_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_quantize_fp32
  23. return quantize_func(input, groups, bit_num)