123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- import math
- from deepspeed.utils import logger
- from deepspeed.ops.quantizer import ds_quantizer
- TWO_D_PARAMS = 6
- class Quantizer(object):
- def __init__(self,
- q_groups=1,
- q_mixed_fp16=False,
- q_change_ratio=0.01,
- q_type=0,
- q_rounding=0,
- q_verbose=False,
- q_eigenvalue=False,
- use_quantizer_kernel=False,
- layer_num=0):
- self.q_groups = q_groups
- self.q_mixed_fp16 = q_mixed_fp16
- self.q_change_ratio = q_change_ratio
- self.q_type = q_type
- self.qsteps = 0
- self.quantize_real_ratio = 1.000
- self.q_verbose = q_verbose
- self.q_eigenvalue = q_eigenvalue
- self.use_quantizer_kernel = use_quantizer_kernel
- self.q_rounding = q_rounding
- self.layer_num = layer_num
- def any_precision_switch(self):
- # Temporary disabled functionality
- if self.layer_num == 0:
- return True
- result = False
- for index in range(self.layer_num):
- if self.q_start_bits[index] != self.q_target_bits:
- next_step = self.qsteps + (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))
- if next_step >= self.q_period[index]:
- result = True
- return result
- def quantize(self, parameter_group, overflow, eigenvalue_enabled, block_eigenvalue={}):
- if overflow and not eigenvalue_enabled:
- return
- self.step()
- self.update_fp16_ratio()
- for i in range(len(parameter_group)):
- for p in parameter_group[i]:
- if len(p.size()) > 1 and hasattr(p, "start_bits") and p.start_bits:
- param_id = id(p)
- if block_eigenvalue is None:
- eigenvalue, layer_id = None, 0
- else:
- eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None,
- 0)
- if eigenvalue is not None:
- factor = 1 + math.floor(eigenvalue * 4)
- p.data = self.compute_quantization(p.data, layer_id, factor)
- else:
- p.data = self.compute_quantization(p, layer_id)
- def step(self):
- self.qsteps += 1
- def quantize_highbit(self, inputs, num_bits):
- q_range = 2**num_bits
- input_flat = inputs.reshape(self.q_groups, -1)
- g_min = input_flat.amin(dim=-1, keepdim=True)
- g_max = input_flat.amax(dim=-1, keepdim=True)
- # Random number generator (Uniform)
- if self.q_rounding == 'nearest':
- p = 0.
- else:
- p = input_flat.new(input_flat.shape).uniform_(-0.5, 0.5)
- if self.q_type == 'symmetric':
- scale = 2 * torch.max(torch.abs(g_min), torch.abs(g_max)) / q_range
- zero_point = 0.
- input_flat = (input_flat / scale + p).round().clamp(-(q_range >> 1), (q_range >> 1) - 1) * scale
- elif self.q_type == 'asymmetric':
- scale = (g_max - g_min) / q_range
- zero_point = (g_min / scale).round() * scale
- input_flat = ((input_flat - zero_point) / scale + p).round().clamp(0, (q_range - 1)) * scale + zero_point
- output = input_flat.reshape(inputs.shape).contiguous()
- return output
- def quantize_tenary(self, inputs):
- input_flat = inputs.reshape(self.q_groups, -1)
- n = input_flat.shape[1]
- m = input_flat.norm(p=1, dim=1).div(n)
- thres = (0.7 * m).view(-1, 1) #.expand_as(input_flat)
- pos = (input_flat > thres).type(inputs.type())
- neg = (input_flat < -thres).type(inputs.type())
- mask = (input_flat.abs() > thres).type(inputs.type())
- alpha = ((mask * input_flat).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
- output = alpha * pos - alpha * neg
- output = output.reshape(inputs.shape).contiguous()
- return output
- def quantize_binary(self, inputs):
- input_flat = inputs.reshape(self.q_groups, -1)
- n = input_flat.shape[1]
- m = input_flat.norm(p=1, dim=1, keepdim=True).div(n)
- output = input_flat.sign().mul(m)
- output = output.reshape(inputs.shape).contiguous()
- return output
- def mixed_fp16_quantize(self, input, input_q, index):
- if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1):
- input_q = input * self.quantize_real_ratio + (1 - self.quantize_real_ratio) * input_q
- return input_q
- return input_q
- def compute_quantization(self, input, index=0, factor=1):
- # fixing the quantization bits based on the training steps
- # when reducing 1 bit at each period, we increase the period
- # to go slowly toward the target quantization bits
- # the period and starting bit can be configured
- if input.start_bits != input.target_bits:
- if self.qsteps >= input.q_period:
- self.quantize_real_ratio = 1.0
- input.q_period <<= 1
- input.q_period *= factor
- input.start_bits -= 1
- if self.q_verbose:
- logger.info(
- f'Quantization settings: current bit-precision = {input.start_bits}, step = {self.qsteps}, quantization period = {input.q_period}, index = {index}'
- )
- assert (input.start_bits >= input.target_bits), \
- 'Quantization bit is lower than target precision bits!'
- if self.use_quantizer_kernel:
- if input.start_bits <= 2:
- raise ValueError('Quantization bit is too low, please do it without quantization kernel!')
- input_q = ds_quantizer(input.data.clone(),
- self.q_groups,
- input.start_bits,
- asym=False if self.q_type == 'symmetric' else True,
- sr=False if self.q_rounding == 'nearest_neighbor' else True)
- else:
- if input.start_bits >= 3:
- input_flat = self.quantize_highbit(input.data, input.start_bits)
- elif input.start_bits == 2:
- assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
- assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest_neighbor!'
- input_flat = self.quantize_tenary(input.data)
- elif input.start_bits == 1:
- assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
- assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest_neighbor!'
- input_flat = self.quantize_binary(input.data)
- if self.use_quantizer_kernel:
- return self.mixed_fp16_quantize(input.data, input_q, index)
- else:
- if self.q_mixed_fp16 and input.start_bits >= input.target_bits - 1:
- input_flat = self.quantize_real_ratio * input.data + \
- (1 - self.quantize_real_ratio) * input_flat
- return input_flat
- def update_fp16_ratio(self):
- if self.q_mixed_fp16:
- if self.quantize_real_ratio > 0:
- self.quantize_real_ratio -= self.q_change_ratio
- else:
- self.quantize_real_ratio = 0.000
|