# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import QuantizerBuilder
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
    pytest.skip("Inference ops are not available on this system", allow_module_level=True)

inference_module = None


def run_quantize_ds(activations, num_groups, q_bits, is_symmetric_quant):
    # Lazily build and cache the DeepSpeed quantizer op on first use.
    global inference_module
    if inference_module is None:
        inference_module = QuantizerBuilder().load()
    return inference_module.quantize(activations, num_groups, q_bits,
                                     inference_module.Symmetric if is_symmetric_quant else inference_module.Asymmetric)


def run_dequantize_ds(activations, params, num_groups, q_bits, is_symmetric_quant):
    global inference_module
    if inference_module is None:
        inference_module = QuantizerBuilder().load()
    return inference_module.dequantize(
        activations,
        params,
        num_groups,
        q_bits,
        inference_module.Symmetric if is_symmetric_quant else inference_module.Asymmetric,
    )


def get_q_props(q_bits):
    # Signed integer range for a q_bits-wide value, e.g. [-128, 127] for 8 bits.
    q_range = 2**q_bits
    q_min = -(2**(q_bits - 1))
    q_max = (2**(q_bits - 1) - 1)

    q_min = torch.IntTensor([q_min]).to(device=get_accelerator().device_name())
    q_max = torch.IntTensor([q_max]).to(device=get_accelerator().device_name())
    return q_range, q_max, q_min


def get_scale_zero_point(q_bits, is_symmetric_quant, max, min, absmax, scales=None, zero_points=None):

    q_range, q_max, q_min = get_q_props(q_bits)

    if is_symmetric_quant:
        # Symmetric: map [-absmax, absmax] onto the full integer range; all-zero groups
        # fall back to a scale of 1 to avoid dividing by zero.
        scale = torch.empty_like(absmax)
        for i, x in enumerate(absmax):
            scale[i] = torch.ones_like(x) if x == 0 else q_range / (2 * x)
        zero_point = torch.zeros(scale.shape, dtype=torch.float32, device=get_accelerator().device_name())
    else:
        # Asymmetric: map [min, max] onto the integer range and shift min to q_min.
        scale = torch.empty_like(max)
        for i, x in enumerate(max):
            scale[i] = torch.ones_like(x) if max[i] == min[i] else q_range / (max[i] - min[i])
        zero_point = q_min - (min * scale)

    return scale, zero_point
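
# Worked example (illustrative only, not used by the tests): with q_bits=8 and asymmetric
# quantization, a group spanning [min, max] = [-0.5, 1.5] gets scale = 256 / 2.0 = 128 and
# zero_point = -128 - (-0.5 * 128) = -64, so 0.0 quantizes to round(0.0 * 128) + (-64) = -64
# while 1.5 maps to 192 - 64 = 128, which the clamp in run_float_quantize brings back to 127.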


def int4x2to2xint4(int4X2tensor):
    # Unpack an int8 tensor holding two signed 4-bit values per byte (high nibble first);
    # the arithmetic shifts sign-extend each nibble back to a full int8.
    high = int4X2tensor >> 4
    low = (int4X2tensor << 4) >> 4
    return torch.stack((high, low), dim=-1).flatten()
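
# Packing example (assumes PyTorch's wrapping overflow and arithmetic right shift on int8):
# a packed byte 0x2F unpacks to high nibble 0x2 -> +2 and low nibble 0xF -> -1, since
# 0x2F >> 4 == 2 and (0x2F << 4) >> 4 sign-extends 0xF0 back to -1.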


def run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups):

    # Reference implementation
    # https://pytorch.org/docs/stable/quantization-support.html

    activations_ref = activations_ref.reshape(num_groups, -1).to(dtype=torch.float32)

    max_abs_activations_ref = torch.amax(torch.abs(activations_ref), dim=-1).view(num_groups, -1)
    max_activations_ref = torch.amax(activations_ref, dim=-1).view(num_groups, -1)
    min_activations_ref = torch.amin(activations_ref, dim=-1).view(num_groups, -1)

    _, q_max, q_min = get_q_props(q_bits)

    scale, zero_point = get_scale_zero_point(q_bits, is_symmetric_quant, max_activations_ref, min_activations_ref,
                                             max_abs_activations_ref)

    # Scale, shift (asymmetric only), round, and clamp to the representable integer range.
    data_f = activations_ref * scale

    if not is_symmetric_quant:
        data_f = data_f + zero_point

    data_i32 = torch.round(data_f).to(dtype=torch.int32)
    data_i32 = torch.minimum(torch.maximum(data_i32, q_min.expand_as(data_i32)), q_max.expand_as(data_i32))
    data_i8 = data_i32.to(dtype=torch.int8)

    # Per-group dequantization parameters: column 0 holds the dequant scale (1 / scale),
    # column 1 the zero point; run_float_dequantize reads them back in this layout.
    scales = (1.0 / scale).reshape(-1, 1)
    offsets = zero_point.reshape(-1, 1)
    params = torch.cat((scales, offsets), dim=-1)

    return data_i8, params


def run_float_dequantize(q_bits, is_symmetric_quant, data_i8, params, num_groups):
    data_f = data_i8.reshape(num_groups, -1).to(dtype=torch.float32)

    scales = params[:, 0].reshape(-1, 1)
    offsets = params[:, 1].reshape(-1, 1)

    if not is_symmetric_quant:
        data_f = data_f - offsets
    else:
        assert offsets.allclose(torch.zeros_like(offsets))

    data_f = data_f * scales

    return data_f


@pytest.mark.inference_ops
@pytest.mark.parametrize("num_groups", [1, 13, 512])
@pytest.mark.parametrize("num_elems", [8, 16, 32, 64, 128, 256, 4096, 8192, 12288, 16384])
@pytest.mark.parametrize("is_symmetric_quant", [True, False])
@pytest.mark.parametrize("q_bits", [4, 8])
@pytest.mark.parametrize("directed_case", ["all_zeros", None])
def test_float_quantize(num_elems, num_groups, is_symmetric_quant, q_bits, directed_case):
    # fix seed
    torch.manual_seed(num_elems)

    if directed_case == "all_zeros":
        activations_ds = torch.zeros((num_groups, num_elems),
                                     dtype=torch.float16,
                                     device=get_accelerator().device_name())
    else:
        activations_ds = torch.randn((num_groups, num_elems),
                                     dtype=torch.float16,
                                     device=get_accelerator().device_name())
    activations_ref = activations_ds.clone().detach()

    ref_out_tensor, ref_params = run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups)
    ref_dequantized_tensor = run_float_dequantize(q_bits, is_symmetric_quant, ref_out_tensor, ref_params, num_groups)
    # we need to convert the tensor to float64 to avoid overflow
    ref_quantization_error = torch.sum(torch.abs((activations_ref - ref_dequantized_tensor).to(torch.float64)))

    ds_out_tensor, ds_out_params = run_quantize_ds(activations_ds, num_groups, q_bits, is_symmetric_quant)
    ds_dequantized_tensor = run_dequantize_ds(ds_out_tensor, ds_out_params, num_groups, q_bits, is_symmetric_quant)
    assert torch.all(torch.isfinite(ds_dequantized_tensor))

    ds_quantization_error = torch.sum(torch.abs((activations_ds - ds_dequantized_tensor).to(torch.float64)))

    # The fused kernel's round-trip error should be within 5% of the reference implementation.
    assert (ds_quantization_error <= ref_quantization_error * 1.05)
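

# Minimal standalone sketch (illustrative, not part of the test suite): exercises only the
# pure-PyTorch reference path above, without loading the fused DeepSpeed kernel. The shapes
# and bit-width below are arbitrary choices for demonstration, not values the test mandates.
if __name__ == "__main__":
    x = torch.randn((4, 256), dtype=torch.float16, device=get_accelerator().device_name())
    q, p = run_float_quantize(8, True, x.clone(), num_groups=4)
    y = run_float_dequantize(8, True, q, p, num_groups=4)
    print("max abs round-trip error:", (x.float() - y).abs().max().item())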