test_lr_schedulers.py

import torch
import deepspeed
import argparse
import pytest
import json
import os
from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR
from deepspeed.runtime.lr_schedules import WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, TOTAL_NUM_STEPS
from deepspeed.runtime.lr_schedules import CYCLE_MIN_LR, CYCLE_MAX_LR


@pytest.mark.parametrize("scheduler_type,params",
                         [(WARMUP_LR, {}),
                          (WARMUP_DECAY_LR, {
                              WARMUP_NUM_STEPS: 10,
                              TOTAL_NUM_STEPS: 20
                          }),
                          (ONE_CYCLE, {
                              CYCLE_MIN_LR: 0,
                              CYCLE_MAX_LR: 0
                          }),
                          (LR_RANGE_TEST, {})])
def test_get_lr_before_train(tmpdir, scheduler_type, params):
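    """lr_scheduler.get_lr() should be safe to call before any training step
    has been taken, for every supported schedule type."""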
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            },
        },
        "scheduler": {
            "type": scheduler_type,
            "params": params
        },
        "gradient_clipping": 1.0
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)
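
    # The distributed_test helper (imported from common) runs the wrapped
    # function in a torch.distributed process group of the requested world
    # size, a single rank here.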
    @distributed_test(world_size=[1])
    def _test_get_lr_before_train(args, model, hidden_dim):
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            # get lr before training starts
            lr_scheduler.get_lr()
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_get_lr_before_train(args=args, model=model, hidden_dim=hidden_dim)


@pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
def test_lr_warmup_schedule(tmpdir, warmup_num_steps):
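    """WARMUP_LR should start at WARMUP_MIN_LR, reach WARMUP_MAX_LR after
    WARMUP_NUM_STEPS, and then hold WARMUP_MAX_LR for the remaining steps."""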
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            },
        },
        "scheduler": {
            "type": WARMUP_LR,
            "params": {
                WARMUP_MIN_LR: 0.1,
                WARMUP_MAX_LR: 0.2,
                WARMUP_NUM_STEPS: warmup_num_steps
            }
        },
        "gradient_clipping": 1.0
    }
    total_num_steps = 2 * warmup_num_steps
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_lr_warmup_schedule(args, model, hidden_dim, schedule_params, num_steps):
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=num_steps * 2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.append(lr_scheduler.get_lr())

        # Verify initial lr
        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

        # Verify warmup completion
        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
        assert step_lrs[warmup_num_steps] == warmup_max_lr

        # Verify lr stays at max after warmup completes
        assert all([warmup_max_lr == lr for lr in step_lrs[warmup_num_steps:]])

    _test_lr_warmup_schedule(args=args,
                             model=model,
                             hidden_dim=hidden_dim,
                             schedule_params=config_dict["scheduler"]["params"],
                             num_steps=total_num_steps)


@pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
def test_lr_warmup_decay_schedule(tmpdir, warmup_num_steps):
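    """WARMUP_DECAY_LR should warm up from WARMUP_MIN_LR to WARMUP_MAX_LR over
    WARMUP_NUM_STEPS and then decay the lr monotonically for the remaining
    steps up to TOTAL_NUM_STEPS."""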
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            },
        },
        "scheduler": {
            "type": WARMUP_DECAY_LR,
            "params": {
                WARMUP_MIN_LR: 0.1,
                WARMUP_MAX_LR: 0.2,
                WARMUP_NUM_STEPS: warmup_num_steps,
                TOTAL_NUM_STEPS: warmup_num_steps * 2
            }
        },
        "gradient_clipping": 1.0
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_lr_warmup_decay_schedule(args,
                                       model,
                                       hidden_dim,
                                       schedule_params,
                                       num_steps):
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=num_steps * 2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.append(lr_scheduler.get_lr())

        # Verify initial lr
        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

        # Verify lr at warmup completion
        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
        assert step_lrs[warmup_num_steps] == warmup_max_lr

        # Verify decay phase
        previous_lr = warmup_max_lr
        for lr in step_lrs[warmup_num_steps + 1:]:
            assert lr < previous_lr
            previous_lr = lr

    schedule_params = config_dict["scheduler"]["params"]
    total_num_steps = schedule_params[TOTAL_NUM_STEPS]
    _test_lr_warmup_decay_schedule(args=args,
                                   model=model,
                                   hidden_dim=hidden_dim,
                                   schedule_params=schedule_params,
                                   num_steps=total_num_steps)