unfused_optimizer.py

'''
Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
'''

from deepspeed.moe.utils import split_params_grads_into_shared_and_expert_params
import torch
from torch._utils import _flatten_dense_tensors

from deepspeed.runtime import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist


class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
    """
    FP16 Optimizer without weight fusion to support the LAMB optimizer.
    For usage example, please see TODO: DeepSpeed V2 Tutorial
    """

    def __init__(self,
                 init_optimizer,
                 deepspeed=None,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_lamb_legacy=False):

        self.fused_lamb_legacy = fused_lamb_legacy
        self._global_grad_norm = 0.

        if dist.get_rank() == 0:
            logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')

        if not get_accelerator().is_available():
            raise SystemError("Cannot use fp16 without accelerator.")
        self.optimizer = init_optimizer

        # param groups
        self.fp16_groups = []
        self.fp32_groups = []

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # fp16 weights that represent the actual model weights
            self.fp16_groups.append(param_group['params'])

            # creating a fp32 copy of the weights that will be updated first and
            # then copied to the fp16 weights
            fp32_group = [p.clone().float().detach() for p in param_group['params']]

            # in case the internal optimizer needs it
            for p in fp32_group:
                p.requires_grad = True

            # setting the param groups in the optimizer to point to fp32;
            # note these are not the weights used by the model --
            # the model uses the fp16 version that we added to fp16_groups
            self.fp32_groups.append(fp32_group)
            param_group['params'] = self.fp32_groups[i]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2.0

            if dynamic_loss_args is None:
                self.cur_scale = 1.0 * 2**16
                self.scale_window = 1000
                self.min_loss_scale = 0.25
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale

        self.custom_loss_scaler = False
        self.external_loss_scale = None

        self.verbose = verbose

        self.clip_grad = clip_grad
        self.norm_type = 2

        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups,
                                              mpu=self.mpu,
                                              deepspeed=deepspeed)

        self.initialize_optimizer_states()
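
    # Usage sketch: in practice this wrapper is created by the DeepSpeed engine, but a
    # minimal manual construction (with a hypothetical half-precision `model` and a plain
    # torch optimizer; the names below are illustrative only) would look roughly like:
    #
    #     model = torch.nn.Linear(512, 512).half().to(get_accelerator().device_name())
    #     base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    #     fp16_optimizer = FP16_UnfusedOptimizer(base_optimizer,
    #                                            static_loss_scale=128.0,
    #                                            clip_grad=1.0)
    #
    # After construction, `base_optimizer.param_groups` point at the fp32 master copies,
    # while the model itself keeps holding the fp16 tensors tracked in `self.fp16_groups`.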

    def zero_grad(self, set_to_none=False):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grad should never exist outside of the step function
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_to_none:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def step_fused_lamb(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute norm for all groups so we know if there is overflow
        grads_groups_flat = []
        grads_groups = []
        norm_groups = []
        expert_norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads = [
                torch.zeros(p.size(),
                            dtype=p.dtype,
                            device=p.device) if p.grad is None else p.grad for p in group
            ]
            grads_groups.append(grads)
            grads_groups_flat.append(_flatten_dense_tensors(grads))

            grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(_flatten_dense_tensors(grads_for_norm),
                                                   mpu=self.mpu)
            norm_groups.append(norm_group_value)

            expert_norm_group_value = 0.0
            if len(expert_grads_for_norm) > 0:
                expert_norm_group_value = get_weight_norm(_flatten_dense_tensors(expert_grads_for_norm),
                                                          mpu=self.mpu)
            expert_norm_groups.append(expert_norm_group_value)

        self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. "
                            "Attempted loss scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        combined_scale = self.unscale_and_clip_grads(self._global_grad_norm, apply_scale=False)
        self.optimizer.step(grads=grads_groups,
                            output_params=self.fp16_groups,
                            scale=combined_scale)

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
                # remove the fp32 grad
                fp32_param.grad = None
                # copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
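
    # Note on the fused-LAMB path above: fp32 gradients are never materialized here. The
    # fp16 grads are flattened per group and handed to the legacy FusedLamb optimizer,
    # which (as this call site assumes) applies the combined unscale/clip factor through
    # its `scale` argument and writes updated values into `output_params`; the final loop
    # only re-syncs the fp16 weights from the fp32 masters and drops any stale fp32 grads.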

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]["lr"]

    def override_loss_scale(self, loss_scale):
        if loss_scale != self.external_loss_scale:
            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        if self.fused_lamb_legacy:
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. "
                            "Attempted loss scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_for_norm, _ = split_params_grads_into_shared_and_expert_params(group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(grads_for_norm, mpu=self.mpu)
            norm_groups.append(norm_group_value)

            # copying gradients to fp32 to work with fp32 parameters
            for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
                if fp16_param.grad is None:
                    fp32_param.grad = torch.zeros(fp16_param.size(),
                                                  dtype=fp32_param.dtype,
                                                  device=fp32_param.device)
                else:
                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        self.unscale_and_clip_grads(self._global_grad_norm)

        self.optimizer.step()

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
                # remove the fp32 grad
                fp32_param.grad = None
                # copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
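
    # Usage sketch for a single training iteration (hypothetical `model`, `inputs`, and the
    # `fp16_optimizer` instance from the construction sketch above): `backward` scales the
    # loss, `step` unscales/clips, updates the fp32 masters, copies them back to fp16, and
    # returns True when the step was skipped because of overflow:
    #
    #     fp16_optimizer.zero_grad(set_to_none=True)
    #     loss = model(inputs).float().sum()
    #     fp16_optimizer.backward(loss)
    #     overflow = fp16_optimizer.step()
    #     if overflow:
    #         logger.info("fp16 overflow: step skipped, loss scale reduced")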

    def unscale_and_clip_grads(self, total_norm, apply_scale=True):
        # compute combined scale factor for this group
        combined_scale = self.cur_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.cur_scale

        if apply_scale:
            for group in self.fp32_groups:
                for param in group:
                    if param.grad is not None:
                        param.grad.data.mul_(1. / combined_scale)

        return combined_scale
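
    # Worked example for unscale_and_clip_grads: with cur_scale=1024, clip_grad=1.0 and a
    # measured total_norm of 2048 (the norm of the *scaled* grads, i.e. a true norm of 2.0),
    # clip = (2048 / 1024 + 1e-6) / 1.0 ≈ 2.0 > 1, so combined_scale ≈ 2048. Dividing every
    # gradient by 2048 removes the loss scale (factor 1024) and clips the global norm down
    # to clip_grad (remaining factor ≈ 2.0) in a single multiplication.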

    def backward(self, loss, create_graph=False, retain_graph=False):
        """
        :attr:`backward` performs the following steps:
        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
            scaled_loss.backward()
        else:
            scaled_loss = (loss.float()) * self.cur_scale
            scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)

    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            prev_scale = self.cur_scale
            if skip:
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
                self.last_overflow_iter = self.cur_iter
                if self.verbose:
                    logger.info("Grad overflow on iteration: %s", self.cur_iter)
                    logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
            else:
                # Ensure self.scale_window updates since last overflow
                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                    self.cur_scale *= self.scale_factor
                    if self.verbose:
                        logger.info(f"No Grad overflow for {self.scale_window} iterations")
                        logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
        else:
            if skip:
                logger.info("Grad overflow on iteration %s", self.cur_iter)
                logger.info("Using static loss scale of %s", self.cur_scale)
        self.cur_iter += 1
        return
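
    # With the default dynamic-scale settings (initial scale 2**16, scale_factor 2.0,
    # scale_window 1000, min_loss_scale 0.25), an overflow halves the scale (never below
    # 0.25) and restarts the window, while roughly scale_window consecutive overflow-free
    # iterations double it again, e.g. 65536 -> overflow -> 32768 -> 1000 clean steps -> 65536.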

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        if self.custom_loss_scaler:
            return self.external_loss_scale
        else:
            return self.cur_scale

    def _set_loss_scale(self, value):
        # this class tracks the scale directly in cur_scale (there is no loss_scaler object)
        self.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.
        Example::
            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
        state_dict['fp32_groups'] = self.fp32_groups
        return state_dict

    # Refresh fp32 master params from fp16 copies
    def refresh_fp32_params(self):
        for current_group, saved_group in zip(self.fp32_groups, self.fp16_groups):
            for current, saved in zip(current_group, saved_group):
                current.data.copy_(saved.data)

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.
        Example::
            model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']

        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        #    We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current_group, saved_group in zip(self.fp32_groups, state_dict['fp32_groups']):
            for current, saved in zip(current_group, saved_group):
                current.data.copy_(saved.data)

    def __repr__(self):
        return repr(self.optimizer)

    def initialize_optimizer_states(self):
        for i, group in enumerate(self.fp16_groups):
            for param in group:
                param.grad = torch.zeros(param.size(),
                                         dtype=param.dtype,
                                         device=get_accelerator().current_device_name())

        for i, group in enumerate(self.fp32_groups):
            for param in group:
                param.grad = torch.zeros(param.size(),
                                         dtype=param.dtype,
                                         device=get_accelerator().current_device_name())

        self.optimizer.step()

        for i, group in enumerate(self.fp16_groups):
            for param in group:
                param.grad = None

        for i, group in enumerate(self.fp32_groups):
            for param in group:
                param.grad = None
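
    # The zero-valued gradients above appear to exist only so the single warm-up
    # self.optimizer.step() call makes the inner optimizer allocate its per-parameter state
    # buffers (e.g. momentum, or exp_avg/exp_avg_sq for Adam-style optimizers) up front;
    # the dummy grads are then released immediately by setting them back to None.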