weight_quantizer.py

'''Copyright The Microsoft DeepSpeed Team'''
import torch
from ..module_inject.replace_policy import HFBertLayerPolicy, replace_policies
from deepspeed.accelerator import get_accelerator


class WeightQuantization(object):
    def __init__(self, mlp_extra_grouping=True, mp_size=1):
        # Per-projection lists of dequantization scales (1 / quantization scale),
        # one entry appended per transformer layer.
        self.dense_scales = []
        self.qkv_scales = []
        self.mlp4hh_scales = []
        self.mlph4h_scales = []
        self.mlp_extra_grouping = mlp_extra_grouping
        self.mp_size = mp_size

    def quantize_data(self, data, quantize_bits, groups, key=None):
        # Symmetric per-group quantization: split the flattened tensor into `groups`
        # chunks, scale each chunk by 2**quantize_bits / (2 * max_abs), then round and
        # clamp to the signed integer range before casting to int8.
        data_groups = torch.split(data.float().view(-1), data.numel() // groups)
        max_d = [max(g.max(), g.min().abs()) for g in data_groups]
        data_scale = [float(1 << quantize_bits) / (2 * mx + 1e-5) for mx in max_d]
        data_int = [(g * s) for g, s in zip(data_groups, data_scale)]
        data_int = [
            di.round().clamp(-(1 << (quantize_bits - 1)),
                             (1 << (quantize_bits - 1)) - 1) for di in data_int
        ]
        data_int = torch.cat(data_int).reshape(data.shape)
        data_int = data_int.to(torch.int8)
        data_scale = torch.cat([s.unsqueeze(0).unsqueeze(0) for s in data_scale])
        return data_int, data_scale
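
    # Worked example (illustrative only, not from the original source): with
    # quantize_bits=8 and a group whose largest magnitude is 2.0, the scale is
    # 256 / (2 * 2.0) ~= 64 (ignoring the 1e-5 epsilon), so a weight of 1.5 maps to
    # round(1.5 * 64) = 96 in int8; dividing by the stored scale recovers ~1.5 at
    # dequantization time.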

    def is_mlp(self, data, merge_count=1):
        # Heuristic: MLP weights have a 4:1 aspect ratio once model parallelism and
        # any merged partitions are accounted for.
        return ((self.mp_size * data.shape[0] * merge_count) / data.shape[1] == 4 or \
                (self.mp_size * data.shape[1] * merge_count) / data.shape[0] == 4)

    def is_qkv(self, data):
        # Heuristic: fused query/key/value weights have a 3:1 aspect ratio.
        return ((self.mp_size * data.shape[0]) / data.shape[1] == 3 or \
                (self.mp_size * data.shape[1]) / data.shape[0] == 3)

    def Quantize(self, value_list, quantize_bits, groups, key, merge_dim=0):
        # Use twice as many groups for MLP weights when extra grouping is enabled.
        if self.mlp_extra_grouping and self.is_mlp(value_list[0],
                                                   merge_count=len(value_list)):
            groups *= 2
        q_scale = []
        index = 0
        for data in value_list:
            data_int, data_scale = self.quantize_data(data, quantize_bits, groups, key)
            q_scale.append(data_scale)
            value_list[index] = data_int
            index += 1
        # Store reciprocal scales (dequantization multipliers) on the accelerator and
        # route them to the list matching this parameter's projection type.
        q_scale = (
            1 /
            torch.cat(q_scale,
                      dim=merge_dim).to(
                          get_accelerator().current_device_name()).view(-1).unsqueeze(0))
        if "mlp.dense_4h_to_h.weight" in key:
            self.mlp4hh_scales.append(q_scale)
        elif "mlp.dense_h_to_4h.weight" in key:
            self.mlph4h_scales.append(q_scale)
        elif "attention.query_key_value.weight" in key:
            self.qkv_scales.append(q_scale)
        else:
            self.dense_scales.append(q_scale)
        return value_list

    def merge_layer_scales(self, layer_scales):
        # Zero-pad each scale row to the widest group count, then stack the
        # per-projection scales of one layer into a single tensor.
        max_dim = max([s.shape[-1] for s in layer_scales])
        layer_scales = [
            torch.cat((s,
                       torch.zeros((1,
                                    max_dim - s.shape[-1]),
                                   device=get_accelerator().current_device_name())),
                      dim=-1) if s.shape[-1] < max_dim else s for s in layer_scales
        ]
        return torch.cat(layer_scales).unsqueeze(0)

    def merge_scales(self):
        # Combine per-layer scales in (qkv, dense, h_to_4h, 4h_to_h) order.
        all_scales = []
        for dense_scale, qkv_scale, m4hh_scale, mh4h_scale in \
            zip(self.dense_scales, self.qkv_scales, self.mlp4hh_scales, self.mlph4h_scales):
            all_scales.append(
                self.merge_layer_scales([qkv_scale,
                                         dense_scale,
                                         mh4h_scale,
                                         m4hh_scale]))
        return torch.cat(all_scales)

    def merge_scales_split(self, split_count):
        # Same merge as above, but each scale row is first split into `split_count`
        # shards (e.g. one per tensor-parallel partition).
        all_scales = [[] for _ in range(split_count)]
        for dense_scale, qkv_scale, m4hh_scale, mh4h_scale in \
            zip(self.dense_scales, self.qkv_scales, self.mlp4hh_scales, self.mlph4h_scales):
            dense_scale = torch.split(dense_scale, dense_scale.numel() // split_count)
            qkv_scale = torch.split(qkv_scale, qkv_scale.numel() // split_count)
            m4hh_scale = torch.split(m4hh_scale, m4hh_scale.numel() // split_count)
            mh4h_scale = torch.split(mh4h_scale, mh4h_scale.numel() // split_count)
            for s in range(split_count):
                all_scales[s].append(
                    torch.cat([
                        torch.cat((qkv_scale[s],
                                   torch.zeros_like(qkv_scale[s])),
                                  dim=1),
                        torch.cat((dense_scale[s],
                                   torch.zeros_like(dense_scale[s])),
                                  dim=1),
                        mh4h_scale[s],
                        m4hh_scale[s]
                    ]).unsqueeze(0))
        for scales_a in all_scales:
            torch.cat(scales_a)
        return all_scales

    def sd_quantize_megatron(self, sd, quantize_bits, groups):
        # Quantize the attention and MLP weight matrices of a Megatron-style state
        # dict in place and return it with the merged dequantization scales.
        keys = sd.keys()
        for key in keys:
            value_list = [sd[key]]
            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key or \
                "mlp.dense_h_to_4h.weight" in key or "attention.query_key_value.weight" in key:
                value_list = self.Quantize(value_list, quantize_bits, groups, key=key)
            sd[key] = value_list[0]

        all_scales = self.merge_scales()
        return sd, all_scales

    def model_quantize(self, model, quantize_policy, quantize_bits, groups):
        all_scales = []

        def quantize_fn(layer, policy_cls):
            # Pull the attention and MLP weight tensors out of the layer via its
            # injection policy and quantize them in place.
            policy = policy_cls(layer)

            _, qkvw, _, dense_w, _, _ = policy.attention()
            _, _h4h_w, _, _4hh_w, _ = policy.mlp()
            keys = [qkvw, dense_w, _h4h_w, _4hh_w]
            layer_scales = []

            for key in range(len(keys)):
                if self.mlp_extra_grouping and self.is_mlp(keys[key]):
                    data_quantized, data_scale = self.quantize_data(keys[key], quantize_bits, groups * 2)
                elif policy_cls is HFBertLayerPolicy and self.is_qkv(keys[key]):
                    data_quantized, data_scale = self.quantize_data(keys[key], quantize_bits, groups * 3)
                else:
                    data_quantized, data_scale = self.quantize_data(keys[key], quantize_bits, groups)
                keys[key].copy_(data_quantized)
                layer_scales.append((1 / data_scale.to(
                    get_accelerator().current_device_name()).view(-1).unsqueeze(0)))
            all_scales.append(self.merge_layer_scales(layer_scales))

            return layer

        def _quantize_module(model, policies):
            # Recursively replace every child module covered by a policy with its
            # quantized version.
            for name, child in model.named_children():
                if child.__class__ in policies:
                    quantize_fn, replace_policy = policies[child.__class__]
                    setattr(model, name, quantize_fn(child, replace_policy))
                else:
                    _quantize_module(child, policies)

            return model

        policy = {}
        if quantize_policy is not None:
            for layer_name, replace_policy in quantize_policy.items():
                policy.update({layer_name: (quantize_fn, replace_policy)})
        else:
            for plcy in replace_policies:
                policy.update({plcy._orig_layer_class: (quantize_fn, plcy)})

        quantized_module = _quantize_module(model, policy)

        return quantized_module, torch.cat(all_scales)
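

# Usage sketch (illustrative only; the checkpoint path, tensor names, and group
# count below are assumptions, not taken from this file). Quantizing a
# Megatron-style state dict to int8 would look roughly like:
#
#     quantizer = WeightQuantization(mlp_extra_grouping=True, mp_size=1)
#     sd = torch.load("checkpoint.pt")  # hypothetical dict of weight tensors
#     sd, all_scales = quantizer.sd_quantize_megatron(sd, quantize_bits=8, groups=64)
#
# Afterwards `sd` holds int8 weights for the attention/MLP matrices and
# `all_scales` holds one block of dequantization multipliers per transformer layer.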