test_fake_quantization.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. import pytest
  6. import deepspeed
  7. from deepspeed.accelerator import get_accelerator
  8. from deepspeed.ops.op_builder import QuantizerBuilder
  9. if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
  10. pytest.skip("Inference ops are not available on this system", allow_module_level=True)
  11. quantizer_cuda_module = None
  12. def allclose(x, y):
  13. assert x.dtype == y.dtype
  14. rtol, atol = {torch.float32: (2e-2, 5e-3), torch.float16: (2e-2, 5e-3)}[x.dtype]
  15. return torch.allclose(x, y, rtol=rtol, atol=atol)
  16. def quantize_dequantize_ref(inputs, bit, num_groups=1):
  17. # quantize
  18. q_range = 2**bit
  19. input_flat = inputs.float().reshape(num_groups, -1).contiguous()
  20. input_flat = torch.nan_to_num(input_flat, nan=0.0)
  21. input_min = input_flat.amin(-1, keepdim=True)
  22. input_max = input_flat.amax(-1, keepdim=True)
  23. scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs() + 1e-5))
  24. input_flat = (input_flat * scale).round().clamp(-q_range // 2, q_range // 2 - 1)
  25. # dequantize
  26. dequant_flat = torch.t(input_flat.to(torch.int8)) / scale.view(-1).to(torch.float16)
  27. return torch.t(dequant_flat).reshape(inputs.shape)
  28. def run_quant_dequant(inputs, groups, bits):
  29. global quantizer_cuda_module
  30. if quantizer_cuda_module is None:
  31. quantizer_cuda_module = QuantizerBuilder().load()
  32. return quantizer_cuda_module.ds_quantize_fp16(inputs, groups, bits)
  33. @pytest.mark.inference_ops
  34. @pytest.mark.parametrize("tensor_shape", [(16, 4096), (128, 256)])
  35. # Test with two tensor shapes as (16, 4096) and (128, 256).
  36. @pytest.mark.parametrize("groups", [1, 16])
  37. # Test with number of quant groups as 1 and 16.
  38. # Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1) <= MAX_REG.
  39. def test_fake_quant_dequant(tensor_shape, groups):
  40. input_tensor = torch.rand((tensor_shape), dtype=torch.float16).to(get_accelerator().device_name())
  41. # 8-bit quantization.
  42. ref_input_8bit = input_tensor.clone().detach()
  43. ds_input_8bit = input_tensor.clone().detach()
  44. ref_out_8bit = quantize_dequantize_ref(ref_input_8bit, 8, groups)
  45. # run_quant_dequant will do quantize then dequantize, and return the dequantized value.
  46. ds_out_8bit = run_quant_dequant(ds_input_8bit, groups, 8)
  47. assert (allclose(ds_out_8bit, ref_out_8bit))
  48. # 4-bit quantization.
  49. ref_input_4bit = input_tensor.clone().detach()
  50. ds_input_4bit = input_tensor.clone().detach()
  51. ref_out_4bit = quantize_dequantize_ref(ref_input_4bit, 4, groups)
  52. ds_out_4bit = run_quant_dequant(ds_input_4bit, groups, 4)
  53. assert (allclose(ds_out_4bit, ref_out_4bit))