# test_checkpointing.py

import torch
import deepspeed
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer

import argparse
import pytest
import json
import os
import numbers
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict


def compare_deepspeed_states(saved_model, loaded_model):
    # These are compared in more depth in other places
    assert hasattr(loaded_model, 'module')
    assert saved_model.csr_tensor_module_names == loaded_model.csr_tensor_module_names
    assert saved_model.skipped_steps == loaded_model.skipped_steps
    assert saved_model.global_steps == loaded_model.global_steps


def compare_model_states(saved_model, loaded_model):
    compare_deepspeed_states(saved_model, loaded_model)

    for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
        assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"

    if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
        for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups,
                          loaded_model.optimizer.single_partition_of_fp32_groups):
            assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
        for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups,
                                          loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
            for p0, p1 in zip(partition0, partition1):
                assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_Optimizer):
        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
            assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
        for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
            for p0, p1 in zip(params0, params1):
                assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
        pass
    else:
        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'
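
# Note: each fp16 optimizer wrapper keeps its fp32 master weights in a different
# container (partitioned groups for ZeRO stage 2, local sub-partitions for ZeRO
# stage 1, flat groups for the fused optimizer, nested per-group lists for the
# unfused one), which is why compare_model_states dispatches on the optimizer type.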


def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer

    for state0, state1 in zip(saved_optimizer.state.values(), loaded_optimizer.state.values()):
        for s0, s1 in zip(state0.values(), state1.values()):
            if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
                assert torch.equal(s0, s1)
            else:
                assert s0 == s1


def compare_lr_scheduler_states(saved_model, loaded_model):
    assert hasattr(saved_model, 'lr_scheduler')
    assert hasattr(loaded_model, 'lr_scheduler')

    saved_scheduler = saved_model.lr_scheduler
    loaded_scheduler = loaded_model.lr_scheduler

    assert hasattr(saved_scheduler, 'state_dict')
    assert hasattr(loaded_scheduler, 'state_dict')

    saved_sd = saved_scheduler.state_dict()
    loaded_sd = loaded_scheduler.state_dict()

    print(f"saved_sd = {saved_sd}")
    print(f"loaded_sd = {loaded_sd}")

    assert saved_sd.keys() == loaded_sd.keys()

    for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
        if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
            assert state0 == state1
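
# Note: compare_lr_scheduler_states only compares numeric entries of the scheduler
# state dicts value-for-value; non-numeric entries are skipped by the isinstance
# check, and only the key sets are required to match.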


def checkpoint_correctness_verification(args,
                                        model,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=False,
                                        fp16=True):
    dtype = torch.half if fp16 else torch.float32
    ds_model, _, _, _ = deepspeed.initialize(args=args,
                                             model=model,
                                             model_parameters=model.parameters())
    data_loader = random_dataloader(model=ds_model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=ds_model.device,
                                    dtype=dtype)
    for n, batch in enumerate(data_loader):
        loss = ds_model(batch[0], batch[1])
        ds_model.backward(loss)
        ds_model.step()

    trained_model = ds_model

    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    save_tag = '1'

    trained_model.save_checkpoint(save_folder, save_tag)

    loaded_model, _, _, _ = deepspeed.initialize(args=args,
                                                 model=model,
                                                 model_parameters=model.parameters())

    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
                                 load_optimizer_states=load_optimizer_states,
                                 load_lr_scheduler_states=load_lr_scheduler_states)

    compare_model_states(trained_model, loaded_model)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
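
# The helper above is the core round trip exercised by every test in this file:
# train a few steps, save a checkpoint, re-initialize, load the checkpoint, and
# compare engine, model, optimizer, and scheduler states. As a minimal standalone
# sketch of that pattern (directory and tag values here are illustrative, not
# taken from this file):
#
#     engine, _, _, _ = deepspeed.initialize(args=args,
#                                            model=model,
#                                            model_parameters=model.parameters())
#     engine.save_checkpoint('/tmp/ckpt_dir', 'tag_1')
#     engine.load_checkpoint('/tmp/ckpt_dir',
#                            'tag_1',
#                            load_optimizer_states=True,
#                            load_lr_scheduler_states=True)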


def test_checkpoint_unfused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True
        },
        "scheduler": {
            "type": "OneCycle",
            "params": {
                "cycle_first_step_size": 1000,
                "cycle_first_stair_count": 500,
                "cycle_second_step_size": 1000,
                "cycle_second_stair_count": 500,
                "decay_step_size": 1000,
                "cycle_min_lr": 0.0001,
                "cycle_max_lr": 0.0010,
                "decay_lr_rate": 0.001,
                "cycle_min_mom": 0.85,
                "cycle_max_mom": 0.99,
                "decay_mom_rate": 0.0
            }
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_unfused_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_unfused_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=True)
    _test_checkpoint_unfused_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=False)


def test_checkpoint_fused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_fused_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)
    _test_checkpoint_fused_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=False)


@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (1, False, 'Adam'),
                             (2, False, 'Adam'),
                             (2, True, 'deepspeed_adam'),
                         ])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": adam_optimizer,
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_zero_optimizer(args=args,
                                    model=model,
                                    hidden_dim=hidden_dim,
                                    load_optimizer_states=True)


@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (1, False, 'Adam'),
                             (2, False, 'Adam'),
                             (2, True, 'deepspeed_adam'),
                         ])
def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": adam_optimizer,
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_no_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_zero_no_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=False)


@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (0, False, 'Adam'),
                             (1, False, 'Adam'),
                             (2, False, 'Adam'),
                             (2, True, 'deepspeed_adam'),
                         ])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": adam_optimizer,
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000
            }
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_lr_scheduler(args,
                                      model,
                                      hidden_dim,
                                      load_optimizer_states,
                                      load_lr_scheduler_states):
        checkpoint_correctness_verification(
            args,
            model,
            hidden_dim,
            tmpdir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states)

    _test_checkpoint_lr_scheduler(args=args,
                                  model=model,
                                  hidden_dim=hidden_dim,
                                  load_optimizer_states=False,
                                  load_lr_scheduler_states=True)


@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (0, False, 'Adam'),
                             (1, False, 'Adam'),
                             (2, False, 'Adam'),
                             (2, True, 'deepspeed_adam'),
                         ])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": adam_optimizer,
            "params": {
                "lr": 1e-5
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000
            }
        },
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_no_lr_scheduler(args,
                                         model,
                                         hidden_dim,
                                         load_optimizer_states,
                                         load_lr_scheduler_states):
        checkpoint_correctness_verification(
            args,
            model,
            hidden_dim,
            tmpdir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states)

    _test_checkpoint_no_lr_scheduler(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=False,
                                     load_lr_scheduler_states=False)


def test_checkpoint_fp32_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": False
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
        checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)

    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
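
# These tests are collected by pytest, and the @distributed_test decorator from
# common.py launches the requested world size (2 ranks here) for each inner test.
# A typical invocation (file path assumed, adjust to your checkout) would be:
#
#     pytest -v test_checkpointing.py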