adam.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import types
import torch
import numpy as np
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.utils import required_torch_version
from deepspeed import comm as dist


class OnebitAdam(torch.optim.Optimizer):
    """Implements the 1-bit Adam algorithm. Currently GPU-only.
    For usage example please see https://www.deepspeed.ai/tutorials/onebit-adam/
    For technical details please read https://arxiv.org/abs/2102.02888

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        freeze_step (int, optional): Number of steps for warmup (uncompressed)
            stage before we start using compressed communication. (default 100000)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in 1-bit Adam!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        cuda_aware (boolean, required): Set True if the underlying MPI implementation
            supports CUDA-Aware communication. (default: False)
        comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl')

    .. _Adam\\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
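
    Example (a minimal sketch, not taken from this file): the optimizer is normally
    constructed by DeepSpeed from the JSON config rather than instantiated directly.
    The config values below are illustrative, and ``model`` is assumed to be an
    existing ``torch.nn.Module`` in an already-launched distributed job::

        import deepspeed

        ds_config = {
            "train_batch_size": 32,
            "optimizer": {
                "type": "OneBitAdam",
                "params": {
                    "lr": 1e-3,
                    "freeze_step": 1000,
                    "cuda_aware": False,
                    "comm_backend_name": "nccl"
                }
            }
        }
        model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                             model_parameters=model.parameters(),
                                                             config=ds_config)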
  39. """
  40. def __init__(self,
  41. params,
  42. deepspeed=None,
  43. lr=1e-3,
  44. freeze_step=100000,
  45. bias_correction=True,
  46. betas=(0.9, 0.999),
  47. eps=1e-8,
  48. eps_inside_sqrt=False,
  49. weight_decay=0.,
  50. max_grad_norm=0.,
  51. amsgrad=False,
  52. cuda_aware=False,
  53. comm_backend_name='nccl'):
  54. if amsgrad:
  55. raise RuntimeError('1-bit Adam does not support the AMSGrad variant.')
  56. defaults = dict(lr=lr,
  57. bias_correction=bias_correction,
  58. betas=betas,
  59. eps=eps,
  60. weight_decay=weight_decay,
  61. max_grad_norm=max_grad_norm)
  62. super(OnebitAdam, self).__init__(params, defaults)
  63. self.eps_mode = 0 if eps_inside_sqrt else 1
  64. self.comm_time = 0.0
  65. self.step_time = 0.0
  66. self.ave_step = 1
  67. self.bk_time = 0.0
  68. self.deepspeed = deepspeed
  69. self.adam_freeze_key = False
  70. self.initialize = False
  71. self.freeze_step = freeze_step
  72. self.cuda_aware = cuda_aware
  73. self.using_pipeline = False
  74. self.comm_backend_name = comm_backend_name
  75. assert dist.is_initialized(), "Please initialize the torch distributed backend."
  76. # Empty initializer. Set handle based on the comm backend as follows.
  77. self.comm_backend_handle = None
  78. if self.comm_backend_name == 'nccl':
  79. assert (
  80. required_torch_version(min_version=1.8)
  81. ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
  82. from deepspeed.runtime.comm.nccl import NcclBackend
  83. self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
  84. self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
  85. elif self.comm_backend_name == 'mpi':
  86. from deepspeed.runtime.comm.mpi import MpiBackend
  87. self.comm_backend_handle = MpiBackend(cuda_aware)
  88. elif self.comm_backend_name == 'hccl':
  89. from deepspeed.runtime.comm.hccl import HcclBackend
  90. self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
  91. self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
  92. self.size = self.comm_backend_handle.size
  93. self.divider = int(self.size * 8 / np.gcd(self.size, 8))
    def step(self, closure=None, grads=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradients to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
        """
        loss = None
        if closure is not None:
            loss = closure()

        gather_time = 0
        allgather_time = 0
        all_time = 0

        if self.adam_freeze_key is False:
            v_diff_buffer = 0.0

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameters means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        for group, grads_this_group in zip(self.param_groups, grads_group):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])

            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad in zip(group['params'], grads_this_group):
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('1-bit Adam does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if not self.initialize or (self.adam_freeze_key and 'worker_error' not in state.keys()):
                    state['tensor_size'] = torch.numel(p.data)
                    state['corrected_tensor_size'] = state['tensor_size']

                    if state['tensor_size'] % (self.size * self.divider) != 0:
                        state['corrected_tensor_size'] += ((self.size * self.divider) -
                                                           (state['tensor_size'] % (self.size * self.divider)))
                    state['server_chunk_size'] = state['corrected_tensor_size'] // self.size
                    get_accelerator().empty_cache()
                    state['worker_error'] = torch.zeros(state['corrected_tensor_size'], device=p.device)
                    state['server_error'] = torch.zeros(state['server_chunk_size'], device=p.device)
                    get_accelerator().empty_cache()
                    self.adam_freeze_key = True
                    if not self.initialize and dist.get_rank() == 0:
                        print("Cupy Buffers Initialized Successfully.")

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if self.adam_freeze_key is False:
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    grad = None
                    if self.initialize:
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

                else:
                    if 'non_freeze' in group.keys() and group['non_freeze'] is True:
                        dist.all_reduce(grad)
                        grad.mul_(1 / dist.get_world_size())
                        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                        grad = None
                    else:
                        if self.initialize is True:
                            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                        grad = None

                        if self.size > 1:
                            exp_avg.set_(
                                self.comm_backend_handle.compressed_allreduce(exp_avg, state['worker_error'],
                                                                              state['server_error'],
                                                                              self.deepspeed.local_rank))
                        # Because 1-bit compression cannot represent exact zero, a momentum mask must be
                        # provided for params that have constant exact zeros in their momentum, otherwise
                        # the compression error would keep accumulating.
                        # For example, for BERT pre-training with seq length 128,
                        # bert.embeddings.position_embeddings.weight always has exact zeros in its momentum
                        # for rows 129 to 512, because it only learns up to seq length 128 while the model
                        # supports up to 512 seq length.
                        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
                        if 'exp_avg_mask' in group:
                            if exp_avg.device != group['exp_avg_mask'].device:
                                group['exp_avg_mask'] = group['exp_avg_mask'].to(device=exp_avg.device)
                            exp_avg.mul_(group['exp_avg_mask'])
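                        # Illustrative sketch (not part of this optimizer): a training script can attach
                        # such a mask when it builds the parameter groups that are handed to DeepSpeed
                        # (e.g. via model_parameters in deepspeed.initialize). Names below are hypothetical.
                        #
                        #     mask = torch.ones_like(pos_emb.weight)
                        #     mask[trained_seq_len:, :] = 0.0   # rows whose momentum stays exactly zero
                        #     param_groups = [
                        #         {'params': [pos_emb.weight], 'exp_avg_mask': mask},
                        #         {'params': other_params},
                        #     ]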
                    if self.initialize:
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

                if self.initialize:
                    if group['weight_decay'] > 0.0:
                        update += group['weight_decay'] * p.data
                    with torch.no_grad():
                        p.add_(-group['lr'] * update)

            if not self.initialize:
                print('Pop out errors', flush=True)
                state.pop('worker_error')
                state.pop('server_error')

        if not self.initialize:
            self.adam_freeze_key = False
            self.initialize = True
            print(f"Finished the initialization step at rank {dist.get_rank()}")
            return loss

        if self.adam_freeze_key is False:
            if state['step'] >= self.freeze_step:
                print('OnebitAdam - starting compressed communication')
                self.adam_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        return loss
    def load_state_dict(self, state_dict):
        """
        Overrides load_state_dict() to add special handling when loading checkpoints
        """
        # Because the exp_avg_mask may change at different training stages (e.g.,
        # BERT pre-training with seqlen 128 and then 512), we don't use the exp_avg_mask
        # stored in checkpoints but always use the one provided in the training script.
        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
        # Thus here we keep the exp_avg_mask unchanged when loading a checkpoint.
        for i, group in enumerate(self.param_groups):
            if 'exp_avg_mask' in group:
                state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
            elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
                state_dict['param_groups'][i].pop('exp_avg_mask')
        super().load_state_dict(state_dict)

        if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
            if dist.get_rank() == 0:
                print("Checkpoint loaded and OnebitAdam warmup stage starts/continues.")
            if self.adam_freeze_key is True:
                self.adam_freeze_key = False
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = True
                else:
                    self.deepspeed.enable_backward_allreduce = True
        else:
            if dist.get_rank() == 0:
                print("Checkpoint loaded and OnebitAdam compression stage starts/continues.")
            if self.adam_freeze_key is False:
                self.adam_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        # We reset the compression errors when loading checkpoints for three reasons:
        # 1) The worker and server errors at each GPU are distinct, so in the current implementation
        #    only rank 0's errors are saved in the checkpoint; thus we have to reset the errors.
        #    Saving them correctly would require O(num_gpu * model_size) memory to gather all the
        #    errors, which is a very large memory requirement. It is possible to save them in a
        #    distributed way, but that would make checkpoint saving/loading much more complicated.
        # 2) Even if we were able to save the compression errors correctly, loading them would
        #    require the exact same number of GPUs.
        # 3) We verified on BERT pre-training that occasionally resetting the compression errors
        #    at checkpoint loading does not affect convergence.
        # However, please avoid frequent checkpoint loading, which could break the error
        # compensation mechanism and thus affect convergence.
        for group in self.param_groups:
            for p in group['params']:
                if 'worker_error' in self.state[p]:
                    self.state[p].pop('worker_error')
                if 'server_error' in self.state[p]:
                    self.state[p].pop('server_error')