loss_scaler.py

# Copyright 2019 The Microsoft DeepSpeed Team
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Taken and modified for DeepSpeed from:
#     https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
# Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9

import torch

INITIAL_LOSS_SCALE = 'init_scale'
SCALE_WINDOW = 'scale_window'
DELAYED_SHIFT = 'delayed_shift'
MIN_LOSS_SCALE = 'min_scale'


# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    return t[0]


class LossScalerBase:
    """LossScalerBase
    Base class for a loss scaler
    """

    def __init__(self, cur_scale):
        self.cur_scale = cur_scale
        self.dynamic = False

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def update_scale(self, overflow):
        pass

    def backward(self, loss, retain_graph=False):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward(retain_graph=retain_graph)
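

# A minimal sketch (not part of the original file): LossScalerBase.backward()
# multiplies the loss by `cur_scale` before calling .backward(), so the resulting
# gradients come out scaled and must be divided by `loss_scale` before the
# optimizer step. The helper name `_example_scaled_backward` is hypothetical,
# added purely for illustration.
def _example_scaled_backward():
    w = torch.randn(4, requires_grad=True)
    scaler = LossScalerBase(cur_scale=128.0)
    loss = (w * w).sum()
    scaler.backward(loss)           # w.grad now holds 128 * d(loss)/dw
    w.grad.div_(scaler.loss_scale)  # unscale before stepping the optimizer
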

class LossScaler(LossScalerBase):
    """
    Class that manages a static loss scale. This class is intended to interact with
    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.

    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
    :class:`FP16_Optimizer`'s constructor.

    Args:
        scale (float, optional, default=1.0): The loss scale.
    """

    def __init__(self, scale=1):
        super(LossScaler, self).__init__(scale)

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        return False

    # `x` is a torch.Tensor
    def _has_inf_or_nan(self, x):
        return False
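

# A minimal sketch (not part of the original file): with a static scale the
# scaler never reports overflow, so FP16_Optimizer never skips a step; the fixed
# `loss_scale` is simply applied and undone every iteration. The helper name
# `_example_static_scaler` is hypothetical, added purely for illustration.
def _example_static_scaler():
    scaler = LossScaler(scale=128.0)
    assert scaler.loss_scale == 128.0
    assert scaler.dynamic is False
    # The static scaler's overflow check is a no-op.
    assert scaler.has_overflow(params=[]) is False
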

class DynamicLossScaler(LossScalerBase):
    """
    Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
    :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
    operates, because the default options can be changed using the
    ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.

    Loss scaling is designed to combat the problem of underflowing gradients encountered after
    long periods of training fp16 networks. Dynamic loss scaling begins by attempting a very high
    loss scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
    encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
    occurred.
    :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
    and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients detected,
    :class:`DynamicLossScaler` increases the loss scale once more.
    In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
    always using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`.
        scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss_scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
        scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
    """

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000,
                 min_scale=1,
                 delayed_shift=1,
                 consecutive_hysteresis=False,
                 raise_error_at_min_scale=True):
        super(DynamicLossScaler, self).__init__(init_scale)
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.delayed_shift = delayed_shift
        self.cur_hysteresis = delayed_shift
        self.consecutive_hysteresis = consecutive_hysteresis
        self.raise_error_at_min_scale = raise_error_at_min_scale
        self.dynamic = True

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params):
        for p in params:
            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                return True
        return False

    # `x` is a torch.Tensor
    def _has_inf_or_nan(self, x):
        try:
            # If x is half, the .float() incurs an additional deep copy, but it's necessary
            # because PyTorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent versions of PyTorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check whether `instance` is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
                return True
            return False

    # `overflow` is a boolean indicating whether the gradient overflowed
    def update_scale(self, overflow):
        if overflow:
            # self.cur_scale /= self.scale_factor
            if self.delayed_shift == 1 or self.cur_hysteresis == 1:
                if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale:
                    raise Exception(
                        "Current loss scale already at minimum - cannot decrease scale anymore. Exiting run."
                    )
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
            else:
                self.cur_hysteresis -= 1
            self.last_overflow_iter = self.cur_iter
        else:
            if self.consecutive_hysteresis:
                self.cur_hysteresis = self.delayed_shift
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                if not self.consecutive_hysteresis:
                    self.cur_hysteresis = self.delayed_shift
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1
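

# A minimal sketch (not part of the original file): a short simulation of
# update_scale(). With the default delayed_shift=1, one overflow halves the
# scale, and `scale_window` clean iterations double it again. The helper name
# `_example_dynamic_scaling` is hypothetical, added purely for illustration.
def _example_dynamic_scaling():
    scaler = DynamicLossScaler(init_scale=2**16, scale_factor=2., scale_window=4)
    scaler.update_scale(overflow=True)    # overflow: 2**16 -> 2**15
    assert scaler.loss_scale == 2**15
    for _ in range(4):                    # 4 clean steps: scale doubles again
        scaler.update_scale(overflow=False)
    assert scaler.loss_scale == 2**16
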


# Although loss scaling is only defined for fp16, for backwards compatibility
# we still create a scaler for other dtypes (fp32, bf16) which does not perform any scaling.
def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
    if dtype == torch.half and dynamic_scaling:
        if dynamic_loss_args is None:
            return DynamicLossScaler()
        return DynamicLossScaler(**dynamic_loss_args)

    loss_scale_value = static_loss_scale if dtype == torch.half else 1.0
    return LossScaler(scale=loss_scale_value)
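

# A minimal sketch (not part of the original file): CreateLossScaler() picks the
# scaler from the dtype and config. fp16 with dynamic scaling gets a
# DynamicLossScaler, optionally configured through a dict keyed by
# INITIAL_LOSS_SCALE / SCALE_WINDOW / DELAYED_SHIFT / MIN_LOSS_SCALE, while
# bf16/fp32 fall back to a no-op LossScaler with scale 1.0. The helper name
# `_example_create_loss_scaler` is hypothetical, added purely for illustration.
def _example_create_loss_scaler():
    dyn = CreateLossScaler(dtype=torch.half,
                           static_loss_scale=1.0,
                           dynamic_scaling=True,
                           dynamic_loss_args={INITIAL_LOSS_SCALE: 2**16, SCALE_WINDOW: 500})
    assert isinstance(dyn, DynamicLossScaler) and dyn.loss_scale == 2**16
    static = CreateLossScaler(dtype=torch.bfloat16,
                              static_loss_scale=1.0,
                              dynamic_scaling=True,
                              dynamic_loss_args=None)
    assert isinstance(static, LossScaler) and static.loss_scale == 1.0
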


##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
    import torch
    from dynamic_loss_scaler import DynamicLossScaler

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random tensors to hold inputs, outputs, and trainable weights.
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    w1 = torch.randn(D_in, H, requires_grad=True)
    w2 = torch.randn(H, D_out, requires_grad=True)
    parameters = [w1, w2]

    learning_rate = 1e-6
    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
    loss_scaler = DynamicLossScaler()

    for t in range(500):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
        print('Iter {} scaled loss: {}'.format(t, loss.item()))
        print('Iter {} unscaled loss: {}'.format(t, loss.item() / loss_scaler.loss_scale))

        # Run backprop
        optimizer.zero_grad()
        loss.backward()

        # Check for overflow
        has_overflow = loss_scaler.has_overflow_serial(parameters)

        # If no overflow, unscale the gradients and update as usual
        if not has_overflow:
            for param in parameters:
                param.grad.data.mul_(1. / loss_scaler.loss_scale)
            optimizer.step()
        # Otherwise, don't do anything -- i.e., skip this iteration
        else:
            print('fp16 dynamic loss scale overflow!')

        # Update loss scale for next iteration
        loss_scaler.update_scale(has_overflow)
"""