quantizer.py 1.2 KB

1234567891011121314151617181920212223242526272829303132
  1. '''
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. '''
  4. import json
  5. import math
  6. import importlib
  7. import torch
  8. from torch import nn
  9. from torch.autograd import Function
  10. from ..op_builder import QuantizerBuilder
# Module-level cache for the quantizer extension. It stays None until
# ds_quantizer() is first called, which stores the result of
# QuantizerBuilder().load() here so later calls reuse it.
quantizer_cuda_module = None
  13. def ds_quantizer(input, groups=1, bit_num=8, sr=False, asym=False):
  14. # Load cuda modules if needed
  15. global quantizer_cuda_module
  16. if quantizer_cuda_module is None:
  17. quantizer_cuda_module = QuantizerBuilder().load()
  18. if sr:
  19. if asym:
  20. quantize_func = quantizer_cuda_module.ds_sr_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_asym_fp32
  21. else:
  22. quantize_func = quantizer_cuda_module.ds_sr_quantize_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_fp32
  23. else:
  24. if asym:
  25. quantize_func = quantizer_cuda_module.ds_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_quantize_asym_fp32
  26. else:
  27. quantize_func = quantizer_cuda_module.ds_quantize_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_quantize_fp32
  28. return quantize_func(input, groups, bit_num)