1234567891011121314151617181920212223242526272829303132 |
- '''
- Copyright 2020 The Microsoft DeepSpeed Team
- '''
- import json
- import math
- import importlib
- import torch
- from torch import nn
- from torch.autograd import Function
- from ..op_builder import QuantizerBuilder
- # Cuda modules will be imported if needed
- quantizer_cuda_module = None
- def ds_quantizer(input, groups=1, bit_num=8, sr=False, asym=False):
- # Load cuda modules if needed
- global quantizer_cuda_module
- if quantizer_cuda_module is None:
- quantizer_cuda_module = QuantizerBuilder().load()
- if sr:
- if asym:
- quantize_func = quantizer_cuda_module.ds_sr_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_asym_fp32
- else:
- quantize_func = quantizer_cuda_module.ds_sr_quantize_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_fp32
- else:
- if asym:
- quantize_func = quantizer_cuda_module.ds_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_quantize_asym_fp32
- else:
- quantize_func = quantizer_cuda_module.ds_quantize_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_quantize_fp32
- return quantize_func(input, groups, bit_num)
|