# test_other_optimizer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
from deepspeed.ops.op_builder import FusedLambBuilder

from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import checkpoint_correctness_verification

import pytest
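

# Checkpoint save/load round-trip tests for three optimizer variants:
# an unfused FP16 optimizer (Lamb), a fused FP16 optimizer (Adam),
# and a plain FP32 optimizer (Adam).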
class TestOtherOptimizerCheckpoint(DistributedTest):
    world_size = 2

    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
    def test_checkpoint_unfused_optimizer(self, tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Lamb",
                "params": {
                    "lr": 0.00015
                }
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True
            },
            "scheduler": {
                "type": "OneCycle",
                "params": {
                    "cycle_first_step_size": 1000,
                    "cycle_first_stair_count": 500,
                    "cycle_second_step_size": 1000,
                    "cycle_second_stair_count": 500,
                    "decay_step_size": 1000,
                    "cycle_min_lr": 0.0001,
                    "cycle_max_lr": 0.0010,
                    "decay_lr_rate": 0.001,
                    "cycle_min_mom": 0.85,
                    "cycle_max_mom": 0.99,
                    "decay_mom_rate": 0.0
                }
            }
        }
        args = args_from_dict(tmpdir, config_dict)
        hidden_dim = 10
        models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

        # Load & verify optimizer states
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True)

        # Ignore optimizer states
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=False)
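
    # With optimizer type "Adam" and fp16 enabled, DeepSpeed takes its fused
    # FP16 path (FusedAdam wrapped by the FP16 optimizer).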
    def test_checkpoint_fused_optimizer(self, tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": True
            }
        }
        args = args_from_dict(tmpdir, config_dict)
        hidden_dim = 10
        models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

        # Load & verify optimizer states
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True)

        # Ignore optimizer states
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=False)
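
    # With fp16 disabled, the optimizer runs entirely in FP32, so the
    # checkpoint holds a single set of full-precision weights and states.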
    def test_checkpoint_fp32_optimizer(self, tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": False
            }
        }
        args = args_from_dict(tmpdir, config_dict)
        hidden_dim = 10
        models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            fp16=False)
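
# For reference, a minimal sketch of the save/load round-trip that
# checkpoint_correctness_verification exercises, assuming DeepSpeed's public
# engine API and a distributed launcher (e.g. `deepspeed --num_gpus=2 ...`);
# `config_dict` stands in for any of the configs above, and the snippet is
# illustrative, not part of this suite:
#
#   import deepspeed
#   from unit.simple_model import SimpleModel
#
#   model = SimpleModel(hidden_dim=10, empty_grad=False)
#   engine, _, _, _ = deepspeed.initialize(model=model,
#                                          model_parameters=model.parameters(),
#                                          config=config_dict)
#   engine.save_checkpoint(save_dir="ckpt_dir")
#   engine.load_checkpoint(load_dir="ckpt_dir", load_optimizer_states=True)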