all_reduce.py

from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import torch  # used directly below (torch.ones, dtype lookup via getattr)
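
# Note: the wildcard imports above are assumed to provide the helpers used
# below (sync_all, get_bw, get_metric_strings, convert_size, print_rank_0,
# print_header, max_numel, benchmark_parser, init_processes) plus the
# benchmark constants, as defined in benchmarks/communication/utils.py and
# benchmarks/communication/constants.py.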

def timed_all_reduce(input, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        dist.all_reduce(input, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.all_reduce(input, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_reduce', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")

def run_all_reduce(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'all_reduce')

    world_size = dist.get_world_size()
    global_rank = dist.get_rank()

    if args.scan:
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M,
                                 dtype=getattr(torch, args.dtype)).cuda(local_rank)
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_all_reduce(input, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        # Don't need output tensor, so we double mem_factor
        elements_per_gpu = max_numel(comm_op='all_reduce',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor * 2,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu,
                             dtype=getattr(torch, args.dtype)).cuda(local_rank)
            input = ((mat.mul_(float(global_rank))).view(-1))
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
        sync_all()
        timed_all_reduce(input, args)
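

# Rough guide to the message sizes exercised above (illustrative helper only,
# not called by the benchmark): when args.scan is set, each trial all-reduces a
# tensor of world_size * M elements with M = 2, 4, ..., 2**(maxsize - 1);
# otherwise a single message of max_numel(...) elements is used.
def _scan_message_sizes_bytes(world_size, maxsize, element_size):
    # element_size is the per-element byte width of args.dtype,
    # e.g. torch.ones(1, dtype=dtype).element_size()
    return [world_size * (2**p) * element_size for p in range(1, maxsize)]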

if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_all_reduce(local_rank=rank, args=args)
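
# Example invocation (a sketch, not a definitive command line: the exact flag
# names are defined by benchmark_parser in benchmarks/communication/utils, and
# only --mem-factor is referenced explicitly above):
#
#   deepspeed all_reduce.py --dist deepspeed
#   deepspeed all_reduce.py --dist torch --scan
#
# Any launcher that initializes torch.distributed / deepspeed.comm and passes
# --local_rank (or the equivalent) should work, since init_processes is driven
# by args.local_rank.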