test_sparse.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import deepspeed
  5. from unit.common import DistributedTest
  6. from unit.simple_model import *
  7. import pytest
  8. class TestSparseCheckpoint(DistributedTest):
  9. world_size = 2
  10. @pytest.mark.parametrize(["to_save_model_has_embedding", "to_save_model_sparse"], [
  11. [False, False],
  12. [True, False],
  13. [True, True],
  14. ])
  15. @pytest.mark.parametrize(["destination_has_embedding", "destination_sparse"], [
  16. [False, False],
  17. [True, False],
  18. [True, True],
  19. ])
  20. def test_non_strict_load_sparse(self, tmpdir, to_save_model_has_embedding, to_save_model_sparse,
  21. destination_has_embedding, destination_sparse):
  22. class ModelNoEmbedding(torch.nn.Module):
  23. def __init__(self):
  24. super().__init__()
  25. self.linear = torch.nn.Linear(3, 1)
  26. def forward(self, x):
  27. return self.linear(x)
  28. class ModelEmbedding(torch.nn.Module):
  29. def __init__(self):
  30. super().__init__()
  31. self.emb = torch.nn.Embedding(10, 3)
  32. self.linear = torch.nn.Linear(3, 1)
  33. def forward(self, x, offsets):
  34. return self.linear(self.emb(x, offsets))
  35. if to_save_model_has_embedding:
  36. model_to_save = ModelEmbedding()
  37. else:
  38. model_to_save = ModelNoEmbedding()
  39. if destination_has_embedding:
  40. model_destination = ModelEmbedding()
  41. else:
  42. model_destination = ModelNoEmbedding()
  43. engine_to_save, _, _, _ = deepspeed.initialize(model=model_to_save,
  44. config={
  45. "train_batch_size": 2,
  46. "sparse_gradients": to_save_model_sparse
  47. })
  48. engine_destination, _, _, _ = deepspeed.initialize(model=model_destination,
  49. config={
  50. "train_batch_size": 2,
  51. "sparse_gradients": destination_sparse
  52. })
  53. save_folder = os.path.join(tmpdir, 'saved_checkpoint')
  54. save_tag = '1'
  55. engine_to_save.save_checkpoint(save_folder, tag=save_tag)
  56. is_sparse_destination = isinstance(model_destination, ModelEmbedding) and destination_sparse
  57. if isinstance(model_destination, ModelEmbedding) and model_destination.emb.sparse:
  58. assert "emb.weight" in engine_destination.sparse_tensor_module_names
  59. engine_destination.load_checkpoint(save_folder,
  60. tag=save_tag,
  61. load_module_strict=False,
  62. load_optimizer_states=False,
  63. load_lr_scheduler_states=False,
  64. load_module_only=False)
  65. if isinstance(model_destination, ModelEmbedding) and isinstance(model_to_save, ModelEmbedding):
  66. assert engine_destination.sparse_tensor_module_names == engine_to_save.sparse_tensor_module_names
  67. elif isinstance(model_destination, ModelEmbedding):
  68. assert not is_sparse_destination or "emb.weight" in engine_destination.sparse_tensor_module_names
  69. else:
  70. assert len(engine_destination.sparse_tensor_module_names) == 0