# test_mics_optimizer.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import pytest

import deepspeed

from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import *
  11. class TestMiCSCheckpoint(DistributedTest):
  12. world_size = 4
  13. def _toy_model_config(self, shard_size):
  14. config_dict = {
  15. "train_micro_batch_size_per_gpu": 2,
  16. "steps_per_print": 1,
  17. "optimizer": {
  18. "type": 'Adam',
  19. "params": {
  20. "lr": 0.00015,
  21. "betas": [0.8, 0.999],
  22. "eps": 1e-8,
  23. "weight_decay": 3e-7
  24. }
  25. },
  26. "fp16": {
  27. "enabled": True,
  28. "initial_scale_power": 8
  29. },
  30. "wall_clock_breakdown": True,
  31. "zero_optimization": {
  32. "stage": 3,
  33. "mics_shard_size": shard_size
  34. }
  35. }
  36. hidden_dim = 10
  37. with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
  38. models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
  39. return config_dict, hidden_dim, models
  40. @pytest.mark.parametrize('shard_size', [1, 2, 4])
  41. def test_load_optimizer_state(self, tmpdir, shard_size):
  42. config_dict, hidden_dim, models = self._toy_model_config(shard_size)
  43. checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=True)
  44. @pytest.mark.parametrize('shard_size', [1, 2, 4])
  45. def test_not_load_optimizer_state(self, tmpdir, shard_size):
  46. config_dict, hidden_dim, models = self._toy_model_config(shard_size)
  47. checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_optimizer_states=False)
  48. @pytest.mark.parametrize('shard_size', [1, 2, 4])
  49. def test_load_module_only(self, tmpdir, shard_size):
  50. config_dict, hidden_dim, models = self._toy_model_config(shard_size)
  51. checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)
  52. @pytest.mark.parametrize('shard_size', [1, 2, 4])
  53. def test_save_checkpoint_on_first_partition_group(self, tmpdir, shard_size):
  54. config_dict, _, models = self._toy_model_config(shard_size)
  55. ds_engine, _, _, _ = deepspeed.initialize(config=config_dict,
  56. model=models[0],
  57. model_parameters=models[0].parameters(),
  58. optimizer=None)
  59. ds_engine.save_checkpoint(tmpdir)
  60. if ds_engine.global_rank < shard_size:
  61. assert ds_engine.save_non_zero_checkpoint == True
  62. else:
  63. assert ds_engine.save_non_zero_checkpoint == False