quantize.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import abc
from abc import ABC

from deepspeed.ops.op_builder import FPQuantizerBuilder

# Handle to the compiled FP quantization kernels; loaded lazily on first use.
fp_quant_module = None


class Quantizer(ABC):
    """
    Abstract Quantizer class that implements quantize/dequantize methods.

    Arguments:
        group_size (int, optional): number of values or elements that are grouped
            together for the quantization process.
    """

    def __init__(self, group_size=512) -> None:
        self.group_size = group_size

    @abc.abstractmethod
    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        ...

    @abc.abstractmethod
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        ...


class FP_Quantize(Quantizer):

    def __init__(self, group_size=512) -> None:
        global fp_quant_module
        super().__init__(group_size=group_size)
        # Build/load the FP quantization op once and share it across instances.
        if fp_quant_module is None:
            fp_quant_module = FPQuantizerBuilder().load()
        self.orig_dtype = None

    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        assert input.dtype == torch.bfloat16, "only support bf16 for now"
        if return_meta_tensor:
            assert q_bits == 8, "meta tensor is only supported with q_bit=8"

        # Remember the original dtype/shape so dequantize can restore them.
        self.orig_dtype = input.dtype
        self.orig_shape = input.shape

        # For widths other than 8 bits, override q_mantisa_bits with the layout
        # the kernel templates expect (remaining bits minus sign form the exponent).
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"

        out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)

        if return_meta_tensor:
            # Split the packed output into quantized data and per-group scales.
            data, scale = out.split(self.group_size, dim=-1)
            return data.contiguous().reshape(input.shape), scale.contiguous()

        return out

    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            # Re-attach the per-group scales so the kernel sees the packed layout
            # produced by quantize(..., return_meta_tensor=True).
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out

    def selective_dequantize(self,
                             input_q,
                             indexes,
                             fp_out=None,
                             q_bits=8,
                             q_mantisa_bits=3,
                             scale=None) -> torch.Tensor:
        assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \
            "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function."
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        # Only the rows selected by `indexes` along the first dimension are dequantized.
        fp_out = torch.empty(
            (indexes.shape[0],
             *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                             q_bits - q_mantisa_bits - 1)
        return fp_out
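

# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal round-trip example, assuming the FPQuantizerBuilder extension
# builds on this machine and a CUDA device with bf16 support is available.
# q_bits=8 with return_meta_tensor=True matches the constraints asserted
# in quantize() above.
if __name__ == "__main__":
    if torch.cuda.is_available():
        quantizer = FP_Quantize(group_size=512)
        weight = torch.randn(4, 1024, 1024, dtype=torch.bfloat16, device="cuda")
        # Quantize to 8-bit floating point, keeping per-group scales separate.
        data, scale = quantizer.quantize(weight, q_bits=8, return_meta_tensor=True)
        # Reconstruct the full bf16 tensor from the quantized data and scales.
        restored = quantizer.dequantize(data, q_bits=8, scale=scale)
        print("max abs error:", (weight - restored).abs().max().item())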