# quantize.py

import torch
import math
from deepspeed.utils import log_dist
from deepspeed.utils import logger
from deepspeed.ops.quantizer import ds_quantizer

# number of 2-dimensional parameters in a layer
# this is set for transformer-based models
TWO_D_PARAMS = 6
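

# Quantizer gradually lowers weight precision from q_start_bits toward
# q_target_bits, dropping one bit whenever the step counter reaches the
# current period (which is then at least doubled). Weights can be quantized
# symmetrically or asymmetrically, with nearest-value or stochastic rounding,
# either in pure PyTorch or through the ds_quantizer kernel, and optionally
# blended with their fp16 values while q_mixed_fp16 is enabled.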
class Quantizer(object):
    def __init__(self,
                 q_target_bits=8,
                 q_start_bits=16,
                 q_period=100,
                 q_offset=100,
                 q_groups=1,
                 q_mixed_fp16=False,
                 q_change_ratio=0.01,
                 q_type=0,
                 q_rounding=0,
                 q_verbose=False,
                 q_eigenvalue=False,
                 use_quantizer_kernel=False,
                 layer_num=0):
        self.q_target_bits = q_target_bits
        self.q_start_bits = [q_start_bits] * (layer_num if layer_num != 0 else 1)
        self.q_period = [q_period] * (layer_num if layer_num != 0 else 1)
        self.q_offset = q_offset
        self.q_groups = q_groups
        self.q_mixed_fp16 = q_mixed_fp16
        self.q_change_ratio = q_change_ratio
        self.q_type = q_type
        self.qsteps = 0
        self.q_init_period = q_period
        self.quantize_real_ratio = 1.000
        self.q_verbose = q_verbose
        self.q_eigenvalue = q_eigenvalue
        self.use_quantizer_kernel = use_quantizer_kernel
        self.q_rounding = q_rounding
        self.layer_num = layer_num

    def any_precision_switch(self):
        # returns True if any layer is due to drop a bit within the next step
        if self.layer_num == 0:
            return True
        result = False
        for index in range(self.layer_num):
            if self.q_start_bits[index] != self.q_target_bits:
                next_step = self.qsteps + (
                    TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))
                if next_step >= self.q_period[index]:
                    result = True
        return result

    def quantize(self,
                 parameter_group,
                 overflow,
                 eigenvalue_enabled,
                 block_eigenvalue={}):

        if overflow and not eigenvalue_enabled:
            return

        self.step()
        self.update_fp16_ratio()

        for i in range(len(parameter_group)):
            for p in parameter_group[i]:
                if len(p.size()) > 1:
                    param_id = id(p)
                    eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0)
                    if eigenvalue is not None:
                        factor = 1 + math.floor(eigenvalue * 4)
                        p.data = self.compute_quantization(p.data, layer_id, factor)
                    else:
                        p.data = self.compute_quantization(p.data, layer_id)

    def step(self):
        self.qsteps += (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))

    def sr_quantize(self, input_flat, input_g, q_range):
        # Random number generator (Uniform)
        p = torch.cuda.FloatTensor(input_flat.size(),
                                   device=input_flat.device).uniform_()
        p = torch.split(p, p.size(0) // self.q_groups)

        add_s = torch.zeros_like(input_flat)
        add_s = torch.split(add_s, add_s.size(0) // self.q_groups)

        scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g]

        # Quantize with INT rounding
        input_flat = [(g * s).int().float() / s for (g, s) in zip(input_g, scale)]
        # Compute the error
        error = [((g - q).abs() / s) for (g, s, q) in zip(input_g, scale, input_flat)]

        # Stochastic Rounding
        add_s = [
            a_s.masked_fill_(pg < err_g, 1 / s)
            for (a_s, pg, err_g, s) in zip(add_s, p, error, scale)
        ]
        add_s = [
            a_s * (g > 0).float() - a_s * (g < 0).float()
            for a_s, g in zip(add_s, input_flat)
        ]
        input_flat = [
            ((q + a_s) * s).clamp(-(q_range >> 1), (q_range >> 1) - 1) / s
            for q, a_s, s in zip(input_flat, add_s, scale)
        ]
        return input_flat

    def mixed_fp16_quantize(self, input, input_q, index):
        if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1):
            # until the precision falls below (q_target_bits - 1) bits, blend the
            # fp16 weights with their quantized values using quantize_real_ratio
            input_q = input * self.quantize_real_ratio + (
                1 - self.quantize_real_ratio) * input_q
        return input_q

    def compute_quantization(self, input, index=0, factor=1):
        # fixing the quantization bits based on the training steps
        # when reducing 1 bit at each period, we increase the period
        # to go slowly toward the target quantization bits
        # the period and starting bit can be configured
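        #
        # Illustrative schedule (assuming the default settings and factor == 1):
        # with q_start_bits=16, q_target_bits=8 and q_period=100, one bit is
        # dropped each time qsteps reaches the current period and the period
        # doubles after every drop, so the precision reaches 8 bits roughly at
        # qsteps 100, 200, 400, 800, 1600, 3200, 6400 and 12800 (counted after
        # the initial q_offset warm-up, which resets qsteps).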

        if self.q_offset > 0:
            if self.qsteps >= self.q_offset:
                self.q_offset = 0
                self.qsteps = 0
            else:
                return input

        if self.q_start_bits[index] != self.q_target_bits:
            if self.qsteps >= self.q_period[index]:
                self.quantize_real_ratio = 1.0

                if self.q_eigenvalue:
                    self.q_period[index] <<= 1
                    self.q_period[index] *= factor
                    self.q_start_bits[index] -= 1
                else:
                    for i in range(len(self.q_start_bits)):
                        self.q_start_bits[i] -= 1
                        self.q_period[i] <<= 1

                if self.q_verbose:
                    logger.info(
                        f'Quantization settings: current bit-precision = {self.q_start_bits[index]}, step = {self.qsteps}, quantization period = {self.q_period[index]}, index = {index}'
                    )

        assert (self.q_start_bits[index] >= self.q_target_bits), \
            'Quantization bit is lower than target precision bits!'

        # quantize the weights based on the selected bits and the value range
        if not self.use_quantizer_kernel:
            q_range = 2**self.q_start_bits[index]
            input_flat = input.view(-1)
            input_g = torch.split(input_flat, input_flat.size(0) // self.q_groups)

        if self.q_type == 0:  # symmetric
            if self.use_quantizer_kernel:
                input_q = ds_quantizer(input.clone(),
                                       self.q_groups,
                                       self.q_start_bits[index])
            else:
                scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g]
                if self.q_rounding == 0:  # Nearest value rounding
                    input_flat = [
                        (g * s).round().clamp(-(q_range >> 1), (q_range >> 1) - 1) / s
                        for g, s in zip(input_g, scale)
                    ]
                else:  # Stochastic Rounding
                    if self.use_quantizer_kernel:
                        input_q = ds_quantizer(input.clone(),
                                               self.q_groups,
                                               self.q_start_bits[index],
                                               sr=True)
                    else:
                        input_flat = self.sr_quantize(input_flat, input_g, q_range)
        else:  # asymmetric
            if self.q_rounding == 0:  # Nearest value rounding
                if self.use_quantizer_kernel:
                    input_q = ds_quantizer(input.clone(),
                                           self.q_groups,
                                           self.q_start_bits[index],
                                           asym=True)
                else:
                    scale = [(g.max() - g.min()) / q_range for g in input_g]
                    input_flat = [
                        ((g - g.min()) / s).round().clamp(0, (q_range - 1)) * s + g.min()
                        for g, s in zip(input_g, scale)
                    ]
            else:  # Stochastic Rounding
                input_q = ds_quantizer(input.clone(),
                                       self.q_groups,
                                       self.q_start_bits[index],
                                       asym=True)

        if self.use_quantizer_kernel or (self.q_type and self.q_rounding):
            return self.mixed_fp16_quantize(input, input_q, index)
        else:
            if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1):
                input_flat = [
                    (self.quantize_real_ratio * g) +
                    ((1 - self.quantize_real_ratio) * g_q)
                    for g, g_q in zip(input_g, input_flat)
                ]
            input_q = torch.cat(input_flat)
            input_q = input_q.reshape(input.size())
            return input_q

    def update_fp16_ratio(self):
        if self.q_mixed_fp16:
            if self.quantize_real_ratio > 0:
                self.quantize_real_ratio -= self.q_change_ratio
            else:
                self.quantize_real_ratio = 0.000
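

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library API): quantize
    # one 2-D weight matrix to 8 bits on the pure-PyTorch fallback path.
    # Assumes a CPU tensor, symmetric quantization (q_type=0) and nearest-value
    # rounding (q_rounding=0), so neither CUDA nor the ds_quantizer kernel is
    # exercised; q_offset=0 makes quantization start immediately.
    quantizer = Quantizer(q_target_bits=8,
                          q_start_bits=8,
                          q_offset=0,
                          q_groups=1,
                          use_quantizer_kernel=False)
    weight = torch.nn.Parameter(torch.randn(4, 8))
    original = weight.data.clone()
    # parameter_group is a list of parameter lists, as quantize() expects
    quantizer.quantize([[weight]], overflow=False, eigenvalue_enabled=False)
    print("max abs quantization error:",
          (weight.data - original).abs().max().item())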