common.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import numbers
import deepspeed
import deepspeed.comm as dist  # dist.barrier() is used after saving checkpoints below
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from unit.simple_model import *


def compare_deepspeed_states(saved_model, loaded_model):
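    """Compare engine-level bookkeeping between a saved engine and a freshly
    loaded one: module handle, sparse tensor module names and step counters."""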
    # These are compared in more depth in other places
    assert hasattr(loaded_model, 'module')
    assert saved_model.sparse_tensor_module_names == loaded_model.sparse_tensor_module_names
    assert saved_model.skipped_steps == loaded_model.skipped_steps
    assert saved_model.global_steps == loaded_model.global_steps


def zero3_params_to_fetch(param_list):
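    """Return the parameters in param_list that are ZeRO stage 3 partitioned
    (they carry a ds_id) and are currently NOT_AVAILABLE on this rank, i.e.
    the ones that must be gathered before their values can be inspected."""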
    return [p for p in param_list if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]


def compare_model_states(saved_model, loaded_model, compare_optimizer=True, load_module_only=False):
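    """Assert that the module parameters of loaded_model match those of saved_model
    and, when compare_optimizer is True, that the fp32 master weights held by the
    wrapped optimizer match as well (ZeRO stage 1/2/3, FP16_Optimizer,
    FP16_UnfusedOptimizer or a bare torch.optim.Optimizer)."""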
    if not load_module_only:
        compare_deepspeed_states(saved_model, loaded_model)
    # gather any ZeRO-3 partitioned parameters before comparing their values;
    # zero3_params_to_fetch expects Parameter objects, so pass parameters() here
    params_to_fetch = zero3_params_to_fetch(
        list(saved_model.module.parameters()) + list(loaded_model.module.parameters()))
    enable_gather = len(params_to_fetch) > 0
    with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=enable_gather):
        for p0, p1 in zip(saved_model.module.named_parameters(), loaded_model.module.named_parameters()):
            np0, p0 = p0
            np1, p1 = p1
            if 'deepspeed_moe.gate.wg' in np0:
                # these params are converted to float at runtime, cast to half for comparison
                p1 = p1.half()
                p0 = p0.half()
            assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
            try:
                assert torch.allclose(p0, p1,
                                      atol=1e-07), f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}"
            except RuntimeError as err:
                print(f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}")
                raise err

    if not compare_optimizer:
        return

    if DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(saved_model.optimizer, DeepSpeedZeroOptimizer_Stage3):
        for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat,
                          loaded_model.optimizer.fp32_partitioned_groups_flat):
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, DeepSpeedZeroOptimizer):
        for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups,
                          loaded_model.optimizer.single_partition_of_fp32_groups):
            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_Optimizer):
        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
            assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
        for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
            for p0, p1 in zip(params0, params1):
                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
        pass
    else:
        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'


def compare_state_dicts(state0, state1, expected_mismatch_keys=[]):
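    """Compare two optimizer state dicts entry by entry, skipping keys listed in
    expected_mismatch_keys. Tensor entries must be equal (and must not alias each
    other); all other entries are compared with ==."""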
    for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()):
        assert k0 == k1, f'failure due to key mismatch {k0} != {k1}'
        if k0 in expected_mismatch_keys:
            continue
        if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
            assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
            assert torch.equal(s0.to('cpu'), s1.to('cpu'))
        else:
            assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}'


def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
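    """Compare the per-parameter optimizer state of the two engines, unwrapping
    the fp16 optimizer wrapper when fp16 is True."""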
    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer

    for state0, state1 in zip(saved_optimizer.state.values(), loaded_optimizer.state.values()):
        compare_state_dicts(state0, state1)


def compare_lr_scheduler_states(saved_model, loaded_model):
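    """Compare the LR scheduler state dicts of the two engines; numeric entries
    are compared value by value."""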
    assert hasattr(saved_model, 'lr_scheduler')
    assert hasattr(loaded_model, 'lr_scheduler')
    saved_scheduler = saved_model.lr_scheduler
    loaded_scheduler = loaded_model.lr_scheduler

    assert hasattr(saved_scheduler, 'state_dict')
    assert hasattr(loaded_scheduler, 'state_dict')
    saved_sd = saved_scheduler.state_dict()
    loaded_sd = loaded_scheduler.state_dict()

    print(f"saved_sd = {saved_sd}")
    print(f"loaded_sd = {loaded_sd}")

    assert saved_sd.keys() == loaded_sd.keys()

    for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
        if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
            assert state0 == state1


# following mixture-of-experts.md
def create_moe_param_groups(model):
    from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

    parameters = {'params': [p for p in model.parameters()], 'name': 'parameters'}

    return split_params_into_different_moe_groups_for_optimizer(parameters)
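
# Illustrative only (not used by the tests in this file): following the tutorial
# referenced above, the split groups are meant to be handed to a client optimizer
# as its param_groups. The optimizer class and learning rate below are arbitrary
# assumptions for the example; SimpleModel comes from unit.simple_model.
#
#   net = SimpleModel(hidden_dim=16)
#   base_optimizer = torch.optim.AdamW(create_moe_param_groups(net), lr=1e-3)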


def create_deepspeed_model(config_dict, model, base_optimizer):
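    """Build a DeepSpeed engine from config_dict with MoE-aware parameter groups,
    flushing the engine's partition cache before returning it."""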
    ds_model, _, _, _ = deepspeed.initialize(config=config_dict,
                                             model=model,
                                             model_parameters=create_moe_param_groups(model),
                                             optimizer=base_optimizer)

    ds_model.empty_partition_cache()
    return ds_model


def checkpoint_correctness_verification(config_dict,
                                        models,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=False,
                                        fp16=True,
                                        train_batch=False,
                                        base_optimizers=[None, None],
                                        empty_tag=False,
                                        seq_dataloader=False,
                                        load_module_only=False):
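    """Train models[0] for a few steps, save a checkpoint, load it into models[1]
    and assert that module, optimizer and LR scheduler state survive the round trip.
    Also verifies that MoE expert checkpoints do not serialize more tensor storage
    than their tensors actually need."""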
    dtype = torch.half if fp16 else torch.float32
    ds_model = create_deepspeed_model(config_dict=config_dict, model=models[0], base_optimizer=base_optimizers[0])

    if seq_dataloader:
        data_loader = sequence_dataloader(model=ds_model,
                                          total_samples=50,
                                          hidden_dim=hidden_dim,
                                          device=ds_model.device,
                                          dtype=dtype)
    else:
        data_loader = random_dataloader(model=ds_model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=ds_model.device,
                                        dtype=dtype)

    if train_batch:
        ds_model.set_dataloader(data_loader)
        for _, batch in enumerate(data_loader):
            loss = ds_model.train_batch()
    else:
        for _, batch in enumerate(data_loader):
            loss = ds_model(batch[0], batch[1])
            ds_model.backward(loss)
            ds_model.step()

    # Flush zero stage 3 cache
    ds_model.empty_partition_cache()

    trained_model = ds_model
    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    save_tag = None if empty_tag else '1'

    trained_model.save_checkpoint(save_folder, tag=save_tag)
    dist.barrier()

    # Ensure each MoE expert checkpoint does not serialize more tensor storage
    # than its tensors actually need (shared storage is counted only once).
    for root, _, files in os.walk(save_folder):
        for f in files:
            if "_expert_" in f and "_model_states" in f:
                expert = torch.load(os.path.join(root, f))
                needed, storages = 0, {}
                for name, tensor in expert.items():
                    needed += tensor.size().numel()
                    storage = tensor.storage()
                    # some storage can be shared within an expert's checkpoint
                    storages[storage.data_ptr()] = storage.size()
                stored = sum(v for _, v in storages.items())
                assert needed == stored, f"MoE expert checkpoint uses more storage than required: {f}"

    loaded_model = create_deepspeed_model(config_dict=config_dict, model=models[1], base_optimizer=base_optimizers[1])
    assert list(trained_model.parameters())[0].dtype == list(loaded_model.parameters())[0].dtype

    loaded_model.load_checkpoint(save_folder,
                                 tag=save_tag,
                                 load_optimizer_states=load_optimizer_states,
                                 load_lr_scheduler_states=load_lr_scheduler_states,
                                 load_module_only=load_module_only)

    compare_model_states(trained_model,
                         loaded_model,
                         compare_optimizer=load_optimizer_states,
                         load_module_only=load_module_only)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
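

# The sketch below is illustrative only and is not invoked anywhere in the test
# suite: it shows how a test might drive checkpoint_correctness_verification.
# The config values, hidden_dim and SimpleModel usage are assumptions for the
# example; real tests supply their own configs and run under a distributed test
# launcher so that deepspeed.initialize can set up communication.
def _example_checkpoint_roundtrip(tmpdir):
    hidden_dim = 10
    config_dict = {
        "train_batch_size": 2,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    # two independently constructed models: one is trained and saved,
    # the other receives the checkpoint
    models = [SimpleModel(hidden_dim) for _ in range(2)]
    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=True)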