test_adamw.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import torch
import pytest

from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from deepspeed.accelerator import get_accelerator

if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
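
# Each row below pins down which optimizer implementation DeepSpeed should
# select for a given config: the config's optimizer name, whether ZeRO offloads
# optimizer state to CPU, whether torch's native optimizer is requested, the
# adam_w_mode flag, and the expected (class, adam_w_mode) result. An expected
# adam_w_mode of None marks the torch optimizers, which have no such attribute.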
# yapf: disable
# optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer
adam_configs = [["AdamW", False, False, False, (FusedAdam, True)],
                ["AdamW", False, True, False, (torch.optim.AdamW, None)],
                ["AdamW", True, False, False, (DeepSpeedCPUAdam, True)],
                ["AdamW", True, True, False, (torch.optim.AdamW, None)],
                ["AdamW", False, False, True, (FusedAdam, True)],
                ["AdamW", False, True, True, (torch.optim.AdamW, None)],
                ["AdamW", True, False, True, (DeepSpeedCPUAdam, True)],
                ["AdamW", True, True, True, (torch.optim.AdamW, None)],
                ["Adam", False, False, False, (FusedAdam, False)],
                ["Adam", False, True, False, (torch.optim.Adam, None)],
                ["Adam", True, False, False, (DeepSpeedCPUAdam, False)],
                ["Adam", True, True, False, (torch.optim.Adam, None)],
                ["Adam", False, False, True, (FusedAdam, True)],
                ["Adam", False, True, True, (torch.optim.AdamW, None)],
                ["Adam", True, False, True, (DeepSpeedCPUAdam, True)],
                ["Adam", True, True, True, (torch.optim.AdamW, None)]]
@pytest.mark.parametrize('optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer', adam_configs)
class TestAdamConfigs(DistributedTest):
    world_size = 1
    reuse_dist_env = True

    def test(self, optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer):
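        # Minimal ZeRO stage-2 fp16 config; only the optimizer section varies
        # across the parametrized cases.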
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": optimizer,
                "params": {
                    "lr": 0.00015,
                    "torch_adam": torch_adam,
                    "adam_w_mode": adam_w_mode
                }
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 2,
                "cpu_offload": zero_offload
            }
        }
        model = SimpleModel(10)
        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                              model=model,
                                              model_parameters=model.parameters())
        # get base optimizer under zero
        ds_optimizer = model.optimizer.optimizer
        opt_class, adam_w_mode = resulting_optimizer
        assert isinstance(ds_optimizer, opt_class)
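        # torch's native Adam/AdamW do not expose an adam_w_mode attribute,
        # so those rows carry None and skip the attribute check below.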
        if adam_w_mode in [True, False]:
            assert ds_optimizer.adam_w_mode == adam_w_mode
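
# To run this suite locally (the path is an assumption based on DeepSpeed's
# usual tests/unit layout; adjust it to wherever this file lives):
#   pytest tests/unit/ops/adam/test_adamw.py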