# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import abc
from abc import ABC

from deepspeed.ops.op_builder import FPQuantizerBuilder

# Lazily-loaded handle to the compiled FP quantization kernel module.
fp_quant_module = None


class Quantizer(ABC):
    """
    Abstract Quantizer class that implements quantize/dequantize methods.

    Arguments:
        group_size (int, optional): number of values or elements that are grouped
            together for the quantization process.
    """

    def __init__(self, group_size=512) -> None:
        self.group_size = group_size

    @abc.abstractmethod
    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        ...

    @abc.abstractmethod
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        ...
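

# A minimal pure-PyTorch sketch of what "grouping" means here, assuming a
# per-group absmax scale. The fused kernel used by FP_Quantize below computes
# its own scales; this hypothetical helper is illustrative only and is not
# called anywhere in this module.
def _grouped_absmax_scales(input: torch.Tensor, group_size: int) -> torch.Tensor:
    assert input.numel() % group_size == 0, "tensor must divide evenly into groups"
    # View the tensor as one row per group of `group_size` contiguous values,
    # so an outlier only inflates the scale of its own group.
    groups = input.reshape(-1, group_size)
    return groups.abs().amax(dim=-1, keepdim=True)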


class FP_Quantize(Quantizer):

    def __init__(self, group_size=512) -> None:
        global fp_quant_module
        super().__init__(group_size=group_size)
        # Build/load the CUDA extension once and share it across instances.
        if fp_quant_module is None:
            fp_quant_module = FPQuantizerBuilder().load()
        self.orig_dtype = None

    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        assert input.dtype == torch.bfloat16, "only support bf16 for now"
        if return_meta_tensor:
            assert q_bits == 8, "meta tensor is only supported with q_bits=8"

        # Remember the original dtype/shape so dequantize can restore them.
        self.orig_dtype = input.dtype
        self.orig_shape = input.shape

        # Each supported total bit-width has a fixed mantissa/exponent split.
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            raise ValueError(
                f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
            )

        out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)

        if return_meta_tensor:
            # The kernel packs quantized data and per-group scales along the
            # last dimension; split them so the caller can store them apart.
            data, scale = out.split(self.group_size, dim=-1)
            return data.contiguous().reshape(input.shape), scale.contiguous()

        return out
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            raise ValueError(
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
            )

        if scale is not None:
            # Re-pack the separately stored scales next to the quantized data,
            # matching the layout the kernel produced originally.
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        # The exponent width is whatever remains after the sign and mantissa bits.
        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out
    def selective_dequantize(self,
                             input_q,
                             indexes,
                             fp_out=None,
                             q_bits=8,
                             q_mantisa_bits=3,
                             scale=None) -> torch.Tensor:
        assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \
            "Selective-Dequantization works on 3d tensors only! Please reshape the tensor before calling the dequantize function."
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        # Only the rows picked by `indexes` are materialized in the output.
        fp_out = torch.empty(
            (indexes.shape[0],
             *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            raise ValueError(
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
            )

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                             q_bits - q_mantisa_bits - 1)
        return fp_out
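

# Hedged usage sketch, assuming a CUDA device is available and the
# fp_quantizer op builds on this machine; the shapes are illustrative and
# the flattened size is assumed to be divisible by group_size.
if __name__ == "__main__":
    quantizer = FP_Quantize(group_size=512)
    x = torch.randn(4, 1024, dtype=torch.bfloat16, device='cuda')

    # Round-trip with the packed layout (data + scales in one tensor).
    x_q = quantizer.quantize(x, q_bits=8)
    x_dq = quantizer.dequantize(x_q, q_bits=8)
    print("max abs error:", (x - x_dq).abs().max().item())

    # Round-trip with data and scales returned as separate tensors
    # (only supported for q_bits=8).
    data, scale = quantizer.quantize(x, q_bits=8, return_meta_tensor=True)
    x_dq2 = quantizer.dequantize(data, q_bits=8, scale=scale)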