# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import pytest
import torch
import deepspeed

from unit.common import DistributedTest
from unit.simple_model import create_config_from_dict


  9. @pytest.mark.inference
  10. class TestInferenceConfig(DistributedTest):
  11. world_size = 1
  12. def test_overlap_kwargs(self):
  13. config = {"replace_with_kernel_inject": True, "dtype": torch.float32}
  14. kwargs = {"replace_with_kernel_inject": True}
  15. engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs)
  16. assert engine._config.replace_with_kernel_inject
  17. def test_overlap_kwargs_conflict(self):
  18. config = {"replace_with_kernel_inject": True}
  19. kwargs = {"replace_with_kernel_inject": False}
  20. with pytest.raises(ValueError):
  21. engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs)
  22. def test_kwargs_and_config(self):
  23. config = {"replace_with_kernel_inject": True}
  24. kwargs = {"dtype": torch.float32}
  25. engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs)
  26. assert engine._config.replace_with_kernel_inject
  27. assert engine._config.dtype == kwargs["dtype"]
  28. def test_json_config(self, tmpdir):
  29. config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"}
  30. config_json = create_config_from_dict(tmpdir, config)
  31. engine = deepspeed.init_inference(torch.nn.Module(), config=config_json)
  32. assert engine._config.replace_with_kernel_inject