# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import deepspeed
from deepspeed.runtime.utils import required_torch_version

from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import *

import pytest

# These tests are only validated on torch <= 2.0; skip the whole module on
# newer torch until the torch 2.1 problems are resolved.
if not required_torch_version(max_version=2.0):
    pytest.skip("Skipping until we resolve problems with torch 2.1", allow_module_level=True)
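

# MiCS (enabled via "mics_shard_size" under ZeRO stage 3) shards model states
# within groups of `mics_shard_size` ranks and replicates those groups across
# the remaining ranks, so checkpoint save/load must stay correct for every
# group size that divides the world size.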
class TestMiCSCheckpoint(DistributedTest):
    world_size = 4

    def _toy_model_config(self, shard_size):
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": 'Adam',
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "wall_clock_breakdown": True,
            "zero_optimization": {
                "stage": 3,
                "mics_shard_size": shard_size
            }
        }

        hidden_dim = 10
        # Build the models under MiCS_Init so their parameters are partitioned
        # into shard groups of `shard_size` at construction time.
        with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
            models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

        return config_dict, hidden_dim, models
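
    # With world_size == 4, the shard sizes parametrized below span the MiCS
    # spectrum: shard_size == 1 keeps a full replica of the model states on
    # every rank, shard_size == 4 shards them across all four ranks (plain
    # ZeRO-3 behavior), and shard_size == 2 forms two replication groups of
    # two ranks each.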
    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_load_optimizer_state(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=True)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_not_load_optimizer_state(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=False)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_load_module_only(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_save_checkpoint_on_first_partition_group(self, tmpdir, shard_size):
        config_dict, _, models = self._toy_model_config(shard_size)
        ds_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                  model=models[0],
                                                  model_parameters=models[0].parameters(),
                                                  optimizer=None)
        ds_engine.save_checkpoint(tmpdir)
        # Only ranks in the first partition group (global ranks 0..shard_size-1)
        # should write the non-ZeRO (model) checkpoint files.
        if ds_engine.global_rank < shard_size:
            assert ds_engine.save_non_zero_checkpoint
        else:
            assert not ds_engine.save_non_zero_checkpoint
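
    # Illustrative sketch only (not part of the original test): a checkpoint
    # written by save_checkpoint() above would be restored through the
    # standard DeepSpeed engine API, which is the round trip that
    # checkpoint_correctness_verification() exercises in the tests above:
    #
    #     ds_engine.load_checkpoint(tmpdir,
    #                               load_optimizer_states=True,
    #                               load_module_only=False)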