# test_checkpointing.py

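# Unit tests for DeepSpeed checkpoint save/load correctness: FP16 unfused (Lamb) and
# fused (Adam) optimizers, ZeRO stages 1-2 (with and without CPU offload), LR
# schedulers, an FP32 optimizer path, and pipeline-parallel engine/module checkpoints
# across different topologies.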
import torch
import torch.distributed as dist

import deepspeed
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer

from deepspeed.runtime.pipe.topology import *
PipeTopo = PipeDataParallelTopology

import argparse
import pytest
import json
import os
import numbers

from common import distributed_test
from simple_model import *


def compare_deepspeed_states(saved_model, loaded_model):
    # These are compared in more depth in other places
    assert hasattr(loaded_model, 'module')

    assert saved_model.csr_tensor_module_names == loaded_model.csr_tensor_module_names
    assert saved_model.skipped_steps == loaded_model.skipped_steps
    assert saved_model.global_steps == loaded_model.global_steps


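# compare_model_states: checks that the reloaded engine matches the trained one.
# It first compares the (fp16) module parameters, then dispatches on the wrapped
# optimizer type to compare the fp32 master weights: ZeRO stage 2
# (FP16_DeepSpeedZeroOptimizer), ZeRO stage 1 (FP16_DeepSpeedZeroOptimizer_Stage1),
# fused FP16 (FP16_Optimizer), unfused FP16 (FP16_UnfusedOptimizer), or a plain
# torch.optim.Optimizer (for which nothing extra is compared).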
def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
    compare_deepspeed_states(saved_model, loaded_model)

    for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
        assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"

    if not compare_optimizer:
        return

    if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
        for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups,
                          loaded_model.optimizer.single_partition_of_fp32_groups):
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
        for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups,
                                          loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
            for p0, p1 in zip(partition0, partition1):
                assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_Optimizer):
        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
            assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
        for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
            for p0, p1 in zip(params0, params1):
                assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
        pass
    else:
        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'


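# compare_optimizer_states: compares the per-parameter state (e.g. Adam moments) of
# the underlying torch optimizer. With fp16 enabled, the engine's optimizer is a
# wrapper, so the inner optimizer is reached via `.optimizer.optimizer`.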
def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer

    for state0, state1 in zip(saved_optimizer.state.values(),
                              loaded_optimizer.state.values()):
        for s0, s1 in zip(state0.values(), state1.values()):
            if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
                assert torch.equal(s0, s1)
            else:
                assert s0 == s1


def compare_lr_scheduler_states(saved_model, loaded_model):
    assert hasattr(saved_model, 'lr_scheduler')
    assert hasattr(loaded_model, 'lr_scheduler')

    saved_scheduler = saved_model.lr_scheduler
    loaded_scheduler = loaded_model.lr_scheduler

    assert hasattr(saved_scheduler, 'state_dict')
    assert hasattr(loaded_scheduler, 'state_dict')

    saved_sd = saved_scheduler.state_dict()
    loaded_sd = loaded_scheduler.state_dict()

    print(f"saved_sd = {saved_sd}")
    print(f"loaded_sd = {loaded_sd}")

    assert saved_sd.keys() == loaded_sd.keys()

    for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
        if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
            assert state0 == state1


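# checkpoint_correctness_verification: the core save/load round trip used by every
# test below. It trains a DeepSpeed engine for a few steps on random data, calls
# save_checkpoint(), builds a second engine from the same args/model, calls
# load_checkpoint(), and then compares module, optimizer, and LR-scheduler state
# between the two engines.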
def checkpoint_correctness_verification(args,
                                        model,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=False,
                                        fp16=True,
                                        train_batch=False):
    dtype = torch.half if fp16 else torch.float32
    ds_model, _, _, _ = deepspeed.initialize(args=args,
                                             model=model,
                                             model_parameters=model.parameters())
    data_loader = random_dataloader(model=ds_model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=ds_model.device,
                                    dtype=dtype)
    if train_batch:
        ds_model.set_dataloader(data_loader)
        for n, batch in enumerate(data_loader):
            loss = ds_model.train_batch()
    else:
        for n, batch in enumerate(data_loader):
            loss = ds_model(batch[0], batch[1])
            print(loss)
            ds_model.backward(loss)
            ds_model.step()

    trained_model = ds_model

    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    save_tag = '1'

    trained_model.save_checkpoint(save_folder, save_tag)

    loaded_model, _, _, _ = deepspeed.initialize(args=args,
                                                 model=model,
                                                 model_parameters=model.parameters())

    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
                                 load_optimizer_states=load_optimizer_states,
                                 load_lr_scheduler_states=load_lr_scheduler_states)

    compare_model_states(trained_model, loaded_model)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)


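# The tests below all follow the same pattern: build a DeepSpeed config dict, turn it
# into CLI-style args with args_from_dict(), construct a SimpleModel, and run
# checkpoint_correctness_verification() inside a @distributed_test worker group.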
@pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
                    reason="lamb is not installed")
def test_checkpoint_unfused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True
        },
        "scheduler": {
            "type": "OneCycle",
            "params": {
                "cycle_first_step_size": 1000,
                "cycle_first_stair_count": 500,
                "cycle_second_step_size": 1000,
                "cycle_second_stair_count": 500,
                "decay_step_size": 1000,
                "cycle_min_lr": 0.0001,
                "cycle_max_lr": 0.0010,
                "decay_lr_rate": 0.001,
                "cycle_min_mom": 0.85,
                "cycle_max_mom": 0.99,
                "decay_mom_rate": 0.0
            }
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_unfused_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_unfused_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=True)
    _test_checkpoint_unfused_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=False)


def test_checkpoint_fused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_fused_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)
    _test_checkpoint_fused_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=False)


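# ZeRO checkpointing: parametrized over ZeRO stage 1 and stage 2, with the stage-2
# CPU-offload case using the 'deepspeed_adam' optimizer (skipped when the cpu-adam
# op is not installed).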
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (1, False, 'Adam'),
                             (2, False, 'Adam'),
                             (2, True, 'deepspeed_adam'),
                         ])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")

    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": adam_optimizer,
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_zero_optimizer(args=args,
                                    model=model,
                                    hidden_dim=hidden_dim,
                                    load_optimizer_states=True)


@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (1, False, "Adam"),
                             (2, False, "Adam"),
                             (2, True, 'deepspeed_adam'),
                         ])
def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")

    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": adam_optimizer,
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_no_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_zero_no_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=False)


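# LR-scheduler checkpointing: parametrized over ZeRO stages 0-2 (plus stage-2 CPU
# offload) with a WarmupLR scheduler; verifies that scheduler state round-trips when
# load_lr_scheduler_states=True even though optimizer state loading is skipped.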
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (0, False, 'Adam'),
                             (1, False, 'Adam'),
                             (2, False, 'Adam'),
                             (2, True, 'deepspeed_adam'),
                         ])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")

    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": adam_optimizer,
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000
            }
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_lr_scheduler(args,
                                      model,
                                      hidden_dim,
                                      load_optimizer_states,
                                      load_lr_scheduler_states):
        checkpoint_correctness_verification(
            args,
            model,
            hidden_dim,
            tmpdir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states)

    _test_checkpoint_lr_scheduler(args=args,
                                  model=model,
                                  hidden_dim=hidden_dim,
                                  load_optimizer_states=False,
                                  load_lr_scheduler_states=True)


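# Counterpart to the test above: same WarmupLR config, but the checkpoint is loaded
# with both load_optimizer_states=False and load_lr_scheduler_states=False, so the
# optimizer- and scheduler-state comparisons are skipped.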
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (0, False, 'Adam'),
                             (1, False, 'Adam'),
                             (2, False, 'Adam'),
                             (2, True, 'deepspeed_adam'),
                         ])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")

    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": adam_optimizer,
            "params": {
                "lr": 1e-5
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000
            }
        },
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_no_lr_scheduler(args,
                                         model,
                                         hidden_dim,
                                         load_optimizer_states,
                                         load_lr_scheduler_states):
        checkpoint_correctness_verification(
            args,
            model,
            hidden_dim,
            tmpdir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states)

    _test_checkpoint_no_lr_scheduler(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=False,
                                     load_lr_scheduler_states=False)


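# FP32 path: fp16 is disabled, so the engine wraps a plain torch.optim optimizer and
# checkpoint_correctness_verification is called with fp16=False.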
def test_checkpoint_fp32_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": False
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
        checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)

    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)


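# Pipeline-engine checkpointing: runs a two-stage LinearStackPipe on four ranks,
# training via engine.train_batch() (train_batch=True), and checks that optimizer and
# OneCycle scheduler state survive the save/load round trip. fp16 is only enabled for
# the ZeRO stage 1 case.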
@pytest.mark.parametrize("zero_stage", [0, 1])
def test_checkpoint_pipe_engine(zero_stage, tmpdir, stages=2):
    config_dict = {
        "train_batch_size": 2,
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-5
            }
        },
        "zero_optimization": {
            "stage": zero_stage
        },
        "fp16": {
            "enabled": zero_stage > 0
        },
        "scheduler": {
            "type": "OneCycle",
            "params": {
                "cycle_first_step_size": 1000,
                "cycle_first_stair_count": 500,
                "cycle_second_step_size": 1000,
                "cycle_second_stair_count": 500,
                "decay_step_size": 1000,
                "cycle_min_lr": 0.0001,
                "cycle_max_lr": 0.0010,
                "decay_lr_rate": 0.001,
                "cycle_min_mom": 0.85,
                "cycle_max_mom": 0.99,
                "decay_mom_rate": 0.0
            }
        }
    }

    @distributed_test(world_size=4)
    def _test(save_folder, num_stages):
        args = args_from_dict(tmpdir, config_dict)
        model = LinearStackPipe(num_stages=num_stages)
        checkpoint_correctness_verification(args=args,
                                            model=model,
                                            hidden_dim=model.hidden_dim,
                                            tmpdir=save_folder,
                                            fp16=config_dict['fp16']['enabled'],
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=True,
                                            train_batch=True)

    _test(tmpdir, num_stages=stages)


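# Pipeline-module repartitioning: saves a LinearStackPipe state dict under one
# pipeline/data-parallel topology and reloads it under another, then compares the
# layers pairwise by mapping local layer indices through each partition's
# _local_start offset.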
@pytest.mark.parametrize("base_topo,test_topo",
                         [
                             (PipeTopo(num_pp=1, num_dp=4), PipeTopo(num_pp=4, num_dp=1)),
                             (PipeTopo(num_pp=2, num_dp=2), PipeTopo(num_pp=2, num_dp=2)),
                             (PipeTopo(num_pp=4, num_dp=1), PipeTopo(num_pp=2, num_dp=2)),
                         ])
def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir):
    @distributed_test(world_size=4)
    def _test(base_topo, test_topo, save_folder):
        base_model = LinearStackPipe(topology=base_topo)
        base_model.save_state_dict(save_folder)

        dist.barrier()

        test_model = LinearStackPipe(topology=test_topo)
        test_model.load_state_dir(save_folder)

        # Base and test can have different lengths, so make sure we map from the
        # smaller to larger model
        if len(base_model.forward_funcs) < len(test_model.forward_funcs):
            A = base_model
            B = test_model
        else:
            A = test_model
            B = base_model

        # Compare layers individually since partitions are different
        for idx, A_layer in enumerate(A.forward_funcs):
            if not hasattr(A_layer, 'parameters'):
                # Skip functionals, etc.
                continue

            # Find the corresponding layer in B
            global_idx = idx + A._local_start
            B_local_idx = global_idx - B._local_start
            B_layer = B.forward_funcs[B_local_idx]

            # Compare layer parameters
            for p0, p1 in zip(A_layer.parameters(), B_layer.parameters()):
                assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"

    _test(base_topo, test_topo, save_folder=tmpdir)