coalesced_collectives.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
batched collective operations for overhead amortization and better
bandwidth utilization
"""

import math
from typing import List

import torch
from torch import Tensor

from deepspeed import comm as dist
# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
from torch.distributed import ProcessGroup, all_to_all_single

from deepspeed.accelerator import get_accelerator
from deepspeed.utils import instrument_w_nvtx
from deepspeed.ops import op_builder


def _torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group=None, async_op=False, prof=False):
    return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor, input_tensor, group=group, async_op=False)


quantizer_module = None


@instrument_w_nvtx
@torch.no_grad()
def all_to_all_quant_reduce(tensors: List[Tensor], groups: {}) -> List[Tensor]:
    """quantized all-to-all based reduction of a list of 2-D tensors: each tensor is
    4-bit quantized, exchanged and reduced within a node, then exchanged across nodes
    and dequantized; 1-D tensors fall back to reduce_scatter_coalesced
    """
    global quantizer_module
    if quantizer_module is None:
        quantizer_module = op_builder.QuantizerBuilder().load()
    local_world_size = get_accelerator().device_count()
    global_world_size = dist.get_world_size()
    num_nodes = global_world_size // local_world_size
    this_rank = dist.get_rank()
    intra_idx = int(this_rank / local_world_size)
    inter_idx = this_rank % local_world_size
    output_lst: List[Tensor] = [None] * len(tensors)
    for idx, tensor in enumerate(tensors):
        if tensor.dim() == 1:
            # 1-D tensors are handled by the unquantized coalesced reduce-scatter
            intra_quant_group = global_world_size
            output_lst[idx] = reduce_scatter_coalesced([tensor])[0]
            continue
        else:
            intra_quant_group = max(tensor.shape[0], tensor.shape[1], global_world_size)

        inter_quant_group = intra_quant_group // local_world_size
        # 4-bit symmetric quantization followed by an intra-node all-to-all exchange
        intra_quant_int4, intra_q_scales = quantizer_module.swizzle_quant(tensor, intra_quant_group, 4,
                                                                          quantizer_module.Symmetric, 1, num_nodes,
                                                                          local_world_size)
        local_output = torch.empty_like(intra_quant_int4)
        scale_output = torch.empty_like(intra_q_scales)
        all_to_all_single(local_output, intra_quant_int4, group=groups[f'local_{intra_idx}'])
        all_to_all_single(scale_output, intra_q_scales, group=groups[f'local_{intra_idx}'])
        # reduce the received chunks while still quantized, then exchange across nodes
        global_input_tensor, global_scales = quantizer_module.quantized_reduction(
            local_output, scale_output, intra_quant_group, inter_quant_group, 4, quantizer_module.Symmetric,
            local_world_size)
        global_output = torch.empty_like(global_input_tensor)
        global_scale_output = torch.empty_like(global_scales)
        all_to_all_single(global_output, global_input_tensor, group=groups[f'global_{inter_idx}'])
        all_to_all_single(global_scale_output, global_scales, group=groups[f'global_{inter_idx}'])
        # dequantize and average the per-node partial results
        final_output = quantizer_module.dequantize(global_output, global_scale_output, global_scale_output.numel(),
                                                   4, quantizer_module.Symmetric)
        output_lst[idx] = (sum(list(final_output.chunk(num_nodes))) / num_nodes).view(-1)
    return output_lst
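

# Illustrative sketch (an assumption, not part of the original module): one plausible
# way to build the `groups` mapping consumed by all_to_all_quant_reduce. From the
# indexing above, "local_<node_index>" groups hold all ranks on one node and
# "global_<gpu_index>" groups hold the same local GPU index across nodes. The sketch
# assumes one process per accelerator, an initialized deepspeed.comm backend, and that
# dist.new_group mirrors torch.distributed.new_group; the helper name
# `_build_all_to_all_quant_groups` is hypothetical.
#
#     def _build_all_to_all_quant_groups():
#         local_world_size = get_accelerator().device_count()
#         global_world_size = dist.get_world_size()
#         num_nodes = global_world_size // local_world_size
#         groups = {}
#         for node in range(num_nodes):
#             # intra-node group: every rank hosted on this node
#             ranks = list(range(node * local_world_size, (node + 1) * local_world_size))
#             groups[f'local_{node}'] = dist.new_group(ranks)
#         for gpu in range(local_world_size):
#             # inter-node group: the gpu-th rank on every node
#             ranks = list(range(gpu, global_world_size, local_world_size))
#             groups[f'global_{gpu}'] = dist.new_group(ranks)
#         return groups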


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """simultaneously reduce-scatter a list of tensors - this can be done more
    efficiently than individual reduce scatter calls

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
    """
    this_rank = dist.get_rank(group)
    world_sz = dist.get_world_size(group)

    partition_lst_for_each_tensor = [None] * len(tensors)
    for tensor_idx, tensor in enumerate(tensors):
        flattened_tensor = tensor.view(-1)
        chunk_sz = math.ceil(tensor.numel() / world_sz)
        partition_lst_for_each_tensor[tensor_idx] = [
            flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz] for rank in range(0, world_sz)
        ]

    padded_partition_sz_for_each_tensor = tuple(math.ceil(t.numel() / world_sz) for t in tensors)

    if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
        # if there's only one tensor being reduced and we don't need to pad
        # we have an opportunity to avoid a memory allocation
        tensor_partition_flat_buffer = tensors[0].view(-1)
    else:
        # interleave tensor partitions such that the correct reduced partitions of each tensor
        # end up at each rank
        tensor_partitions_lst_with_padding = []
        for rank in range(world_sz):
            for tensor_idx in range(len(tensors)):
                # add tensor content
                tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
                tensor_partitions_lst_with_padding.append(tensor_chunk)

                # add padding if necessary
                padding_sz = padded_partition_sz_for_each_tensor[tensor_idx] - tensor_chunk.numel()
                if padding_sz > 0:
                    tensor_partitions_lst_with_padding.append(
                        torch.empty(padding_sz, dtype=tensor_chunk.dtype, device=tensor_chunk.device))

        tensor_partition_flat_buffer = instrument_w_nvtx(torch.cat)(tensor_partitions_lst_with_padding)

    tensor_partition_flat_buffer.div_(world_sz)  # pre-divide
    tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(tensor_partition_flat_buffer, world_sz)

    # batched reduce-scatter call
    _torch_reduce_scatter_fn(tensor_partition_flat_buffer,
                             tensor_partition_buffer_for_each_rank[this_rank],
                             group=group)

    # reverse procedure of the interleaving done previously, done on the
    # result of the batched reduce-scatter
    output_lst: List[Tensor] = [None] * len(tensors)
    offset = 0
    for tensor_idx in range(len(tensors)):
        output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
            0, offset, partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

        offset += padded_partition_sz_for_each_tensor[tensor_idx]

    return output_lst
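

# Minimal usage sketch for reduce_scatter_coalesced (added for illustration, not in
# the original file). It assumes a standard multi-process launch such as
# `torchrun --nproc_per_node=N`, one process per accelerator, and that
# deepspeed.comm.init_distributed() sets up the default group; tensor shapes are
# arbitrary examples.
if __name__ == "__main__":
    dist.init_distributed()
    device = get_accelerator().current_device_name()

    # every rank contributes its own copy of two differently sized "gradients"
    grads = [
        torch.full((10, ), float(dist.get_rank()), device=device),
        torch.full((7, ), float(dist.get_rank()), device=device),
    ]

    # each rank gets back one partition per input tensor; the buffer is pre-divided by
    # world size, so with the default sum reduction the result is the cross-rank average
    partitions = reduce_scatter_coalesced(grads)
    for i, part in enumerate(partitions):
        print(f"rank {dist.get_rank()}: tensor {i} -> partition with {part.numel()} elements")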