all_gather.py

from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import torch
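
# The helper functions used below (sync_all, get_bw, get_metric_strings, print_rank_0,
# print_header, convert_size, max_numel, benchmark_parser, init_processes) are expected
# to come from the star imports above.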


# Run all_gather and print metrics
def timed_all_gather(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        # use all_gather_base if available
        if args.dist == 'torch':
            if hasattr(torch.distributed, "_all_gather_base"):
                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
            else:
                # fall back to all_gather on a list of per-rank views of the output buffer
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
        elif args.dist == 'deepspeed':
            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        # use all_gather_base if available
        if args.dist == 'torch':
            if hasattr(torch.distributed, "_all_gather_base"):
                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
            else:
                # fall back to all_gather on a list of per-rank views of the output buffer
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
        elif args.dist == 'deepspeed':
            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
    sync_all()
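    # With --async-op the loop above only launches the collectives; the trailing sync_all()
    # (assumed to barrier and synchronize the device) is what bounds the elapsed time, so the
    # averaged duration below is amortized per-trial time, not per-op completion latency.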
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_gather', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
def run_all_gather(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'all_gather')
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    if args.scan:
        # Create list of message sizes
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M,
                                 dtype=getattr(torch, args.dtype)).cuda(local_rank)
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
                # Delete original mat to avoid OOM
                del mat
                torch.cuda.empty_cache()
                output = torch.zeros(input.nelement() * world_size,
                                     dtype=getattr(torch, args.dtype)).cuda(local_rank)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
                # re-raise anything that is not an OOM so failures are not silently ignored
                raise
            sync_all()
            timed_all_gather(input, output, args)
    else:
        # all_gather_base saves memory: it gathers into a single flat output buffer instead of
        # a list of per-rank tensors, so a somewhat larger share of GPU memory can be used
        if (args.dist == 'torch' and hasattr(torch.distributed, "_all_gather_base")) or (
                args.dist == 'deepspeed' and dist.has_allgather_base):
            mem_factor = args.mem_factor + 0.2
        else:
            mem_factor = args.mem_factor
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        sync_all()
        elements_per_gpu = max_numel(comm_op='all_gather',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=mem_factor,
                                     local_rank=local_rank,
                                     args=args)
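        # elements_per_gpu is the per-rank input length; the gathered output allocated below is
        # world_size times larger, so max_numel (from the shared utils) is assumed to leave
        # enough headroom for both buffers at the requested mem_factor.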
        try:
            mat = torch.ones(elements_per_gpu,
                             dtype=getattr(torch, args.dtype)).cuda(local_rank)
            # multiply each GPU's tensor by the rank to ease debugging
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            torch.cuda.empty_cache()
            output = torch.zeros(elements_per_gpu * world_size,
                                 dtype=getattr(torch, args.dtype)).cuda(local_rank)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
            # re-raise anything that is not an OOM so failures are not silently ignored
            raise
        sync_all()
        timed_all_gather(input, output, args)


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_all_gather(local_rank=rank, args=args)
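
# Example invocation (run under a distributed launcher; option names assumed to match the
# argparse flags exposed by benchmark_parser() for the fields referenced above):
#   deepspeed all_gather.py --scan --trials 5 --warmups 3 --dist deepspeed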