quantize.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
from deepspeed.utils import logger
from deepspeed.ops.quantizer import ds_quantizer

TWO_D_PARAMS = 6


class Quantizer(object):

    def __init__(self,
                 q_groups=1,
                 q_mixed_fp16=False,
                 q_change_ratio=0.01,
                 q_type=0,
                 q_rounding=0,
                 q_verbose=False,
                 q_eigenvalue=False,
                 use_quantizer_kernel=False,
                 layer_num=0):
        self.q_groups = q_groups
        self.q_mixed_fp16 = q_mixed_fp16
        self.q_change_ratio = q_change_ratio
        self.q_type = q_type
        self.qsteps = 0
        self.quantize_real_ratio = 1.000
        self.q_verbose = q_verbose
        self.q_eigenvalue = q_eigenvalue
        self.use_quantizer_kernel = use_quantizer_kernel
        self.q_rounding = q_rounding
        self.layer_num = layer_num

    def any_precision_switch(self):
        # Temporarily disabled functionality
        if self.layer_num == 0:
            return True
        result = False
        for index in range(self.layer_num):
            if self.q_start_bits[index] != self.q_target_bits:
                next_step = self.qsteps + (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))
                if next_step >= self.q_period[index]:
                    result = True
        return result

    def quantize(self, parameter_group, overflow, eigenvalue_enabled, block_eigenvalue={}):
        if overflow and not eigenvalue_enabled:
            return

        self.step()
        self.update_fp16_ratio()

        for i in range(len(parameter_group)):
            for p in parameter_group[i]:
                if len(p.size()) > 1 and hasattr(p, "start_bits") and p.start_bits:
                    param_id = id(p)
                    if block_eigenvalue is None:
                        eigenvalue, layer_id = None, 0
                    else:
                        eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0)
                    if eigenvalue is not None:
                        factor = 1 + math.floor(eigenvalue * 4)
                        p.data = self.compute_quantization(p.data, layer_id, factor)
                    else:
                        p.data = self.compute_quantization(p, layer_id)

    def step(self):
        self.qsteps += 1

    def quantize_highbit(self, inputs, num_bits):
        # Group-wise (a)symmetric quantization for bit-widths >= 3.
        q_range = 2**num_bits
        input_flat = inputs.reshape(self.q_groups, -1)
        g_min = input_flat.amin(dim=-1, keepdim=True)
        g_max = input_flat.amax(dim=-1, keepdim=True)

        # Uniform noise in [-0.5, 0.5) before rounding gives stochastic rounding;
        # 'nearest' rounding adds no noise.
        if self.q_rounding == 'nearest':
            p = 0.
        else:
            p = input_flat.new(input_flat.shape).uniform_(-0.5, 0.5)

        if self.q_type == 'symmetric':
            scale = 2 * torch.max(torch.abs(g_min), torch.abs(g_max)) / q_range
            zero_point = 0.
            input_flat = (input_flat / scale + p).round().clamp(-(q_range >> 1), (q_range >> 1) - 1) * scale
        elif self.q_type == 'asymmetric':
            scale = (g_max - g_min) / q_range
            zero_point = (g_min / scale).round() * scale
            input_flat = ((input_flat - zero_point) / scale + p).round().clamp(0, (q_range - 1)) * scale + zero_point
        output = input_flat.reshape(inputs.shape).contiguous()
        return output

    def quantize_tenary(self, inputs):
        # Ternary quantization: entries below the 0.7 * mean-|x| threshold map to 0,
        # the rest to +/- alpha, the mean magnitude of the surviving entries per group.
        input_flat = inputs.reshape(self.q_groups, -1)
        n = input_flat.shape[1]
        m = input_flat.norm(p=1, dim=1).div(n)
        thres = (0.7 * m).view(-1, 1)  #.expand_as(input_flat)
        pos = (input_flat > thres).type(inputs.type())
        neg = (input_flat < -thres).type(inputs.type())
        mask = (input_flat.abs() > thres).type(inputs.type())
        alpha = ((mask * input_flat).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
        output = alpha * pos - alpha * neg
        output = output.reshape(inputs.shape).contiguous()
        return output

    def quantize_binary(self, inputs):
        # Binary quantization: sign of each entry scaled by the group's mean absolute value.
        input_flat = inputs.reshape(self.q_groups, -1)
        n = input_flat.shape[1]
        m = input_flat.norm(p=1, dim=1, keepdim=True).div(n)
        output = input_flat.sign().mul(m)
        output = output.reshape(inputs.shape).contiguous()
        return output

    def mixed_fp16_quantize(self, input, input_q, index):
        if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1):
            input_q = input * self.quantize_real_ratio + (1 - self.quantize_real_ratio) * input_q
            return input_q
        return input_q

    def compute_quantization(self, input, index=0, factor=1):
        # Fix the quantization bits based on the training steps:
        # when reducing 1 bit at each period, we increase the period
        # to go slowly toward the target quantization bits;
        # the period and starting bit can be configured.
        if input.start_bits != input.target_bits:
            if self.qsteps >= input.q_period:
                self.quantize_real_ratio = 1.0
                input.q_period <<= 1
                input.q_period *= factor
                input.start_bits -= 1
                if self.q_verbose:
                    logger.info(
                        f'Quantization settings: current bit-precision = {input.start_bits}, step = {self.qsteps}, quantization period = {input.q_period}, index = {index}'
                    )
        assert (input.start_bits >= input.target_bits), \
            'Quantization bit is lower than target precision bits!'

        if self.use_quantizer_kernel:
            if input.start_bits <= 2:
                raise ValueError('Quantization bit is too low, please do it without quantization kernel!')
            input_q = ds_quantizer(input.data.clone(),
                                   self.q_groups,
                                   input.start_bits,
                                   asym=False if self.q_type == 'symmetric' else True,
                                   sr=False if self.q_rounding == 'nearest_neighbor' else True)
        else:
            if input.start_bits >= 3:
                input_flat = self.quantize_highbit(input.data, input.start_bits)
            elif input.start_bits == 2:
                assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
                assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest!'
                input_flat = self.quantize_tenary(input.data)
            elif input.start_bits == 1:
                assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
                assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest!'
                input_flat = self.quantize_binary(input.data)

        if self.use_quantizer_kernel:
            return self.mixed_fp16_quantize(input.data, input_q, index)
        else:
            if self.q_mixed_fp16 and input.start_bits >= input.target_bits - 1:
                input_flat = self.quantize_real_ratio * input.data + \
                             (1 - self.quantize_real_ratio) * input_flat
            return input_flat

    def update_fp16_ratio(self):
        if self.q_mixed_fp16:
            if self.quantize_real_ratio > 0:
                self.quantize_real_ratio -= self.q_change_ratio
            else:
                self.quantize_real_ratio = 0.000
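

# Illustrative usage sketch: a minimal, hedged example of driving compute_quantization()
# directly, assuming the non-kernel path (use_quantizer_kernel=False) and attaching the
# per-parameter attributes the method reads (start_bits, target_bits, q_period).
# The concrete values below (8 -> 4 bits, q_period=0, a random 4x8 weight) are arbitrary
# assumptions for demonstration, not settings taken from DeepSpeed configs.
if __name__ == "__main__":
    quantizer = Quantizer(q_groups=1, q_type='symmetric', q_rounding='nearest', q_verbose=True)
    weight = torch.randn(4, 8)
    # q_period=0 makes the first call drop the bit-width immediately (8 -> 7 bits here).
    weight.start_bits, weight.target_bits, weight.q_period = 8, 4, 0
    quantized = quantizer.compute_quantization(weight)
    print(quantized)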