# fused_optimizer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
"""

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import groups, logger, log_dist
from deepspeed import comm as dist
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
from deepspeed.accelerator import get_accelerator

OVERFLOW_CHECK_TIMER = 'overflow_check'
COMPUTE_NORM_TIMER = 'compute_norm'
UNSCALE_AND_CLIP_TIMER = 'unscale_and_clip'
BASIC_STEP_TIMER = 'basic_step'
UPDATE_FP16_TIMER = 'update_fp16'

OVERFLOW_TIMERS = [COMPUTE_NORM_TIMER, OVERFLOW_CHECK_TIMER]
STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP_TIMER, BASIC_STEP_TIMER, UPDATE_FP16_TIMER]
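
# Usage sketch (illustrative only; the engine, model, and hyperparameter values below are
# assumptions for the example, not values mandated by this module):
#
#   base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   optimizer = FP16_Optimizer(base_optimizer,
#                              deepspeed=engine,        # DeepSpeed engine that owns the model
#                              dynamic_loss_scale=True,
#                              clip_grad=1.0,
#                              timers=engine.timers)
#   loss = model(batch)
#   optimizer.backward(loss)   # scales the loss, then backprops scaled gradients
#   optimizer.step()           # unscale, clip, inner step, copy fp32 master -> fp16 params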
class FP16_Optimizer(DeepSpeedOptimizer):
    """
    FP16 Optimizer for training fp16 models. Handles loss scaling.

    For a usage example, see TODO: DeepSpeed V2 Tutorial
    """

    def __init__(self,
                 init_optimizer,
                 deepspeed=None,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 initial_dynamic_scale=2**32,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_adam_legacy=False,
                 has_moe_layers=False,
                 timers=None):
        self.fused_adam_legacy = fused_adam_legacy
        self.timers = timers
        self.deepspeed = deepspeed
        self.has_moe_layers = has_moe_layers
        self.using_pipeline = self.deepspeed.pipeline_parallelism
        if not get_accelerator().is_available():
            raise SystemError("Cannot use fp16 without accelerator.")
        self.optimizer = init_optimizer

        # params flattened by group
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        self._global_grad_norm = 0.

        # loop over param groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before modifying it
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weights to slices of the flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weights, flattened
            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
            # modify optimizer to use the flat master weight
            self.fp32_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.fp32_groups_flat[i]]
        # we may have a way of fusing dynamic scale. Not supported for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2

            if dynamic_loss_args is None:
                self.cur_scale = initial_dynamic_scale
                self.scale_window = 1000
                self.min_loss_scale = 1
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale

        self.verbose = verbose

        self.custom_loss_scaler = False
        self.external_loss_scale = None

        self.clip_grad = clip_grad
        self.norm_type = 2

        if required_torch_version(max_version=0.4):
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        # model parallel object
        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)

        self.initialize_optimizer_states()
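
    # Prime the wrapped optimizer: give every flat fp32 master weight a zero gradient and take
    # one step so the inner optimizer can allocate its internal state (e.g. Adam moments) up
    # front rather than lazily on the first real step, then release the dummy gradients.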
    def initialize_optimizer_states(self):
        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = torch.zeros(self.fp32_groups_flat[i].size(),
                                                        device=self.fp32_groups_flat[i].device)

        self.optimizer.step()

        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = None

        return
    def zero_grad(self, set_to_none=True):
        """
        Zero FP16 parameter grads.
        """
        # For speed, set model fp16 grads to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_to_none:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()
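
    # Legacy fused Adam path: the flattened, still-scaled gradients are handed to the inner
    # optimizer directly, together with a combined unscale/clip factor that the fused kernel
    # applies itself (see the scale= and grad_norms= arguments passed to self.optimizer.step).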
    def step_fused_adam(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the norm for all groups so we know if there is an overflow
        grads_groups_flat = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
                ]))
            norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        scaled_grad_norm = get_global_norm(norm_list=norm_groups)

        combined_scale = self.unscale_and_clip_grads(grads_groups_flat, scaled_grad_norm, apply_scale=False)

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_grad_norm / self.cur_scale

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        return self.overflow
    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]["lr"]

    def override_loss_scale(self, loss_scale):
        if loss_scale != self.external_loss_scale:
            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale
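
    # Main (non-legacy) update path. In order: check the fp16 gradients for inf/nan, adjust the
    # dynamic loss scale (skipping the step and clearing gradients on overflow), flatten the
    # fp16 gradients into the fp32 master buffers, unscale and optionally clip them, run the
    # wrapped optimizer on the fp32 master weights, and copy the results back into the fp16
    # model parameters.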
    def step(self, closure=None):
        """
        Not supporting closure.
        """

        if self.fused_adam_legacy:
            return self.step_fused_adam()

        # First determine if there is overflow.
        self.timers(OVERFLOW_CHECK_TIMER).start()
        fp16_params = []
        for i, group in enumerate(self.fp16_groups):
            fp16_params.extend([p for p in group if p.grad is not None])
        self.overflow = self.overflow_checker.has_overflow(fp16_params)
        self.timers(OVERFLOW_CHECK_TIMER).stop()
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                log_dist(
                    "Overflow detected. Skipping step. Attempted loss "
                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
                    ranks=[0])
            # Clear gradients
            for i, group in enumerate(self.fp16_groups):
                for p in group:
                    p.grad = None

            self.timers.log(OVERFLOW_TIMERS)
            return self.overflow

        grads_groups_flat = []
        for i, group in enumerate(self.fp16_groups):
            data_type = self.fp32_groups_flat[i].dtype

            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=data_type, device=p.device) if p.grad is None else p.grad.to(data_type)
                    for p in group
                ]))

            for p in group:
                p.grad = None

            self.fp32_groups_flat[i].grad = grads_groups_flat[i]

        self.timers(COMPUTE_NORM_TIMER).start()

        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)

        self.timers(COMPUTE_NORM_TIMER).stop()

        if self.has_moe_layers:
            all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm)

        scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_global_grad_norm / self.cur_scale

        self.timers(UNSCALE_AND_CLIP_TIMER).start()
        self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
        self.timers(UNSCALE_AND_CLIP_TIMER).stop()

        self.timers(BASIC_STEP_TIMER).start()
        self.optimizer.step()
        self.timers(BASIC_STEP_TIMER).stop()

        # get rid of the fp32 gradients. Not needed anymore
        for group in self.fp32_groups_flat:
            group.grad = None

        self.timers(UPDATE_FP16_TIMER).start()

        for i in range(len(self.fp16_groups)):
            updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data.copy_(q.data)

        self.timers(UPDATE_FP16_TIMER).stop()

        self.timers.log(STEP_TIMERS)

        return self.overflow
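
    # With MoE layers, expert gradients are not identical across data-parallel ranks, so the
    # locally computed norm is averaged over the data-parallel group: each rank contributes
    # norm / world_size and the contributions are summed with an all_reduce.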
    def _get_norm_with_moe_layers(self, all_groups_norm):
        #all_groups_norm_old = all_groups_norm
        # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
        if self.using_pipeline:
            pg = self.deepspeed.mpu.get_data_parallel_group()
        else:
            pg = groups._get_data_parallel_group()
        scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
        scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[0].device, dtype=torch.float)
        dist.all_reduce(scaled_norm_tensor, group=pg)
        all_groups_norm = scaled_norm_tensor.item()
        #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
        return all_groups_norm
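
    # The gradients passed in here still carry the loss scale, and total_norm is the norm of
    # those scaled gradients. Dividing by combined_scale therefore unscales and, when needed,
    # clips in a single multiply. Worked example (numbers are illustrative): with cur_scale=1024,
    # clip_grad=1.0 and an unscaled gradient norm of 4.0, clip is about 4.0, combined_scale is
    # about 4096, and dividing the scaled gradients by 4096 leaves them unscaled with a norm of
    # roughly 1.0.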
    def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
        # compute combined scale factor for this group
        combined_scale = self.cur_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.cur_scale

        if apply_scale:
            for grad in grad_groups_flat:
                grad.data.mul_(1. / combined_scale)

        return combined_scale
    def backward(self, loss, create_graph=False, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
            scaled_loss.backward()
        else:
            scaled_loss = (loss.float()) * self.cur_scale
            scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)
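
    # Dynamic loss scale policy: on overflow the scale is divided by scale_factor (bounded
    # below by min_loss_scale); after every scale_window overflow-free iterations since the
    # last overflow it is multiplied by scale_factor. With a static scale, overflows are only
    # logged and the scale is left unchanged.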
    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            prev_scale = self.cur_scale
            if skip:
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
                self.last_overflow_iter = self.cur_iter
                if self.verbose:
                    logger.info(f"\nGrad overflow on iteration {self.cur_iter}")
                    logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
            else:
                # Only grow the scale after self.scale_window overflow-free iterations since the last overflow
                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                    self.cur_scale *= self.scale_factor
                    if self.verbose:
                        logger.info(f"No Grad overflow for {self.scale_window} iterations")
                        logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
        else:
            if skip:
                logger.info("Grad overflow on iteration: %s", self.cur_iter)
                logger.info("Using static loss scale of: %s", self.cur_scale)
        self.cur_iter += 1
        return
    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        state_dict[CLIP_GRAD] = self.clip_grad
        return state_dict
    # Refresh fp32 master params from the fp16 copies
    def refresh_fp32_params(self):
        for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
            current.data.copy_(saved.data)
    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
        self.clip_grad = state_dict[CLIP_GRAD]
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        #    We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)
    def __repr__(self):
        return repr(self.optimizer)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        if self.custom_loss_scaler:
            return self.external_loss_scale
        else:
            return self.cur_scale

    def _set_loss_scale(self, value):
        # This optimizer tracks the loss scale in self.cur_scale (there is no separate
        # loss_scaler object on the fused path), so write the new value there.
        self.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)