fused_optimizer.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
"""

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, get_flattened_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers, is_model_parallel_parameter
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger, log_dist
from deepspeed.utils.torch import required_torch_version
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import is_moe_param_group
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank

OVERFLOW_CHECK_TIMER = 'overflow_check'
COMPUTE_NORM_TIMER = 'compute_norm'
UNSCALE_AND_CLIP_TIMER = 'unscale_and_clip'
BASIC_STEP_TIMER = 'basic_step'
UPDATE_FP16_TIMER = 'update_fp16'

OVERFLOW_TIMERS = [COMPUTE_NORM_TIMER, OVERFLOW_CHECK_TIMER]
STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP_TIMER, BASIC_STEP_TIMER, UPDATE_FP16_TIMER]


class FP16_Optimizer(DeepSpeedOptimizer):
    """
    FP16 Optimizer for training fp16 models. Handles loss scaling.

    For a usage example, please see TODO: DeepSpeed V2 Tutorial
    """
    def __init__(self,
                 init_optimizer,
                 deepspeed=None,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 initial_dynamic_scale=2**32,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_adam_legacy=False,
                 has_moe_layers=False,
                 timers=None):

        self.fused_adam_legacy = fused_adam_legacy
        self.timers = timers
        self.deepspeed = deepspeed
        self.has_moe_layers = has_moe_layers
        self.using_pipeline = self.deepspeed.pipeline_parallelism
        if not get_accelerator().is_available():
            raise SystemError("Cannot use fp16 without accelerator.")
        self.optimizer = init_optimizer

        # params flattened by groups
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        self.flatten_grad_norm_mask_list = []
        self.has_executed_step = False
        self._global_grad_norm = 0.

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to list before modifying it
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weight to slices of flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weight, flattened
            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
            # modify optimizer to have flat master weight
            self.fp32_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.fp32_groups_flat[i]]

        # we may have a way of fusing dynamic scale. Not supported for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2

            if dynamic_loss_args is None:
                self.cur_scale = initial_dynamic_scale
                self.scale_window = 1000
                self.min_loss_scale = 1
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale

        self.verbose = verbose

        self.custom_loss_scaler = False
        self.external_loss_scale = None

        self.clip_grad = clip_grad
        self.norm_type = 2

        if required_torch_version(max_version=0.4):
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        # model parallel object
        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)

        self.initialize_optimizer_states()

    def initialize_optimizer_states(self):
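        # Materialize the inner optimizer's state (e.g. momentum buffers in Adam)
        # by taking one step on all-zero gradients, then release those gradients.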
        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = torch.zeros(self.fp32_groups_flat[i].size(),
                                                        device=self.fp32_groups_flat[i].device)

        self.optimizer.step()

        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = None

        return

    def zero_grad(self, set_to_none=True):
        """
        Zero FP16 parameter grads.
        """
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_to_none:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def step_fused_adam(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute norm for all groups so we know if there is overflow
        grads_groups_flat = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
                ]))
            norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        scaled_grad_norm = get_global_norm(norm_list=norm_groups)
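        # apply_scale=False: only the scalar combined_scale is computed here, since
        # the legacy fused Adam step below is expected to apply that divisor itself
        # through its `scale` argument.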
        combined_scale = self.unscale_and_clip_grads(grads_groups_flat, scaled_grad_norm, apply_scale=False)

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_grad_norm / self.cur_scale

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        return self.overflow

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]["lr"]

    def override_loss_scale(self, loss_scale):
        if loss_scale != self.external_loss_scale:
            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale

    def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank):
        # for filtering replicated tensors under pipeline/tensor model parallelism
        if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
            return True
        if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p):
            return True
        return False

    def _get_norm_mask_idx(self, group):
        """The function preserves the parallel information for norm
        from unflattened gradients.

        Args:
            group (Iterable[Tensor]): params group

        Returns:
            torch.Tensor: A 2D tensor containing index ranges for each group,
                where each row represents a [start index, end index].
        """
        group_mask_idx_list = []
        grad_flat_st_idx = 0
        grad_flat_en_idx = 0

        for p in group:
            grad_flat_en_idx = grad_flat_st_idx + p.numel()
            if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)):
                # merge range
                if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]:
                    group_mask_idx_list[-1][-1] = grad_flat_en_idx
                else:
                    group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
            grad_flat_st_idx = grad_flat_en_idx

        return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name())

    def step(self, closure=None):
        """
        Not supporting closure.
        """

        if self.fused_adam_legacy:
            return self.step_fused_adam()

        # First determine if there is overflow.
        self.timers(OVERFLOW_CHECK_TIMER).start()
        fp16_params = []
        for i, group in enumerate(self.fp16_groups):
            fp16_params.extend([p for p in group if p.grad is not None])
        self.overflow = self.overflow_checker.has_overflow(fp16_params)
        self.timers(OVERFLOW_CHECK_TIMER).stop()
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                log_dist(
                    "Overflow detected. Skipping step. Attempted loss "
                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
                    ranks=[0])
            # Clear gradients
            for i, group in enumerate(self.fp16_groups):
                for p in group:
                    p.grad = None

            self.timers.log(OVERFLOW_TIMERS)
            return self.overflow

        grads_groups_flat = []
        non_experts_grads_for_norm = []
        expert_grads_for_norm = {}
        assert len(self.fp16_groups) == len(self.optimizer.param_groups)

        for i, group in enumerate(self.fp16_groups):
            data_type = self.fp32_groups_flat[i].dtype

            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=data_type, device=p.device) if p.grad is None else p.grad.to(data_type)
                    for p in group
                ]))

            self.fp32_groups_flat[i].grad = grads_groups_flat[i]
            param_group = self.optimizer.param_groups[i]

            # split expert and non_expert grads for norm
            if self.has_moe_layers and is_moe_param_group(param_group):
                if param_group['name'] not in expert_grads_for_norm:
                    expert_grads_for_norm[param_group['name']] = []

                expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i])
            else:
                # retrieves the required mask for calculating the norm of flat_grad
                # perform this collect operation only once
                if not self.has_executed_step:
                    cur_flat_grad_norm_mask = self._get_norm_mask_idx(group)
                    self.flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask)

                non_experts_grads_for_norm.append(self.fp32_groups_flat[i])
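            # The fp16 grads were just copied into the flat fp32 buffer above, so the
            # per-parameter fp16 grads can be released here to save memory.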
            for p in group:
                p.grad = None

        self.timers(COMPUTE_NORM_TIMER).start()

        all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm,
                                                  mpu=self.mpu,
                                                  grad_norm_mask=self.flatten_grad_norm_mask_list)

        if self.has_moe_layers:
            all_groups_norm = get_norm_with_moe_layers(all_groups_norm,
                                                       mpu=self.mpu,
                                                       expert_tensors=expert_grads_for_norm,
                                                       norm_type=self.norm_type)

        scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
        self.timers(COMPUTE_NORM_TIMER).stop()

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_global_grad_norm / self.cur_scale

        self.timers(UNSCALE_AND_CLIP_TIMER).start()
        self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
        self.timers(UNSCALE_AND_CLIP_TIMER).stop()

        self.timers(BASIC_STEP_TIMER).start()
        self.optimizer.step()
        self.timers(BASIC_STEP_TIMER).stop()

        # get rid of the fp32 gradients. Not needed anymore
        for group in self.fp32_groups_flat:
            group.grad = None

        self.timers(UPDATE_FP16_TIMER).start()

        for i in range(len(self.fp16_groups)):
            updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data.copy_(q.data)
        self.has_executed_step = True
        self.timers(UPDATE_FP16_TIMER).stop()

        self.timers.log(STEP_TIMERS)

        return self.overflow

    def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
        # compute combined scale factor for this group
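        # Worked example with illustrative numbers: if cur_scale = 1024, the scaled
        # total_norm is 4096 (i.e. an unscaled norm of 4.0) and clip_grad = 1.0, then
        # clip ~= 4.0 > 1, so combined_scale = 4096; multiplying every grad by
        # 1 / combined_scale both removes the loss scale and clips the global norm to ~1.0.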
        combined_scale = self.cur_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.cur_scale

        if apply_scale:
            for grad in grad_groups_flat:
                grad.data.mul_(1. / combined_scale)

        return combined_scale

    def backward(self, loss, create_graph=False, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
            scaled_loss.backward()
        else:
            scaled_loss = (loss.float()) * self.cur_scale
            scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)

    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            prev_scale = self.cur_scale
            if skip:
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
                self.last_overflow_iter = self.cur_iter
                if self.verbose:
                    logger.info(f"\nGrad overflow on iteration {self.cur_iter}")
                    logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
            else:
                # Increase the scale only after self.scale_window consecutive
                # overflow-free iterations since the last overflow
                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
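                # e.g. with the default scale_window = 1000 and scale_factor = 2, and the
                # last overflow at iteration 0, the scale doubles at iterations 1001,
                # 2001, ... as long as no new overflow occurs (illustrative numbers).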
                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                    self.cur_scale *= self.scale_factor
                    if self.verbose:
                        logger.info(f"No Grad overflow for {self.scale_window} iterations")
                        logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
        else:
            if skip:
                logger.info("Grad overflow on iteration: %s", self.cur_iter)
                logger.info("Using static loss scale of: %s", self.cur_scale)
        self.cur_iter += 1
        return

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        state_dict[CLIP_GRAD] = self.clip_grad
        return state_dict

    # Refresh fp32 master params from fp16 copies
    def refresh_fp32_params(self):
        for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
            current.data.copy_(saved.data)

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
        self.clip_grad = state_dict[CLIP_GRAD]
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        #    We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)

    def __repr__(self):
        return repr(self.optimizer)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        if self.custom_loss_scaler:
            return self.external_loss_scale
        else:
            return self.cur_scale
    def _set_loss_scale(self, value):
        # This class tracks its loss scale directly in `cur_scale`; it has no
        # separate `loss_scaler` object to delegate to.
        self.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)