test_mup_optimizers.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import deepspeed
  5. import torch
  6. import pytest
  7. from unit.common import DistributedTest
  8. from unit.simple_model import SimpleModel, random_dataloader
  9. from mup.shape import set_base_shapes
  10. @pytest.mark.parametrize("optimizer, expected_opt_class", [("MuAdam", torch.optim.Adam),
  11. ("MuAdamW", torch.optim.AdamW), ("MuSGD", torch.optim.SGD)]) # yapf: disable
  12. @pytest.mark.parametrize("zero_offload", [True, False]) # yapf: disable
  13. class TestMuPOptimizers(DistributedTest):
  14. world_size = 1
  15. reuse_dist_env = True
  16. def test(self, optimizer, expected_opt_class, zero_offload):
  17. config_dict = {
  18. "train_batch_size": 2,
  19. "steps_per_print": 1,
  20. "zero_allow_untested_optimizer": True,
  21. "optimizer": {
  22. "type": optimizer,
  23. "params": {
  24. "lr": 0.00015,
  25. }
  26. },
  27. "gradient_clipping": 1.0,
  28. "fp16": {
  29. "enabled": True
  30. },
  31. "zero_optimization": {
  32. "stage": 2,
  33. "cpu_offload": zero_offload
  34. }
  35. }
  36. hidden_dim = 10
  37. model = SimpleModel(hidden_dim)
  38. set_base_shapes(model, None)
  39. model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
  40. data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
  41. for n, batch in enumerate(data_loader):
  42. loss = model(batch[0], batch[1])
  43. model.backward(loss)
  44. model.step()
  45. ds_optimizer = model.optimizer.optimizer
  46. assert isinstance(ds_optimizer, expected_opt_class)