# test_fp16.py

import torch
import apex
import deepspeed
import argparse
import pytest
import json
import os
from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict

lamb_available = pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
                                    reason="lamb is not installed")

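# Most tests below follow the same pattern: build a DeepSpeed config dict, wrap
# SimpleModel with deepspeed.initialize inside a @distributed_test worker, and
# run a few training steps over random data.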
@lamb_available
def test_lamb_fp32_grad_clip(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1, 2])
    def _test_lamb_fp32_grad_clip(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_lamb_fp32_grad_clip(args=args, model=model, hidden_dim=hidden_dim)

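# Same LAMB setup as above, but with fp16 enabled in the config.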
@lamb_available
def test_lamb_fp16_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1, 2])
    def _test_lamb_fp16_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_lamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)

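# LAMB + fp16 where SimpleModel is built with empty_grad=True, run on two ranks.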
@lamb_available
def test_lamb_fp16_empty_grad(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)

    @distributed_test(world_size=[2])
    def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_lamb_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)

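# Adam in fp32 (fp16 disabled, torch.float data) with the same empty_grad=True
# setup on two ranks.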
def test_adam_fp32_empty_grad(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": False
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)

    @distributed_test(world_size=[2])
    def _test_adam_fp32_empty_grad(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adam_fp32_empty_grad(args=args, model=model, hidden_dim=hidden_dim)

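# The AdamW tests pass a client torch.optim.AdamW instance directly to
# deepspeed.initialize instead of configuring an optimizer in the config dict.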
def test_adamw_fp16_basic(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_adamw_fp16_basic(args, model, hidden_dim):
        optimizer = torch.optim.AdamW(params=model.parameters())
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adamw_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)

def test_adamw_fp16_empty_grad(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=True)

    @distributed_test(world_size=[1])
    def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
        optimizer = torch.optim.AdamW(params=model.parameters())
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)

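# Checks that fp16 + ZeRO (stages 1 and 2, including stage 2 with CPU offload)
# trains with the OneCycle LR scheduler configured.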
@pytest.mark.parametrize('zero_stage, use_cpu_offload',
                         [(1, False),
                          (2, False),
                          (2, True)])
def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "scheduler": {
            "type": "OneCycle",
            "params": {
                "cycle_first_step_size": 16000,
                "cycle_first_stair_count": 8000,
                "decay_step_size": 16000,
                "cycle_min_lr": 1e-06,
                "cycle_max_lr": 3e-05,
                "decay_lr_rate": 1e-07,
                "cycle_min_mom": 0.85,
                "cycle_max_mom": 0.99,
                "decay_mom_rate": 0.0
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=True)

    @distributed_test(world_size=[1])
    def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adam_fp16_zero_onecycle_compatibility(args=args,
                                                model=model,
                                                hidden_dim=hidden_dim)

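# Static loss scaling: a fixed "loss_scale" in the fp16 config should disable
# dynamic loss scaling and be picked up by the wrapped optimizer.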
@pytest.mark.parametrize('zero_stage, use_cpu_offload',
                         [(1, False),
                          (2, False),
                          (2, True)])
def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    config_dict = {
        "train_batch_size": 4,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 138.
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=2)
    def _test_zero_static_scale(args):
        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())

        # Ensure the static scaler is configured.
        assert optim.dynamic_loss_scale == False
        assert optim.loss_scaler.loss_scale == 138.

        # Now make sure things work..
        data_loader = random_dataloader(model=model,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_zero_static_scale(args)

def test_zero_static_scale_deprecated_format(tmpdir):
    config_dict = {
        "train_batch_size": 4,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 138.
        },
        "zero_optimization": True
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=2)
    def _test_zero_static_scale(args):
        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())

        # Ensure the static scaler is configured.
        assert optim.dynamic_loss_scale == False
        assert optim.loss_scaler.loss_scale == 138.

        # Now make sure things work..
        data_loader = random_dataloader(model=model,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_zero_static_scale(args)

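# With zero_allow_untested_optimizer disabled, passing an unsupported client
# optimizer to deepspeed.initialize should raise an AssertionError.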
@pytest.mark.parametrize('zero_stage, use_cpu_offload',
                         [(1, False),
                          (2, False),
                          (2, True)])
def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    config_dict = {
        "train_batch_size": 4,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True,
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        },
        "zero_allow_untested_optimizer": False
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=[1])
    def _test_zero_allow_untested_optimizer(args):
        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=True)
        optimizer = SimpleOptimizer(model.parameters())
        with pytest.raises(AssertionError):
            model, optim, _, _ = deepspeed.initialize(args=args,
                                                      model=model,
                                                      optimizer=optimizer,
                                                      model_parameters=model.parameters())

    _test_zero_allow_untested_optimizer(args)

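# Runs ZeRO on three ranks with a model that has only two parameters, so one
# partition is left empty.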
@pytest.mark.parametrize('zero_stage, use_cpu_offload',
                         [(1, False),
                          (2, False),
                          (2, True)])
def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "fp16": {
            "enabled": True,
            "initial_scale_power": 8
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload,
            "reduce_bucket_size": 100,
            "allgather_bucket_size": 100
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=[3])
    def _test_zero_empty_partition(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim)
        # Ensure model has 2 parameters, to cause empty partition with DP=3
        assert len(list(model.parameters())) == 2
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())

        # Now make sure things work..
        data_loader = random_dataloader(model=model,
                                        total_samples=1,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_zero_empty_partition(args)

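# The amp tests enable mixed precision through the "amp" section of the config
# (Apex-based) rather than the "fp16" section.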
def test_adam_amp_basic(tmpdir):
    config_dict = {"train_batch_size": 1, "steps_per_print": 1, "amp": {"enabled": True}}
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_adam_amp_basic(args, model, hidden_dim):
        optimizer = torch.optim.Adam(params=model.parameters())
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adam_amp_basic(args=args, model=model, hidden_dim=hidden_dim)

@lamb_available
def test_lamb_amp_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "amp": {
            "enabled": True,
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1, 2])
    def _test_lamb_amp_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_lamb_amp_basic(args=args, model=model, hidden_dim=hidden_dim)

def test_adam_amp_o2(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "amp": {
            "enabled": True,
            "opt_level": "O2"
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1, 2])
    def _test_adam_amp_o2(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)

def test_adam_amp_o2_empty_grad(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "amp": {
            "enabled": True,
            "opt_level": "O2"
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False, rank=args.local_rank)

    @distributed_test(world_size=[2])
    def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adam_amp_o2_empty_grad(args=args, model=model, hidden_dim=hidden_dim)

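# Checks that deepspeed.initialize accepts supported client optimizers
# (apex FusedAdam and torch.optim.Adam) under ZeRO stages 1 and 2; no training
# steps are run.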
@pytest.mark.parametrize('zero_stage, optimizer_constructor',
                         [(1, apex.optimizers.FusedAdam),
                          (2, torch.optim.Adam),
                          (2, apex.optimizers.FusedAdam)])
def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_constructor):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_zero_supported_client_optimizer(args, model, optimizer_constructor):
        client_optimizer = optimizer_constructor(params=model.parameters())
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=client_optimizer)

    _test_zero_supported_client_optimizer(args=args,
                                          model=model,
                                          optimizer_constructor=optimizer_constructor)

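# ZeRO stage 2 configured with reduce_scatter and overlap_comm disabled.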
def test_zero2_reduce_scatter_off(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "allgather_bucket_size": 2000000000,
            "reduce_bucket_size": 200000000,
            "overlap_comm": False,
            "reduce_scatter": False
        },
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, rank=args.local_rank)

    @distributed_test(world_size=[2])
    def _helper(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _helper(args=args, model=model, hidden_dim=hidden_dim)