# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
import pytest

from unit.common import DistributedTest
from unit.simple_model import SimpleModel, random_dataloader

from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, LR_RANGE_TEST_MIN_LR, LR_RANGE_TEST_STEP_RATE, LR_RANGE_TEST_STEP_SIZE, LR_RANGE_TEST_STAIRCASE
from deepspeed.runtime.lr_schedules import WARMUP_LR, WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, WARMUP_TYPE, WARMUP_LOG_RATE, WARMUP_LINEAR_RATE
from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE
from deepspeed.runtime.lr_schedules import CYCLE_MIN_MOM, CYCLE_MAX_MOM, DECAY_MOM_RATE
from deepspeed.runtime.lr_schedules import WARMUP_DECAY_LR, TOTAL_NUM_STEPS


def _verify_continuous_decrease(values):
    # Every value must be strictly smaller than the one before it.
    for i in range(len(values) - 1):
        assert values[i] > values[i + 1]


def _verify_continuous_increase(values):
    # Every value must be strictly larger than the one before it.
    for i in range(len(values) - 1):
        assert values[i] < values[i + 1]


def _verify_staircase_increase(values, step_size):
    # Values must stay constant within each window of `step_size` consecutive steps.
    num_values = len(values)
    for i in range(0, num_values, step_size):
        j = min(i + step_size, num_values)
        assert all([values[i] == v for v in values[i:j]])


@pytest.mark.parametrize("scheduler_type,params", [(WARMUP_LR, {}),
                                                   (WARMUP_DECAY_LR, {
                                                       WARMUP_NUM_STEPS: 10,
                                                       TOTAL_NUM_STEPS: 20
                                                   }), (ONE_CYCLE, {
                                                       CYCLE_MIN_LR: 0,
                                                       CYCLE_MAX_LR: 0.1
                                                   }), (LR_RANGE_TEST, {})])
class TestGetLrBeforeTrain(DistributedTest):
    world_size = 1

    def test(self, scheduler_type, params):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": scheduler_type,
                "params": params
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            # get lr before training starts
            lr_scheduler.get_lr()
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()


@pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
@pytest.mark.parametrize("warmup_type", [WARMUP_LOG_RATE, WARMUP_LINEAR_RATE])
class TestLrSchedule(DistributedTest):
    world_size = 1

    def test_lr_warmup_schedule(self, warmup_num_steps, warmup_type):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": WARMUP_LR,
                "params": {
                    WARMUP_MIN_LR: 0.1,
                    WARMUP_MAX_LR: 0.2,
                    WARMUP_NUM_STEPS: warmup_num_steps,
                    WARMUP_TYPE: warmup_type,
                }
            },
            "gradient_clipping": 1.0
        }
        schedule_params = config_dict["scheduler"]["params"]
        total_num_steps = 2 * warmup_num_steps
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=total_num_steps * 2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.append(lr_scheduler.get_lr())

        # Verify initial lr
        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

        # Verify lr at warmup completion
        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
        assert step_lrs[warmup_num_steps] == warmup_max_lr

        # Verify lr stays at the max value after warmup completes
        assert all([warmup_max_lr == lr for lr in step_lrs[warmup_num_steps:]])

    def test_lr_warmup_decay_schedule(self, warmup_num_steps, warmup_type):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": WARMUP_DECAY_LR,
                "params": {
                    WARMUP_MIN_LR: 0.1,
                    WARMUP_MAX_LR: 0.2,
                    WARMUP_NUM_STEPS: warmup_num_steps,
                    TOTAL_NUM_STEPS: warmup_num_steps * 2,
                    WARMUP_TYPE: warmup_type
                }
            },
            "gradient_clipping": 1.0
        }
        schedule_params = config_dict["scheduler"]["params"]
        total_num_steps = schedule_params[TOTAL_NUM_STEPS]
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=total_num_steps * 2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.append(lr_scheduler.get_lr())

        # Verify initial lr
        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

        # Verify lr at warmup completion
        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
        assert step_lrs[warmup_num_steps] == warmup_max_lr

        # Verify decay phase
        previous_lr = warmup_max_lr
        for lr in step_lrs[warmup_num_steps + 1:]:
            assert lr < previous_lr
            previous_lr = lr


@pytest.mark.parametrize("scheduler_type,params", [(WARMUP_LR, {}),
                                                   (WARMUP_DECAY_LR, {
                                                       WARMUP_NUM_STEPS: 5,
                                                       TOTAL_NUM_STEPS: 10
                                                   }),
                                                   (ONE_CYCLE, {
                                                       CYCLE_MIN_LR: 0,
                                                       CYCLE_MAX_LR: 0.1,
                                                       CYCLE_FIRST_STEP_SIZE: 5,
                                                       DECAY_STEP_SIZE: 5
                                                   }),
                                                   (LR_RANGE_TEST, {
                                                       LR_RANGE_TEST_MIN_LR: 1e-4,
                                                       LR_RANGE_TEST_STEP_SIZE: 1
                                                   })])
class TestSchedulerOptimizerParity(DistributedTest):
    world_size = 1

    def test(self, scheduler_type, params):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": scheduler_type,
                "params": params
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            assert lr_scheduler.get_lr() == model.get_lr()


@pytest.mark.parametrize("min_lr, step_rate, step_size, staircase",
                         [(1e-4, 1e-5, 1, True),
                          (1e-5, 1e-5, 1, False),
                          (1e-4, 1e-3, 10, True),
                          (1e-3, 1e-3, 10, False),
                          (1e-2, 1e-2, 19, True),
                          (1e-2, 1e-2, 19, False)
                          ])  # yapf: disable
class TestLrRange(DistributedTest):
    world_size = 1

    def test(self, min_lr, step_rate, step_size, staircase):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": LR_RANGE_TEST,
                "params": {
                    LR_RANGE_TEST_MIN_LR: min_lr,
                    LR_RANGE_TEST_STEP_RATE: step_rate,
                    LR_RANGE_TEST_STEP_SIZE: step_size,
                    LR_RANGE_TEST_STAIRCASE: staircase
                }
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50, step_size * 2),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for _, batch in enumerate(data_loader):
            step_lrs.extend(lr_scheduler.get_lr())
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Verify starting lr
        assert step_lrs[0] == min_lr

        if staircase:
            # Verify staircase increasing lr
            _verify_staircase_increase(step_lrs, step_size)
        else:
            # Verify continuous increasing lr
            _verify_continuous_increase(step_lrs)


class TestOneCycle(DistributedTest):
    world_size = 1

    @pytest.mark.parametrize("min_lr, max_lr, decay_rate, cycle_step_size, decay_step_size",
                             [
                                 (1e-5, 1e-2, 1e-3, 10, 10),
                                 (1e-3, 1e-1, 0, 21, 21),
                                 (1e-5, 1e-2, 1e-3, 10, 10),
                                 (1e-3, 1e-1, 1e-1, 21, 21),
                                 (1e-5, 1e-1, 0, 10, 0),
                             ])  # yapf: disable
    def test_lr(self, min_lr, max_lr, decay_rate, cycle_step_size, decay_step_size):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": ONE_CYCLE,
                "params": {
                    CYCLE_MIN_LR: min_lr,
                    CYCLE_MAX_LR: max_lr,
                    DECAY_LR_RATE: decay_rate,
                    CYCLE_FIRST_STEP_SIZE: cycle_step_size,
                    DECAY_STEP_SIZE: decay_step_size
                }
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50, cycle_step_size * 3),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for _, batch in enumerate(data_loader):
            step_lrs.extend(lr_scheduler.get_lr())
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Verify starting lr
        assert step_lrs[0] == min_lr

        # Verify peak lr
        assert step_lrs[cycle_step_size] == max_lr

        # Verify increasing phase
        _verify_continuous_increase(step_lrs[:cycle_step_size])

        # Verify decreasing phase
        _verify_continuous_decrease(step_lrs[cycle_step_size:(cycle_step_size * 2)])

        # Verify decay phase
        if decay_rate > 0:
            _verify_continuous_decrease(step_lrs[(cycle_step_size * 2):])

    @pytest.mark.parametrize("min_mom, max_mom, decay_rate, step_size",
                             [
                                 (0.08, 0.09, 1e-3, 10),
                                 (0.08, 0.09, 0, 21),
                                 (0.08, 0.09, 1e-3, 10),
                                 (0.08, 0.09, 0, 21),
                             ])  # yapf: disable
    def test_mom(self, min_mom, max_mom, decay_rate, step_size):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                },
            },
            "scheduler": {
                "type": ONE_CYCLE,
                "params": {
                    CYCLE_MIN_LR: 1e-3,
                    CYCLE_MAX_LR: 1e-2,
                    CYCLE_MIN_MOM: min_mom,
                    CYCLE_MAX_MOM: max_mom,
                    DECAY_MOM_RATE: decay_rate,
                    CYCLE_FIRST_STEP_SIZE: step_size,
                    DECAY_STEP_SIZE: step_size
                }
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50, step_size * 3),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_moms = []
        for _, batch in enumerate(data_loader):
            step_moms.append(lr_scheduler.get_mom())
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Verify starting momentum (momentum begins at its max while lr is at its min)
        assert step_moms[0][0][0] == max_mom

        # Verify momentum at the cycle peak (momentum reaches its min when lr peaks)
        assert step_moms[step_size][0][0] == min_mom

        # Verify decreasing phase
        _verify_continuous_decrease(step_moms[:step_size])

        # Verify increasing phase
        _verify_continuous_increase(step_moms[step_size:(step_size * 2)])

        # Verify decay phase
        if decay_rate > 0:
            _verify_continuous_increase(step_moms[(step_size * 2):])
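

# Illustrative sketch: runs only when this file is executed directly, not under pytest.
# It shows the value patterns the verification helpers at the top of this file accept;
# the literal values below are hypothetical and chosen purely for illustration.
if __name__ == "__main__":
    # Strictly monotonic sequences pass the continuous checks.
    _verify_continuous_increase([1e-4, 2e-4, 3e-4])
    _verify_continuous_decrease([3e-4, 2e-4, 1e-4])

    # Staircase: values stay constant within each window of `step_size` steps.
    _verify_staircase_increase([1e-4, 1e-4, 2e-4, 2e-4, 3e-4], step_size=2)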