'''
Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
'''

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import groups, logger, log_dist
from deepspeed import comm as dist
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
from deepspeed.accelerator import get_accelerator


class FP16_Optimizer(DeepSpeedOptimizer):
    """
    FP16 Optimizer for training fp16 models. Handles loss scaling.

    For usage example please see, TODO: DeepSpeed V2 Tutorial
    """

    def __init__(self,
                 init_optimizer,
                 deepspeed=None,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 initial_dynamic_scale=2**32,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_adam_legacy=False,
                 has_moe_layers=False,
                 timers=None):

        self.fused_adam_legacy = fused_adam_legacy
        self.timers = timers
        self.deepspeed = deepspeed
        self.has_moe_layers = has_moe_layers
        self.using_pipeline = self.deepspeed.pipeline_parallelism
        if not get_accelerator().is_available():
            raise SystemError("Cannot use fp16 without accelerator.")
        self.optimizer = init_optimizer

        # params flattened by groups
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        self._global_grad_norm = 0.

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before modifying it
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(
                _flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weights to slices of the flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weights, flattened
            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
            # modify optimizer to have flat master weights
            self.fp32_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.fp32_groups_flat[i]]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2

            if dynamic_loss_args is None:
                self.cur_scale = initial_dynamic_scale
                self.scale_window = 1000
                self.min_loss_scale = 1
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale

        self.verbose = verbose

        self.custom_loss_scaler = False
        self.external_loss_scale = None

        self.clip_grad = clip_grad
        self.norm_type = 2
        self.step_count = 0

        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        # model parallel object
        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups,
                                              mpu=self.mpu,
                                              deepspeed=deepspeed)
        self.initialize_optimizer_states()
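
    # Minimal standalone sketch of the flatten/alias pattern used in __init__
    # (illustrative names only, not part of this class):
    #
    #   import torch
    #   from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
    #   params = [torch.randn(3).half(), torch.randn(5).half()]
    #   flat = _flatten_dense_tensors([p.clone().detach() for p in params])
    #   for p, view in zip(params, _unflatten_dense_tensors(flat, params)):
    #       p.data = view.data  # each param now shares storage with `flat`
    #   master = flat.clone().float().detach()  # fp32 master copy stepped by the inner optimizer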

    def initialize_optimizer_states(self):
        # Attach zero-valued gradients and take a dummy step so that the wrapped
        # optimizer allocates its internal state (e.g. momentum buffers) up front.
        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = torch.zeros(self.fp32_groups_flat[i].size(),
                                                        device=self.fp32_groups_flat[i].device)

        self.optimizer.step()

        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = None

        return

    def zero_grad(self, set_to_none=False):
        """
        Zero FP16 parameter grads.
        """
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_to_none:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def step_fused_adam(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute norms for all groups so we know if there is overflow
        grads_groups_flat = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad
                    for p in group
                ]))
            norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        scaled_grad_norm = get_global_norm(norm_list=norm_groups)

        combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                     scaled_grad_norm,
                                                     apply_scale=False)

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_grad_norm / self.cur_scale

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        return self.overflow

    def start_timers(self, name_list):
        if self.timers is not None:
            for name in name_list:
                self.timers(name).start()

    def stop_timers(self, name_list):
        if self.timers is not None:
            for name in name_list:
                self.timers(name).stop()

    def log_timers(self, name_list):
        if self.timers is not None:
            self.timers.log(name_list)

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]["lr"]

    def override_loss_scale(self, loss_scale):
        if loss_scale != self.external_loss_scale:
            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        if self.fused_adam_legacy:
            return self.step_fused_adam()

        COMPUTE_NORM = "compute_norm"
        OVERFLOW_CHECK = 'overflow_check'
        OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
        UNSCALE_AND_CLIP = 'unscale_and_clip'
        BASIC_STEP = 'basic_step'
        UPDATE_FP16 = 'update_fp16'
        STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

        # First determine if there is overflow.
        self.start_timers([OVERFLOW_CHECK])
        fp16_params = []
        for i, group in enumerate(self.fp16_groups):
            fp16_params.extend([p for p in group if p.grad is not None])
        self.overflow = self.overflow_checker.has_overflow(fp16_params)
        self.stop_timers([OVERFLOW_CHECK])
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                log_dist("Overflow detected. Skipping step. Attempted loss "
                         f"scale: {prev_scale}, reducing to {self.cur_scale}",
                         ranks=[0])
            # Clear gradients
            for i, group in enumerate(self.fp16_groups):
                for p in group:
                    p.grad = None

            self.log_timers(OVERFLOW_TIMERS)
            return self.overflow

        grads_groups_flat = []
        for i, group in enumerate(self.fp16_groups):
            data_type = self.fp32_groups_flat[i].dtype

            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(), dtype=data_type, device=p.device)
                    if p.grad is None else p.grad.to(data_type) for p in group
                ]))

            for p in group:
                p.grad = None

            self.fp32_groups_flat[i].grad = grads_groups_flat[i]

        self.start_timers([COMPUTE_NORM])
        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
        self.stop_timers([COMPUTE_NORM])

        if self.has_moe_layers:
            all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm)

        scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_global_grad_norm / self.cur_scale

        self.start_timers([UNSCALE_AND_CLIP])
        self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
        self.stop_timers([UNSCALE_AND_CLIP])

        self.start_timers([BASIC_STEP])
        self.optimizer.step()
        self.stop_timers([BASIC_STEP])

        # Get rid of the fp32 gradients. Not needed anymore
        for group in self.fp32_groups_flat:
            group.grad = None

        self.start_timers([UPDATE_FP16])
        for i in range(len(self.fp16_groups)):
            updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data.copy_(q.data)
        self.stop_timers([UPDATE_FP16])

        self.log_timers(STEP_TIMERS)

        self.step_count += 1

        return self.overflow

    def _get_norm_with_moe_layers(self, all_groups_norm):
        #all_groups_norm_old = all_groups_norm
        # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
        if self.using_pipeline:
            pg = self.deepspeed.mpu.get_data_parallel_group()
        else:
            pg = groups._get_data_parallel_group()
        scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
        scaled_norm_tensor = torch.tensor(scaled_norm,
                                          device=self.fp32_groups_flat[0].device,
                                          dtype=torch.float)
        dist.all_reduce(scaled_norm_tensor, group=pg)
        all_groups_norm = scaled_norm_tensor.item()
        #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
        return all_groups_norm
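
    # Worked example of the averaging above (assumed numbers, 4 data-parallel
    # ranks): if the local group norms are 1.0, 2.0, 3.0 and 4.0, each rank
    # contributes norm / 4 and the all-reduce sum leaves every rank with
    # (1.0 + 2.0 + 3.0 + 4.0) / 4 = 2.5 as the shared gradient norm.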

    def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
        # compute combined scale factor for this group
        combined_scale = self.cur_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.cur_scale

        if apply_scale:
            for grad in grad_groups_flat:
                grad.data.mul_(1. / combined_scale)

        return combined_scale
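
    # Worked example of the combined scale (assumed numbers): with cur_scale=1024,
    # clip_grad=1.0 and a scaled total_norm of 2048, the unscaled norm is
    # 2048 / 1024 = 2.0, so clip ≈ 2.0 / 1.0 > 1 and combined_scale = 2.0 * 1024 = 2048.
    # Multiplying the gradients by 1 / 2048 both removes the loss scale and clips
    # the gradient norm to ~1.0.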

    def backward(self, loss, create_graph=False, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
            scaled_loss.backward()
        else:
            scaled_loss = (loss.float()) * self.cur_scale
            scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)
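
    # Illustrative training-loop sketch (assumed names, not part of this class):
    # the wrapper replaces the usual loss.backward() / optimizer.step() pair so
    # that loss scaling and overflow handling stay inside the optimizer.
    #
    #   for batch in dataloader:
    #       loss = model(batch)
    #       optimizer.zero_grad()
    #       optimizer.backward(loss)       # scales the loss, then backprops
    #       overflowed = optimizer.step()  # unscales, clips, steps; skips on overflow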

    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            prev_scale = self.cur_scale
            if skip:
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
                self.last_overflow_iter = self.cur_iter
                if self.verbose:
                    logger.info(f"\nGrad overflow on iteration {self.cur_iter}")
                    logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
            else:
                # Ensure self.scale_window updates since last overflow
                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                    self.cur_scale *= self.scale_factor
                    if self.verbose:
                        logger.info(f"No Grad overflow for {self.scale_window} iterations")
                        logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
        else:
            if skip:
                logger.info("Grad overflow on iteration: %s", self.cur_iter)
                logger.info("Using static loss scale of: %s", self.cur_scale)
        self.cur_iter += 1
        return
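
    # Example of the dynamic-scale schedule (assumed defaults: scale_factor=2,
    # scale_window=1000, initial scale 2**32): each overflow halves the scale
    # (never below min_loss_scale) and resets the window, while 1000 consecutive
    # overflow-free iterations double it again. After e.g. 20 early overflows the
    # scale sits near 2**12 and then creeps back up as training stabilizes.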

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        state_dict[CLIP_GRAD] = self.clip_grad
        return state_dict

    # Refresh fp32 master params from fp16 copies
    def refresh_fp32_params(self):
        for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
            current.data.copy_(saved.data)

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
        self.clip_grad = state_dict[CLIP_GRAD]
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        #    We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)

    def __repr__(self):
        return repr(self.optimizer)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        if self.custom_loss_scaler:
            return self.external_loss_scale
        else:
            return self.cur_scale

    def _set_loss_scale(self, value):
        # This class tracks the scale directly in self.cur_scale (there is no
        # separate loss_scaler object), so update it in place.
        self.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)