- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- import pytest
- import deepspeed
- from deepspeed.accelerator import get_accelerator
- from deepspeed.ops.op_builder import QuantizerBuilder
- if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
- pytest.skip("Inference ops are not available on this system", allow_module_level=True)
- quantizer_cuda_module = None
- def allclose(x, y):
- assert x.dtype == y.dtype
- rtol, atol = {torch.float32: (2e-2, 5e-3), torch.float16: (2e-2, 5e-3)}[x.dtype]
- return torch.allclose(x, y, rtol=rtol, atol=atol)
- def quantize_dequantize_ref(inputs, bit, num_groups=1):
- # quantize
- q_range = 2**bit
- input_flat = inputs.float().reshape(num_groups, -1).contiguous()
- input_flat = torch.nan_to_num(input_flat, nan=0.0)
- input_min = input_flat.amin(-1, keepdim=True)
- input_max = input_flat.amax(-1, keepdim=True)
- scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs() + 1e-5))
- input_flat = (input_flat * scale).round().clamp(-q_range // 2, q_range // 2 - 1)
- # dequantize
- dequant_flat = torch.t(input_flat.to(torch.int8)) / scale.view(-1).to(torch.float16)
- return torch.t(dequant_flat).reshape(inputs.shape)
- def run_quant_dequant(inputs, groups, bits):
- global quantizer_cuda_module
- if quantizer_cuda_module is None:
- quantizer_cuda_module = QuantizerBuilder().load()
- return quantizer_cuda_module.ds_quantize_fp16(inputs, groups, bits)
- @pytest.mark.inference_ops
- @pytest.mark.parametrize("tensor_shape", [(16, 4096), (128, 256)])
- # Test with two tensor shapes as (16, 4096) and (128, 256).
- @pytest.mark.parametrize("groups", [1, 16])
- # Test with number of quant groups as 1 and 16.
- # Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1) <= MAX_REG.
- def test_fake_quant_dequant(tensor_shape, groups):
- input_tensor = torch.rand((tensor_shape), dtype=torch.float16).to(get_accelerator().device_name())
- # 8-bit quantization.
- ref_input_8bit = input_tensor.clone().detach()
- ds_input_8bit = input_tensor.clone().detach()
- ref_out_8bit = quantize_dequantize_ref(ref_input_8bit, 8, groups)
- # run_quant_dequant will do quantize then dequantize, and return the dequantized value.
- ds_out_8bit = run_quant_dequant(ds_input_8bit, groups, 8)
- assert (allclose(ds_out_8bit, ref_out_8bit))
- # 4-bit quantization.
- ref_input_4bit = input_tensor.clone().detach()
- ds_input_4bit = input_tensor.clone().detach()
- ref_out_4bit = quantize_dequantize_ref(ref_input_4bit, 4, groups)
- ds_out_4bit = run_quant_dequant(ds_input_4bit, groups, 4)
- assert (allclose(ds_out_4bit, ref_out_4bit))
|