1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import pytest
- import torch
- import deepspeed
- from unit.common import DistributedTest
- from unit.simple_model import create_config_from_dict
@pytest.mark.inference
class TestInferenceConfig(DistributedTest):
    """Check how ``deepspeed.init_inference`` combines ``config`` with keyword arguments."""

    # NOTE(review): presumably the number of processes DistributedTest launches
    # for each test -- confirm against unit.common.
    world_size = 1

    def test_overlap_kwargs(self):
        """An option given with the same value in both config and kwargs is accepted."""
        config = {"replace_with_kernel_inject": True, "dtype": torch.float32}
        kwargs = {"replace_with_kernel_inject": True}

        engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs)

        assert engine._config.replace_with_kernel_inject

    def test_overlap_kwargs_conflict(self):
        """An option given with different values in config and kwargs raises ValueError."""
        config = {"replace_with_kernel_inject": True}
        kwargs = {"replace_with_kernel_inject": False}

        # The call is expected to raise, so its result is deliberately not bound.
        with pytest.raises(ValueError):
            deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs)

    def test_kwargs_and_config(self):
        """Distinct options from config and kwargs both show up on the engine config."""
        config = {"replace_with_kernel_inject": True}
        kwargs = {"dtype": torch.float32}

        engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs)

        assert engine._config.replace_with_kernel_inject
        assert engine._config.dtype == kwargs["dtype"]

    def test_json_config(self, tmpdir):
        """A config produced by ``create_config_from_dict`` is accepted in place of a dict."""
        config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"}
        # NOTE(review): presumably this writes the dict as JSON under tmpdir and
        # returns the file path -- confirm against unit.simple_model.
        config_json = create_config_from_dict(tmpdir, config)

        engine = deepspeed.init_inference(torch.nn.Module(), config=config_json)

        assert engine._config.replace_with_kernel_inject
|