linear.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Linear module for use with ZeRO Stage 3, allowing parameter memory to be released
# after the module executes during the forward pass.
# Instead of saving variables with save_for_backward, we save variable ids,
# allowing us to retrieve a variable without holding a pointer to it.
# This lets the underlying tensor be garbage collected once it is partitioned,
# as required by the ZeRO Stage 3 optimizer.
# TODO: instead of patching the Linear module, we could patch ctx.save_for_backward /
# ctx.saved_tensors so that this approach works for all nn modules built on
# torch.nn.functional. The issue, however, is that many modules use C++ implementations
# that have no PyTorch equivalent, e.g. torch.addmm, which acts as a functional
# when implemented outside of torch.autograd.Function.

import math

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn.modules.module import Module

from deepspeed.runtime.utils import noop_decorator
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator


def print_rank_0(message, debug=False, force=False):
    if dist.get_rank() == 0 and (debug or force):
        print(message)


try:
    autocast_custom_fwd = get_accelerator().amp().custom_fwd
    autocast_custom_bwd = get_accelerator().amp().custom_bwd
except (ImportError, AttributeError) as exp:
    autocast_custom_fwd = noop_decorator
    autocast_custom_bwd = noop_decorator


class LinearFunctionForZeroStage3(torch.autograd.Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    @autocast_custom_fwd
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)

        if input.dim() == 2 and bias is not None:
            # fused op is marginally faster
            ret = torch.addmm(bias, input, weight.t())
        else:
            output = input.matmul(weight.t())
            if bias is not None:
                output += bias
            ret = output

        return ret

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        # Compute dim once up front so the bias gradient can be computed even when
        # the weight gradient is skipped.
        dim = grad_output.dim()
        if ctx.needs_input_grad[0]:
            #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
            grad_input = grad_output.matmul(weight)
            #print(f"Computed grad input {grad_input.shape}")
        if ctx.needs_input_grad[1]:
            #print("Computing grad weight")
            if dim > 2:
                grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                    input.reshape(-1, input.shape[-1]))
            else:
                grad_weight = grad_output.t().matmul(input)
            #print(f"Computed grad weight grad_weight {grad_weight.shape}")
        if bias is not None and ctx.needs_input_grad[2]:
            #print("Computing grad bias")
            if dim > 2:
                grad_bias = grad_output.sum([i for i in range(dim - 1)])
            else:
                grad_bias = grad_output.sum(0)
            #print("Done computing grad bias")
            #print("needs bias")
        #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")
        return grad_input, grad_weight, grad_bias


def zero3_linear_wrap(input, weight, bias=None):
    if bias is None:
        return LinearFunctionForZeroStage3.apply(input, weight)
    else:
        return LinearFunctionForZeroStage3.apply(input, weight, bias)
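

# Hedged illustration (hypothetical helper, not part of the original module):
# zero3_linear_wrap mirrors the (input, weight, bias) signature of torch.nn.functional.linear,
# so ZeRO Stage 3 can route plain F.linear calls through LinearFunctionForZeroStage3 by
# swapping the functional in. The helper below sketches that pattern; it is never called here.
def _example_patch_functional_linear():
    import torch.nn.functional as F
    original_linear = F.linear
    # Route functional linear calls through the ZeRO-friendly autograd function.
    F.linear = zero3_linear_wrap
    # Return a callable that restores the stock implementation.
    return lambda: setattr(F, 'linear', original_linear)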


class LinearModuleForZeroStage3(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    The weights are pre-transposed and stored as A^T instead of transposing during each
    forward. Memory savings proportional to the parameter size.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(LinearModuleForZeroStage3, self).__init__()
        print("Building ZeRO module")
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features,
                                                                  self.bias is not None)
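

# Illustrative usage sketch (added for clarity, not part of the original module): running
# this file directly exercises LinearModuleForZeroStage3 and the custom autograd function
# above on arbitrary shapes. It is skipped entirely on import.
if __name__ == "__main__":
    m = LinearModuleForZeroStage3(20, 30)
    x = torch.randn(128, 20, requires_grad=True)
    y = m(x)  # forward runs through LinearFunctionForZeroStage3.forward
    y.sum().backward()  # backward fills in grads for x, m.weight and m.bias
    print(y.size(), m.weight.grad.size(), m.bias.grad.size())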