'''
Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
'''
import types
import torch

from ..op_builder import FusedLambBuilder

class FusedLamb(torch.optim.Optimizer):
    """Implements the LAMB algorithm. Currently GPU-only.

    LAMB was proposed in `Large Batch Optimization for Deep Learning:
    Training BERT in 76 minutes` (https://arxiv.org/abs/1904.00962).

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): bias correction (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip global grad norm
            (default: 0.0)
        max_coeff (float, optional): maximum value of the lamb coefficient (default: 10.0)
        min_coeff (float, optional): minimum value of the lamb coefficient (default: 0.01)
        amsgrad (boolean, optional): NOT SUPPORTED in FusedLamb!
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 max_coeff=10.0,
                 min_coeff=0.01,
                 amsgrad=False):
        self.fused_lamb_cuda = FusedLambBuilder().load()

        if amsgrad:
            raise RuntimeError('FusedLamb does not support the AMSGrad variant.')
        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        max_coeff=max_coeff,
                        min_coeff=min_coeff)
        super(FusedLamb, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        self.lamb_coeffs = []
    def step(self,
             closure=None,
             grads=None,
             output_params=None,
             scale=1.,
             grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradients to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): A reduced precision copy
                of the updated weights written out in addition to the regular
                updated weights. Have to be of same type as gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
            grad_norms (list, optional): pre-computed gradient norms (already
                multiplied by ``scale``), used together with max_grad_norm to
                compute the clipping factor. (default: None)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            output_params_group = [None] * len(self.param_groups)
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0]) != list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        if grad_norms is None:
            grad_norms = [None] * len(self.param_groups)

        # remove the previous coeffs
        del self.lamb_coeffs[:]

        for group, grads_this_group, output_params_this_group, grad_norm_group in zip(
                self.param_groups, grads_group, output_params_group, grad_norms):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])
            if output_params_this_group is None:
                output_params_this_group = [None] * len(group['params'])
            if grad_norm_group is None:
                grad_norm_group = [None] * len(group['params'])
            elif not isinstance(grad_norm_group, list):
                grad_norm_group = [grad_norm_group]

            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad, output_param, grad_norm in zip(group['params'],
                                                        grads_this_group,
                                                        output_params_this_group,
                                                        grad_norm_group):
                # compute combined scale factor for this group
                combined_scale = scale
                if group['max_grad_norm'] > 0:
                    # norm is in fact norm*scale
                    clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
                    if clip > 1:
                        combined_scale = clip * scale

                # note: p.grad should not ever be set for correct operation of
                # mixed precision optimizer that sometimes sends None gradients
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('FusedLamb does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                max_coeff = group['max_coeff']
                min_coeff = group['min_coeff']

                state['step'] += 1

                out_p = torch.tensor(
                    [], dtype=torch.float) if output_param is None else output_param
                lamb_coeff = self.fused_lamb_cuda.lamb(p.data,
                                                       out_p,
                                                       exp_avg,
                                                       exp_avg_sq,
                                                       grad,
                                                       group['lr'],
                                                       beta1,
                                                       beta2,
                                                       max_coeff,
                                                       min_coeff,
                                                       group['eps'],
                                                       combined_scale,
                                                       state['step'],
                                                       self.eps_mode,
                                                       bias_correction,
                                                       group['weight_decay'])
                self.lamb_coeffs.append(lamb_coeff)
        return loss

    def get_lamb_coeffs(self):
        lamb_coeffs = [lamb_coeff.item() for lamb_coeff in self.lamb_coeffs]
        return lamb_coeffs
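
# A minimal usage sketch, not part of the original module. It assumes DeepSpeed
# is installed, a CUDA device is available, and the FusedLambBuilder extension
# can be built and loaded; the import path, model, batch, and hyperparameters
# shown are illustrative only:
#
#     import torch
#     from deepspeed.ops.lamb import FusedLamb
#
#     model = torch.nn.Linear(128, 10).cuda()
#     optimizer = FusedLamb(model.parameters(), lr=1e-3, weight_decay=0.01)
#
#     inputs = torch.randn(32, 128, device='cuda')
#     targets = torch.randint(0, 10, (32,), device='cuda')
#
#     optimizer.zero_grad()
#     loss = torch.nn.functional.cross_entropy(model(inputs), targets)
#     loss.backward()
#     optimizer.step()
#
#     # Per-tensor LAMB trust ratios computed during the last step.
#     print(optimizer.get_lamb_coeffs())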