1234567891011121314151617181920212223242526272829303132333435363738 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- # Copyright (c) 2023, 2023, Oracle and/or its affiliates.
- import os
- import torch
- from unit.common import DistributedTest
- from deepspeed.ops.op_builder import InferenceBuilder
- from deepspeed.accelerator import get_accelerator
class TestDequantization(DistributedTest):
    """Compare the ``dequantize_fp16`` inference op against a PyTorch reference."""

    def init(self):
        """Select this rank's device and bind the op under test.

        Called explicitly at the start of ``test_dequantize``. The local rank
        is read from the ``LOCAL_RANK`` environment variable (default ``"0"``).
        """
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(get_accelerator().device_name(local_rank))

        self.dequantize_func = InferenceBuilder().load().dequantize_fp16

    def run_dequantize_test(self, M, N, num_groups):
        """Dequantize a random int8 ``(M, N)`` weight and check the op's output.

        Args:
            M: Number of rows of the weight matrix.
            N: Number of columns of the weight matrix.
            num_groups: Number of groups, each with one scale factor. The
                weight is flattened into ``num_groups`` equally sized rows,
                so ``M * N`` must be divisible by ``num_groups``.

        Raises:
            AssertionError: If the op's output is not close to the reference.
        """
        # Sample from the real int8 range. The upper bound of randint is
        # exclusive, hence ``max + 1``. Sampling a wider range and casting
        # would rely on overflow wrap-around and skew the distribution.
        int8_info = torch.iinfo(torch.int8)
        weight = torch.randint(int8_info.min, int8_info.max + 1, (M, N)).to(dtype=torch.int8, device=self.device)
        scale = torch.rand(num_groups, 1).to(device=self.device)

        # Reference: one scale per flattened group, broadcast across the group,
        # then restored to (M, N) and cast to fp16.
        weight_deq = (weight.reshape(num_groups, -1) * scale).reshape(M, N).to(torch.float16).contiguous()
        weight_deq_backend = self.dequantize_func(weight, scale, num_groups)

        assert torch.allclose(weight_deq, weight_deq_backend), (
            f"dequantize_fp16 mismatch for M={M}, N={N}, num_groups={num_groups}")

    def test_dequantize(self):
        """Run the dequantization check over several shapes and group counts."""
        self.init()

        cases = (
            (14336, 7168, 32),
            (14336, 1792, 32),
            (768, 768, 32),
            (768, 768, 48),
        )
        for M, N, num_groups in cases:
            self.run_dequantize_test(M, N, num_groups)
|