import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
def timed_all_to_all(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_to_all', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_all_to_all(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    # Prepare benchmark header
    print_header(args, 'all_to_all')

    if args.scan:
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M, dtype=getattr(torch, args.dtype)).cuda(local_rank)
                assert mat.numel() % world_size == 0, f"tensor cannot be divided in {world_size} chunks"
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
                output = (mat.clone().view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
            sync_all()
            timed_all_to_all(input, output, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        elements_per_gpu = max_numel(comm_op='all_to_all',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
            assert mat.numel() % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided in {world_size} chunks"
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            torch.cuda.empty_cache()
            output = torch.zeros(elements_per_gpu, dtype=getattr(torch, args.dtype)).cuda(local_rank)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
        sync_all()

        if args.debug:
            for i in range(world_size):
                if i == global_rank:
                    print(f"Before AllToAll Input List at rank {global_rank}: {input}")
                dist.barrier()

        timed_all_to_all(input, output, args)

        if args.debug:
            for i in range(world_size):
                if i == global_rank:
                    print(f"AllToAll Results at rank {global_rank}: {output}")
                dist.barrier()


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_all_to_all(local_rank=rank, args=args)
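
# Example invocation (a sketch, not a verified command line): the flags below mirror
# the args consumed above (--scan, --maxsize, --mem-factor, --dist, --dtype, --trials,
# --warmups, --debug); the launcher choice and GPU count are assumptions about your setup.
#
#   deepspeed --num_gpus=8 benchmarks/communication/all_to_all.py --dist deepspeed --scan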