import torch

import deepspeed
from deepspeed.runtime.utils import partition_uniform as partition


def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension. Adapted from Megatron-LM.

    Arguments:
        tensor: input tensor.
        partitions: list of partition sizes to supply to torch.split
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    # Split.
    tensor_list = torch.split(tensor, partitions, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
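

# A minimal usage sketch (not part of the original module): splitting the last
# dimension of a tensor into explicit chunk sizes. The shapes, split sizes, and
# the helper name ``_example_split_last_dim`` are illustrative assumptions.
def _example_split_last_dim():
    x = torch.randn(4, 10)
    chunks = split_tensor_along_last_dim(x, [6, 4], contiguous_split_chunks=True)
    assert chunks[0].shape == (4, 6) and chunks[1].shape == (4, 4)
    return chunks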


class TiledLinear(torch.nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 in_splits=1,
                 out_splits=1,
                 input_is_already_split=False,
                 combine_out_splits=True,
                 linear_cls=torch.nn.Linear,
                 init_linear=None,
                 **kwargs):
        """A replacement for ``torch.nn.Linear`` that works with ZeRO-3 to reduce
        memory requirements via tiling.

        TiledLinear breaks the input and output dimensions of a linear layer
        into tiles that are processed in sequence. This class enables huge
        linear layers when combined with ZeRO-3 because inactive tiles can be
        partitioned and offloaded. A brief usage sketch follows this class
        definition.

        .. note::
            We recommend using as few tiles as necessary. Tiling
            significantly reduces memory usage, but can reduce throughput
            for inexpensive layers. This is due to the smaller kernels having
            less parallelism and lower arithmetic intensity, while
            introducing more frequent synchronization and communication.

        Args:
            in_features (int): See ``torch.nn.Linear``
            out_features (int): See ``torch.nn.Linear``
            bias (bool, optional): See ``torch.nn.Linear``
            in_splits (int, optional): The number of tiles along the input dimension. Defaults to 1.
            out_splits (int, optional): The number of tiles along the output dimension. Defaults to 1.
            input_is_already_split (bool, optional): If set to ``True``, assume that the ``input_``
                to ``forward()`` is already split into ``in_splits`` chunks. Defaults to ``False``.
            combine_out_splits (bool, optional): If set to ``False``, do not combine the ``out_splits`` outputs
                into a single tensor. Defaults to ``True``.
            linear_cls (class, optional): The underlying class to build individual tiles.
                Defaults to ``torch.nn.Linear``.
            init_linear (``torch.nn.Linear``, optional): If set, copy the parameters of
                ``init_linear``. Useful for debugging. Defaults to ``None``.
            kwargs (dict, optional): additional keyword arguments to provide to ``linear_cls()``.

        Raises:
            RuntimeError: ``in_splits`` must be within the range [1, in_features].
            RuntimeError: ``out_splits`` must be within the range [1, out_features].
        """
        super().__init__()

        if (in_splits < 1) or (in_splits > in_features):
            raise RuntimeError('in splits must be in range [1, in_features].')
        if (out_splits < 1) or (out_splits > out_features):
            raise RuntimeError('out splits must be in range [1, out_features].')

        # global, not necessarily local
        self.in_features = in_features
        self.out_features = out_features

        self.use_bias = bias
        self.out_splits = out_splits
        self.in_splits = in_splits
        self.input_is_already_split = input_is_already_split
        self.combine_out_splits = combine_out_splits

        # Build partition lists. These are CSR-style splits [0, part0, part1, ..., features].
        # For example, in_parts[p] gives the start of input partition p and in_parts[p + 1]
        # is its exclusive end; out_parts behaves the same for the output dimension.
        self.in_parts = partition(num_items=in_features, num_parts=in_splits)
        self.out_parts = partition(num_items=out_features, num_parts=out_splits)
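
        # Illustrative example (an assumption about partition_uniform's output for
        # evenly divisible sizes): with in_features=8 and in_splits=2,
        # in_parts == [0, 4, 8], so tile 0 covers columns [0, 4) and tile 1
        # covers columns [4, 8).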

        assert len(self.out_parts) == out_splits + 1
        assert len(self.in_parts) == in_splits + 1
        assert self.out_parts[0] == 0
        assert self.out_parts[out_splits] == out_features
        assert self.in_parts[in_splits] == in_features

        self.linears = torch.nn.ModuleList()
        for out_id in range(out_splits):
            self.linears.append(torch.nn.ModuleList())

            local_out_dim = self.out_parts[out_id + 1] - self.out_parts[out_id]

            for in_id in range(in_splits):
                # If the input dimension is split, we only need one bias:
                # attach it to the last input tile so it is added exactly once.
                local_bias = bias if in_id == (in_splits - 1) else False

                local_in_dim = self.in_parts[in_id + 1] - self.in_parts[in_id]

                local = linear_cls(local_in_dim, local_out_dim, bias=local_bias, **kwargs)
                self.linears[out_id].append(local)

        # Optionally initialize with a known tensor
        if init_linear is not None:
            self.copy_params_from(init_linear)

    def forward(self, input_):
        if self.in_splits > 1 and not self.input_is_already_split:
            input_parts = partition(input_.shape[-1], self.in_splits)
            split_sizes = [
                input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)
            ]
            inputs = self._split_global_input(input_, split_sizes)
        elif self.in_splits > 1:
            inputs = input_
            assert len(inputs) == self.in_splits, \
                f"Col splits {self.in_splits} does not match input splits {len(inputs)}"
        else:
            # no splits
            inputs = [input_]

        outputs = [None] * self.out_splits
        for out_id in range(self.out_splits):
            for in_id in range(self.in_splits):
                local_output = self.linears[out_id][in_id](inputs[in_id])
                outputs[out_id] = self._reduce_local_output(in_id=in_id,
                                                            out_id=out_id,
                                                            current_out=outputs[out_id],
                                                            new_out=local_output)

        if self.combine_out_splits:
            return self._combine_output_splits(outputs)

        return outputs
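
    # Sketch of what the nested loop in forward() computes, in index notation
    # (illustrative, using the boundary lists built in __init__):
    #
    #   out[..., out_parts[r]:out_parts[r + 1]] =
    #       sum over c of linears[r][c](input_[..., in_parts[c]:in_parts[c + 1]])
    #
    # i.e. each output tile is the sum of its row of tile products, with the
    # bias contributed only by the last input tile.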

    def _split_global_input(self, input, split_sizes):
        """Partition an input tensor along the last dimension, aligned with the given splits.

        Subclasses should override this method to account for new input types.

        Args:
            input (Tensor): The tensor to partition along the last dimension.
            split_sizes (List[int]): The size of each partition.

        Returns:
            List[Any]: A list of the chunks of ``input``.
        """
        return split_tensor_along_last_dim(input, split_sizes)

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduce (sum) a new local result into the existing local results.

        Subclasses should override this method.

        For a given ``out_id``, this method is called once per input split. On the
        first call ``current_out`` is ``None`` and the result is a simple assignment;
        subsequent calls accumulate into it.

        Args:
            in_id (int): The input split that produced ``new_out``.
            out_id (int): The output split that produced ``new_out``.
            current_out (Any): The reduced form of all previous ``out_id`` results.
            new_out (Any): The local result from forward (``in_id``, ``out_id``).

        Returns:
            Any: The combined result of ``current_out`` and ``new_out``.
        """
        if current_out is None:
            # This clone is necessary to preserve autograd;
            # there is an issue with in-place updates for outputs that are views.
            return new_out.clone()
        else:
            return current_out + new_out

    def _combine_output_splits(self, outputs):
        """Join the splits of the output into a single result.

        Args:
            outputs (List[Any]): The reduced outputs for each output split.

        Returns:
            Any: The combined outputs.
        """
        assert len(outputs) == self.out_splits
        return torch.cat(outputs, dim=-1)

    @torch.no_grad()
    def copy_params_from(self, other):
        """Copy the weight and bias data from ``other``.

        This is especially useful for reproducible initialization and testing.

        Equivalent to:

        .. code-block:: python

            with torch.no_grad():
                self.weight.copy_(other.weight)
                if self.bias is not None:
                    self.bias.copy_(other.bias)

        .. note::
            If ZeRO-3 is enabled, this is a collective operation and the
            updated parameters of data-parallel rank 0 will be visible on all
            ranks. See :class:`deepspeed.zero.GatheredParameters` for more
            information.

        Args:
            other (``torch.nn.Linear``): the linear layer to copy from.
        """
        assert hasattr(other, 'weight')
        assert other.weight.size() == (self.out_features, self.in_features)
        if self.use_bias:
            assert hasattr(other, 'bias')
            assert other.bias is not None
            assert other.bias.size() == (self.out_features, )
        else:
            assert other.bias is None

        for row in range(self.out_splits):
            rstart = self.out_parts[row]
            rstop = self.out_parts[row + 1]

            for col in range(self.in_splits):
                cstart = self.in_parts[col]
                cstop = self.in_parts[col + 1]

                local = self.linears[row][col]
                global_weight = other.weight[rstart:rstop, cstart:cstop]
                with deepspeed.zero.GatheredParameters(local.weight, modifier_rank=0):
                    local.weight.copy_(global_weight)

                if local.bias is not None:
                    with deepspeed.zero.GatheredParameters(local.bias, modifier_rank=0):
                        local.bias.copy_(other.bias[rstart:rstop])
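

# A minimal usage sketch (not part of the original module). It builds a
# TiledLinear from an ordinary torch.nn.Linear via ``init_linear`` and checks
# that the tiled forward pass matches the dense layer. The shapes, split
# counts, and the helper name ``_example_tiled_linear`` are illustrative
# assumptions; without a ZeRO-3 context the tiles are plain local parameters.
def _example_tiled_linear():
    dense = torch.nn.Linear(8, 6)
    tiled = TiledLinear(in_features=8,
                        out_features=6,
                        in_splits=2,
                        out_splits=2,
                        init_linear=dense)
    x = torch.randn(4, 8)
    # With combine_out_splits=True (the default) the output is a single tensor
    # of shape (4, 6), matching torch.nn.Linear up to floating-point error.
    assert torch.allclose(tiled(x), dense(x), atol=1e-5)
    return tiled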


class TiledLinearReturnBias(TiledLinear):
    """Wrapper for a Linear class that returns its own bias parameter, such as
    the layers used by Megatron-LM.
    """

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduce the output tensors, but pass the returned bias through unreduced."""
        if current_out is not None:
            old_tensor, old_bias = current_out
        else:
            old_tensor, old_bias = None, None

        assert isinstance(new_out, tuple)
        assert len(new_out) == 2

        tensor, bias = new_out
        assert tensor is not None

        tensor = super()._reduce_local_output(in_id=in_id,
                                              out_id=out_id,
                                              current_out=old_tensor,
                                              new_out=tensor)

        if bias is None:
            bias = old_bias

        return tensor, bias

    def _combine_output_splits(self, outputs):
        # Concatenate the output tensors along the last dimension.
        tensors = [o[0] for o in outputs]
        tensor = super()._combine_output_splits(tensors)

        # Concatenate the biases, if any tiles returned one.
        biases = [o[1] for o in outputs if o[1] is not None]
        if len(biases) > 0:
            bias = super()._combine_output_splits(biases)
        else:
            bias = None

        return tensor, bias
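

# A minimal sketch (not part of the original module) of how TiledLinearReturnBias
# is intended to be used: ``linear_cls`` must return ``(output, bias)`` rather than
# adding the bias itself, as Megatron-LM's parallel linear layers do. The class
# ``_LinearReturnBias`` below is a hypothetical stand-in for such a layer.
class _LinearReturnBias(torch.nn.Linear):
    def forward(self, input_):
        # Return the un-biased matmul result together with the bias parameter
        # (which is None for tiles constructed with bias=False).
        return torch.nn.functional.linear(input_, self.weight), self.bias


def _example_tiled_linear_return_bias():
    tiled = TiledLinearReturnBias(in_features=8,
                                  out_features=6,
                                  in_splits=2,
                                  out_splits=2,
                                  linear_cls=_LinearReturnBias)
    x = torch.randn(4, 8)
    out, bias = tiled(x)
    # ``out`` has shape (4, 6); ``bias`` holds the concatenated per-tile biases
    # (shape (6,)) and has not been added to ``out``.
    assert out.shape == (4, 6) and bias.shape == (6, )
    return out, bias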