test_dynamic_loss_scale.py

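# These tests exercise DeepSpeed's dynamic loss scaling ("loss_scale": 0): the
# "fused" tests use Adam and the "unfused" tests use Lamb. In every case the
# loss scale is expected to halve on each step whose gradients overflow and to
# double after `loss_scale_window` consecutive overflow-free steps, never
# dropping below `min_loss_scale` when one is configured.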
import torch
import deepspeed
import argparse
import pytest
import json
import os
import numpy as np
from common import distributed_test
from simple_model import SimpleModel, args_from_dict

# Skip the Lamb-based tests when the DeepSpeed Lamb op is not installed.
lamb_available = pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
                                    reason="lamb is not installed")


def run_model_step(model, gradient_list):
    # For each value, fill every parameter's gradient with that value and take
    # one engine step, giving the loss scaler one overflow check per value.
    for value in gradient_list:
        for p in model.parameters():
            p.grad = torch.empty_like(p, dtype=p.dtype)
            p.grad.fill_(value)
        model.step()


def test_fused_no_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 8,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_fused_no_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**8
        expected_scale_window = 2
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.scale_window == expected_scale_window

        # With no overflows, the scale should double after every
        # `loss_scale_window` consecutive successful steps.
        for i, value in enumerate(np.random.uniform(-0.1, 0.1, 10)):
            run_model_step(model, [value])
            assert optim.cur_scale == expected_loss_scale
            assert optim.cur_iter == (i + 1)
            if optim.cur_iter % expected_scale_window == 0:
                expected_loss_scale *= 2

    _test_fused_no_overflow(args)


def test_fused_all_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 4,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_fused_all_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**4
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale

        # Every step overflows, so the scale should halve each step (floored at 1).
        overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6
        for i, value in enumerate(overflow_gradients):
            run_model_step(model, [value])
            expected_loss_scale = max(expected_loss_scale / 2, 1)
            assert optim.cur_scale == expected_loss_scale
            assert optim.cur_iter == (i + 1)

    _test_fused_all_overflow(args)


def test_fused_some_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 8,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_fused_some_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**8
        expected_scale_window = 2
        expected_iteration = 0
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.scale_window == expected_scale_window

        # Run model with overflows to decrease scale
        overflow_gradients = [float('inf'), float('nan')]
        expected_iteration += len(overflow_gradients)
        run_model_step(model, overflow_gradients)
        expected_loss_scale /= (2**len(overflow_gradients))
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

        # Run model scale_window + 1 times to increase scale once
        normal_gradients = np.random.uniform(-0.1, 0.1, expected_scale_window + 1)
        expected_iteration += len(normal_gradients)
        run_model_step(model, normal_gradients)
        expected_loss_scale *= 2
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

        # Run model with overflows to decrease scale
        overflow_gradients = [float('inf')]
        expected_iteration += len(overflow_gradients)
        run_model_step(model, overflow_gradients)
        expected_loss_scale /= (2**len(overflow_gradients))
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

    _test_fused_some_overflow(args)
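

# The unfused tests below mirror the fused cases above but use the Lamb
# optimizer; they are skipped when the DeepSpeed Lamb op is not installed.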
@lamb_available
def test_unfused_no_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 8,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_unfused_no_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**8
        expected_scale_window = 2
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.scale_window == expected_scale_window

        for i, value in enumerate(np.random.uniform(-0.1, 0.1, 10)):
            run_model_step(model, [value])
            assert optim.cur_scale == expected_loss_scale
            assert optim.cur_iter == (i + 1)
            if optim.cur_iter % expected_scale_window == 0:
                expected_loss_scale *= 2

    _test_unfused_no_overflow(args)


@lamb_available
def test_unfused_all_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 4,
            "loss_scale_window": 2,
            "min_loss_scale": 0.25
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_unfused_all_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**4
        expected_min_loss_scale = 0.25
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.min_loss_scale == expected_min_loss_scale

        # Every step overflows; the scale halves each step but never drops
        # below the configured min_loss_scale.
        overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6
        for i, value in enumerate(overflow_gradients):
            run_model_step(model, [value])
            expected_loss_scale = max(expected_loss_scale / 2, expected_min_loss_scale)
            assert optim.cur_scale == expected_loss_scale
            assert optim.cur_iter == (i + 1)

    _test_unfused_all_overflow(args)


@lamb_available
def test_unfused_some_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 8,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_unfused_some_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**8
        expected_scale_window = 2
        expected_iteration = 0
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.scale_window == expected_scale_window

        # Run model with overflows to decrease scale
        overflow_gradients = [float('inf'), float('nan')]
        expected_iteration += len(overflow_gradients)
        run_model_step(model, overflow_gradients)
        expected_loss_scale /= (2**len(overflow_gradients))
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

        # Run model scale_window + 1 times to increase scale once
        normal_gradients = np.random.uniform(-0.1, 0.1, expected_scale_window + 1)
        expected_iteration += len(normal_gradients)
        run_model_step(model, normal_gradients)
        expected_loss_scale *= 2
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

        # Run model with overflows to decrease scale
        overflow_gradients = [float('inf')]
        expected_iteration += len(overflow_gradients)
        run_model_step(model, overflow_gradients)
        expected_loss_scale /= (2**len(overflow_gradients))
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

    _test_unfused_some_overflow(args)