utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import functools
from typing import Callable, Dict, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

import deepspeed
from deepspeed.accelerator import get_accelerator

device = get_accelerator().device_name() if get_accelerator().is_available() else 'cpu'

quantizer_cuda_module = None


def get_quantizer_cuda_module():
    global quantizer_cuda_module
    if quantizer_cuda_module is None:
        quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load()
    return quantizer_cuda_module


def tensor_clamp(tensor: Tensor, min, max) -> Tensor:
    if tensor.device.type == 'cpu' and tensor.dtype == torch.float16:
        # CPU does not support FP16 clamp
        return tensor.to(dtype=torch.float32).clamp_(min, max).to(dtype=torch.float16)
    else:
        return tensor.clamp_(min, max)


def tensor_round(tensor: Tensor) -> Tensor:
    if tensor.device.type == 'cpu' and tensor.dtype == torch.float16:
        # CPU does not support FP16 round
        return tensor.to(dtype=torch.float32).round_().to(dtype=torch.float16)
    else:
        return tensor.round_()


class Quantizer:

    def __init__(self, config: Dict) -> None:
        self.config = config
        assert self.config['num_bits'] in (4, 8), 'Only INT4 and INT8 quantization is supported.'
        assert self.config['symmetric'] == False, 'Only asymmetric quantization is supported at this moment.'

    def quantize(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        assert tensor.shape[self.config['group_dim']] % self.config['group_size'] == 0, \
            f'Tensor shape: {tensor.shape} quantization config {self.config}'

        tensor = torch.clone(tensor)

        shape = tensor.shape
        num_groups = shape[self.config['group_dim']] // self.config['group_size']
        new_shape = (shape[:self.config['group_dim']] + (num_groups, self.config['group_size']) +
                     shape[self.config['group_dim'] + 1:])
        tensor = tensor.view(new_shape)

        quantized_tensor, scale, min_value = self._quantize_int8(tensor)
        quantized_tensor = quantized_tensor.view(shape)

        if self.config['num_bits'] == 4:
            return self._compress_uint8_to_uint4(quantized_tensor), scale, min_value
        if self.config['num_bits'] == 8:
            return quantized_tensor, scale, min_value

        assert False, 'Unsupported quantization bits {}'.format(self.config['num_bits'])

    def _quantize_int8(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        q_range = 2**self.config['num_bits'] - 1
        min_value = tensor.amin(dim=self.config['group_dim'] + 1, keepdim=True)
        max_value = tensor.amax(dim=self.config['group_dim'] + 1, keepdim=True)

        scale = q_range / (max_value - min_value)

        tensor = tensor.sub_(min_value).mul_(scale)
        tensor = tensor_round(tensor_clamp(tensor, 0, q_range)).to(torch.uint8)
        return tensor, scale, min_value

    def _compress_uint8_to_uint4(self, tensor: Tensor) -> Tensor:
        assert tensor.shape[-1] % 2 == 0

        new_data_shape = list(tensor.shape)
        new_data_shape[-1] = new_data_shape[-1] // 2

        data = torch.empty(new_data_shape, dtype=torch.uint8, device=tensor.device)
        # Pack two 4-bit values per byte: even indices go to the high nibble, odd indices to the low nibble.
        data = torch.bitwise_or(tensor[..., 0::2].bitwise_left_shift(4), tensor[..., 1::2])

        return data
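

# Illustrative usage sketch (not part of the original module): group-wise
# asymmetric INT8 quantization along the last dimension. The config values and
# tensor shape below are example assumptions, not required settings.
#
#   cfg = {'num_bits': 8, 'group_size': 64, 'group_dim': 1, 'symmetric': False}
#   weight = torch.randn(4, 128).half()
#   q_weight, scale, min_val = Quantizer(cfg).quantize(weight)
#   # q_weight: uint8 with shape (4, 128); scale/min_val: one value per 64-element group.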


class DeQuantizer:

    def __init__(self, config: Dict, dtype: torch.dtype) -> None:
        self.config = config
        self.dtype = dtype
        assert self.config['num_bits'] in (4, 8), 'Only INT4 and INT8 quantization is supported.'
        assert self.config['symmetric'] == False, 'Only asymmetric quantization is supported at this moment.'

    def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor:
        # Use customized CUDA quantization kernel if possible.
        if self.config['group_size'] % 8 == 0 and \
                self.config['num_bits'] == 4 and \
                self.config['group_dim'] == len(tensor.shape) - 1 and \
                self.dtype == torch.float16 and device == 'cuda':

            last_dimension_size = self.config['group_size']
            if self.config['num_bits'] == 4:
                last_dimension_size = last_dimension_size // 2
            quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
                tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
                tensor.numel() // last_dimension_size, self.config['group_size'])

            shape = list(tensor.shape)
            if self.config['num_bits'] == 4:
                shape[-1] = shape[-1] * 2

            return quantized_tensor.reshape(shape)

        if self.config['num_bits'] == 4:
            tensor = self._decompress_uint4_to_uint8(tensor)
        elif self.config['num_bits'] != 8:
            assert False, 'Unsupported quantization bits {}'.format(self.config['num_bits'])

        shape = tensor.shape
        num_groups = shape[self.config['group_dim']] // self.config['group_size']
        new_shape = (shape[:self.config['group_dim']] + (num_groups, self.config['group_size']) +
                     shape[self.config['group_dim'] + 1:])
        tensor = tensor.view(new_shape)

        dequantized_tensor = self._dequantize_int8(tensor, quant_scale, quant_min).view(shape)
        return dequantized_tensor

    def _dequantize_int8(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor:
        assert tensor.dtype == torch.uint8
        data = torch.zeros_like(tensor, dtype=self.dtype, device=tensor.device)
        data = data.copy_(tensor)
        data = data.div_(quant_scale).add_(quant_min)
        return data

    def _decompress_uint4_to_uint8(self, tensor: Tensor) -> Tensor:
        new_data_shape = list(tensor.shape)
        new_data_shape[-1] = new_data_shape[-1] * 2
        data = torch.empty(new_data_shape, dtype=torch.uint8, device=tensor.device)
        # Unpack each byte into two 4-bit values: high nibble to even indices, low nibble to odd indices.
        data[..., 0::2] = tensor.bitwise_right_shift(4)
        data[..., 1::2] = tensor.bitwise_and(0xF)
        return data
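

# Illustrative round-trip sketch (not part of the original module), continuing
# the Quantizer example above; the reconstruction is approximate, not exact.
#
#   deq = DeQuantizer(cfg, torch.float16)
#   restored = deq.dequantize(q_weight, scale, min_val)
#   # restored has the same shape and dtype as the original weight, up to quantization error.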


def get_AsyncPartitionedParameterSwapper(model: nn.Module):
    for param_name, param in model.named_parameters():
        if hasattr(param, 'nvme_swapper') and param.nvme_swapper is not None:
            return param.nvme_swapper
    return None


def recursive_setattr(model, module_name, module):
    """
    Recursively set the attribute of a module.

    Args:
        model (`torch.nn.Module`)
            The model to set the attribute in.
        module_name (`str`)
            The name of the module to set the attribute in.
        module (`torch.nn.Module`)
            The module to set the attribute to.
    """
    split_list = module_name.split('.')
    output = model
    for name in split_list[:-1]:
        output = getattr(output, name)
    output.__setattr__(split_list[-1], module)
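

# Example (illustrative only; 'encoder.layer' is a hypothetical submodule path):
#
#   recursive_setattr(model, 'encoder.layer', nn.Linear(16, 16))
#   # Equivalent to: model.encoder.layer = nn.Linear(16, 16)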


def concat_to_compat_param(quantized_weight: Tensor,
                           quant_scale: Tensor,
                           quant_min: Tensor,
                           return_param: bool = True) -> Union[nn.Parameter, Tensor]:
    shape_weight = quantized_weight.shape
    shape_scale = quant_scale.shape
    shape_min = quant_min.shape

    quantized_weight = torch.flatten(quantized_weight)
    quant_scale = torch.flatten(quant_scale)
    quant_min = torch.flatten(quant_min)

    def deconcat_individual_tensors(shape_weight: torch.Size, shape_scale: torch.Size,
                                    shape_min: torch.Size) -> Callable:

        def fn(compat_tensor: nn.Parameter) -> Tuple[Tensor, Tensor, Tensor]:
            weight = torch.narrow(compat_tensor, 0, 0, shape_weight.numel()).view(shape_weight)
            scale = torch.narrow(compat_tensor, 0, shape_weight.numel(), shape_scale.numel()).view(shape_scale)
            min_val = torch.narrow(compat_tensor, 0,
                                   shape_weight.numel() + shape_scale.numel(), shape_min.numel()).view(shape_min)
            return weight, scale, min_val

        return fn

    compat_tensor = torch.concat([quantized_weight, quant_scale, quant_min])
    if return_param:
        compat_tensor = nn.Parameter(compat_tensor, requires_grad=False)
    compat_tensor.deconcat = deconcat_individual_tensors(shape_weight, shape_scale, shape_min)

    return compat_tensor
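

# Illustrative sketch of the concat/deconcat round trip (not part of the original
# module). The three tensors must share a dtype, which is why _quantize_param
# below views the uint8 weight as the parameter dtype before concatenation.
#
#   w = torch.randn(4, 64).half()
#   s = torch.randn(4, 1).half()
#   m = torch.randn(4, 1).half()
#   compat = concat_to_compat_param(w, s, m)   # flat nn.Parameter, requires_grad=False
#   w2, s2, m2 = compat.deconcat(compat)       # original shapes restored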


def _quantize_param(param: nn.Parameter, quant_config: Dict):
    assert not hasattr(param, 'weight_quantized'), 'Parameter has already been quantized.'
    quantizer = Quantizer(quant_config)
    dequantizer = DeQuantizer(quant_config, param.dtype)

    quantized_weight, quant_scale, quant_min = quantizer.quantize(param.data)

    quantized_weight = quantized_weight.view(param.dtype)
    quant_scale = quant_scale.view(param.dtype)
    quant_min = quant_min.view(param.dtype)

    quantized_compat_tensor = concat_to_compat_param(quantized_weight, quant_scale, quant_min)
    param.data = quantized_compat_tensor
    param.deconcat = quantized_compat_tensor.deconcat
    param.quantizer = quantizer
    param.dequantizer = dequantizer
    setattr(param, 'weight_quantized', True)


def wrap_quantized_functional(f):

    @functools.wraps(f)
    def wrapper(input: Tensor, weight: nn.Parameter, *args, **kwargs) -> Tensor:
        if hasattr(weight, 'weight_quantized') and getattr(weight, 'weight_quantized'):
            quantized_weight, quant_scale, quant_min = weight.deconcat(weight)
            temp_dequantized_weight = weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
                                                                    quant_min)
            return f(input, temp_dequantized_weight, *args, **kwargs)
        else:
            return f(input, weight, *args, **kwargs)

    return wrapper
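

# Illustrative sketch (not part of the original module): routing a functional op
# through the wrapper so parameters produced by _quantize_param are dequantized
# on the fly, while unquantized weights pass through unchanged. `x` and `layer`
# are placeholders.
#
#   import torch.nn.functional as F
#   quantized_linear = wrap_quantized_functional(F.linear)
#   out = quantized_linear(x, layer.weight, layer.bias)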


def wrap_load_from_state_dict(f):

    @functools.wraps(f)
    def wrapper(model, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        replaced_old_value = None
        key = None
        # We may have nested wrappers if we launch multiple initialization contexts.
        # Use the state_dict_quantized flag to quantize the state_dict only once.
        if hasattr(model.weight, 'weight_quantized') and getattr(
                model.weight, 'weight_quantized') and not hasattr(model.weight, 'state_dict_quantized'):
            setattr(model.weight, 'state_dict_quantized', True)
            key = prefix + 'weight'
            if key in state_dict:
                quantized_weight, quant_scale, quant_min = model.weight.quantizer.quantize(state_dict[key])
                quantized_weight = quantized_weight.view(model.weight.dtype)
                quant_scale = quant_scale.view(model.weight.dtype)
                quant_min = quant_min.view(model.weight.dtype)

                replaced_old_value = state_dict[key]
                state_dict[key] = concat_to_compat_param(quantized_weight, quant_scale, quant_min)

        f(model, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        if replaced_old_value is not None:
            state_dict[key] = replaced_old_value
            delattr(model.weight, 'state_dict_quantized')

    return wrapper


WEIGHT_QUANTIZATION_LAYERS = (
    nn.Linear,
    nn.Embedding,
)
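

# Illustrative end-to-end sketch (not part of the original module) showing how the
# helpers above could be combined. `model` and the config values are assumptions,
# and each quantized layer's weight must divide evenly into groups along `group_dim`.
#
#   quant_config = {'num_bits': 8, 'group_size': 64, 'group_dim': 0, 'symmetric': False}
#   for _, module in model.named_modules():
#       if isinstance(module, WEIGHT_QUANTIZATION_LAYERS):
#           _quantize_param(module.weight, quant_config)
#   # Forward calls are then routed through ops patched with
#   # wrap_quantized_functional, e.g. torch.nn.functional.linear.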