fused_lamb.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
"""
import types
import torch
from deepspeed.ops.op_builder import FusedLambBuilder


class FusedLamb(torch.optim.Optimizer):
    """Implements the LAMB algorithm. Currently GPU-only.

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    https://arxiv.org/abs/1904.00962

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): bias correction (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip global grad norm
            (default: 0.0)
        max_coeff (float, optional): maximum value of the LAMB coefficient (default: 10.0)
        min_coeff (float, optional): minimum value of the LAMB coefficient (default: 0.01)
        amsgrad (boolean, optional): NOT SUPPORTED in FusedLamb!
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 max_coeff=10.0,
                 min_coeff=0.01,
                 amsgrad=False):
        self.fused_lamb_cuda = FusedLambBuilder().load()

        if amsgrad:
            raise RuntimeError('FusedLamb does not support the AMSGrad variant.')
        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        max_coeff=max_coeff,
                        min_coeff=min_coeff)
        super(FusedLamb, self).__init__(params, defaults)
        # eps_mode selects where the fused kernel applies eps: 0 adds it inside the
        # square root of the second-moment estimate, 1 adds it after the square root.
        self.eps_mode = 0 if eps_inside_sqrt else 1
        # LAMB coefficients produced by the most recent step(); see get_lamb_coeffs().
        self.lamb_coeffs = []
    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradients to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): A reduced precision copy
                of the updated weights written out in addition to the regular
                updated weights. Must be of the same type as the gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
            grad_norms (list, optional): precomputed gradient norms, organized per
                parameter group, used for clipping when max_grad_norm > 0.
                (default: None)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameters means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            output_params_group = [None] * len(self.param_groups)
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0]) != list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        if grad_norms is None:
            grad_norms = [None] * len(self.param_groups)

        # remove the previous coeffs
        del self.lamb_coeffs[:]

        for group, grads_this_group, output_params_this_group, grad_norm_group in zip(
                self.param_groups, grads_group, output_params_group, grad_norms):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])
            if output_params_this_group is None:
                output_params_this_group = [None] * len(group['params'])
            if grad_norm_group is None:
                grad_norm_group = [None] * len(group['params'])
            elif not isinstance(grad_norm_group, list):
                grad_norm_group = [grad_norm_group]

            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad, output_param, grad_norm in zip(group['params'], grads_this_group,
                                                        output_params_this_group, grad_norm_group):

                # compute combined scale factor for this group
                combined_scale = scale
                if group['max_grad_norm'] > 0:
                    # norm is in fact norm*scale
                    clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
                    if clip > 1:
                        combined_scale = clip * scale
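
                # Worked example (illustrative, not part of the original file): suppose the
                # loss scale is 1024 and a parameter's true gradient norm is 2.0, so the
                # caller passes grad_norm = 2048 (norm * scale). With max_grad_norm = 1.0,
                # clip = (2048 / 1024 + 1e-6) / 1.0 ≈ 2.0 > 1, hence combined_scale ≈ 2048.
                # The fused kernel divides the gradients by combined_scale, which removes the
                # loss scale and shrinks the true norm from 2.0 down to max_grad_norm.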
                # note: p.grad should not ever be set for correct operation of mixed precision
                # optimizer that sometimes sends None gradients
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('FusedLamb does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                max_coeff = group['max_coeff']
                min_coeff = group['min_coeff']

                state['step'] += 1

                out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
                lamb_coeff = self.fused_lamb_cuda.lamb(p.data, out_p, exp_avg, exp_avg_sq, grad, group['lr'], beta1,
                                                       beta2, max_coeff, min_coeff, group['eps'], combined_scale,
                                                       state['step'], self.eps_mode, bias_correction,
                                                       group['weight_decay'])
                self.lamb_coeffs.append(lamb_coeff)

        return loss
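
    # Usage sketch for step() (illustrative, not part of the original file). In ordinary
    # FP32 training the extra arguments stay at their defaults and step() falls back to
    # p.grad; a hypothetical mixed-precision wrapper could instead pass FP16 gradients,
    # FP16 output copies of the FP32 master weights, and its current loss scale. The
    # names fp16_grads, fp16_params, and cur_scale below are assumptions maintained by
    # that caller, not names defined in this file.
    #
    #     loss.backward()
    #     optimizer.step()
    #
    #     optimizer.step(grads=fp16_grads, output_params=fp16_params, scale=cur_scale)
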
    def get_lamb_coeffs(self):
        """Return the LAMB coefficients computed during the most recent call to step() as Python floats."""
        lamb_coeffs = [lamb_coeff.item() for lamb_coeff in self.lamb_coeffs]
        return lamb_coeffs
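

# Reading back the LAMB coefficients (illustrative sketch, not part of the original file):
# after a call to step(), get_lamb_coeffs() returns one Python float per parameter updated
# in that step, each clipped by the fused kernel to the configured [min_coeff, max_coeff]
# range, which can be useful for monitoring the per-layer trust ratios.
#
#     optimizer.step()
#     coeffs = optimizer.get_lamb_coeffs()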