# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
from types import SimpleNamespace
from torch.utils._pytree import tree_map
from deepspeed.utils.torch import required_torch_version
from deepspeed.checkpoint import UNIVERSAL_CHECKPOINT_INFO
from deepspeed.checkpoint.ds_to_universal import main as convert_to_universal

from unit.common import DistributedTest, DistributedFixture
from unit.simple_model import *
from unit.util import bf16_required_version_check
from unit.checkpoint.common import compare_opt_state_dicts, compare_state_dicts

import pytest
import deepspeed.comm as dist


def get_expected_mismatch_keys():
    # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to
    # false positive mismatches in checkpoint state comparisons.
    # Newer torch versions store tensor ids as 0, 1, 2, ...
    return [] if required_torch_version(min_version=1.4) else ['params']
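# A "step-like" value: anything that is not a tensor, or a one-element CPU tensor
# (e.g. the optimizer's step counter), is passed through unchanged by the helpers below.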
def maybe_step(t):
    return not torch.is_tensor(t) or (t.device.type == 'cpu' and t.numel() == 1)
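# All-gather each partitioned (flattened) optimizer-state tensor across data-parallel
# ranks so the full state can be compared regardless of the world size it was saved with.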
def gather_opt_state(optimizer_state):

    def gather_tensor(t):
        if maybe_step(t):
            return t
        else:
            buffer = [torch.zeros_like(t.flatten()) for _ in range(dist.get_world_size())]
            dist.all_gather(buffer, t.flatten())
            return torch.cat(buffer)

    return tree_map(gather_tensor, optimizer_state)
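# Padding of the flattened optimizer-state partitions can differ between data-parallel
# sizes, so drop everything past num_params before comparing state dicts.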
def remove_pad_in_opt_state(optimizer_state, num_params):

    def remove_pad(t):
        if maybe_step(t):
            return t
        else:
            return t[:num_params]

    return tree_map(remove_pad, optimizer_state)


CP_TAG = "test_tag"
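# Build the DeepSpeed engine either from the optimizer section of ds_config or,
# when use_torch_adam is set, from a client-side torch.optim.Adam instance.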
def init_ds_engine(model, ds_config, use_torch_adam):
    if use_torch_adam:
        ds_optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
        del ds_config["optimizer"]
        model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, optimizer=ds_optimizer)
    else:
        model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters())
    return model
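# Train for a few steps, save a regular ZeRO checkpoint, convert it to a universal
# checkpoint on rank 0, and persist the baseline model/optimizer state for later comparison.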
def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir):
    if dtype == torch.bfloat16 and not bf16_required_version_check():
        return

    test_step = 8

    model = SimpleModel(hidden_dim)
    model = init_ds_engine(model, ds_config, use_torch_adam)
    data_loader = random_dataloader(model=model,
                                    total_samples=test_step,
                                    hidden_dim=hidden_dim,
                                    device=model.device,
                                    dtype=dtype)
    for batch in data_loader:
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

    if ds_config["zero_optimization"]["stage"] == 3:
        model.optimizer._set_fp32_optimizer_param_groups()
        sd = model.optimizer.optimizer.state_dict() if load_optim else None
        model.optimizer._clear_fp32_optimizer_param_groups()
    else:
        sd = model.optimizer.optimizer.state_dict() if load_optim else None

    client_state = {}
    client_state[UNIVERSAL_CHECKPOINT_INFO] = {}
    client_state['iteration'] = test_step
    model.save_checkpoint(tmpdir, tag=CP_TAG, client_state=client_state)

    cp_dir = os.path.join(tmpdir, CP_TAG)
    univ_cp_dir = f"{cp_dir}_universal"

    args = SimpleNamespace(input_folder=cp_dir,
                           output_folder=univ_cp_dir,
                           num_extract_workers=1,
                           num_merge_workers=1,
                           keep_temp_folder=False,
                           strict=True,
                           inject_missing_state=False)

    dist.barrier()
    if dist.get_rank() == 0:
        convert_to_universal(args)

    model_state = model.state_dict()
    optimizer_state = None
    if load_optim:
        if ds_config["zero_optimization"]["stage"] == 3:
            model.optimizer._set_fp32_optimizer_param_groups()
            optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
            model.optimizer._clear_fp32_optimizer_param_groups()
        else:
            optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())

    if dist.get_rank() == 0:
        torch.save((model_state, optimizer_state), os.path.join(tmpdir, "baseline_state.pt"))

    dist.barrier()
    return model, sd
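# Minimal DeepSpeed config, parametrized by ZeRO stage and dtype.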
@pytest.fixture
def ds_config(zero_stage, dtype):
    ds_config = {
        "train_batch_size": 8,
        "optimizer": {
            "type": 'Adam'
        },
        "zero_optimization": {
            "stage": zero_stage,
        }
    }
    if dtype == torch.float16:
        ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
    elif dtype == torch.bfloat16:
        ds_config["bf16"] = {"enabled": True}
    return ds_config
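# Distributed fixture that produces the baseline checkpoint (and its universal
# conversion) at a fixed world size before the dependent test runs, possibly at a
# different world size.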
class _baseline(DistributedFixture):
    world_size = None

    def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam):
        hidden_dim = 10
        train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir)


class baseline_ws2(_baseline):
    world_size = 2


class baseline_ws4(_baseline):
    world_size = 4
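# Verify that a universal checkpoint written at one data-parallel world size can be
# loaded, validated against the baseline state, and trained further at another world size.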
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("zero_stage", [1, 3])
@pytest.mark.parametrize("use_torch_adam", [False, True])
@pytest.mark.parametrize("load_optim", [False, True])
class TestZeROUniversalCheckpointDP(DistributedTest):

    def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
        if dtype == torch.bfloat16 and not bf16_required_version_check():
            pytest.skip(
                "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA >= 11.0 and HW support for BFloat16 to run correctly"
            )

        hidden_dim = 10
        loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")

        ds_config["checkpoint"] = {"load_universal": True}
        univ_model = SimpleModel(hidden_dim)
        univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam)
        univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim)

        model_state = univ_model.state_dict()
        compare_state_dicts(model_state, loaded_model_state)

        if load_optim and ds_config["zero_optimization"]["stage"] != 3:
            optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
            # padding sizes may differ when dp sizes are different
            param_count = sum(p.numel() for p in univ_model.parameters())
            optimizer_state = remove_pad_in_opt_state(optimizer_state, param_count)
            loaded_optimizer_state = remove_pad_in_opt_state(loaded_optimizer_state, param_count)
            compare_opt_state_dicts(optimizer_state, loaded_optimizer_state, get_expected_mismatch_keys())

        # Run training again to verify that the optimizer has the necessary states
        test_step = 8
        data_loader = random_dataloader(model=univ_model,
                                        total_samples=test_step,
                                        hidden_dim=hidden_dim,
                                        device=univ_model.device,
                                        dtype=dtype)
        for batch in data_loader:
            loss = univ_model(batch[0], batch[1])
            univ_model.backward(loss)
            univ_model.step()

    @pytest.mark.world_size(2)
    def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)

    @pytest.mark.world_size(2)
    def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)

    @pytest.mark.world_size(4)
    def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)