123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- """
- batched collective operations for overhead amortization and better
- bandwidth utilization
- """
- import math
- from typing import List
- import torch
- from torch import Tensor
- from deepspeed import comm as dist
- # NOTE: Use torch.distributed's ProcessGroup class until we have our own.
- from torch.distributed import ProcessGroup, all_to_all_single
- from deepspeed.accelerator import get_accelerator
- from deepspeed.utils import instrument_w_nvtx
- from deepspeed.ops import op_builder
def _torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group=None, async_op=False, prof=False):
    """Run ``dist.reduce_scatter_fn`` wrapped with NVTX instrumentation.

    Args:
        input_tensor: buffer handed to the collective as its input.
        output_tensor: buffer that receives this rank's result.
        group: forwarded unchanged as the ``group`` keyword.
        async_op: accepted for signature compatibility but NOT forwarded;
            the underlying call is always made with ``async_op=False``.
        prof: accepted but unused by this wrapper.

    Returns:
        Whatever the instrumented ``dist.reduce_scatter_fn`` call returns.
    """
    traced_reduce_scatter = instrument_w_nvtx(dist.reduce_scatter_fn)
    # The backend takes (output, input), the reverse of this wrapper's
    # (input, output) parameter order.
    return traced_reduce_scatter(output_tensor, input_tensor, group=group, async_op=False)
# Cache for the quantizer op module. It stays None until the first call to
# all_to_all_quant_reduce, which loads it via op_builder.QuantizerBuilder().
quantizer_module = None


@instrument_w_nvtx
@torch.no_grad()
def all_to_all_quant_reduce(tensors: List[Tensor], groups: dict) -> List[Tensor]:
    """Reduce a list of tensors across ranks using quantized all-to-all exchanges.

    1-D tensors are passed, unquantized, to ``reduce_scatter_coalesced``.
    Every other tensor goes through two stages:

    1. ``swizzle_quant`` -> all-to-all over ``groups['local_<i>']`` (payload,
       then scales) -> ``quantized_reduction``.
    2. all-to-all over ``groups['global_<j>']`` (payload, then scales) ->
       ``dequantize`` -> mean of the ``num_nodes`` chunks of the result.

    Here ``i = rank // device_count`` and ``j = rank % device_count``, where
    ``device_count`` comes from the accelerator and ``rank`` from ``dist``.

    Args:
        tensors: tensors to reduce. A tensor that is not 1-D must have at
            least two dimensions, because ``shape[0]`` and ``shape[1]`` are
            both read.
        groups: mapping that contains the ``'local_<i>'`` and ``'global_<j>'``
            keys described above. The values are passed as ``group=`` to
            ``all_to_all_single`` (presumably process groups; confirm against
            the caller).

    Returns:
        A list with one entry per input tensor, in the same order. Entries
        produced by the quantized path are flattened with ``view(-1)``.
    """
    global quantizer_module
    if quantizer_module is None:
        quantizer_module = op_builder.QuantizerBuilder().load()

    local_world_size = get_accelerator().device_count()
    global_world_size = dist.get_world_size()
    num_nodes = global_world_size // local_world_size
    this_rank = dist.get_rank()
    # Integer division: avoids the float round-trip of int(a / b).
    intra_idx = this_rank // local_world_size
    inter_idx = this_rank % local_world_size

    output_lst: List[Tensor] = [None] * len(tensors)
    for idx, tensor in enumerate(tensors):
        if tensor.dim() == 1:
            # 1-D tensors skip quantization entirely.
            output_lst[idx] = reduce_scatter_coalesced([tensor])[0]
            continue

        intra_quant_group = max(tensor.shape[0], tensor.shape[1], global_world_size)
        inter_quant_group = intra_quant_group // local_world_size

        # NOTE(review): the literal 4 is presumably the quantization bit
        # width (cf. the `int4` naming); confirm against the quantizer op.
        intra_quant_int4, intra_q_scales = quantizer_module.swizzle_quant(tensor, intra_quant_group, 4,
                                                                          quantizer_module.Symmetric, 1, num_nodes,
                                                                          local_world_size)

        # Stage 1: exchange quantized payload, then scales, in the local group.
        local_output = torch.empty_like(intra_quant_int4)
        scale_output = torch.empty_like(intra_q_scales)
        local_group = groups[f'local_{intra_idx}']
        all_to_all_single(local_output, intra_quant_int4, group=local_group)
        all_to_all_single(scale_output, intra_q_scales, group=local_group)

        global_input_tensor, global_scales = quantizer_module.quantized_reduction(
            local_output, scale_output, intra_quant_group, inter_quant_group, 4, quantizer_module.Symmetric,
            local_world_size)

        # Stage 2: exchange the reduced payload, then scales, in the global group.
        global_output = torch.empty_like(global_input_tensor)
        global_scale_output = torch.empty_like(global_scales)
        global_group = groups[f'global_{inter_idx}']
        all_to_all_single(global_output, global_input_tensor, group=global_group)
        all_to_all_single(global_scale_output, global_scales, group=global_group)

        final_output = quantizer_module.dequantize(global_output, global_scale_output, global_scale_output.numel(),
                                                   4, quantizer_module.Symmetric)
        # Average the num_nodes chunks and flatten the result.
        output_lst[idx] = (sum(final_output.chunk(num_nodes)) / num_nodes).view(-1)
    return output_lst
@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """Reduce-scatter several tensors with one batched collective call.

    Each tensor is flattened and cut into ``world_size`` slices of
    ``ceil(numel / world_size)`` elements (trailing slices may be shorter or
    empty). The slices are packed rank-major into one flat buffer, short
    slices are padded with uninitialised memory, the buffer is divided in
    place by the world size, and a single reduce-scatter is issued.

    When exactly one tensor is given and its element count is a multiple of
    the world size, the tensor's own flattened view is used as the buffer.
    No copy is made, so the input tensor is modified in place.

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL

    Args:
        tensors: tensors to reduce; each must support ``view(-1)``.
        group: process group passed to the rank/world-size queries and to the
            collective.

    Returns:
        One 1-D tensor per input, each a ``narrow`` view into this rank's
        chunk of the flat buffer, sized to this rank's unpadded slice.
    """
    rank_in_group = dist.get_rank(group)
    group_size = dist.get_world_size(group)

    # Width of the (padded) slot every rank gets for each tensor.
    slot_sizes = tuple(math.ceil(t.numel() / group_size) for t in tensors)

    # shards[i][r] is the unpadded slice of tensor i destined for rank r.
    shards = []
    for tensor, slot in zip(tensors, slot_sizes):
        flat = tensor.view(-1)
        shards.append([flat[r * slot:(r + 1) * slot] for r in range(group_size)])

    must_pack = len(tensors) != 1 or tensors[0].numel() % group_size != 0
    if must_pack:
        # Lay out the slices rank by rank, tensor by tensor, so that rank r's
        # chunk of the flat buffer holds its slice of every tensor.
        pieces = []
        for r in range(group_size):
            for shard_row, slot in zip(shards, slot_sizes):
                piece = shard_row[r]
                pieces.append(piece)
                shortfall = slot - piece.numel()
                if shortfall > 0:
                    # Padding content is never read back (outputs are narrowed
                    # to the unpadded length), so it is left uninitialised.
                    pieces.append(torch.empty(shortfall, dtype=piece.dtype, device=piece.device))
        send_buffer = instrument_w_nvtx(torch.cat)(pieces)
    else:
        # Single tensor, no padding needed: reuse its storage and avoid an
        # allocation.
        send_buffer = tensors[0].view(-1)

    send_buffer.div_(group_size)  # pre-divide
    per_rank_chunks: List[Tensor] = torch.chunk(send_buffer, group_size)
    my_chunk = per_rank_chunks[rank_in_group]

    # Batched reduce-scatter; the result is written into this rank's own
    # chunk of the send buffer.
    _torch_reduce_scatter_fn(send_buffer, my_chunk, group=group)

    # Undo the interleaving: walk this rank's chunk slot by slot and expose
    # only the unpadded part of each tensor's slice.
    results: List[Tensor] = []
    cursor = 0
    for shard_row, slot in zip(shards, slot_sizes):
        results.append(my_chunk.narrow(0, cursor, shard_row[rank_in_group].numel()))
        cursor += slot
    return results
|