loss_scaler.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  6. Licensed under the Apache License, Version 2.0 (the "License");
  7. you may not use this file except in compliance with the License.
  8. You may obtain a copy of the License at
  9. http://www.apache.org/licenses/LICENSE-2.0
  10. Unless required by applicable law or agreed to in writing, software
  11. distributed under the License is distributed on an "AS IS" BASIS,
  12. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. See the License for the specific language governing permissions and
  14. limitations under the License.
  15. Taken and modified for DeepSpeed from:
  16. https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
  17. Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
  18. """
  19. import torch
  20. from deepspeed import comm as dist
  21. from deepspeed.utils import logger
# Key names accepted in a ``dynamic_loss_args`` config dict; they mirror the
# keyword parameters of DynamicLossScaler.__init__ (presumably used by callers
# to build that dict — the keys are not read elsewhere in this file).
INITIAL_LOSS_SCALE = 'init_scale'
SCALE_WINDOW = 'scale_window'
DELAYED_SHIFT = 'delayed_shift'
CONSECUTIVE_HYSTERESIS = 'consecutive_hysteresis'
MIN_LOSS_SCALE = 'min_scale'
  27. # item() is a recent addition, so this helps with backward compatibility.
  28. def to_python_float(t):
  29. if hasattr(t, 'item'):
  30. return t.item()
  31. return t[0]
  32. class LossScalerBase:
  33. """LossScalarBase
  34. Base class for a loss scaler
  35. """
  36. def __init__(self, cur_scale):
  37. self.cur_scale = cur_scale
  38. self.dynamic = False
  39. @property
  40. def loss_scale(self):
  41. return self.cur_scale
  42. def scale_gradient(self, module, grad_in, grad_out):
  43. return tuple(self.loss_scale * g for g in grad_in)
  44. def update_scale(self, overflow):
  45. pass
  46. def backward(self, loss, retain_graph=False):
  47. scaled_loss = loss * self.loss_scale
  48. scaled_loss.backward(retain_graph=retain_graph)
  49. # print(f'LossScalerBackward: {scaled_loss=}')
  50. class LossScaler(LossScalerBase):
  51. """
  52. Class that manages a static loss scale. This class is intended to interact with
  53. :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
  54. Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
  55. :class:`FP16_Optimizer`'s constructor.
  56. Args:
  57. scale (float, optional, default=1.0): The loss scale.
  58. """
  59. def __init__(self, scale=1):
  60. super(LossScaler, self).__init__(scale)
  61. # `params` is a list / generator of torch.Variable
  62. def has_overflow(self, params):
  63. return False
  64. # `x` is a torch.Tensor
  65. def _has_inf_or_nan(x):
  66. return False
  67. class DynamicLossScaler(LossScalerBase):
  68. """
  69. Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
  70. indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
  71. :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
  72. operates, because the default options can be changed using the
  73. the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
  74. Loss scaling is designed to combat the problem of underflowing gradients encountered at long
  75. times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
  76. scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
  77. encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
  78. occurred.
  79. :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
  80. and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
  81. If a certain number of iterations occur without overflowing gradients detected,
  82. :class:`DynamicLossScaler` increases the loss scale once more.
  83. In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
  84. always using the highest loss scale possible without incurring overflow.
  85. Args:
  86. init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
  87. scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
  88. scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
  89. consecutive_hysteresis (bool, optional, default=False): Whether to refill hysteresis if we reach an iteration that doesn't overflow
  90. """
  91. def __init__(self,
  92. init_scale=2**32,
  93. scale_factor=2.,
  94. scale_window=1000,
  95. min_scale=1,
  96. delayed_shift=1,
  97. consecutive_hysteresis=False,
  98. raise_error_at_min_scale=True,
  99. dtype=torch.half):
  100. super(DynamicLossScaler, self).__init__(init_scale)
  101. self.cur_iter = 0
  102. self.last_overflow_iter = -1
  103. self.scale_factor = scale_factor
  104. self.scale_window = scale_window
  105. self.min_scale = min_scale
  106. self.delayed_shift = delayed_shift
  107. self.cur_hysteresis = delayed_shift
  108. self.consecutive_hysteresis = consecutive_hysteresis
  109. self.raise_error_at_min_scale = raise_error_at_min_scale
  110. self.dynamic = True
  111. self.dtype = dtype
  112. # `params` is a list / generator of torch.Variable
  113. def has_overflow_serial(self, params):
  114. for p in params:
  115. if p.grad is not None and self._has_inf_or_nan(p.grad.data):
  116. return True
  117. return False
  118. # `x` is a torch.Tensor
  119. def _has_inf_or_nan(x):
  120. try:
  121. # if x is half, the .float() incurs an additional deep copy, but it's necessary if
  122. # Pytorch's .sum() creates a one-element tensor of the same type as x
  123. # (which is true for some recent version of pytorch).
  124. cpu_sum = float(x.float().sum())
  125. # More efficient version that can be used if .sum() returns a Python scalar
  126. # cpu_sum = float(x.sum())
  127. except RuntimeError as instance:
  128. # We want to check if inst is actually an overflow exception.
  129. # RuntimeError could come from a different error.
  130. # If so, we still want the exception to propagate.
  131. if "value cannot be converted" not in instance.args[0]:
  132. raise
  133. return True
  134. else:
  135. if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
  136. return True
  137. return False
  138. # `overflow` is boolean indicating whether the gradient overflowed
  139. def update_scale(self, overflow):
  140. if overflow:
  141. # self.cur_scale /= self.scale_factor
  142. if self.delayed_shift == 1 or self.cur_hysteresis == 1:
  143. if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale:
  144. raise Exception(
  145. "Current loss scale already at minimum - cannot decrease scale anymore. Exiting run.")
  146. else:
  147. next_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
  148. if dist.get_rank() == 0:
  149. overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
  150. if self.dtype == torch.half:
  151. overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, reducing to {int(next_scale)}"
  152. logger.info(overflow_msg)
  153. self.cur_scale = next_scale
  154. else:
  155. if dist.get_rank() == 0:
  156. overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
  157. if self.dtype == torch.half:
  158. overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, but hysteresis is {self.cur_hysteresis}. Reducing hysteresis to {self.cur_hysteresis-1}"
  159. logger.info(overflow_msg)
  160. self.cur_hysteresis -= 1
  161. self.last_overflow_iter = self.cur_iter
  162. else:
  163. if self.consecutive_hysteresis:
  164. if dist.get_rank() == 0:
  165. hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}"
  166. logger.info(hysteresis_msg)
  167. self.cur_hysteresis = self.delayed_shift
  168. if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
  169. if not self.consecutive_hysteresis:
  170. self.cur_hysteresis = self.delayed_shift
  171. self.cur_scale *= self.scale_factor
  172. self.cur_iter += 1
  173. # Although loss scaling is only defined for fp16, yet for backwards compatibility
  174. # we still create a scaler for other dtypes (fp32, bf16) which does not perform any scaling.
  175. def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
  176. if dtype == torch.half and dynamic_scaling:
  177. if dynamic_loss_args is None:
  178. return DynamicLossScaler(dtype=dtype)
  179. return DynamicLossScaler(dtype=dtype, **dynamic_loss_args)
  180. loss_scale_value = static_loss_scale if dtype == torch.half else 1.0
  181. return LossScaler(scale=loss_scale_value)
  182. ##############################################################
  183. # Example usage below here -- assuming it's in a separate file
  184. ##############################################################
  185. """
TODO: separate this usage example out into its own file.
  187. if __name__ == "__main__":
  188. import torch
  189. from torch.autograd import Variable
  190. from dynamic_loss_scaler import DynamicLossScaler
  191. # N is batch size; D_in is input dimension;
  192. # H is hidden dimension; D_out is output dimension.
  193. N, D_in, H, D_out = 64, 1000, 100, 10
  194. # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
  195. x = Variable(torch.randn(N, D_in), requires_grad=False)
  196. y = Variable(torch.randn(N, D_out), requires_grad=False)
  197. w1 = Variable(torch.randn(D_in, H), requires_grad=True)
  198. w2 = Variable(torch.randn(H, D_out), requires_grad=True)
  199. parameters = [w1, w2]
  200. learning_rate = 1e-6
  201. optimizer = torch.optim.SGD(parameters, lr=learning_rate)
  202. loss_scaler = DynamicLossScaler()
  203. for t in range(500):
  204. y_pred = x.mm(w1).clamp(min=0).mm(w2)
  205. loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
  206. print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
  207. print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
  208. print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
  209. # Run backprop
  210. optimizer.zero_grad()
  211. loss.backward()
  212. # Check for overflow
  213. has_overflow = DynamicLossScaler.has_overflow(parameters)
  214. # If no overflow, unscale grad and update as usual
  215. if not has_overflow:
  216. for param in parameters:
  217. param.grad.data.mul_(1. / loss_scaler.loss_scale)
  218. optimizer.step()
  219. # Otherwise, don't do anything -- ie, skip iteration
  220. else:
  221. print('fp16 dynamic loss scale overflow!')
  222. # Update loss scale for next iteration
  223. loss_scaler.update_scale(has_overflow)
  224. """