pt2pt.py

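# Point-to-point (send/recv) benchmark: rank 0 sends a tensor to rank 1, either
# blocking (send/recv) or asynchronous (isend/irecv), using torch.distributed or
# deepspeed.comm as the backend. Message sizes are either scanned over powers of
# two (args.scan) or set to the largest tensor that fits in GPU memory.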
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import torch
import time


def timed_pt2pt(input, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        if dist.get_rank() == 0:
            if args.async_op:
                dist.isend(input, 1)
            else:
                dist.send(input, 1)
        if dist.get_rank() == 1:
            if args.async_op:
                dist.irecv(input, src=0)
            else:
                dist.recv(input, src=0)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        if dist.get_rank() == 0:
            if args.async_op:
                dist.isend(input, 1)
            else:
                dist.send(input, 1)
        if dist.get_rank() == 1:
            if args.async_op:
                dist.irecv(input, src=0)
            else:
                dist.recv(input, src=0)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('pt2pt', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_pt2pt(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'pt2pt')
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    if args.scan:
        # Create list of message sizes
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M,
                                 dtype=getattr(torch, args.dtype)).cuda(local_rank)
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_pt2pt(input, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        # Don't need output tensor, so double mem_factor
        elements_per_gpu = max_numel(comm_op='pt2pt',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor * 2,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu,
                             dtype=getattr(torch, args.dtype)).cuda(local_rank)
            input = ((mat.mul_(float(global_rank))).view(-1))
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
        sync_all()
        timed_pt2pt(input, args)


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_pt2pt(local_rank=rank, args=args)
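
# Example launch (a sketch, not taken from the repo docs; flag names other than
# --mem-factor are assumed to match the fields read from `args` above and the
# parser returned by benchmark_parser() in benchmarks/communication/utils, and
# at least two ranks are needed since rank 0 sends to rank 1):
#   deepspeed --num_gpus=2 pt2pt.py --scan --dist deepspeed
#   deepspeed --num_gpus=2 pt2pt.py --dist torch --mem-factor 0.4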