tiling.py

import torch
import deepspeed
from deepspeed.runtime.utils import partition_uniform as partition


def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension. Adapted from Megatron-LM.

    Arguments:
        tensor: input tensor.
        partitions: list of partition sizes to supply to ``torch.split``.
        contiguous_split_chunks: If True, make each chunk contiguous in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    # Split.
    tensor_list = torch.split(tensor, partitions, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
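
# A minimal illustration of split_tensor_along_last_dim (not from the original
# module; the shapes below are assumptions chosen for the example):
#
#   x = torch.randn(4, 10)
#   a, b = split_tensor_along_last_dim(x, [6, 4], contiguous_split_chunks=True)
#   # a.shape == (4, 6), b.shape == (4, 4); both chunks are contiguous.
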

class TiledLinear(torch.nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 in_splits=1,
                 out_splits=1,
                 input_is_already_split=False,
                 combine_out_splits=True,
                 linear_cls=torch.nn.Linear,
                 init_linear=None,
                 **kwargs):
  32. """A replacement for ``torch.nn.Linear`` that works with ZeRO-3 to reduce
  33. memory requirements via tiling.
  34. TiledLinear breaks the input and output dimensions of a linear layer
  35. into tiles that are processed in sequence. This class enables huge
  36. linear layers when combined with ZeRO-3 because inactive tiles can be
  37. partitioned and offloaded.
  38. .. note::
  39. We recommend using as few tiles as necessary. Tiling
  40. significantly reduces memory usage, but can reduce throughput
  41. for inexpensive layers. This due to the smaller kernels having
  42. less parallelism and lower arithmetic intensity, while
  43. introducing more frequent synchronization and communication.
  44. Args:
  45. in_features (int): See ``torch.nn.Linear``
  46. out_features (int): See ``torch.nn.Linear``
  47. bias (bool, optional): See ``torch.nn.Linear``
  48. in_splits (int, optional): The number of tiles along the input dimension. Defaults to 1.
  49. out_splits (int, optional): The number of tiles along the output dimension. Defaults to 1.
  50. input_is_already_split (bool, optional): If set to ``True``, assume that the ``input_`` in
  51. to ``forward()`` is already split into ``in_splits`` chunks. Defaults to ``False``.
  52. combine_out_splits (bool, optional): If set to ``False``, do not combine the ``out_splits`` outputs
  53. into a single tensor. Defaults to ``True``.
  54. linear_cls (class, optional): The underlying class to build individual tiles.
  55. Defaults to ``torch.nn.Linear``.
  56. init_linear (``torch.nn.Linear``, optional): If set, copy the parameters of
  57. ``init_linear``. Useful for debugging. Defaults to ``None``.
  58. kwargs (dict, optional): additional keyword arguments to provide to ``linear_cls()``.
  59. Raises:
  60. RuntimeError: ``in_splits`` must be within the range [1, in_features).
  61. RuntimeError: ``out_splits`` must be within the range of [1, out_features).
  62. """
        super().__init__()

        if (in_splits < 1) or (in_splits > in_features):
            raise RuntimeError('in splits must be in range [1, in_features].')
        if (out_splits < 1) or (out_splits > out_features):
            raise RuntimeError('out splits must be in range [1, out_features].')

        # global, not necessarily local
        self.in_features = in_features
        self.out_features = out_features

        self.use_bias = bias
        self.out_splits = out_splits
        self.in_splits = in_splits
        self.input_is_already_split = input_is_already_split
        self.combine_out_splits = combine_out_splits

        # Build partition lists. These are CSR-style splits [0, part0, part1, ..., features].
        # For example, out_parts[p] gives the start of output partition p and
        # out_parts[p + 1] is the exclusive end.
        self.in_parts = partition(num_items=in_features, num_parts=in_splits)
        self.out_parts = partition(num_items=out_features, num_parts=out_splits)

        assert len(self.out_parts) == out_splits + 1
        assert len(self.in_parts) == in_splits + 1
        assert self.out_parts[0] == 0
        assert self.out_parts[out_splits] == out_features
        assert self.in_parts[in_splits] == in_features

        self.linears = torch.nn.ModuleList()
        for out_id in range(out_splits):
            self.linears.append(torch.nn.ModuleList())
            local_out_dim = self.out_parts[out_id + 1] - self.out_parts[out_id]

            for in_id in range(in_splits):
                # If the input dimension is split, only the last tile in each row
                # needs a bias; the partial products are summed in forward().
                local_bias = bias if in_id == (in_splits - 1) else False

                local_in_dim = self.in_parts[in_id + 1] - self.in_parts[in_id]

                local = linear_cls(local_in_dim, local_out_dim, bias=local_bias, **kwargs)
                self.linears[out_id].append(local)

        # Optionally initialize with a known tensor
        if init_linear is not None:
            self.copy_params_from(init_linear)
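
    # A hedged construction sketch (not part of the original file; the sizes
    # and split counts below are illustrative assumptions):
    #
    #   layer = TiledLinear(in_features=4096, out_features=4096,
    #                       in_splits=2, out_splits=4)
    #   y = layer(x)   # x: [..., 4096] -> y: [..., 4096]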

    def forward(self, input_):
        if self.in_splits > 1 and not self.input_is_already_split:
            input_parts = partition(input_.shape[-1], self.in_splits)
            split_sizes = [input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)]
            inputs = self._split_global_input(input_, split_sizes)
        elif self.in_splits > 1:
            inputs = input_
            assert len(inputs) == self.in_splits, \
                f"Col splits {self.in_splits} does not match input splits {len(inputs)}"
        else:
            # no splits
            inputs = [input_]

        outputs = [None] * self.out_splits
        for out_id in range(self.out_splits):
            for in_id in range(self.in_splits):
                local_output = self.linears[out_id][in_id](inputs[in_id])
                outputs[out_id] = self._reduce_local_output(in_id=in_id,
                                                            out_id=out_id,
                                                            current_out=outputs[out_id],
                                                            new_out=local_output)

        if self.combine_out_splits:
            return self._combine_output_splits(outputs)
        return outputs
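
    # In effect, the loop above computes (a sketch of the math, not code from
    # the original file):
    #
    #   outputs[out_id] = sum over in_id of linears[out_id][in_id](inputs[in_id])
    #
    # Each output tile is the sum of partial products over the input tiles; the
    # bias, when enabled, is added only by the last input tile's linear.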

    def _split_global_input(self, input, split_sizes):
        """Partition an input tensor along the last dimension, aligned with the given splits.

        Subclasses should override this method to account for new input types.

        Args:
            input (Tensor): The tensor to partition along the last dimension.
            split_sizes (List[int]): The size of each partition.

        Returns:
            List[Any]: A list of the chunks of ``input``.
        """
        return split_tensor_along_last_dim(input, split_sizes)

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduce (sum) a new local result into the existing local results.

        Subclasses should override this method.

        For a given ``out_id``, this method is called once per input split; the
        first call (``in_id == 0``) is a simple assignment.

        Args:
            in_id (int): The input split that produced ``new_out``.
            out_id (int): The output split that produced ``new_out``.
            current_out (Any): The reduced form of all previous ``out_id`` results.
            new_out (Any): The local result from forward (``in_id``, ``out_id``).

        Returns:
            Any: The combined result of ``current_out`` and ``new_out``.
        """
        if current_out is None:
            # This clone is necessary to preserve autograd; there is an issue
            # with in-place updates for outputs that are views.
            return new_out.clone()
        else:
            return current_out + new_out

    def _combine_output_splits(self, outputs):
        """Join the splits of the output into a single result.

        Args:
            outputs (List[Any]): The reduced outputs for each output split.

        Returns:
            Any: The combined outputs.
        """
        assert len(outputs) == self.out_splits
        return torch.cat(outputs, dim=-1)

    @torch.no_grad()
    def copy_params_from(self, other):
        """Copy the weight and bias data from ``other``.

        This is especially useful for reproducible initialization and testing.

        Equivalent to:

        .. code-block:: python

            with torch.no_grad():
                self.weight.copy_(other.weight)
                if self.bias is not None:
                    self.bias.copy_(other.bias)

        .. note::
            If ZeRO-3 is enabled, this is a collective operation and the
            updated parameters of data-parallel rank 0 will be visible on all
            ranks. See :class:`deepspeed.zero.GatheredParameters` for more
            information.

        Args:
            other (``torch.nn.Linear``): the linear layer to copy from.
        """
        assert hasattr(other, 'weight')
        assert other.weight.size() == (self.out_features, self.in_features)
        if self.use_bias:
            assert hasattr(other, 'bias')
            assert other.bias is not None
            assert other.bias.size() == (self.out_features, )
        else:
            assert other.bias is None

        for row in range(self.out_splits):
            rstart = self.out_parts[row]
            rstop = self.out_parts[row + 1]

            for col in range(self.in_splits):
                cstart = self.in_parts[col]
                cstop = self.in_parts[col + 1]

                local = self.linears[row][col]
                global_weight = other.weight[rstart:rstop, cstart:cstop]
                with deepspeed.zero.GatheredParameters(local.weight, modifier_rank=0):
                    local.weight.copy_(global_weight)

                if local.bias is not None:
                    with deepspeed.zero.GatheredParameters(local.bias, modifier_rank=0):
                        local.bias.copy_(other.bias[rstart:rstop])


class TiledLinearReturnBias(TiledLinear):
    """Wrapper for a Linear class that returns its own bias parameter, such as
    used by Megatron-LM.
    """

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduces output tensors, but not the returned bias."""
        if current_out is not None:
            old_tensor, old_bias = current_out
        else:
            old_tensor, old_bias = None, None

        assert isinstance(new_out, tuple)
        assert len(new_out) == 2

        tensor, bias = new_out
        assert tensor is not None

        tensor = super()._reduce_local_output(in_id=in_id,
                                              out_id=out_id,
                                              current_out=old_tensor,
                                              new_out=tensor)

        if bias is None:
            bias = old_bias

        return tensor, bias

    def _combine_output_splits(self, outputs):
        # Concatenate the output tensors along the last dimension.
        tensors = [o[0] for o in outputs]
        tensor = super()._combine_output_splits(tensors)

        # Concatenate the biases, if present.
        biases = [o[1] for o in outputs if o[1] is not None]
        if len(biases) > 0:
            bias = super()._combine_output_splits(biases)
        else:
            bias = None
        return tensor, bias
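

# A minimal, hedged usage sketch (not part of the original module). It assumes a
# single-process run without ZeRO-3; per the copy_params_from() docstring, the
# parameter gather only becomes a collective operation when ZeRO-3 is enabled.
if __name__ == "__main__":
    torch.manual_seed(0)
    ref = torch.nn.Linear(16, 8)

    # Build a tiled layer with 2 input tiles and 4 output tiles, copying the
    # reference layer's parameters into the tiles via init_linear.
    tiled = TiledLinear(16, 8, in_splits=2, out_splits=4, init_linear=ref)

    x = torch.randn(3, 16)
    # The tiled result should match the dense reference up to floating-point
    # accumulation order.
    print(torch.allclose(tiled(x), ref(x), atol=1e-6))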