test_dequantization.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. # Copyright (c) 2023, 2023, Oracle and/or its affiliates.
  5. import os
  6. import torch
  7. from unit.common import DistributedTest
  8. from deepspeed.ops.op_builder import InferenceBuilder
  9. from deepspeed.accelerator import get_accelerator
  10. class TestDequantization(DistributedTest):
  11. def init(self):
  12. local_rank = int(os.getenv("LOCAL_RANK", "0"))
  13. self.device = torch.device(get_accelerator().device_name(local_rank))
  14. self.dequantize_func = InferenceBuilder().load().dequantize_fp16
  15. def run_dequantize_test(self, M, N, num_groups):
  16. weight = torch.randint(-255, 255, (M, N)).to(dtype=torch.int8, device=self.device)
  17. scale = torch.rand(num_groups, 1).to(device=self.device)
  18. weight_deq = (weight.reshape(num_groups, -1) * scale).reshape(M, N).to(torch.float16).contiguous()
  19. weight_deq_backend = self.dequantize_func(weight, scale, num_groups)
  20. assert torch.allclose(weight_deq, weight_deq_backend)
  21. def test_dequantize(self):
  22. self.init()
  23. self.run_dequantize_test(14336, 7168, 32)
  24. self.run_dequantize_test(14336, 1792, 32)
  25. self.run_dequantize_test(768, 768, 32)
  26. self.run_dequantize_test(768, 768, 48)