# unfused_optimizer.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Copyright NVIDIA/apex
  6. This file is adapted from FP16_Optimizer in NVIDIA/apex
  7. """
  8. from deepspeed.moe.utils import split_params_grads_into_shared_and_expert_params
  9. import torch
  10. from torch._utils import _flatten_dense_tensors
  11. from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
  12. from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm
  13. from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
  14. from deepspeed.utils import logger
  15. from deepspeed.utils.torch import required_torch_version
  16. from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
  17. from deepspeed.accelerator import get_accelerator
  18. from deepspeed import comm as dist
class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
    """
    FP16 Optimizer without weight fusion to support LAMB optimizer.

    Keeps the model's fp16 parameters in ``fp16_groups`` and a fp32 master
    copy in ``fp32_groups``; the wrapped optimizer steps on the fp32 copies,
    which are then copied back into the fp16 model weights.

    For usage example please see, TODO: DeepSpeed V2 Tutorial
    """
    def __init__(self,
                 init_optimizer,
                 deepspeed=None,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_lamb_legacy=False):
        """Wrap ``init_optimizer`` with fp16 loss scaling and master weights.

        Args:
            init_optimizer: the underlying optimizer; its param groups hold
                the model's fp16 parameters and are re-pointed to fp32 copies.
            deepspeed: engine handle forwarded to ``CheckOverflow``
                (presumably the DeepSpeedEngine — confirm against caller).
            static_loss_scale: fixed scale used when dynamic scaling is off.
            dynamic_loss_scale: enable dynamic loss scaling.
            dynamic_loss_args: dict with INITIAL_LOSS_SCALE / SCALE_WINDOW /
                MIN_LOSS_SCALE keys; defaults are used when None.
            verbose: log loss-scale adjustments.
            mpu: model-parallel unit passed to norm/overflow helpers.
            clip_grad: max gradient norm; 0 disables clipping.
            fused_lamb_legacy: route ``step()`` through ``step_fused_lamb``.

        Raises:
            SystemError: when no accelerator is available (fp16 needs one).
        """
        self.fused_lamb_legacy = fused_lamb_legacy
        self._global_grad_norm = 0.

        if dist.get_rank() == 0:
            logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')

        if not get_accelerator().is_available():
            raise SystemError("Cannot use fp16 without accelerator.")
        self.optimizer = init_optimizer

        # param groups
        self.fp16_groups = []
        self.fp32_groups = []

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # fp16 weights that represents the actual model weights
            self.fp16_groups.append(param_group['params'])

            # creating a fp32 copy of the weights that will be updated first then
            # copied to fp16 weights
            fp32_group = [p.clone().float().detach() for p in param_group['params']]

            # in case the internal optimizer needs it
            for p in fp32_group:
                p.requires_grad = True

            # setting the param groups in the optimizer to point to fp32
            # note these are not the weights used by the model
            # the model uses the fp16 version that we added to fp16_group
            self.fp32_groups.append(fp32_group)
            param_group['params'] = self.fp32_groups[i]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2.0
            if dynamic_loss_args is None:
                self.cur_scale = 1.0 * 2**16
                self.scale_window = 1000
                self.min_loss_scale = 0.25
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale

        # external loss scale support (see override_loss_scale)
        self.custom_loss_scaler = False
        self.external_loss_scale = None

        self.verbose = verbose

        self.clip_grad = clip_grad
        self.norm_type = 2

        # clip_grad_norm was renamed to clip_grad_norm_ after torch 0.4
        if required_torch_version(max_version=0.4):
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)

        # materialize the inner optimizer's state buffers up front
        self.initialize_optimizer_states()
  90. def zero_grad(self, set_to_none=True):
  91. """
  92. Zero FP16 parameter grads.
  93. """
  94. # FP32 grad should never exist outside of the step function
  95. # For speed, set model fp16 grad to None by default
  96. for group in self.fp16_groups:
  97. for p in group:
  98. if set_to_none:
  99. p.grad = None
  100. else:
  101. if p.grad is not None:
  102. p.grad.detach_()
  103. p.grad.zero_()
    def step_fused_lamb(self, closure=None):
        """
        Not supporting closure.

        Step path for the legacy fused LAMB optimizer: gradients are flattened
        per group and handed to ``optimizer.step`` together with the combined
        scale, so unscaling is applied inside the fused kernel rather than
        in-place on the grads. Returns the overflow flag; when an overflow is
        detected the step is skipped entirely.
        """
        # First compute norm for all group so we know if there is overflow
        grads_groups_flat = []
        grads_groups = []
        norm_groups = []
        expert_norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            # substitute zeros for missing grads so flattening stays aligned
            grads = [
                torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
            ]
            grads_groups.append(grads)
            grads_groups_flat.append(_flatten_dense_tensors(grads))
            # MoE: expert grads are normed separately from shared grads
            grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(_flatten_dense_tensors(grads_for_norm), mpu=self.mpu)
            norm_groups.append(norm_group_value)
            expert_norm_group_value = 0.0
            if len(expert_grads_for_norm) > 0:
                expert_norm_group_value = get_weight_norm(_flatten_dense_tensors(expert_grads_for_norm), mpu=self.mpu)
            expert_norm_groups.append(expert_norm_group_value)
        # overflow is inferred from the norms (inf/nan) across all groups
        self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups)
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow
        # NOTE: global norm excludes expert norms here — matches step(); confirm intended
        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        # apply_scale=False: the fused kernel consumes the scale directly
        combined_scale = self.unscale_and_clip_grads(self._global_grad_norm, apply_scale=False)
        self.optimizer.step(grads=grads_groups, output_params=self.fp16_groups, scale=combined_scale)
        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
                # remove the fp32 grad
                fp32_param.grad = None
                # copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)
        return self.overflow
  146. def set_lr(self, lr):
  147. """Set the learning rate."""
  148. for param_group in self.optimizer.param_groups:
  149. param_group["lr"] = lr
  150. def get_lr(self):
  151. """Return the current learning rate."""
  152. return self.optimizer.param_groups[0]["lr"]
  153. def override_loss_scale(self, loss_scale):
  154. if loss_scale != self.external_loss_scale:
  155. logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
  156. self.custom_loss_scaler = True
  157. self.external_loss_scale = loss_scale
    def step(self, closure=None):
        """
        Not supporting closure.

        Standard step path: check overflow, copy fp16 grads to the fp32
        master params, unscale/clip in place, step the inner optimizer on
        fp32, then copy the updated fp32 weights back into the fp16 model
        params. Returns the overflow flag; on overflow the step is skipped.
        """
        if self.fused_lamb_legacy:
            # legacy fused LAMB has its own flattened-grad step path
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        prev_scale = self.cur_scale

        # adjust dynamic loss scale before deciding whether to skip
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            # MoE: expert grads are excluded from the shared grad norm
            grads_for_norm, _ = split_params_grads_into_shared_and_expert_params(group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(grads_for_norm, mpu=self.mpu)
            norm_groups.append(norm_group_value)

            # copying gradients to fp32 to work with fp32 parameters
            for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
                if fp16_param.grad is None:
                    # missing grad becomes a zero grad so the optimizer still steps
                    fp32_param.grad = torch.zeros(fp16_param.size(), dtype=fp32_param.dtype, device=fp32_param.device)
                else:
                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        # divide fp32 grads by the combined scale (loss scale and, if
        # clipping is enabled, the clip factor) in place
        self.unscale_and_clip_grads(self._global_grad_norm)

        self.optimizer.step()

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
                # remove the fp32 grad
                fp32_param.grad = None
                # copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow
  195. def unscale_and_clip_grads(self, total_norm, apply_scale=True):
  196. # compute combined scale factor for this group
  197. combined_scale = self.cur_scale
  198. if self.clip_grad > 0.:
  199. # norm is in fact norm*scale
  200. clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
  201. if clip > 1:
  202. combined_scale = clip * self.cur_scale
  203. if apply_scale:
  204. for group in self.fp32_groups:
  205. for param in group:
  206. if param.grad is not None:
  207. param.grad.data.mul_(1. / combined_scale)
  208. return combined_scale
  209. def backward(self, loss, create_graph=False, retain_graph=False):
  210. """
  211. :attr:`backward` performs the following steps:
  212. 1. fp32_loss = loss.float()
  213. 2. scaled_loss = fp32_loss*loss_scale
  214. 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
  215. """
  216. if self.custom_loss_scaler:
  217. scaled_loss = self.external_loss_scale * loss
  218. scaled_loss.backward()
  219. else:
  220. scaled_loss = (loss.float()) * self.cur_scale
  221. scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)
  222. def _update_scale(self, skip):
  223. if self.dynamic_loss_scale:
  224. prev_scale = self.cur_scale
  225. if skip:
  226. self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
  227. self.last_overflow_iter = self.cur_iter
  228. if self.verbose:
  229. logger.info("Grad overflow on iteration: %s", self.cur_iter)
  230. logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
  231. else:
  232. # Ensure self.scale_window updates since last overflow
  233. stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
  234. if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
  235. self.cur_scale *= self.scale_factor
  236. if self.verbose:
  237. logger.info(f"No Grad overflow for {self.scale_window} iterations")
  238. logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
  239. else:
  240. if skip:
  241. logger.info("Grad overflow on iteration %s", self.cur_iter)
  242. logger.info("Using static loss scale of %s", self.cur_scale)
  243. self.cur_iter += 1
  244. return
    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        # delegate to the wrapped optimizer's per-parameter state mapping
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)
    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        # note: these groups hold the fp32 master params, not the fp16 model params
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
  258. # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
  259. def _get_loss_scale(self):
  260. if self.custom_loss_scaler:
  261. return self.external_loss_scale
  262. else:
  263. return self.cur_scale
  264. def _set_loss_scale(self, value):
  265. self.loss_scaler.cur_scale = value
  266. loss_scale = property(_get_loss_scale, _set_loss_scale)
  267. def state_dict(self):
  268. """
  269. Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
  270. This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
  271. of the contained Pytorch optimizer.
  272. Example::
  273. checkpoint = {}
  274. checkpoint['model'] = model.state_dict()
  275. checkpoint['optimizer'] = optimizer.state_dict()
  276. torch.save(checkpoint, "saved.pth")
  277. """
  278. state_dict = {}
  279. state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
  280. state_dict['cur_scale'] = self.cur_scale
  281. state_dict['cur_iter'] = self.cur_iter
  282. if state_dict['dynamic_loss_scale']:
  283. state_dict['last_overflow_iter'] = self.last_overflow_iter
  284. state_dict['scale_factor'] = self.scale_factor
  285. state_dict['scale_window'] = self.scale_window
  286. state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
  287. state_dict['fp32_groups'] = self.fp32_groups
  288. return state_dict
  289. # Refresh fp32 master params from fp16 copies
  290. def refresh_fp32_params(self):
  291. for current_group, saved_group in zip(self.fp32_groups, self.fp16_groups):
  292. for current, saved in zip(current_group, saved_group):
  293. current.data.copy_(saved.data)
  294. def load_state_dict(self, state_dict, load_optimizer_states=True):
  295. """
  296. Loads a state_dict created by an earlier call to state_dict().
  297. If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
  298. whose parameters in turn came from ``model``, it is expected that the user
  299. will call ``model.load_state_dict()`` before
  300. ``fp16_optimizer_instance.load_state_dict()`` is called.
  301. Example::
  302. model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
  303. optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
  304. optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
  305. ...
  306. checkpoint = torch.load("saved.pth")
  307. model.load_state_dict(checkpoint['model'])
  308. optimizer.load_state_dict(checkpoint['optimizer'])
  309. """
  310. # I think it should actually be ok to reload the optimizer before the model.
  311. self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
  312. self.cur_scale = state_dict['cur_scale']
  313. self.cur_iter = state_dict['cur_iter']
  314. if state_dict['dynamic_loss_scale']:
  315. self.last_overflow_iter = state_dict['last_overflow_iter']
  316. self.scale_factor = state_dict['scale_factor']
  317. self.scale_window = state_dict['scale_window']
  318. if load_optimizer_states:
  319. self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
  320. # At this point, the optimizer's references to the model's fp32 parameters are up to date.
  321. # The optimizer's hyperparameters and internal buffers are also up to date.
  322. # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
  323. # out of date. There are two options.
  324. # 1: Refresh the master params from the model's fp16 params.
  325. # This requires less storage but incurs precision loss.
  326. # 2: Save and restore the fp32 master copies separately.
  327. # We choose option 2.
  328. #
  329. # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
  330. # of their associated parameters, because it's possible those buffers might not exist yet in
  331. # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
  332. # constructed in the same way as the one whose state_dict we are loading, the same master params
  333. # are guaranteed to exist, so we can just copy_() from the saved master params.
  334. for current_group, saved_group in zip(self.fp32_groups, state_dict['fp32_groups']):
  335. for current, saved in zip(current_group, saved_group):
  336. current.data.copy_(saved.data)
  337. def __repr__(self):
  338. return repr(self.optimizer)
  339. def initialize_optimizer_states(self):
  340. for i, group in enumerate(self.fp16_groups):
  341. for param in group:
  342. param.grad = torch.zeros(param.size(),
  343. dtype=param.dtype,
  344. device=get_accelerator().current_device_name())
  345. for i, group in enumerate(self.fp32_groups):
  346. for param in group:
  347. param.grad = torch.zeros(param.size(),
  348. dtype=param.dtype,
  349. device=get_accelerator().current_device_name())
  350. self.optimizer.step()
  351. for i, group in enumerate(self.fp16_groups):
  352. for param in group:
  353. param.grad = None
  354. for i, group in enumerate(self.fp32_groups):
  355. for param in group:
  356. param.grad = None