# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch

import deepspeed.comm as dist
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine

from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import checkpoint_correctness_verification
from unit.util import skip_on_arch


class TestPipelineCheckpoint(DistributedTest):
    world_size = 4

    @pytest.mark.parametrize("zero_stage", [0, 1])
    def test_checkpoint_pipe_engine(self, zero_stage, tmpdir):
        skip_on_arch(min_arch=7)

        config_dict = {
            "train_batch_size": 2,
            "train_micro_batch_size_per_gpu": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-5
                }
            },
            "zero_optimization": {
                "stage": zero_stage
            },
            "fp16": {
                "enabled": zero_stage > 0
            },
            "scheduler": {
                "type": "OneCycle",
                "params": {
                    "cycle_first_step_size": 1000,
                    "cycle_first_stair_count": 500,
                    "cycle_second_step_size": 1000,
                    "cycle_second_stair_count": 500,
                    "decay_step_size": 1000,
                    "cycle_min_lr": 0.0001,
                    "cycle_max_lr": 0.0010,
                    "decay_lr_rate": 0.001,
                    "cycle_min_mom": 0.85,
                    "cycle_max_mom": 0.99,
                    "decay_mom_rate": 0.0
                }
            }
        }

        models = [LinearStackPipe(num_stages=2) for _ in range(2)]
        checkpoint_correctness_verification(config_dict=config_dict,
                                            models=models,
                                            hidden_dim=models[0].hidden_dim,
                                            tmpdir=tmpdir,
                                            fp16=config_dict['fp16']['enabled'],
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=True,
                                            train_batch=True)

    @pytest.mark.parametrize(
        "base_topo,test_topo",
        [
            #(PipeTopo(num_pp=1,
            #          num_dp=4),
            # PipeTopo(num_pp=4,
            #          num_dp=1)),
            #(PipeTopo(num_pp=2,
            #          num_dp=2),
            # PipeTopo(num_pp=2,
            #          num_dp=2)),
            #(PipeTopo(num_pp=4,
            #          num_dp=1),
            # PipeTopo(num_pp=2,
            #          num_dp=2)),
        ])
    def test_checkpoint_pipe_module(self, base_topo, test_topo, tmpdir):
        checkpoint_engine = TorchCheckpointEngine()

        base_model = LinearStackPipe(topology=base_topo)
        base_model.save_state_dict(tmpdir, checkpoint_engine=checkpoint_engine)

        dist.barrier()

        test_model = LinearStackPipe(topology=test_topo)
        test_model.load_state_dir(tmpdir, checkpoint_engine=checkpoint_engine)

        # Base and test can have different lengths, so make sure we map from the
        # smaller to the larger model.
        if len(base_model.forward_funcs) < len(test_model.forward_funcs):
            A = base_model
            B = test_model
        else:
            A = test_model
            B = base_model

        # Compare layers individually since the partitions are different.
        for idx, A_layer in enumerate(A.forward_funcs):
            if not hasattr(A_layer, 'parameters'):
                # Skip functionals, etc.
                continue

            # Find the corresponding layer in B: convert A's partition-local
            # index to a global layer index, then back to B's local index.
            global_idx = idx + A._local_start
            B_local_idx = global_idx - B._local_start
            B_layer = B.forward_funcs[B_local_idx]

            # Compare layer parameters
            for p0, p1 in zip(A_layer.parameters(), B_layer.parameters()):
                assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"
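

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite): the partition-index
# arithmetic used in test_checkpoint_pipe_module above, isolated as a
# standalone function so the mapping can be verified by hand. The helper name
# map_local_to_peer and its signature are hypothetical, introduced only for
# illustration; inside DeepSpeed, PipelineModule tracks the same offsets via
# the _local_start attribute used in the test.
# ---------------------------------------------------------------------------
def map_local_to_peer(local_idx, own_start, peer_start):
    """Map a layer index local to one pipeline partition onto the local index
    of the same layer in a differently partitioned copy of the model."""
    global_idx = local_idx + own_start  # position in the full, unpartitioned layer list
    return global_idx - peer_start  # position within the peer partition's slice


# Example: an 8-layer model split 4+4 (A) vs. 2+6 (B). Layer 0 of A's second
# stage (own_start=4) is global layer 4, which lands at local index 2 in B's
# second stage (peer_start=2).
assert map_local_to_peer(0, own_start=4, peer_start=2) == 2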