# unfused_optimizer.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Copyright NVIDIA/apex
  6. This file is adapted from FP16_Optimizer in NVIDIA/apex
  7. """
  8. from deepspeed.moe.utils import split_params_grads_into_shared_and_expert_params
  9. import torch
  10. from torch._utils import _flatten_dense_tensors
  11. from deepspeed.runtime import DeepSpeedOptimizer
  12. from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm, required_torch_version
  13. from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
  14. from deepspeed.utils import logger
  15. from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
  16. from deepspeed.accelerator import get_accelerator
  17. from deepspeed import comm as dist
  18. class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
  19. """
  20. FP16 Optimizer without weight fusion to support LAMB optimizer
  21. For usage example please see, TODO: DeepSpeed V2 Tutorial
  22. """
  23. def __init__(self,
  24. init_optimizer,
  25. deepspeed=None,
  26. static_loss_scale=1.0,
  27. dynamic_loss_scale=False,
  28. dynamic_loss_args=None,
  29. verbose=True,
  30. mpu=None,
  31. clip_grad=0.0,
  32. fused_lamb_legacy=False):
  33. self.fused_lamb_legacy = fused_lamb_legacy
  34. self._global_grad_norm = 0.
  35. if dist.get_rank() == 0:
  36. logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
  37. if not get_accelerator().is_available():
  38. raise SystemError("Cannot use fp16 without accelerator.")
  39. self.optimizer = init_optimizer
  40. # param groups
  41. self.fp16_groups = []
  42. self.fp32_groups = []
  43. # loop to deal with groups
  44. for i, param_group in enumerate(self.optimizer.param_groups):
  45. #fp16 weights that represents the actual model weights
  46. self.fp16_groups.append(param_group['params'])
  47. #creating a fp32 copy of the weights that will be updated first then
  48. #copied to fp16 weights
  49. fp32_group = [p.clone().float().detach() for p in param_group['params']]
  50. #in case the internal optimizer needs it
  51. for p in fp32_group:
  52. p.requires_grad = True
  53. #setting the param groups in the optimizer to point to fp32
  54. #note these are not the weights used by the model
  55. #the model uses the fp16 version that we added to fp16_group
  56. self.fp32_groups.append(fp32_group)
  57. param_group['params'] = self.fp32_groups[i]
  58. # we may have a way of fusing dynamic scale. Do not support for now
  59. if dynamic_loss_scale:
  60. self.dynamic_loss_scale = True
  61. self.cur_iter = 0
  62. self.last_overflow_iter = -1
  63. self.scale_factor = 2.0
  64. if dynamic_loss_args is None:
  65. self.cur_scale = 1.0 * 2**16
  66. self.scale_window = 1000
  67. self.min_loss_scale = 0.25
  68. else:
  69. self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
  70. self.scale_window = dynamic_loss_args[SCALE_WINDOW]
  71. self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
  72. else:
  73. self.dynamic_loss_scale = False
  74. self.cur_iter = 0
  75. self.cur_scale = static_loss_scale
  76. self.custom_loss_scaler = False
  77. self.external_loss_scale = None
  78. self.verbose = verbose
  79. self.clip_grad = clip_grad
  80. self.norm_type = 2
  81. if required_torch_version(max_version=0.4):
  82. self.clip_grad_norm = torch.nn.utils.clip_grad_norm
  83. else:
  84. self.clip_grad_norm = torch.nn.utils.clip_grad_norm_
  85. self.mpu = mpu
  86. self.overflow = False
  87. self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)
  88. self.initialize_optimizer_states()
  89. def zero_grad(self, set_to_none=False):
  90. """
  91. Zero FP16 parameter grads.
  92. """
  93. # FP32 grad should never exist outside of the step function
  94. # For speed, set model fp16 grad to None by default
  95. for group in self.fp16_groups:
  96. for p in group:
  97. if set_to_none:
  98. p.grad = None
  99. else:
  100. if p.grad is not None:
  101. p.grad.detach_()
  102. p.grad.zero_()
  103. def step_fused_lamb(self, closure=None):
  104. """
  105. Not supporting closure.
  106. """
  107. # First compute norm for all group so we know if there is overflow
  108. grads_groups_flat = []
  109. grads_groups = []
  110. norm_groups = []
  111. expert_norm_groups = []
  112. for i, group in enumerate(self.fp16_groups):
  113. grads = [
  114. torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
  115. ]
  116. grads_groups.append(grads)
  117. grads_groups_flat.append(_flatten_dense_tensors(grads))
  118. grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(group)
  119. norm_group_value = 0.0
  120. if len(grads_for_norm) > 0:
  121. norm_group_value = get_weight_norm(_flatten_dense_tensors(grads_for_norm), mpu=self.mpu)
  122. norm_groups.append(norm_group_value)
  123. expert_norm_group_value = 0.0
  124. if len(expert_grads_for_norm) > 0:
  125. expert_norm_group_value = get_weight_norm(_flatten_dense_tensors(expert_grads_for_norm), mpu=self.mpu)
  126. expert_norm_groups.append(expert_norm_group_value)
  127. self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups)
  128. prev_scale = self.cur_scale
  129. self._update_scale(self.overflow)
  130. if self.overflow:
  131. if self.verbose:
  132. logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
  133. "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
  134. return self.overflow
  135. self._global_grad_norm = get_global_norm(norm_list=norm_groups)
  136. combined_scale = self.unscale_and_clip_grads(self._global_grad_norm, apply_scale=False)
  137. self.optimizer.step(grads=grads_groups, output_params=self.fp16_groups, scale=combined_scale)
  138. for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
  139. for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
  140. #remove the fp32 grad
  141. fp32_param.grad = None
  142. #copy data from fp32 to fp16
  143. fp16_param.data.copy_(fp32_param.data)
  144. return self.overflow
  145. def set_lr(self, lr):
  146. """Set the learning rate."""
  147. for param_group in self.optimizer.param_groups:
  148. param_group["lr"] = lr
  149. def get_lr(self):
  150. """Return the current learning rate."""
  151. return self.optimizer.param_groups[0]["lr"]
  152. def override_loss_scale(self, loss_scale):
  153. if loss_scale != self.external_loss_scale:
  154. logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
  155. self.custom_loss_scaler = True
  156. self.external_loss_scale = loss_scale
  157. def step(self, closure=None):
  158. """
  159. Not supporting closure.
  160. """
  161. if self.fused_lamb_legacy:
  162. return self.step_fused_lamb()
  163. self.overflow = self.overflow_checker.check()
  164. prev_scale = self.cur_scale
  165. self._update_scale(self.overflow)
  166. if self.overflow:
  167. if self.verbose:
  168. logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
  169. "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
  170. return self.overflow
  171. norm_groups = []
  172. for i, group in enumerate(self.fp16_groups):
  173. grads_for_norm, _ = split_params_grads_into_shared_and_expert_params(group)
  174. norm_group_value = 0.0
  175. if len(grads_for_norm) > 0:
  176. norm_group_value = get_weight_norm(grads_for_norm, mpu=self.mpu)
  177. norm_groups.append(norm_group_value)
  178. # copying gradients to fp32 to wor k with fp32 parameters
  179. for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
  180. if fp16_param.grad is None:
  181. fp32_param.grad = torch.zeros(fp16_param.size(), dtype=fp32_param.dtype, device=fp32_param.device)
  182. else:
  183. fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)
  184. self._global_grad_norm = get_global_norm(norm_list=norm_groups)
  185. self.unscale_and_clip_grads(self._global_grad_norm)
  186. self.optimizer.step()
  187. for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
  188. for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
  189. #remove the fp32 grad
  190. fp32_param.grad = None
  191. #copy data from fp32 to fp16
  192. fp16_param.data.copy_(fp32_param.data)
  193. return self.overflow
  194. def unscale_and_clip_grads(self, total_norm, apply_scale=True):
  195. # compute combined scale factor for this group
  196. combined_scale = self.cur_scale
  197. if self.clip_grad > 0.:
  198. # norm is in fact norm*scale
  199. clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
  200. if clip > 1:
  201. combined_scale = clip * self.cur_scale
  202. if apply_scale:
  203. for group in self.fp32_groups:
  204. for param in group:
  205. if param.grad is not None:
  206. param.grad.data.mul_(1. / combined_scale)
  207. return combined_scale
  208. def backward(self, loss, create_graph=False, retain_graph=False):
  209. """
  210. :attr:`backward` performs the following steps:
  211. 1. fp32_loss = loss.float()
  212. 2. scaled_loss = fp32_loss*loss_scale
  213. 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
  214. """
  215. if self.custom_loss_scaler:
  216. scaled_loss = self.external_loss_scale * loss
  217. scaled_loss.backward()
  218. else:
  219. scaled_loss = (loss.float()) * self.cur_scale
  220. scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)
  221. def _update_scale(self, skip):
  222. if self.dynamic_loss_scale:
  223. prev_scale = self.cur_scale
  224. if skip:
  225. self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
  226. self.last_overflow_iter = self.cur_iter
  227. if self.verbose:
  228. logger.info("Grad overflow on iteration: %s", self.cur_iter)
  229. logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
  230. else:
  231. # Ensure self.scale_window updates since last overflow
  232. stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
  233. if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
  234. self.cur_scale *= self.scale_factor
  235. if self.verbose:
  236. logger.info(f"No Grad overflow for {self.scale_window} iterations")
  237. logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
  238. else:
  239. if skip:
  240. logger.info("Grad overflow on iteration %s", self.cur_iter)
  241. logger.info("Using static loss scale of %s", self.cur_scale)
  242. self.cur_iter += 1
  243. return
  244. # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
  245. def _get_state(self):
  246. return self.optimizer.state
  247. def _set_state(self, value):
  248. self.optimizer.state = value
  249. state = property(_get_state, _set_state)
  250. # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
  251. # (for example, to adjust the learning rate)
  252. def _get_param_groups(self):
  253. return self.optimizer.param_groups
  254. def _set_param_groups(self, value):
  255. self.optimizer.param_groups = value
  256. param_groups = property(_get_param_groups, _set_param_groups)
  257. # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
  258. def _get_loss_scale(self):
  259. if self.custom_loss_scaler:
  260. return self.external_loss_scale
  261. else:
  262. return self.cur_scale
  263. def _set_loss_scale(self, value):
  264. self.loss_scaler.cur_scale = value
  265. loss_scale = property(_get_loss_scale, _set_loss_scale)
  266. def state_dict(self):
  267. """
  268. Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
  269. This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
  270. of the contained Pytorch optimizer.
  271. Example::
  272. checkpoint = {}
  273. checkpoint['model'] = model.state_dict()
  274. checkpoint['optimizer'] = optimizer.state_dict()
  275. torch.save(checkpoint, "saved.pth")
  276. """
  277. state_dict = {}
  278. state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
  279. state_dict['cur_scale'] = self.cur_scale
  280. state_dict['cur_iter'] = self.cur_iter
  281. if state_dict['dynamic_loss_scale']:
  282. state_dict['last_overflow_iter'] = self.last_overflow_iter
  283. state_dict['scale_factor'] = self.scale_factor
  284. state_dict['scale_window'] = self.scale_window
  285. state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
  286. state_dict['fp32_groups'] = self.fp32_groups
  287. return state_dict
  288. # Refresh fp32 master params from fp16 copies
  289. def refresh_fp32_params(self):
  290. for current_group, saved_group in zip(self.fp32_groups, self.fp16_groups):
  291. for current, saved in zip(current_group, saved_group):
  292. current.data.copy_(saved.data)
  293. def load_state_dict(self, state_dict, load_optimizer_states=True):
  294. """
  295. Loads a state_dict created by an earlier call to state_dict().
  296. If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
  297. whose parameters in turn came from ``model``, it is expected that the user
  298. will call ``model.load_state_dict()`` before
  299. ``fp16_optimizer_instance.load_state_dict()`` is called.
  300. Example::
  301. model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
  302. optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
  303. optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
  304. ...
  305. checkpoint = torch.load("saved.pth")
  306. model.load_state_dict(checkpoint['model'])
  307. optimizer.load_state_dict(checkpoint['optimizer'])
  308. """
  309. # I think it should actually be ok to reload the optimizer before the model.
  310. self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
  311. self.cur_scale = state_dict['cur_scale']
  312. self.cur_iter = state_dict['cur_iter']
  313. if state_dict['dynamic_loss_scale']:
  314. self.last_overflow_iter = state_dict['last_overflow_iter']
  315. self.scale_factor = state_dict['scale_factor']
  316. self.scale_window = state_dict['scale_window']
  317. if load_optimizer_states:
  318. self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
  319. # At this point, the optimizer's references to the model's fp32 parameters are up to date.
  320. # The optimizer's hyperparameters and internal buffers are also up to date.
  321. # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
  322. # out of date. There are two options.
  323. # 1: Refresh the master params from the model's fp16 params.
  324. # This requires less storage but incurs precision loss.
  325. # 2: Save and restore the fp32 master copies separately.
  326. # We choose option 2.
  327. #
  328. # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
  329. # of their associated parameters, because it's possible those buffers might not exist yet in
  330. # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
  331. # constructed in the same way as the one whose state_dict we are loading, the same master params
  332. # are guaranteed to exist, so we can just copy_() from the saved master params.
  333. for current_group, saved_group in zip(self.fp32_groups, state_dict['fp32_groups']):
  334. for current, saved in zip(current_group, saved_group):
  335. current.data.copy_(saved.data)
  336. def __repr__(self):
  337. return repr(self.optimizer)
  338. def initialize_optimizer_states(self):
  339. for i, group in enumerate(self.fp16_groups):
  340. for param in group:
  341. param.grad = torch.zeros(param.size(),
  342. dtype=param.dtype,
  343. device=get_accelerator().current_device_name())
  344. for i, group in enumerate(self.fp32_groups):
  345. for param in group:
  346. param.grad = torch.zeros(param.size(),
  347. dtype=param.dtype,
  348. device=get_accelerator().current_device_name())
  349. self.optimizer.step()
  350. for i, group in enumerate(self.fp16_groups):
  351. for param in group:
  352. param.grad = None
  353. for i, group in enumerate(self.fp32_groups):
  354. for param in group:
  355. param.grad = None