test_dequantization.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. # Copyright (c) 2023, 2023, Oracle and/or its affiliates.
  5. import os
  6. import torch
  7. import pytest
  8. from unit.common import DistributedTest
  9. import deepspeed
  10. from deepspeed.accelerator import get_accelerator
  11. class TestDequantization(DistributedTest):
  12. def init(self):
  13. local_rank = int(os.getenv("LOCAL_RANK", "0"))
  14. self.device = torch.device(get_accelerator().device_name(local_rank))
  15. from deepspeed.ops.op_builder import InferenceBuilder
  16. if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
  17. pytest.skip("InferenceBuilder is not implemented")
  18. else:
  19. self.dequantize_func = InferenceBuilder().load().dequantize_fp16
  20. def run_dequantize_test(self, M, N, num_groups):
  21. weight = torch.randint(-255, 255, (M, N)).to(dtype=torch.int8, device=self.device)
  22. scale = torch.rand(num_groups, 1).to(device=self.device)
  23. weight_deq = (weight.reshape(num_groups, -1) * scale).reshape(M, N).to(torch.float16).contiguous()
  24. weight_deq_backend = self.dequantize_func(weight, scale, num_groups)
  25. assert torch.allclose(weight_deq, weight_deq_backend)
  26. def test_dequantize(self):
  27. self.init()
  28. self.run_dequantize_test(14336, 7168, 32)
  29. self.run_dequantize_test(14336, 1792, 32)
  30. self.run_dequantize_test(768, 768, 32)
  31. self.run_dequantize_test(768, 768, 48)