broadcast.py

import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
import time
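
# The helpers used below (sync_all, get_bw, get_metric_strings, convert_size,
# print_rank_0, print_header, max_numel, benchmark_parser, init_processes) are
# assumed to come from the wildcard imports above
# (benchmarks.communication.utils / benchmarks.communication.constants).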


def timed_broadcast(input, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        dist.broadcast(input, 0, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.broadcast(input, 0, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('broadcast', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
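
# Note (assumption, not stated in this file): get_bw is expected to follow the
# NCCL-tests convention, under which broadcast's bus bandwidth equals its
# algorithmic bandwidth (message size / avg_duration), so tput and busbw
# typically report the same number here.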


def run_broadcast(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'broadcast')

    world_size = dist.get_world_size()
    global_rank = dist.get_rank()

    if args.scan:
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M, dtype=getattr(torch, args.dtype)).cuda(local_rank)
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_broadcast(input, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        # Don't need output tensor, so we double mem_factor
        elements_per_gpu = max_numel(comm_op='broadcast',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor * 2,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
            input = ((mat.mul_(float(global_rank))).view(-1))
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
        sync_all()
        timed_broadcast(input, args)


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_broadcast(local_rank=rank, args=args)
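
# Example launch (a sketch; the exact flag set is defined by benchmark_parser in
# benchmarks.communication.utils and may vary between DeepSpeed versions):
#
#   deepspeed broadcast.py --scan --dist deepspeed
#   deepspeed broadcast.py --dist torch
#
# If the out-of-memory warning is printed in the non-scan path, lower --mem-factor.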