# coalesced_collectives.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
batched collective operations for overhead amortization and better
bandwidth utilization
"""

import math
from typing import List

import torch
from torch import Tensor

from deepspeed import comm as dist
from deepspeed.comm import ProcessGroup, all_to_all_single
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import instrument_w_nvtx
from deepspeed.ops import op_builder
from deepspeed.utils import logger


def _torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group=None, async_op=False, prof=False):
    return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor, input_tensor, group=group, async_op=False)


quantizer_module = None


@instrument_w_nvtx
@torch.no_grad()
def all_to_all_quant_reduce(tensors: List[Tensor], groups: {}) -> List[Tensor]:
    global quantizer_module
    if quantizer_module is None:
        quantizer_module = op_builder.QuantizerBuilder().load()
    local_world_size = get_accelerator().device_count()
    global_world_size = dist.get_world_size()
    num_nodes = global_world_size // local_world_size
    this_rank = dist.get_rank()
    intra_idx = int(this_rank / local_world_size)
    inter_idx = this_rank % local_world_size
    output_lst: List[Tensor] = [None] * len(tensors)
    for idx, tensor in enumerate(tensors):
        if tensor.dim() == 1:
            # 1-D tensors skip the quantized path and use the plain coalesced reduce-scatter.
            output_lst[idx] = reduce_scatter_coalesced([tensor])[0]
        elif tensor.numel() % (2 * global_world_size) != 0:
            # Due to the constraint of the 2-stage all-to-all, the number of elements in the
            # input tensor must be divisible by 2 * global_world_size; otherwise the
            # all-to-all cannot be performed because of a shape mismatch.
            # See more at https://github.com/microsoft/DeepSpeed/pull/5056
            logger.warning(
                f"qgZ falls back to reduce_scatter because tensor size = {tensor.numel()} is not divisible by "
                f"(2 * global_world_size) = {2 * global_world_size}. Please consider allocating a new world to enable qgZ"
            )
            output_lst[idx] = reduce_scatter_coalesced([tensor])[0]
        else:
            intra_quant_group = max(tensor.shape[0], tensor.shape[1], global_world_size)
            inter_quant_group = intra_quant_group // local_world_size
            # Stage 1: quantize to int4 and exchange within the node.
            intra_quant_int4, intra_q_scales = quantizer_module.swizzle_quant(tensor, intra_quant_group, 4,
                                                                              quantizer_module.Symmetric, 1, num_nodes,
                                                                              local_world_size)
            local_output = torch.empty_like(intra_quant_int4)
            scale_output = torch.empty_like(intra_q_scales)
            all_to_all_single(local_output, intra_quant_int4, group=groups[f'local_{intra_idx}'])
            all_to_all_single(scale_output, intra_q_scales, group=groups[f'local_{intra_idx}'])
            # Stage 2: reduce the quantized chunks, then exchange across nodes.
            global_input_tensor, global_scales = quantizer_module.quantized_reduction(
                local_output, scale_output, intra_quant_group, inter_quant_group, 4, quantizer_module.Symmetric,
                local_world_size)
            global_output = torch.empty_like(global_input_tensor)
            global_scale_output = torch.empty_like(global_scales)
            all_to_all_single(global_output, global_input_tensor, group=groups[f'global_{inter_idx}'])
            all_to_all_single(global_scale_output, global_scales, group=groups[f'global_{inter_idx}'])
            # Dequantize and average the per-node partial results.
            final_output = quantizer_module.dequantize(global_output, global_scale_output, global_scale_output.numel(),
                                                       4, quantizer_module.Symmetric)
            assert final_output.numel() % num_nodes == 0, \
                f"final_output.numel()={final_output.numel()} is not divisible by num_nodes={num_nodes}"
            output_lst[idx] = (sum(list(final_output.chunk(num_nodes))) / num_nodes).view(-1)
    return output_lst
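

# ----------------------------------------------------------------------------
# Hedged usage sketch (illustrative addition, not part of the original module).
# all_to_all_quant_reduce expects `groups` to hold one intra-node process group
# per node keyed 'local_<node_idx>' and one inter-node group per local device
# keyed 'global_<local_rank>'. The helper below shows one plausible way such
# groups could be built; the function name and the group construction are
# assumptions for illustration only - in DeepSpeed these groups are normally
# created by the framework (e.g. the ZeRO++ setup), not by callers of this file.
def _example_all_to_all_quant_reduce():
    """Minimal sketch, assuming dist is already initialized with one process per device."""
    local_world_size = get_accelerator().device_count()
    global_world_size = dist.get_world_size()
    num_nodes = global_world_size // local_world_size

    # Build the 'local_*' (intra-node) and 'global_*' (inter-node) groups this
    # module indexes into. Every rank must participate in every new_group call.
    groups = {}
    for node in range(num_nodes):
        ranks = list(range(node * local_world_size, (node + 1) * local_world_size))
        groups[f'local_{node}'] = dist.new_group(ranks)
    for gpu in range(local_world_size):
        ranks = [node * local_world_size + gpu for node in range(num_nodes)]
        groups[f'global_{gpu}'] = dist.new_group(ranks)

    # A 2-D gradient-like tensor whose element count is divisible by
    # 2 * global_world_size, so the quantized (qgZ) path is taken instead of
    # the reduce_scatter_coalesced fallback.
    grad = torch.randn(2 * global_world_size, 1024, device=get_accelerator().current_device_name())
    partitions = all_to_all_quant_reduce([grad], groups)
    return partitions[0]  # this rank's averaged partition, flattened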


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """simultaneously reduce-scatter a list of tensors - this can be done more
    efficiently than individual reduce scatter calls

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
    """
    this_rank = dist.get_rank(group)
    world_sz = dist.get_world_size(group)

    partition_lst_for_each_tensor = [None] * len(tensors)
    for tensor_idx, tensor in enumerate(tensors):
        flattened_tensor = tensor.view(-1)
        chunk_sz = math.ceil(tensor.numel() / world_sz)
        partition_lst_for_each_tensor[tensor_idx] = [
            flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz] for rank in range(0, world_sz)
        ]

    padded_partition_sz_for_each_tensor = tuple(math.ceil(t.numel() / world_sz) for t in tensors)

    if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
        # if there's only one tensor being reduced and we don't need to pad
        # we have an opportunity to avoid a memory allocation
        tensor_partition_flat_buffer = tensors[0].view(-1)
    else:
        # interleave tensor partitions such that the correct reduced partitions of each tensor
        # end up at each rank
        tensor_partitions_lst_with_padding = []
        for rank in range(world_sz):
            for tensor_idx in range(len(tensors)):
                # add tensor content
                tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
                tensor_partitions_lst_with_padding.append(tensor_chunk)

                # add padding if necessary
                padding_sz = padded_partition_sz_for_each_tensor[tensor_idx] - tensor_chunk.numel()
                if padding_sz > 0:
                    tensor_partitions_lst_with_padding.append(
                        torch.empty(padding_sz, dtype=tensor_chunk.dtype, device=tensor_chunk.device))

        tensor_partition_flat_buffer = instrument_w_nvtx(torch.cat)(tensor_partitions_lst_with_padding)

    tensor_partition_flat_buffer.div_(world_sz)  # pre-divide
    tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(tensor_partition_flat_buffer, world_sz)

    # batched reduce-scatter call
    _torch_reduce_scatter_fn(tensor_partition_flat_buffer,
                             tensor_partition_buffer_for_each_rank[this_rank],
                             group=group)

    # reverse procedure of the interleaving done previously, done on the
    # result of the batched reduce-scatter
    output_lst: List[Tensor] = [None] * len(tensors)
    offset = 0
    for tensor_idx in range(len(tensors)):
        output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
            0, offset, partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

        offset += padded_partition_sz_for_each_tensor[tensor_idx]

    return output_lst
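

# ----------------------------------------------------------------------------
# Hedged usage sketch (illustrative addition, not part of the original module).
# reduce_scatter_coalesced interleaves the per-rank partitions of every input
# tensor into a single flat buffer, pre-divides by the world size, and issues
# one batched reduce-scatter, so each rank receives the already-averaged
# partition of every tensor. The tensor sizes and function name below are
# assumptions chosen only to illustrate the calling convention.
def _example_reduce_scatter_coalesced():
    """Minimal sketch, assuming dist is already initialized."""
    world_sz = dist.get_world_size()
    device = get_accelerator().current_device_name()

    # Two gradient-like tensors of different sizes; the second does not divide
    # evenly by world_sz, so its tail partition is padded internally before the
    # single batched reduce-scatter call.
    grads = [torch.randn(world_sz * 4, device=device), torch.randn(world_sz * 7 + 3, device=device)]
    partitions = reduce_scatter_coalesced(grads, group=None)

    # Each output is this rank's slice of the averaged tensor; the final rank's
    # slice may be shorter than the padded partition size.
    for grad, part in zip(grads, partitions):
        assert part.numel() <= math.ceil(grad.numel() / world_sz)
    return partitions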