all_to_all.py

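# Benchmarks the all_to_all collective: every rank exchanges an equal-sized chunk of
# its buffer with every other rank via dist.all_to_all_single, using either
# torch.distributed or deepspeed.comm as the backend selected by args.dist.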
import torch  # used directly below (torch.ones/zeros, torch.cuda, getattr(torch, ...))

from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
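# timed_all_to_all: run args.warmups untimed iterations so the backend can establish
# connections, then time args.trials calls to all_to_all_single and report the average
# latency along with the throughput/bus-bandwidth strings produced by the get_bw and
# get_metric_strings helpers pulled in by the wildcard imports above.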
def timed_all_to_all(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_to_all', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
def run_all_to_all(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    # Prepare benchmark header
    print_header(args, 'all_to_all')

    if args.scan:
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M, dtype=getattr(torch, args.dtype)).cuda(local_rank)
                assert mat.numel() % world_size == 0, f"tensor cannot be divided in {world_size} chunks"
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
                output = (mat.clone().view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_all_to_all(input, output, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        elements_per_gpu = max_numel(comm_op='all_to_all',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
            assert mat.numel() % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided in {world_size} chunks"
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            torch.cuda.empty_cache()
            output = torch.zeros(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
        sync_all()

        if args.debug:
            for i in range(world_size):
                if i == global_rank:
                    print(f"Before AllToAll Input List at rank {global_rank}: {input}")
                dist.barrier()

        timed_all_to_all(input, output, args)

        if args.debug:
            for i in range(world_size):
                if i == global_rank:
                    print(f"AllToAll Results at rank {global_rank}: {output}")
                dist.barrier()
if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_all_to_all(local_rank=rank, args=args)
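# Example invocation (a sketch, not a verified command line: it assumes the deepspeed
# launcher spawns one process per GPU and that benchmark_parser exposes flags mirroring
# the args.* fields used above, e.g. --trials/--warmups/--scan/--dist/--mem-factor):
#
#   deepspeed all_to_all.py --trials 5 --warmups 2 --scan --dist deepspeed
#   deepspeed all_to_all.py --trials 5 --warmups 2 --dist torch --mem-factor 0.3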