coalesced_collectives.py

'''Copyright The Microsoft DeepSpeed Team'''
"""batched collective operations for overhead amortization and better
bandwidth utilization"""

import math
from typing import List

import torch
from torch import Tensor
from deepspeed import comm as dist
# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
from torch.distributed import ProcessGroup
import torch.nn.functional

from deepspeed.utils import instrument_w_nvtx
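

# Thin wrapper that forwards to deepspeed.comm's reduce_scatter_fn (which
# takes the output tensor first) and runs the call through instrument_w_nvtx,
# so the collective shows up as an NVTX range in profiler traces.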
def _torch_reduce_scatter_fn(input_tensor: Tensor,
                             output_tensor: Tensor,
                             group=None,
                             async_op=False,
                             prof=False):
    return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor,
                                                     input_tensor,
                                                     group=group,
                                                     async_op=async_op)


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """simultaneously reduce-scatter a list of tensors - this can be done more
    efficiently than individual reduce scatter calls

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
    """
    this_rank = dist.get_rank(group)
    world_sz = dist.get_world_size(group)

    partition_lst_for_each_tensor = [None] * len(tensors)
    for tensor_idx, tensor in enumerate(tensors):
        flattened_tensor = tensor.view(-1)
        chunk_sz = math.ceil(tensor.numel() / world_sz)
        partition_lst_for_each_tensor[tensor_idx] = [
            flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz]
            for rank in range(0, world_sz)
        ]
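
    # e.g. (hypothetical shapes, for illustration only) with world_sz=4 and a
    # 10-element tensor, chunk_sz = ceil(10 / 4) = 3, giving per-rank
    # partitions of sizes [3, 3, 3, 1]; the short final partition is padded
    # out below so every rank's slice of the flat buffer has the same length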

    padded_partition_sz_for_each_tensor = tuple(
        math.ceil(t.numel() / world_sz) for t in tensors)

    if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
        # if there's only one tensor being reduced and we don't need to pad
        # we have an opportunity to avoid a memory allocation
        tensor_partition_flat_buffer = tensors[0].view(-1)
    else:
        # interleave tensor partitions such that the correct reduced partitions of each tensor
        # end up at each rank
        tensor_partitions_lst_with_padding = []
        for rank in range(world_sz):
            for tensor_idx in range(len(tensors)):
                # add tensor content
                tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
                tensor_partitions_lst_with_padding.append(tensor_chunk)

                # add padding if necessary
                padding_sz = padded_partition_sz_for_each_tensor[
                    tensor_idx] - tensor_chunk.numel()
                if padding_sz > 0:
                    tensor_partitions_lst_with_padding.append(
                        torch.empty(padding_sz,
                                    dtype=tensor_chunk.dtype,
                                    device=tensor_chunk.device))

        tensor_partition_flat_buffer = instrument_w_nvtx(
            torch.cat)(tensor_partitions_lst_with_padding)
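
    # the flat buffer is now laid out rank-major:
    #   [t0_rank0 | t1_rank0 | ... | t0_rank1 | t1_rank1 | ...]
    # with each partition padded to its tensor's fixed partition size, so
    # chunking it into world_sz pieces gives each rank one contiguous block
    # containing its partition of every tensor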

    tensor_partition_flat_buffer.div_(world_sz)  # pre-divide
    tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
        tensor_partition_flat_buffer, world_sz)

    # batched reduce-scatter call
    _torch_reduce_scatter_fn(tensor_partition_flat_buffer,
                             tensor_partition_buffer_for_each_rank[this_rank],
                             group=group)

    # reverse procedure of the interleaving done previously, done on the
    # result of the batched reduce-scatter
    output_lst: List[Tensor] = [None] * len(tensors)
    offset = 0
    for tensor_idx in range(len(tensors)):
        output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
            0, offset, partition_lst_for_each_tensor[tensor_idx][this_rank].numel())
        offset += padded_partition_sz_for_each_tensor[tensor_idx]

    return output_lst
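

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the upstream module. It assumes the
# process is started under a distributed launcher (e.g. `deepspeed` or
# `torchrun`) so deepspeed.init_distributed() can set up the backend, and that
# each rank has a GPU available for the default NCCL backend. Each rank
# contributes its local copies of the tensors and receives its partition of
# the mean-over-ranks result (the collective sums, and the buffer was
# pre-divided by world size above).
if __name__ == "__main__":
    import deepspeed

    deepspeed.init_distributed()
    local_rank = dist.get_rank() % torch.cuda.device_count()
    device = torch.device("cuda", local_rank)

    # two tensors whose sizes are not divisible by typical world sizes, to
    # exercise the padding / interleaving path
    tensors = [
        torch.full((10, ), float(dist.get_rank()), device=device),
        torch.full((7, ), float(dist.get_rank()), device=device),
    ]
    partitions = reduce_scatter_coalesced(tensors)
    # partitions[i] is this rank's slice of the averaged tensors[i]
    print(f"rank {dist.get_rank()}: {[p.tolist() for p in partitions]}")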