# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

import deepspeed
from deepspeed.runtime.utils import partition_uniform as partition


def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension. Adapted from Megatron-LM.

    Arguments:
        tensor: input tensor.
        partitions: list of partition sizes to supply to torch.split
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1

    # Split.
    tensor_list = torch.split(tensor, partitions, dim=last_dim)

    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list
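

# A minimal usage sketch (illustrative only, not part of the module): splitting
# the last dimension of a (4, 6) tensor into chunks of sizes [2, 4]. The sizes
# here are example values chosen for the sketch.
#
#   >>> x = torch.randn(4, 6)
#   >>> a, b = split_tensor_along_last_dim(x, [2, 4], contiguous_split_chunks=True)
#   >>> a.shape, b.shape
#   (torch.Size([4, 2]), torch.Size([4, 4]))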


class TiledLinear(torch.nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 in_splits=1,
                 out_splits=1,
                 input_is_already_split=False,
                 combine_out_splits=True,
                 linear_cls=torch.nn.Linear,
                 init_linear=None,
                 **kwargs):
        """A replacement for ``torch.nn.Linear`` that works with ZeRO-3 to reduce
        memory requirements via tiling.

        TiledLinear breaks the input and output dimensions of a linear layer
        into tiles that are processed in sequence. This class enables huge
        linear layers when combined with ZeRO-3 because inactive tiles can be
        partitioned and offloaded.

        .. note::
            We recommend using as few tiles as necessary. Tiling
            significantly reduces memory usage, but can reduce throughput
            for inexpensive layers. This is because the smaller kernels have
            less parallelism and lower arithmetic intensity, while tiling
            introduces more frequent synchronization and communication.

        Args:
            in_features (int): See ``torch.nn.Linear``
            out_features (int): See ``torch.nn.Linear``
            bias (bool, optional): See ``torch.nn.Linear``
            in_splits (int, optional): The number of tiles along the input dimension. Defaults to 1.
            out_splits (int, optional): The number of tiles along the output dimension. Defaults to 1.
            input_is_already_split (bool, optional): If set to ``True``, assume that the ``input_`` argument
                to ``forward()`` is already split into ``in_splits`` chunks. Defaults to ``False``.
            combine_out_splits (bool, optional): If set to ``False``, do not combine the ``out_splits`` outputs
                into a single tensor. Defaults to ``True``.
            linear_cls (class, optional): The underlying class to build individual tiles.
                Defaults to ``torch.nn.Linear``.
            init_linear (``torch.nn.Linear``, optional): If set, copy the parameters of
                ``init_linear``. Useful for debugging. Defaults to ``None``.
            kwargs (dict, optional): additional keyword arguments to provide to ``linear_cls()``.

        Raises:
            RuntimeError: ``in_splits`` must be within the range [1, in_features].
            RuntimeError: ``out_splits`` must be within the range [1, out_features].
        """
        super().__init__()

        if (in_splits < 1) or (in_splits > in_features):
            raise RuntimeError('in splits must be in range [1, in_features].')
        if (out_splits < 1) or (out_splits > out_features):
            raise RuntimeError('out splits must be in range [1, out_features].')

        # global, not necessarily local
        self.in_features = in_features
        self.out_features = out_features

        self.use_bias = bias
        self.out_splits = out_splits
        self.in_splits = in_splits
        self.input_is_already_split = input_is_already_split
        self.combine_out_splits = combine_out_splits

        # Build partition-lists. These are CSR-style splits [0, part0, part1, ..., features]
        # For example, row_parts[p] gives the start of partition p and row_parts[p+1]
        # is the exclusive end.
        self.in_parts = partition(num_items=in_features, num_parts=in_splits)
        self.out_parts = partition(num_items=out_features, num_parts=out_splits)

        assert len(self.out_parts) == out_splits + 1
        assert len(self.in_parts) == in_splits + 1
        assert self.out_parts[0] == 0
        assert self.out_parts[out_splits] == out_features
        assert self.in_parts[in_splits] == in_features

        self.linears = torch.nn.ModuleList()
        for out_id in range(out_splits):
            self.linears.append(torch.nn.ModuleList())
            local_out_dim = self.out_parts[out_id + 1] - self.out_parts[out_id]

            for in_id in range(in_splits):
                # if the input is split, we only need one bias
                local_bias = bias if in_id == (in_splits - 1) else False
                local_in_dim = self.in_parts[in_id + 1] - self.in_parts[in_id]

                local = linear_cls(local_in_dim, local_out_dim, bias=local_bias, **kwargs)
                self.linears[out_id].append(local)

        # Optionally initialize with a known tensor
        if init_linear is not None:
            self.copy_params_from(init_linear)

    def forward(self, input_):
        if self.in_splits > 1 and not self.input_is_already_split:
            input_parts = partition(input_.shape[-1], self.in_splits)
            split_sizes = [input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)]
            inputs = self._split_global_input(input_, split_sizes)
        elif self.in_splits > 1:
            inputs = input_
            assert len(inputs) == self.in_splits, \
                f"Col splits {self.in_splits} does not match input splits {len(inputs)}"
        else:
            # no splits
            inputs = [input_]

        outputs = [None] * self.out_splits
        for out_id in range(self.out_splits):
            for in_id in range(self.in_splits):
                local_output = self.linears[out_id][in_id](inputs[in_id])
                outputs[out_id] = self._reduce_local_output(in_id=in_id,
                                                            out_id=out_id,
                                                            current_out=outputs[out_id],
                                                            new_out=local_output)

        if self.combine_out_splits:
            return self._combine_output_splits(outputs)

        return outputs

    def _split_global_input(self, input, split_sizes):
        """Partition an input tensor along the last dimension, aligned with the given splits.

        Subclasses should override this method to account for new input types.

        Args:
            input (Tensor): The tensor to partition along the last dimension.
            split_sizes (List[int]): The size of each partition.

        Returns:
            List[Any]: A list of the chunks of ``input``.
        """
        return split_tensor_along_last_dim(input, split_sizes)

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduce (sum) a new local result into the existing local results.

        Subclasses should override this method.

        For a given ``out_id``, this method is called once per input split
        (``in_splits`` times in total). The first input split is a simple
        assignment.

        Args:
            in_id (int): The input split that produced ``new_out``.
            out_id (int): The output split that produced ``new_out``.
            current_out (Any): The reduced form of all previous input splits for this ``out_id``.
            new_out (Any): The local result from forward (``in_id``, ``out_id``).

        Returns:
            Any: The combined result of ``current_out`` and ``new_out``.
        """
        if current_out is None:
            # This clone is necessary to preserve autograd; there is an issue
            # with in-place updates for outputs that are views.
            return new_out.clone()
        else:
            return current_out + new_out

    def _combine_output_splits(self, outputs):
        """Join the splits of the output into a single result.

        Args:
            outputs (List[Any]): The reduced outputs for each output split.

        Returns:
            Any: The combined outputs.
        """
        assert len(outputs) == self.out_splits
        return torch.cat(outputs, dim=-1)

    @torch.no_grad()
    def copy_params_from(self, other):
        """Copy the weight and bias data from ``other``.

        This is especially useful for reproducible initialization and testing.

        Equivalent to:

        .. code-block:: python

            with torch.no_grad():
                self.weight.copy_(other.weight)
                if self.bias is not None:
                    self.bias.copy_(other.bias)

        .. note::
            If ZeRO-3 is enabled, this is a collective operation and the
            updated parameters of data-parallel rank 0 will be visible on all
            ranks. See :class:`deepspeed.zero.GatheredParameters` for more
            information.

        Args:
            other (``torch.nn.Linear``): the linear layer to copy from.
        """
        assert hasattr(other, 'weight')
        assert other.weight.size() == (self.out_features, self.in_features)

        if self.use_bias:
            assert hasattr(other, 'bias')
            assert other.bias is not None
            assert other.bias.size() == (self.out_features, )
        else:
            assert other.bias is None

        for row in range(self.out_splits):
            rstart = self.out_parts[row]
            rstop = self.out_parts[row + 1]

            for col in range(self.in_splits):
                cstart = self.in_parts[col]
                cstop = self.in_parts[col + 1]

                local = self.linears[row][col]
                global_weight = other.weight[rstart:rstop, cstart:cstop]
                with deepspeed.zero.GatheredParameters(local.weight, modifier_rank=0):
                    local.weight.copy_(global_weight)

                if local.bias is not None:
                    with deepspeed.zero.GatheredParameters(local.bias, modifier_rank=0):
                        local.bias.data.copy_(other.bias[rstart:rstop].data)
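

# A minimal usage sketch (illustrative only): a 2x2-tiled layer initialized from
# a plain ``torch.nn.Linear`` should match its output up to floating-point error.
# The sizes below are example values; this assumes a non-ZeRO context, where
# ``deepspeed.zero.GatheredParameters`` acts as a no-op.
#
#   >>> ref = torch.nn.Linear(8, 6)
#   >>> tiled = TiledLinear(8, 6, in_splits=2, out_splits=2, init_linear=ref)
#   >>> x = torch.randn(3, 8)
#   >>> torch.allclose(tiled(x), ref(x), atol=1e-6)
#   True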


class TiledLinearReturnBias(TiledLinear):
    """Wrapper for a Linear class that returns its own bias parameter, such as
    used by Megatron-LM.
    """

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduces the output tensors, but not the returned bias."""
        if current_out is not None:
            old_tensor, old_bias = current_out
        else:
            old_tensor, old_bias = None, None

        assert isinstance(new_out, tuple)
        assert len(new_out) == 2

        tensor, bias = new_out
        assert tensor is not None

        tensor = super()._reduce_local_output(in_id=in_id, out_id=out_id, current_out=old_tensor, new_out=tensor)

        if bias is None:
            bias = old_bias

        return tensor, bias

    def _combine_output_splits(self, outputs):
        # concatenate the output tensors
        tensors = [o[0] for o in outputs]
        tensor = super()._combine_output_splits(tensors)

        # concatenate the biases, if any tiles returned one
        biases = [o[1] for o in outputs if o[1] is not None]
        if len(biases) > 0:
            bias = super()._combine_output_splits(biases)
        else:
            bias = None
        return tensor, bias
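

# A minimal usage sketch (illustrative only): ``TiledLinearReturnBias`` expects each
# tile's ``forward()`` to return an ``(output, bias)`` tuple, as Megatron-LM's parallel
# linear layers do when they skip the bias addition. ``ReturnBiasLinear`` below is a
# hypothetical stand-in defined only for this sketch.
#
#   >>> class ReturnBiasLinear(torch.nn.Linear):
#   ...     def forward(self, x):
#   ...         return torch.nn.functional.linear(x, self.weight), self.bias
#   >>> layer = TiledLinearReturnBias(8, 6, in_splits=2, out_splits=2,
#   ...                               linear_cls=ReturnBiasLinear)
#   >>> out, bias = layer(torch.randn(3, 8))
#   >>> out.shape, bias.shape
#   (torch.Size([3, 6]), torch.Size([6]))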