# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import types
import torch
import numpy as np
from deepspeed import comm as dist
from deepspeed.runtime.utils import required_torch_version
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.accelerator import get_accelerator


class OnebitLamb(torch.optim.Optimizer):
    """Implements the 1-bit Lamb algorithm. Currently GPU-only.

    For usage example please see https://www.deepspeed.ai/tutorials/onebit-lamb/
    For technical details please see our paper https://arxiv.org/abs/2104.06069.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        freeze_step (int, optional): Number of steps for warmup (uncompressed)
            stage before we start using compressed communication. (default 100000)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        max_coeff (float, optional): maximum value of the lamb coefficient (default: 10.0)
        min_coeff (float, optional): minimum value of the lamb coefficient (default: 0.01)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in 1-bit Lamb!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        cuda_aware (boolean, required): Set True if the underlying MPI implementation
            supports CUDA-Aware communication. (default: False)
        comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl')
        coeff_beta (float, optional): coefficient used for computing
            running averages of lamb coefficient (default: 0.9) note that you may want to
            increase or decrease this beta depending on the freeze_step you choose, as
            1/(1 - coeff_beta) should be smaller than or equal to freeze_step
        factor_max (float, optional): maximum value of scaling factor to the frozen lamb
            coefficient during compression stage (default: 4.0)
        factor_min (float, optional): minimum value of scaling factor to the frozen lamb
            coefficient during compression stage (default: 0.5)
        factor_threshold (float, optional): threshold of how much the scaling factor can
            fluctuate between steps (default: 0.1)

    .. _Large Batch Optimization for Deep Learning\\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _Adam\\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
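
    # A minimal usage sketch (illustrative, not part of the original file): in practice
    # DeepSpeed builds this optimizer from the "OneBitLamb" entry of the config passed to
    # deepspeed.initialize(); see the tutorial linked above for the authoritative settings.
    # Below, `model` and `model_engine` are hypothetical names for a module and its
    # DeepSpeed engine, and torch.distributed is assumed to be initialized already:
    #
    #     optimizer = OnebitLamb(model.parameters(),
    #                            deepspeed=model_engine,
    #                            lr=1e-3,
    #                            freeze_step=1000,
    #                            comm_backend_name='nccl')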

    def __init__(self,
                 params,
                 deepspeed=None,
                 lr=1e-3,
                 freeze_step=100000,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 max_coeff=10.0,
                 min_coeff=0.01,
                 amsgrad=False,
                 cuda_aware=False,
                 comm_backend_name='nccl',
                 coeff_beta=0.9,
                 factor_max=4.0,
                 factor_min=0.5,
                 factor_threshold=0.1):

        if amsgrad:
            raise RuntimeError('1-bit Lamb does not support the AMSGrad variant.')

        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        max_coeff=max_coeff,
                        min_coeff=min_coeff)

        super(OnebitLamb, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        self.deepspeed = deepspeed
        self.lamb_freeze_key = False
        self.initialize = False
        self.freeze_step = freeze_step
        self.cuda_aware = cuda_aware
        self.coeff_beta = coeff_beta
        self.factor_max = factor_max
        self.factor_min = factor_min
        self.factor_threshold = factor_threshold
        self.using_pipeline = False

        self.comm_backend_name = comm_backend_name

        assert dist.is_initialized(), "Please initialize the torch distributed backend."
        # Empty initializer. Set handle based on the comm backend as follows.
        self.comm_backend_handle = None
        if self.comm_backend_name == 'nccl':
            assert (
                required_torch_version(min_version=1.8)
            ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Lamb. Alternatively, please specify 'mpi' as the 'comm_backend_name' in the config file to proceed with the MPI backend"
            from deepspeed.runtime.comm.nccl import NcclBackend
            self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
        elif self.comm_backend_name == 'mpi':
            from deepspeed.runtime.comm.mpi import MpiBackend
            self.comm_backend_handle = MpiBackend(cuda_aware)
        elif self.comm_backend_name == 'hccl':
            from deepspeed.runtime.comm.hccl import HcclBackend
            self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)

        self.size = self.comm_backend_handle.size

        self.divider = int(self.size * 8 / np.gcd(self.size, 8))
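        # Padding granularity for the fused momentum buffer built in step(): divider = lcm(size, 8).
        # For example, with size = 16 workers, divider = 16 * 8 / gcd(16, 8) = 16 and the fused
        # buffer is padded to a multiple of size * divider = 256 elements, so every server chunk
        # length is a multiple of both the world size and 8.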
        self.exp_avg_flat = []
        self.dummy_exp_avg = {}
        self.corrected_tensor_sizes = []
        self.server_chunk_sizes = []
        self.worker_errors = []
        self.server_errors = []

        self.lamb_coeffs = []

    def step(self, closure=None, grads=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        # remove the previous stats
        del self.lamb_coeffs[:]

        if self.lamb_freeze_key:
            exp_avg_last_step = []
            for group in self.param_groups:
                exp_avg_last_step.append([self.state[p]['exp_avg'].detach().clone() for p in group['params']])
            if 'scaling_coeff' not in self.state[self.param_groups[0]['params'][0]]:
                # Compute the scaling_coeff for each momentum at the end of warmup stage.
                # This is used to reduce compression error during compression stage.
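                # Each momentum's scale is its RMS value, norm(exp_avg) / sqrt(numel(exp_avg));
                # scaling_coeff = (average RMS across all momenta) / (this tensor's RMS), so every
                # momentum is brought to a comparable magnitude before being 1-bit compressed in
                # the fused buffer, and the coefficient is divided back out after the allreduce.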
                momentum_scales = []
                for group in self.param_groups:
                    momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
                                             np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                            for p in group['params']])
                united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
                for i, group in enumerate(self.param_groups):
                    for j, p in enumerate(group['params']):
                        self.state[p]['scaling_coeff'] = united_scale / momentum_scales[i][j]

        for group, grads_this_group in zip(self.param_groups, grads_group):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])

            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad in zip(group['params'], grads_this_group):
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('1-bit Lamb does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0 or (len(state) == 1 and 'scaling_coeff' in state.keys()):
                    state['step'] = 0
                    state['lamb_coeff_freeze'] = 0.0
                    state['last_factor'] = 1.0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    state['exp_avg_sq_fresh'] = torch.zeros_like(p.data)

                if not self.initialize:
                    self.lamb_freeze_key = True

                exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state[
                    'exp_avg_sq_fresh']
                beta1, beta2 = group['betas']
                max_coeff = group['max_coeff']
                min_coeff = group['min_coeff']

                state['step'] += 1

                if self.lamb_freeze_key is False:
                    # warmup stage, baseline Lamb optimization
                    exp_avg.mul_(beta1).add_(1 - beta1, grad)
                    exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                    if state['step'] == self.freeze_step:
                        exp_avg_sq_fresh.data = exp_avg_sq.detach().clone()
                    grad = None
                    if self.initialize:
                        weight_norm = p.data.pow(2).sum().sqrt()
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
                        if group['weight_decay'] > 0.0:
                            update += group['weight_decay'] * p.data
                        update_norm = update.pow(2).sum().sqrt()
                        lamb_coeff = 1.0
                        if weight_norm != 0 and update_norm != 0:
                            lamb_coeff = (weight_norm / update_norm).item()
                            if lamb_coeff > max_coeff:
                                lamb_coeff = max_coeff
                            if lamb_coeff < min_coeff:
                                lamb_coeff = min_coeff
                        if lamb_coeff != 1.0:
                            state['lamb_coeff_freeze'] = self.coeff_beta * state['lamb_coeff_freeze'] + (
                                1 - self.coeff_beta) * lamb_coeff
                        self.lamb_coeffs.append(lamb_coeff)
                        with torch.no_grad():
                            p.add_(-group['lr'] * lamb_coeff * update)
                else:
                    # compression stage, update each momentum locally, then
                    # communicate based on the compressed_allreduce below
                    if self.initialize:
                        exp_avg.mul_(beta1).add_(1 - beta1, grad)
                        exp_avg.mul_(self.state[p]['scaling_coeff'])
                    grad = None

        # init fused momentum
        if len(self.exp_avg_flat) == 0:
            momentum_groups = []
            tensor_size = 0
            for group in self.param_groups:
                for p in group['params']:
                    momentum_groups.append(self.state[p]['exp_avg'])
                    tensor_size += torch.numel(p.data)
            corrected_tensor_size = tensor_size
            if tensor_size % (self.size * self.divider) != 0:
                difference = ((self.size * self.divider) - (tensor_size % (self.size * self.divider)))
                corrected_tensor_size += difference
                self.dummy_exp_avg[0] = torch.zeros(difference, device=momentum_groups[0].data.device)
                momentum_groups.append(self.dummy_exp_avg[0])
            self.corrected_tensor_sizes.append(corrected_tensor_size)
            self.server_chunk_sizes.append(corrected_tensor_size // self.size)

            self.exp_avg_flat.append(_flatten_dense_tensors([p.detach().clone() for p in momentum_groups]))
            updated_params = _unflatten_dense_tensors(self.exp_avg_flat[0], momentum_groups)
            for p, q in zip(momentum_groups, updated_params):
                p.data = q.data

        if self.initialize and len(self.worker_errors) == 0:
            get_accelerator().empty_cache()
            for i in range(len(self.exp_avg_flat)):
                self.worker_errors.append(
                    torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device))
                self.server_errors.append(torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device))
            get_accelerator().empty_cache()
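
        # Compression stage communication: exchange the fused momentum buffer with the
        # error-compensated 1-bit allreduce; worker_errors / server_errors hold each rank's
        # local compression residuals, which are fed back into the next step's communication.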
        if self.lamb_freeze_key:
            if self.size > 1:
                for i in range(len(self.exp_avg_flat)):
                    if not self.initialize:
                        get_accelerator().empty_cache()
                        self.worker_errors.append(
                            torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device))
                        self.server_errors.append(
                            torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device))
                        get_accelerator().empty_cache()
                        if dist.get_rank() == 0:
                            print("Cupy Buffers Initialized Successfully.")

                        self.comm_backend_handle.compressed_allreduce(self.exp_avg_flat[i], self.worker_errors[0],
                                                                      self.server_errors[0], self.deepspeed.local_rank)

                        if dist.get_rank() == 0:
                            print('Pop out errors', flush=True)
                        del self.worker_errors[:]
                        del self.server_errors[:]
                    else:
                        self.comm_backend_handle.compressed_allreduce(self.exp_avg_flat[i], self.worker_errors[i],
                                                                      self.server_errors[i], self.deepspeed.local_rank)

        if self.lamb_freeze_key and self.initialize:
            for i, group in enumerate(self.param_groups):
                bias_correction = 1 if group['bias_correction'] else 0

                for j, p in enumerate(group['params']):
                    state = self.state[p]
                    exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state[
                        'exp_avg_sq_fresh']
                    beta1, beta2 = group['betas']
                    exp_avg.div_(self.state[p]['scaling_coeff'])
                    # Because 1-bit compression cannot represent exact zero, it is required to
                    # provide a momentum mask for those params that have constant exact zeros in their
                    # momentums, otherwise the compression error would keep accumulating.
                    # For example, for BERT pre-training seq 128, bert.embeddings.position_embeddings.weight
                    # always has exact zeros in its momentum for rows 129 to 512, because it only
                    # learns up to seq length 128 while the model supports up to 512 seq length.
                    # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py about how
                    # to add this exp_avg_mask for BERT pre-training.)
                    if 'exp_avg_mask' in group:
                        if exp_avg.device != group['exp_avg_mask'].device:
                            group['exp_avg_mask'] = group['exp_avg_mask'].to(device=exp_avg.device)
                        exp_avg.mul_(group['exp_avg_mask'])
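
                    # During compression only exp_avg was updated (as beta1 * exp_avg_old +
                    # (1 - beta1) * grad) before the allreduce, so the (averaged) gradient is
                    # reconstructed as (exp_avg - beta1 * exp_avg_old) / (1 - beta1) and used to
                    # keep the fresh second-moment estimate exp_avg_sq_fresh up to date.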
                    grad_reconstruct = ((exp_avg - exp_avg_last_step[i][j] * beta1) / (1 - beta1))
                    exp_avg_sq_fresh.mul_(beta2).addcmul_(1 - beta2, grad_reconstruct, grad_reconstruct)
                    denom = exp_avg_sq.sqrt() + group['eps']
                    update_prelim = exp_avg / denom

                    if group['weight_decay'] > 0.0:
                        update = update_prelim + group['weight_decay'] * p.data
                    else:
                        update = update_prelim

                    lamb_coeff = 1.0
                    update_norm = update.pow(2).sum().sqrt()
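                    # exp_avg_sq stays frozen at its freeze_step value during compression, so the
                    # frozen lamb coefficient is rescaled by how much the stale denominator exceeds
                    # the fresh one: factor = max(denom / denom_real), clipped to
                    # [factor_min, factor_max] and rate-limited per step by factor_threshold.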
                    denom_real = exp_avg_sq_fresh.sqrt() + group['eps']
                    factor = (denom / denom_real).max().item()
                    if group['weight_decay'] > 0.0:
                        update_ratio = min(1.0, (update_prelim.pow(2).sum().sqrt() / update_norm).item())
                        factor = factor * update_ratio + (1.0 - update_ratio)
                    if factor > self.factor_max:
                        factor = self.factor_max
                    if factor < self.factor_min:
                        factor = self.factor_min
                    if factor > state['last_factor'] * (1.0 + self.factor_threshold):
                        factor = state['last_factor'] * (1.0 + self.factor_threshold)
                    if factor < state['last_factor'] * (1.0 - self.factor_threshold):
                        factor = state['last_factor'] * (1.0 - self.factor_threshold)
                    state['last_factor'] = factor

                    lamb_coeff = state['lamb_coeff_freeze'] * factor
                    self.lamb_coeffs.append(lamb_coeff)
                    with torch.no_grad():
                        p.add_(-group['lr'] * lamb_coeff * update)
            del exp_avg_last_step[:]
            exp_avg_last_step = None

        if not self.initialize:
            self.lamb_freeze_key = False
            self.initialize = True
            print(f"Finished the initialization step at rank {dist.get_rank()}")
            return loss

        if self.lamb_freeze_key is False:
            if state['step'] >= self.freeze_step:
                print('OnebitLamb - starting compressed communication')
                self.lamb_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        return loss

    def load_state_dict(self, state_dict):
        """
        Overrides load_state_dict() to add special handling when loading checkpoints
        """
        # Because the exp_avg_mask may change at different stages (e.g.,
        # BERT pre-training seqlen 128 and 512), we don't use the exp_avg_mask
        # in checkpoints but always use the one the user provided in the training script.
        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
        # Thus here we keep the exp_avg_mask unchanged when loading a checkpoint.
        for i, group in enumerate(self.param_groups):
            if 'exp_avg_mask' in group:
                state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
            elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
                state_dict['param_groups'][i].pop('exp_avg_mask')
        super().load_state_dict(state_dict)

        # need to reset the fused momentum since loading states will break the linking
        del self.exp_avg_flat[:]
        self.dummy_exp_avg.clear()
        del self.corrected_tensor_sizes[:]
        del self.server_chunk_sizes[:]

        if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
            if dist.get_rank() == 0:
                print("Checkpoint loaded and OnebitLamb warmup stage starts/continues.")
            if self.lamb_freeze_key is True:
                self.lamb_freeze_key = False
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = True
                else:
                    self.deepspeed.enable_backward_allreduce = True
            for group in self.param_groups:
                for p in group['params']:
                    self.state[p]['lamb_coeff_freeze'] = 0.0
                    self.state[p]['last_factor'] = 1.0
                    if 'scaling_coeff' in self.state[p]:
                        self.state[p].pop('scaling_coeff')
        else:
            if dist.get_rank() == 0:
                print("Checkpoint loaded and OnebitLamb compression stage starts/continues.")
            if self.lamb_freeze_key is False:
                self.lamb_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        # We reset the compression errors when loading checkpoints for 3 reasons:
        # 1) The worker and server errors at each GPU are distinct, so in the current implementation
        #    only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors.
        #    If we wanted to save them correctly we would need O(num_gpu*model_size) memory in order to
        #    gather all the errors, which is a very large memory requirement. It's possible to save
        #    them in a distributed way, but it would make checkpoint saving/loading much more complicated.
        # 2) Even if we were able to save the compression errors correctly, you would need the
        #    exact same number of GPUs in order to load them correctly.
        # 3) We verified on BERT pre-training that occasionally resetting the compression errors
        #    at checkpoint loading does not affect convergence.
        #    However, please avoid frequent checkpoint loading, which could break the error
        #    compensation mechanism and thus affect convergence.
        del self.worker_errors[:]
        del self.server_errors[:]

    def get_lamb_coeffs(self):
        return self.lamb_coeffs