
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
"""

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import groups, logger, log_dist
from deepspeed import comm as dist
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
from deepspeed.accelerator import get_accelerator

OVERFLOW_CHECK_TIMER = 'overflow_check'
COMPUTE_NORM_TIMER = 'compute_norm'
UNSCALE_AND_CLIP_TIMER = 'unscale_and_clip'
BASIC_STEP_TIMER = 'basic_step'
UPDATE_FP16_TIMER = 'update_fp16'

OVERFLOW_TIMERS = [COMPUTE_NORM_TIMER, OVERFLOW_CHECK_TIMER]
STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP_TIMER, BASIC_STEP_TIMER, UPDATE_FP16_TIMER]


class FP16_Optimizer(DeepSpeedOptimizer):
    """
    FP16 Optimizer for training fp16 models. Handles loss scaling.

    For a usage example, please see TODO: DeepSpeed V2 Tutorial
    """
    def __init__(self,
                 init_optimizer,
                 deepspeed=None,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 initial_dynamic_scale=2**32,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_adam_legacy=False,
                 has_moe_layers=False,
                 timers=None):

        self.fused_adam_legacy = fused_adam_legacy
        self.timers = timers
        self.deepspeed = deepspeed
        self.has_moe_layers = has_moe_layers
        self.using_pipeline = self.deepspeed.pipeline_parallelism
        if not get_accelerator().is_available():
            raise SystemError("Cannot use fp16 without accelerator.")
        self.optimizer = init_optimizer

        # params flattened by group
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        self._global_grad_norm = 0.

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to list before modifying it
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weights to slices of the flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weights, flattened
            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
            # modify optimizer to use the flat master weight
            self.fp32_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.fp32_groups_flat[i]]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2

            if dynamic_loss_args is None:
                self.cur_scale = initial_dynamic_scale
                self.scale_window = 1000
                self.min_loss_scale = 1
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale
        self.verbose = verbose

        self.custom_loss_scaler = False
        self.external_loss_scale = None

        self.clip_grad = clip_grad
        self.norm_type = 2

        self.step_count = 0

        if required_torch_version(max_version=0.4):
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        # model parallel object
        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)

        self.initialize_optimizer_states()
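    # initialize_optimizer_states() below is a warm-up: it attaches zero-filled
    # gradients to the fp32 flat buffers and takes one inner-optimizer step so that
    # optimizers with lazily allocated state (e.g. Adam's exp_avg/exp_avg_sq) build
    # it up front, then releases the dummy gradients before real training begins.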
    def initialize_optimizer_states(self):
        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = torch.zeros(self.fp32_groups_flat[i].size(),
                                                        device=self.fp32_groups_flat[i].device)

        self.optimizer.step()

        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = None

        return
    def zero_grad(self, set_to_none=True):
  112. """
  113. Zero FP16 parameter grads.
  114. """
  115. # For speed, set model fp16 grad to None by default
  116. for group in self.fp16_groups:
  117. for p in group:
  118. if set_to_none:
  119. p.grad = None
  120. else:
  121. if p.grad is not None:
  122. p.grad.detach_()
  123. p.grad.zero_()
    def step_fused_adam(self, closure=None):
        """
        Not supporting closure.
        """

        # First compute the norm for all groups so we know if there is overflow
        grads_groups_flat = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
                ]))
            norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        scaled_grad_norm = get_global_norm(norm_list=norm_groups)

        combined_scale = self.unscale_and_clip_grads(grads_groups_flat, scaled_grad_norm, apply_scale=False)

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_grad_norm / self.cur_scale

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)
        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
        return self.overflow
    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]["lr"]

    def override_loss_scale(self, loss_scale):
        if loss_scale != self.external_loss_scale:
            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale
    def step(self, closure=None):
        """
        Not supporting closure.
        """
        if self.fused_adam_legacy:
            return self.step_fused_adam()

        # First determine if there is overflow.
        self.timers(OVERFLOW_CHECK_TIMER).start()
        fp16_params = []
        for i, group in enumerate(self.fp16_groups):
            fp16_params.extend([p for p in group if p.grad is not None])
        self.overflow = self.overflow_checker.has_overflow(fp16_params)
        self.timers(OVERFLOW_CHECK_TIMER).stop()
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                log_dist(
                    "Overflow detected. Skipping step. Attempted loss "
                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
                    ranks=[0])
            # Clear gradients
            for i, group in enumerate(self.fp16_groups):
                for p in group:
                    p.grad = None

            self.timers.log(OVERFLOW_TIMERS)
            return self.overflow

        grads_groups_flat = []
        for i, group in enumerate(self.fp16_groups):
            data_type = self.fp32_groups_flat[i].dtype

            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=data_type, device=p.device) if p.grad is None else p.grad.to(data_type)
                    for p in group
                ]))

            for p in group:
                p.grad = None

            self.fp32_groups_flat[i].grad = grads_groups_flat[i]

        self.timers(COMPUTE_NORM_TIMER).start()

        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)

        self.timers(COMPUTE_NORM_TIMER).stop()

        if self.has_moe_layers:
            all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm)

        scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_global_grad_norm / self.cur_scale

        self.timers(UNSCALE_AND_CLIP_TIMER).start()
        self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
        self.timers(UNSCALE_AND_CLIP_TIMER).stop()

        self.timers(BASIC_STEP_TIMER).start()
        self.optimizer.step()
        self.timers(BASIC_STEP_TIMER).stop()

        # get rid of the fp32 gradients. Not needed anymore
        for group in self.fp32_groups_flat:
            group.grad = None

        self.timers(UPDATE_FP16_TIMER).start()

        for i in range(len(self.fp16_groups)):
            updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data.copy_(q.data)

        self.timers(UPDATE_FP16_TIMER).stop()

        self.timers.log(STEP_TIMERS)

        self.step_count += 1

        return self.overflow
    def _get_norm_with_moe_layers(self, all_groups_norm):
        #all_groups_norm_old = all_groups_norm
        # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
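        # Worked example (illustrative numbers): with a data-parallel world size of 4
        # and per-rank group norms of 2.0, 2.0, 4.0 and 4.0, each rank contributes
        # norm/4 and the SUM all-reduce below returns 3.0 on every rank, i.e. the
        # average norm across the data-parallel group.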
        if self.using_pipeline:
            pg = self.deepspeed.mpu.get_data_parallel_group()
        else:
            pg = groups._get_data_parallel_group()
        scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
        scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[0].device, dtype=torch.float)
        dist.all_reduce(scaled_norm_tensor, group=pg)
        all_groups_norm = scaled_norm_tensor.item()
        #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
        return all_groups_norm
    def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
        # compute combined scale factor for this group
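        # Worked example (illustrative numbers): with cur_scale=1024, clip_grad=1.0
        # and a scaled total_norm of 4096 (unscaled norm ~4.0), clip is roughly 4.0,
        # so combined_scale ~= 4096 and each flattened grad is multiplied by 1/4096,
        # unscaling and clipping to a global norm of ~1.0 in a single multiply.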
        combined_scale = self.cur_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.cur_scale

        if apply_scale:
            for grad in grad_groups_flat:
                grad.data.mul_(1. / combined_scale)

        return combined_scale
    def backward(self, loss, create_graph=False, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
            scaled_loss.backward()
        else:
            scaled_loss = (loss.float()) * self.cur_scale
            scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)
    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            prev_scale = self.cur_scale
            if skip:
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
                self.last_overflow_iter = self.cur_iter
                if self.verbose:
                    logger.info(f"\nGrad overflow on iteration {self.cur_iter}")
                    logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
            else:
                # Ensure self.scale_window updates since last overflow
                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                    self.cur_scale *= self.scale_factor
                    if self.verbose:
                        logger.info(f"No Grad overflow for {self.scale_window} iterations")
                        logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
        else:
            if skip:
                logger.info("Grad overflow on iteration: %s", self.cur_iter)
                logger.info("Using static loss scale of: %s", self.cur_scale)
        self.cur_iter += 1
        return
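    # Example trajectory for _update_scale (illustrative numbers): starting from
    # cur_scale=2**32 with scale_factor=2 and scale_window=1000, an overflow at
    # iteration 0 halves the scale to 2**31, and each later run of 1000 consecutive
    # overflow-free iterations doubles it again, never dropping below min_loss_scale.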
    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        state_dict[CLIP_GRAD] = self.clip_grad
        return state_dict
    # Refresh fp32 master params from the fp16 copies
    def refresh_fp32_params(self):
        for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
            current.data.copy_(saved.data)
    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
        self.clip_grad = state_dict[CLIP_GRAD]
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        #    We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)
    def __repr__(self):
        return repr(self.optimizer)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        if self.custom_loss_scaler:
            return self.external_loss_scale
        else:
            return self.cur_scale
    def _set_loss_scale(self, value):
        # The fused optimizer keeps the loss scale in self.cur_scale (there is no
        # self.loss_scaler object here, unlike the unfused FP16_Optimizer).
        self.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)