test_shared_weights.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn

import deepspeed
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from unit.common import DistributedTest


class ModelWithSharedWeights(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(100, 100)
        self.layer1 = nn.Linear(200, 200)
        self.layer2 = nn.Linear(300, 300)
        # tie layer 1 and layer 2: both attributes now reference the same
        # Parameter object, so the state dict contains two keys backed by
        # one underlying tensor
        self.layer1.weight = self.layer2.weight


class TestCheckpointSharedWeights(DistributedTest):
    world_size = 2

    def test_checkpoint_shared_weights(self, tmp_path):
        config = {
            "train_micro_batch_size_per_gpu": 2,
            "zero_allow_untested_optimizer": True,
            "zero_optimization": {
                "stage": 2
            },
        }
        model = ModelWithSharedWeights()
        optimizer = torch.optim.Adam(model.parameters())

        deepspeed_engine, _, _, _ = deepspeed.initialize(
            config=config,
            model=model,
            optimizer=optimizer,
        )
        filename = tmp_path / "checkpoint.pt"
        deepspeed_engine.save_checkpoint(filename, tag="checkpoint")
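
        # Under ZeRO stage 2 each rank writes its own shard of the
        # optimizer state; get_fp32_state_dict_from_zero_checkpoint
        # consolidates those shards into a single fp32 state dict that a
        # plain (non-DeepSpeed) module can load.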
        # load the consolidated state dict into a fresh, non-DeepSpeed model
        model = ModelWithSharedWeights()
        state_dict = get_fp32_state_dict_from_zero_checkpoint(filename, tag="checkpoint")
        model.load_state_dict(state_dict, strict=True)
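
        # Sketch of an extra assertion (not in the original test): with
        # strict=True, the load above already proves both tied keys are
        # present in the consolidated state dict; comparing data_ptr()
        # additionally confirms the fresh model still shares one tensor
        # between layer1 and layer2.
        assert model.layer1.weight.data_ptr() == model.layer2.weight.data_ptr()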