common.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import numbers
import deepspeed
# dist.barrier() below relies on deepspeed.comm; the wildcard import of
# unit.simple_model may also expose `dist`, but importing it explicitly keeps
# this module self-contained.
import deepspeed.comm as dist
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from unit.simple_model import *

from unittest.mock import MagicMock, patch


def compare_deepspeed_states(saved_model, loaded_model):
    # These are compared in more depth in other places
    assert hasattr(loaded_model, 'module')

    assert saved_model.sparse_tensor_module_names == loaded_model.sparse_tensor_module_names
    assert saved_model.skipped_steps == loaded_model.skipped_steps
    assert saved_model.global_steps == loaded_model.global_steps


def zero3_params_to_fetch(param_list):
    return [p for p in param_list if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
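

# Illustrative sketch, not part of the original module: zero3_params_to_fetch() selects
# ZeRO-3 partitioned parameters that are not currently materialized so they can be
# gathered before their values are read. Assuming `engine` is a model returned by
# deepspeed.initialize(), a typical pattern looks like this:
def _example_gather_and_inspect(engine):
    params = zero3_params_to_fetch(list(engine.module.parameters()))
    with deepspeed.zero.GatheredParameters(params, enabled=len(params) > 0):
        # Inside this context the selected parameters hold their full, gathered data.
        return [p.detach().clone() for p in engine.module.parameters()]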


def compare_model_states(saved_model, loaded_model, compare_optimizer=True, load_module_only=False):
    if not load_module_only:
        compare_deepspeed_states(saved_model, loaded_model)

    # zero3_params_to_fetch() expects parameters, not (name, parameter) tuples,
    # so unpack named_parameters() before filtering.
    params_to_fetch = zero3_params_to_fetch([p for _, p in saved_model.module.named_parameters()] +
                                            [p for _, p in loaded_model.module.named_parameters()])
    enable_gather = len(params_to_fetch) > 0
    with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=enable_gather):
        for p0, p1 in zip(saved_model.module.named_parameters(), loaded_model.module.named_parameters()):
            np0, p0 = p0
            np1, p1 = p1
            if 'deepspeed_moe.gate.wg' in np0:
                # these params are converted to float at runtime, cast to half for comparison
                p1 = p1.half()
                p0 = p0.half()
            assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
            try:
                assert torch.allclose(p0, p1,
                                      atol=1e-07), f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}"
            except RuntimeError as err:
                print(f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}")
                raise err

    if not compare_optimizer:
        return

    if DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(saved_model.optimizer, DeepSpeedZeroOptimizer_Stage3):
        for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat,
                          loaded_model.optimizer.fp32_partitioned_groups_flat):
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, DeepSpeedZeroOptimizer):
        for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups,
                          loaded_model.optimizer.single_partition_of_fp32_groups):
            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_Optimizer):
        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
            assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
        for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
            for p0, p1 in zip(params0, params1):
                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
        pass
    else:
        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'


def compare_state_dicts(state0, state1, expected_mismatch_keys=[]):
    for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()):
        assert k0 == k1, f'failure due to key mismatch {k0} != {k1}'
        if k0 in expected_mismatch_keys:
            continue
        if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
            assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
            assert torch.equal(s0.to('cpu'), s1.to('cpu'))
        else:
            assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}'
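

# Illustrative sketch, not part of the original module: compare_state_dicts() walks two
# per-parameter optimizer state dicts in parallel, e.g. the exp_avg / exp_avg_sq moments
# that Adam keeps for a parameter before and after a checkpoint round trip. The tensors
# must match in value but must not be the same object. Minimal usage with plain
# torch.optim.Adam instances (per-parameter states only exist after at least one step):
def _example_compare_adam_states(opt_saved, opt_loaded):
    for s0, s1 in zip(opt_saved.state.values(), opt_loaded.state.values()):
        # 'step' is a counter rather than a weight moment; treat any mismatch there as
        # expected in this sketch and compare only the remaining entries strictly.
        compare_state_dicts(s0, s1, expected_mismatch_keys=['step'])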


def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer

    for state0, state1 in zip(saved_optimizer.state.values(), loaded_optimizer.state.values()):
        compare_state_dicts(state0, state1)


def compare_lr_scheduler_states(saved_model, loaded_model):
    assert hasattr(saved_model, 'lr_scheduler')
    assert hasattr(loaded_model, 'lr_scheduler')

    saved_scheduler = saved_model.lr_scheduler
    loaded_scheduler = loaded_model.lr_scheduler

    assert hasattr(saved_scheduler, 'state_dict')
    assert hasattr(loaded_scheduler, 'state_dict')

    saved_sd = saved_scheduler.state_dict()
    loaded_sd = loaded_scheduler.state_dict()

    print(f"saved_sd = {saved_sd}")
    print(f"loaded_sd = {loaded_sd}")

    assert saved_sd.keys() == loaded_sd.keys()

    for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
        if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
            assert state0 == state1


# following mixture-of-experts.md
def create_moe_param_groups(model):
    from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

    parameters = {'params': [p for p in model.parameters()], 'name': 'parameters'}

    return split_params_into_different_moe_groups_for_optimizer(parameters)
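

# Illustrative sketch, not part of the original module: the param groups returned above
# can be handed to a regular torch optimizer, which a test can then pass to
# deepspeed.initialize() as the base optimizer. Extra keys such as 'name' (and the MoE
# markers added by the split helper) are preserved by torch.optim.Optimizer.
def _example_build_moe_optimizer(model, lr=1e-3):
    param_groups = create_moe_param_groups(model)
    return torch.optim.AdamW(param_groups, lr=lr)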


def create_deepspeed_model(config_dict, model, base_optimizer):
    ds_model, _, _, _ = deepspeed.initialize(config=config_dict,
                                             model=model,
                                             model_parameters=create_moe_param_groups(model),
                                             optimizer=base_optimizer)

    ds_model.empty_partition_cache()
    return ds_model
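

# Illustrative sketch, not part of the original module: a minimal DeepSpeed config of the
# kind tests pass in as `config_dict`. The field names follow the public DeepSpeed JSON
# config schema; the values are placeholders rather than those of any specific test.
_EXAMPLE_CONFIG_DICT = {
    "train_batch_size": 8,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 1
    }
}
# e.g. engine = create_deepspeed_model(_EXAMPLE_CONFIG_DICT, model, base_optimizer=None)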


def checkpoint_correctness_verification(config_dict,
                                        models,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=False,
                                        fp16=True,
                                        train_batch=False,
                                        base_optimizers=[None, None],
                                        empty_tag=False,
                                        seq_dataloader=False,
                                        load_module_only=False):
    dtype = torch.half if fp16 else torch.float32

    ds_model = create_deepspeed_model(config_dict=config_dict, model=models[0], base_optimizer=base_optimizers[0])

    if seq_dataloader:
        data_loader = sequence_dataloader(model=ds_model,
                                          total_samples=50,
                                          hidden_dim=hidden_dim,
                                          device=ds_model.device,
                                          dtype=dtype)
    else:
        data_loader = random_dataloader(model=ds_model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=ds_model.device,
                                        dtype=dtype)

    if train_batch:
        ds_model.set_dataloader(data_loader)
        for _, batch in enumerate(data_loader):
            loss = ds_model.train_batch()
    else:
        for _, batch in enumerate(data_loader):
            loss = ds_model(batch[0], batch[1])
            ds_model.backward(loss)
            ds_model.step()

    # Flush zero stage 3 cache
    ds_model.empty_partition_cache()

    trained_model = ds_model

    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    save_tag = None if empty_tag else '1'

    trained_model.save_checkpoint(save_folder, tag=save_tag)
    dist.barrier()

    for root, _, files in os.walk(save_folder):
        for f in files:
            if "_expert_" in f and "_model_states" in f:
                expert = torch.load(os.path.join(root, f))
                needed, storages = 0, {}
                for name, tensor in expert.items():
                    needed += tensor.size().numel()
                    storage = tensor.storage()
                    # some storage can be shared within an expert's checkpoint
                    storages[storage.data_ptr()] = storage.size()
                stored = sum(v for _, v in storages.items())
                assert needed == stored, f"MoE expert checkpoint uses more storage than required: {f}"

    loaded_model = create_deepspeed_model(config_dict=config_dict, model=models[1], base_optimizer=base_optimizers[1])
    assert list(trained_model.parameters())[0].dtype == list(loaded_model.parameters())[0].dtype

    context = patch.object(loaded_model, "_get_optimizer_ckpt_name",
                           wraps=loaded_model._get_optimizer_ckpt_name) if not load_optimizer_states else MagicMock()
    with context as optim_load_state_dict_mock:
        loaded_model.load_checkpoint(save_folder,
                                     tag=save_tag,
                                     load_optimizer_states=load_optimizer_states,
                                     load_lr_scheduler_states=load_lr_scheduler_states,
                                     load_module_only=load_module_only)
        if not load_optimizer_states:
            # should not attempt to get the file name to load it
            optim_load_state_dict_mock.assert_not_called()

    compare_model_states(trained_model,
                         loaded_model,
                         compare_optimizer=load_optimizer_states,
                         load_module_only=load_module_only)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
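

# Illustrative sketch, not part of the original module: how a unit test might drive
# checkpoint_correctness_verification(). SimpleModel is assumed to come from the wildcard
# import of unit.simple_model; the config keys mirror the example config above. This is a
# sketch of typical usage, not a reproduction of any specific DeepSpeed test.
def _example_checkpoint_roundtrip(tmpdir, hidden_dim=10):
    config_dict = {
        "train_batch_size": 2,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-4
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    # Two independently constructed models: one is trained and saved, the other loads
    # the checkpoint and is compared against the first.
    models = [SimpleModel(hidden_dim) for _ in range(2)]
    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=True)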