# Linear module for use with ZeRO Stage 3, allowing parameter memory to be
# released after the module executes during forward.
# Instead of saving tensors with save_for_backward, we save their Python ids.
# Backward can then retrieve each tensor from a side table without the autograd
# context holding a reference to it, so the underlying storage can be garbage
# collected (i.e. partitioned) whenever the ZeRO Stage 3 optimizer requires.
# TODO: instead of patching the Linear module, we could patch
# ctx.save_for_backward / ctx.saved_tensors so that this approach works for all
# nn modules built on torch.autograd.Function. The obstacle is that many modules
# use C++ implementations that have no PyTorch-level equivalent, e.g. torch.addmm,
# which acts as a functional when used outside of torch.autograd.Function.
import math

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn.modules.module import Module

from deepspeed.runtime.utils import noop_decorator

# Side table mapping id(tensor) -> tensor. Backward looks weights and biases up
# here instead of receiving them through ctx.saved_tensors, so the autograd
# graph never pins their storage.
tensor_map = {}


def print_rank_0(message, debug=False, force=False):
    if torch.distributed.get_rank() == 0 and (debug or force):
        print(message)


try:
    autocast_custom_fwd = torch.cuda.amp.custom_fwd
    autocast_custom_bwd = torch.cuda.amp.custom_bwd
except (ImportError, AttributeError):
    # Older torch without torch.cuda.amp: fall back to no-op decorators.
    autocast_custom_fwd = noop_decorator
    autocast_custom_bwd = noop_decorator


class LinearFunctionForZeroStage3(torch.autograd.Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    @autocast_custom_fwd
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        # Instead of ctx.save_for_backward(input, weight, bias), save only the
        # ids of weight and bias so that no extra reference pins their storage.
        weight_id = id(weight)
        bias_id = id(bias)
        ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id))
        tensor_map[weight_id] = weight
        tensor_map[bias_id] = bias

        if input.dim() == 2 and bias is not None:
            # fused op is marginally faster
            ret = torch.addmm(bias, input, weight.t())
        else:
            output = input.matmul(weight.t())
            if bias is not None:
                output += bias
            ret = output

        return ret
    # This function has only a single output, so it gets only one gradient
    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight_id, bias_id = ctx.saved_tensors
        # Recover the actual parameter tensors from the side table.
        weight = tensor_map[weight_id.item()]
        bias = tensor_map[bias_id.item()]

        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and are there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            dim = grad_output.dim()
            if dim > 2:
                grad_weight = grad_output.reshape(
                    -1, grad_output.shape[-1]).t().matmul(
                        input.reshape(-1, input.shape[-1]))
            else:
                grad_weight = grad_output.t().matmul(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
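

# Usage sketch (illustrative, not part of the original module): the function is
# invoked through .apply, exactly as the wrapper module below does, and the
# backward pass recovers weight and bias through the id lookup in tensor_map.
#
#   x = torch.randn(4, 8, requires_grad=True)
#   w = torch.randn(16, 8, requires_grad=True)
#   b = torch.randn(16, requires_grad=True)
#   y = LinearFunctionForZeroStage3.apply(x, w, b)  # y has shape [4, 16]
#   y.sum().backward()                              # w.grad: [16, 8], b.grad: [16]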


class LinearModuleForZeroStage3(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    The weight is stored pre-transposed as :math:`A^T` instead of being transposed
    during each forward pass; memory savings are proportional to the parameter size.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = LinearModuleForZeroStage3(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(LinearModuleForZeroStage3, self).__init__()
        print("Building ZeRO module")
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)
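

# Minimal smoke test (a sketch under the assumption of a single, non-distributed
# process, so print_rank_0 is never exercised). Run this file directly to check
# that forward/backward work and that parameters are recovered via tensor_map.
if __name__ == "__main__":
    m = LinearModuleForZeroStage3(20, 30)
    inp = torch.randn(128, 20)
    out = m(inp)
    print(out.size())  # torch.Size([128, 30])
    out.sum().backward()
    print(m.weight.grad.shape)  # torch.Size([30, 20])
    print(m.bias.grad.shape)    # torch.Size([30])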