test_mics_optimizer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import deepspeed

from deepspeed.runtime.utils import required_torch_version
from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import *

import pytest

if not required_torch_version(max_version=2.0):
    pytest.skip("Skipping until we resolve problems with torch 2.1", allow_module_level=True)


class TestMiCSCheckpoint(DistributedTest):
    world_size = 4
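
    # With four ranks, every parametrized shard size (1, 2, 4) divides the
    # world evenly into complete MiCS partition groups.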
    def _toy_model_config(self, shard_size):
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": 'Adam',
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "wall_clock_breakdown": True,
            "zero_optimization": {
                "stage": 3,
                "mics_shard_size": shard_size
            }
        }

        hidden_dim = 10
        # Construct the models under MiCS_Init so parameters are allocated
        # already partitioned across the shard group.
        with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
            models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

        return config_dict, hidden_dim, models
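
    # A rough sketch of the layout this config produces (a reading of the MiCS
    # design, not asserted by the original file): ZeRO stage 3 with
    # "mics_shard_size" partitions model/optimizer states only within a group
    # of `shard_size` consecutive ranks, and each group holds a full replica.
    # With world_size = 4:
    #   shard_size=1 -> 4 groups, no partitioning (plain data parallelism)
    #   shard_size=2 -> 2 groups: ranks {0, 1} and {2, 3}
    #   shard_size=4 -> 1 group, equivalent to standard ZeRO-3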

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_load_optimizer_state(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=True)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_not_load_optimizer_state(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=False)

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_load_module_only(self, tmpdir, shard_size):
        config_dict, hidden_dim, models = self._toy_model_config(shard_size)
        checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)
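
    # Note: checkpoint_correctness_verification (from unit.checkpoint.common)
    # roughly does the following: train briefly, save a checkpoint from one
    # engine, load it into a second engine, and compare module weights (plus
    # optimizer state when load_optimizer_states=True).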

    @pytest.mark.parametrize('shard_size', [1, 2, 4])
    def test_save_checkpoint_on_first_partition_group(self, tmpdir, shard_size):
        config_dict, _, models = self._toy_model_config(shard_size)
        ds_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                  model=models[0],
                                                  model_parameters=models[0].parameters(),
                                                  optimizer=None)
        ds_engine.save_checkpoint(tmpdir)
        # Only ranks in the first partition group (global rank < shard_size)
        # should write the non-ZeRO part of the checkpoint; the remaining
        # groups hold replicas of the same states.
        if ds_engine.global_rank < shard_size:
            assert ds_engine.save_non_zero_checkpoint
        else:
            assert not ds_engine.save_non_zero_checkpoint
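

# ---------------------------------------------------------------------------
# Illustration (not part of the original test suite): a minimal, dependency-free
# sketch of the rank layout the last test asserts. `partition_group` and
# `saves_non_zero_checkpoint` are hypothetical helpers written for this sketch;
# they mirror the `global_rank < shard_size` check above, under the assumption
# that MiCS groups consecutive ranks.
# ---------------------------------------------------------------------------
def partition_group(rank: int, shard_size: int) -> list:
    """Return the ranks in the MiCS partition group containing `rank`
    (hypothetical helper; assumes groups of consecutive ranks)."""
    start = (rank // shard_size) * shard_size
    return list(range(start, start + shard_size))


def saves_non_zero_checkpoint(rank: int, shard_size: int) -> bool:
    """True when `rank` lies in the first partition group (ranks
    0..shard_size-1), the group the test expects to write the
    non-ZeRO part of the checkpoint."""
    return rank < shard_size


if __name__ == "__main__":
    world_size = 4
    for shard_size in (1, 2, 4):
        # Group leaders are every shard_size-th rank; expand each to its group.
        groups = [partition_group(r, shard_size) for r in range(0, world_size, shard_size)]
        savers = [r for r in range(world_size) if saves_non_zero_checkpoint(r, shard_size)]
        print(f"shard_size={shard_size}: groups={groups}, saver ranks={savers}")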