import torch
import math
from deepspeed.utils import log_dist
from deepspeed.utils import logger
from deepspeed.ops.quantizer import ds_quantizer

# number of 2-dimensional parameters in a layer
# this is set for transformer-based models
TWO_D_PARAMS = 6
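# Illustrative arithmetic (assumes a hypothetical 12-layer transformer, not
# taken from this module): each call to Quantizer.step() advances the internal
# counter by TWO_D_PARAMS * layer_num = 6 * 12 = 72, and a layer's precision
# drops by one bit once that counter reaches the layer's q_period.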


class Quantizer(object):
    """Progressively reduces weight precision from ``q_start_bits`` down to
    ``q_target_bits`` during training. Each layer keeps its own bit-width and
    quantization period; a bit is dropped whenever the step counter reaches
    that layer's period.
    """

    def __init__(self,
                 q_target_bits=8,
                 q_start_bits=16,
                 q_period=100,
                 q_offset=100,
                 q_groups=1,
                 q_mixed_fp16=False,
                 q_change_ratio=0.01,
                 q_type=0,
                 q_rounding=0,
                 q_verbose=False,
                 q_eigenvalue=False,
                 use_quantizer_kernel=False,
                 layer_num=0):
        self.q_target_bits = q_target_bits
        self.q_start_bits = [q_start_bits] * (layer_num if layer_num != 0 else 1)
        self.q_period = [q_period] * (layer_num if layer_num != 0 else 1)
        self.q_offset = q_offset
        self.q_groups = q_groups
        self.q_mixed_fp16 = q_mixed_fp16
        self.q_change_ratio = q_change_ratio
        self.q_type = q_type
        self.qsteps = 0
        self.q_init_period = q_period
        self.quantize_real_ratio = 1.000
        self.q_verbose = q_verbose
        self.q_eigenvalue = q_eigenvalue
        self.use_quantizer_kernel = use_quantizer_kernel
        self.q_rounding = q_rounding
        self.layer_num = layer_num

    def any_precision_switch(self):
        # Report whether any layer is expected to change its bit-precision
        # within the next quantization step window.
        if self.layer_num == 0:
            return True
        result = False
        for index in range(self.layer_num):
            if self.q_start_bits[index] != self.q_target_bits:
                next_step = self.qsteps + (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))
                if next_step >= self.q_period[index]:
                    result = True
        return result

    def quantize(self, parameter_group, overflow, eigenvalue_enabled, block_eigenvalue={}):
        if overflow and not eigenvalue_enabled:
            return

        self.step()
        self.update_fp16_ratio()

        for i in range(len(parameter_group)):
            for p in parameter_group[i]:
                # only 2-D (weight-matrix) parameters are quantized
                if len(p.size()) > 1:
                    param_id = id(p)
                    eigenvalue, layer_id = block_eigenvalue.get(param_id, (None, 0))
                    if eigenvalue is not None:
                        # the eigenvalue, when available, stretches the layer's
                        # quantization period inside compute_quantization
                        factor = 1 + math.floor(eigenvalue * 4)
                        p.data = self.compute_quantization(p.data, layer_id, factor)
                    else:
                        p.data = self.compute_quantization(p.data, layer_id)

    def step(self):
        # one optimizer step advances the counter by the number of 2-D
        # parameters being quantized
        self.qsteps += (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))

    def sr_quantize(self, input_flat, input_g, q_range):
        # Random number generator (Uniform), used as the stochastic-rounding threshold
        p = torch.cuda.FloatTensor(input_flat.size(), device=input_flat.device).uniform_()
        p = torch.split(p, p.size(0) // self.q_groups)
        add_s = torch.zeros_like(input_flat)
        add_s = torch.split(add_s, add_s.size(0) // self.q_groups)

        scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g]

        # Quantize with INT rounding (truncation toward zero)
        input_flat = [(g * s).int().float() / s for (g, s) in zip(input_g, scale)]
        # Truncation error of each element, expressed in integer-step units
        error = [((g - q).abs() * s) for (g, s, q) in zip(input_g, scale, input_flat)]
        # Stochastic Rounding: move one quantization step away from zero with
        # probability equal to the truncation error
        add_s = [
            a_s.masked_fill_(pg < err_g, 1 / s)
            for (a_s, pg, err_g, s) in zip(add_s, p, error, scale)
        ]
        add_s = [
            a_s * (g > 0).float() - a_s * (g < 0).float()
            for (a_s, g) in zip(add_s, input_flat)
        ]
        input_flat = [
            ((q + a_s) * s).clamp(-(q_range >> 1), (q_range >> 1) - 1) / s
            for (q, a_s, s) in zip(input_flat, add_s, scale)
        ]
        return input_flat
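
    # Worked example for sr_quantize (illustrative numbers, not from the source):
    # with 8-bit precision, q_range = 256; if max|g| = 0.5 then
    # scale = 256 / (2 * 0.5) = 256. A weight of 0.3 maps to 76.8, truncates to
    # 76, leaving an error of 0.8 integer steps, so it is rounded up to 77/256
    # with probability 0.8.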

    def mixed_fp16_quantize(self, input, input_q, index):
        # while the current precision is within one bit of the target, blend
        # the fp16 weights with their quantized counterpart
        if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1):
            input_q = input * self.quantize_real_ratio + (1 - self.quantize_real_ratio) * input_q
        return input_q

    def compute_quantization(self, input, index=0, factor=1):
        # Fix the quantization bits based on the training steps. When reducing
        # one bit at each period, the period is increased so that the bit count
        # approaches the target quantization bits slowly. Both the period and
        # the starting bit-width are configurable.
        if self.q_offset > 0:
            if self.qsteps >= self.q_offset:
                self.q_offset = 0
                self.qsteps = 0
            else:
                return input

        if self.q_start_bits[index] != self.q_target_bits:
            if self.qsteps >= self.q_period[index]:
                self.quantize_real_ratio = 1.0
                if self.q_eigenvalue:
                    self.q_period[index] <<= 1
                    self.q_period[index] *= factor
                    self.q_start_bits[index] -= 1
                else:
                    for i in range(len(self.q_start_bits)):
                        self.q_start_bits[i] -= 1
                        self.q_period[i] <<= 1
                if self.q_verbose:
                    logger.info(
                        f'Quantization settings: current bit-precision = {self.q_start_bits[index]}, '
                        f'step = {self.qsteps}, quantization period = {self.q_period[index]}, index = {index}')

        assert (self.q_start_bits[index] >= self.q_target_bits), \
            'Quantization bit is lower than target precision bits!'

        # quantize the weights based on the selected bits and the value-range
        if not self.use_quantizer_kernel:
            q_range = 2**self.q_start_bits[index]
            input_flat = input.view(-1)
            input_g = torch.split(input_flat, input_flat.size(0) // self.q_groups)

        if self.q_type == 0:  # symmetric
            if self.use_quantizer_kernel:
                input_q = ds_quantizer(input.clone(), self.q_groups, self.q_start_bits[index])
            else:
                scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g]
                if self.q_rounding == 0:  # Nearest value rounding
                    input_flat = [
                        (g * s).round().clamp(-(q_range >> 1), (q_range >> 1) - 1) / s
                        for g, s in zip(input_g, scale)
                    ]
                else:  # Stochastic Rounding (Python fallback path)
                    input_flat = self.sr_quantize(input_flat, input_g, q_range)
        else:  # asymmetric
            if self.q_rounding == 0:  # Nearest value rounding
                if self.use_quantizer_kernel:
                    input_q = ds_quantizer(input.clone(), self.q_groups, self.q_start_bits[index], asym=True)
                else:
                    scale = [(g.max() - g.min()) / q_range for g in input_g]
                    input_flat = [
                        ((g - g.min()) / s).round().clamp(0, (q_range - 1)) * s + g.min()
                        for g, s in zip(input_g, scale)
                    ]
            else:  # Stochastic Rounding (kernel only)
                input_q = ds_quantizer(input.clone(), self.q_groups, self.q_start_bits[index], asym=True)

        if self.use_quantizer_kernel or (self.q_type and self.q_rounding):
            return self.mixed_fp16_quantize(input, input_q, index)
        else:
            if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1):
                input_flat = [
                    (self.quantize_real_ratio * g) + ((1 - self.quantize_real_ratio) * g_q)
                    for g, g_q in zip(input_g, input_flat)
                ]
            input_q = torch.cat(input_flat)
            input_q = input_q.reshape(input.size())
            return input_q

    def update_fp16_ratio(self):
        if self.q_mixed_fp16:
            if self.quantize_real_ratio > 0:
                self.quantize_real_ratio -= self.q_change_ratio
            else:
                self.quantize_real_ratio = 0.000
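

# A minimal usage sketch (illustrative only; the model, layer count, and step
# loop below are hypothetical and not part of this module):
#
#   quantizer = Quantizer(q_target_bits=8,
#                         q_start_bits=16,
#                         q_period=100,
#                         q_offset=100,
#                         layer_num=12)
#   for step in range(num_training_steps):
#       ...  # forward / backward / optimizer step
#       quantizer.quantize([list(model.parameters())],
#                          overflow=False,
#                          eigenvalue_enabled=False)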