test_init_on_device.py 904 B

12345678910111213141516171819202122232425262728
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. import pytest
  6. from unit.simple_model import SimpleModel
  7. from deepspeed import OnDevice
  8. from packaging import version as pkg_version
  9. from deepspeed.accelerator import get_accelerator
  10. from unit.common import DistributedTest
  11. @pytest.mark.parametrize('device', ['meta', get_accelerator().device_name(0)])
  12. class TestOnDevice(DistributedTest):
  13. world_size = 1
  14. def test_on_device(self, device):
  15. if device == "meta" and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"):
  16. pytest.skip("meta tensors only became stable after torch 1.10")
  17. with OnDevice(dtype=torch.half, device=device):
  18. model = SimpleModel(4)
  19. for p in model.parameters():
  20. assert p.device == torch.device(device)
  21. assert p.dtype == torch.half