# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright (c) 2023, Oracle and/or its affiliates.
import os
import torch
import pytest
from unit.common import DistributedTest
import deepspeed
from deepspeed.accelerator import get_accelerator
- class TestDequantization(DistributedTest):
- def init(self):
- local_rank = int(os.getenv("LOCAL_RANK", "0"))
- self.device = torch.device(get_accelerator().device_name(local_rank))
- from deepspeed.ops.op_builder import InferenceBuilder
- if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
- pytest.skip("InferenceBuilder is not implemented")
- else:
- self.dequantize_func = InferenceBuilder().load().dequantize_fp16
- def run_dequantize_test(self, M, N, num_groups):
- weight = torch.randint(-255, 255, (M, N)).to(dtype=torch.int8, device=self.device)
- scale = torch.rand(num_groups, 1).to(device=self.device)
- weight_deq = (weight.reshape(num_groups, -1) * scale).reshape(M, N).to(torch.float16).contiguous()
- weight_deq_backend = self.dequantize_func(weight, scale, num_groups)
- assert torch.allclose(weight_deq, weight_deq_backend)
- def test_dequantize(self):
- self.init()
- self.run_dequantize_test(14336, 7168, 32)
- self.run_dequantize_test(14336, 1792, 32)
- self.run_dequantize_test(768, 768, 32)
- self.run_dequantize_test(768, 768, 48)