test_moe_checkpoint.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
from deepspeed.runtime.utils import required_torch_version

from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import checkpoint_correctness_verification

import pytest
import torch  # explicit import for the torch.optim usage below


class TestMoECheckpoint(DistributedTest):
    world_size = 4

    @pytest.mark.parametrize("ep_size", [4])
    def test_checkpoint_moe(self, tmpdir, ep_size):
        # Checkpoint save/load round-trip for an MoE model with client optimizers
        # (no ZeRO), restoring optimizer states.
        if not required_torch_version(min_version=1.8):
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
        hidden_dim = 16

        models = [SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size) for _ in range(2)]
        optimizers = [torch.optim.AdamW(params=model.parameters()) for model in models]
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=False,
                                            fp16=config_dict["fp16"]["enabled"],
                                            empty_tag=True,
                                            base_optimizers=optimizers,
                                            seq_dataloader=True)
    @pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)])
    def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states):
        # Checkpoint save/load round-trip for an MoE model combined with ZeRO stage 2,
        # with and without restoring optimizer states.
        if not required_torch_version(min_version=1.8):
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {
            "train_batch_size": 8,
            "steps_per_print": 1,
            "optimizer": {
                "type": 'Adam',
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "zero_optimization": {
                "stage": 2,
            }
        }
        hidden_dim = 16

        models = [SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size) for _ in range(2)]
        # param group must have a random unique name (for now)
        # TODO: clean-up this requirement, the unique name should not be required here
        param_groups = [{'params': [p for p in model.parameters()], 'name': 'random-unique-name'} for model in models]
        params = [split_params_into_different_moe_groups_for_optimizer(group) for group in param_groups]
        optimizers = [torch.optim.AdamW(params=param) for param in params]
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=load_optim_states,
                                            load_lr_scheduler_states=False,
                                            fp16=config_dict["fp16"]["enabled"],
                                            empty_tag=True,
                                            base_optimizers=optimizers,
                                            seq_dataloader=True)
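
# ---------------------------------------------------------------------------
# Not part of the original test suite: a minimal sketch of how the MoE param
# groups built in test_checkpoint_moe_and_zero would typically be handed to a
# DeepSpeed engine. It assumes a distributed environment has already been
# launched (e.g. via the deepspeed launcher) and that
# checkpoint_correctness_verification performs an equivalent
# deepspeed.initialize call internally; the helper name _build_moe_engine is
# illustrative only and is not collected by pytest.
# ---------------------------------------------------------------------------
def _build_moe_engine(config_dict, hidden_dim=16, ep_size=2):
    import deepspeed

    model = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size)
    # Split expert and non-expert parameters into separate optimizer groups,
    # mirroring the param-group construction used in the ZeRO test above.
    group = {'params': list(model.parameters()), 'name': 'moe-sketch-group'}
    param_groups = split_params_into_different_moe_groups_for_optimizer(group)
    optimizer = torch.optim.AdamW(params=param_groups)
    engine, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=config_dict)
    return engine, optimizer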