'''Copyright The Microsoft DeepSpeed Team'''
"""batched collective operations for overhead amortization and better
bandwidth utilization"""

import math
from typing import List

import torch
from torch import Tensor
from deepspeed import comm as dist
# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
from torch.distributed import ProcessGroup
import torch.nn.functional

from deepspeed.utils import instrument_w_nvtx


def _torch_reduce_scatter_fn(input_tensor: Tensor,
                             output_tensor: Tensor,
                             group=None,
                             async_op=False,
                             prof=False):
    # NOTE: `prof` is accepted but currently unused here.
    return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor,
                                                     input_tensor,
                                                     group=group,
                                                     async_op=async_op)


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """simultaneously reduce-scatter a list of tensors - this can be done more
    efficiently than individual reduce scatter calls

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
    """
    this_rank = dist.get_rank(group)
    world_sz = dist.get_world_size(group)

    partition_lst_for_each_tensor = [None] * len(tensors)
    for tensor_idx, tensor in enumerate(tensors):
        flattened_tensor = tensor.view(-1)
        chunk_sz = math.ceil(tensor.numel() / world_sz)
        partition_lst_for_each_tensor[tensor_idx] = [
            flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz]
            for rank in range(0, world_sz)
        ]

    padded_partition_sz_for_each_tensor = tuple(
        math.ceil(t.numel() / world_sz) for t in tensors)
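    # Worked example with illustrative values (not from the source): for
    # world_sz=4, a tensor of 10 elements has chunk_sz = ceil(10 / 4) = 3, so
    # its slices hold 3, 3, 3 and 1 elements and its padded partition size is 3.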

    if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
        # if there's only one tensor being reduced and we don't need to pad
        # we have an opportunity to avoid a memory allocation
        tensor_partition_flat_buffer = tensors[0].view(-1)
    else:
        # interleave tensor partitions such that the correct reduced partitions of each tensor
        # end up at each rank
        tensor_partitions_lst_with_padding = []
        for rank in range(world_sz):
            for tensor_idx in range(len(tensors)):
                # add tensor content
                tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
                tensor_partitions_lst_with_padding.append(tensor_chunk)

                # add padding if necessary
                padding_sz = padded_partition_sz_for_each_tensor[
                    tensor_idx] - tensor_chunk.numel()
                if padding_sz > 0:
                    tensor_partitions_lst_with_padding.append(
                        torch.empty(padding_sz,
                                    dtype=tensor_chunk.dtype,
                                    device=tensor_chunk.device))

        tensor_partition_flat_buffer = instrument_w_nvtx(
            torch.cat)(tensor_partitions_lst_with_padding)
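    # The flat buffer is laid out rank-major, e.g. for two tensors t0 and t1:
    #   [t0_rank0 | t1_rank0 | t0_rank1 | t1_rank1 | ...]
    # so each of the world_sz equal-sized chunks below holds one rank's
    # (padded) partition of every tensor.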

    tensor_partition_flat_buffer.div_(world_sz)  # pre-divide
    tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
        tensor_partition_flat_buffer,
        world_sz)

    # batched reduce-scatter call
    _torch_reduce_scatter_fn(tensor_partition_flat_buffer,
                             tensor_partition_buffer_for_each_rank[this_rank],
                             group=group)
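    # NOTE: the reduce op defaults to SUM, so together with the pre-division by
    # world_sz above each reduced partition ends up as the mean across ranks.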

    # reverse procedure of the interleaving done previously, done on the
    # result of the batched reduce-scatter
    output_lst: List[Tensor] = [None] * len(tensors)
    offset = 0
    for tensor_idx in range(len(tensors)):
        output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
            0,
            offset,
            partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

        offset += padded_partition_sz_for_each_tensor[tensor_idx]

    return output_lst
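

# ----------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of this module): assumes the
# default process group is already initialized, e.g. via
# deepspeed.init_distributed() under a distributed launcher, and that the
# backend supports reduce-scatter (e.g. NCCL with CUDA tensors). `model` below
# is a hypothetical torch.nn.Module placed on this rank's device.
#
#   grads = [p.grad for p in model.parameters() if p.grad is not None]
#   shards = reduce_scatter_coalesced(grads)
#   # shards[i] is this rank's (summed and averaged) partition of grads[i]
# ----------------------------------------------------------------------------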