import torch
import os
import math
import argparse

from benchmarks.communication.constants import *

global dist


def init_torch_distributed(backend):
    global dist
    import torch.distributed as dist
    torch.distributed.init_process_group(backend)
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)


def init_deepspeed_comm(backend):
    global dist
    import deepspeed
    import deepspeed.comm as dist
    deepspeed.init_distributed(dist_backend=backend)
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)


def init_processes(local_rank, args):
    if args.dist == 'deepspeed':
        init_deepspeed_comm(args.backend)
    elif args.dist == 'torch':
        init_torch_distributed(args.backend)
    else:
        print_rank_0(f"distributed framework {args.dist} not supported")
        exit(0)
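
# Example launch (illustrative only; the script path and flag values are assumptions,
# not part of this module). The init functions above read LOCAL_RANK from the
# environment, so the launcher is expected to set it, e.g.:
#   torchrun --nproc_per_node=8 benchmarks/communication/all_reduce.py --dist torch --backend nccl
#   deepspeed benchmarks/communication/all_reduce.py --dist deepspeed --backend nccl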


def print_rank_0(message):
    if dist.get_rank() == 0:
        print(message)


def print_header(args, comm_op):
    if comm_op == 'pt2pt':
        world_size = 2
    else:
        world_size = dist.get_world_size()
    tput = f'Throughput ({args.bw_unit})'
    busbw = f'BusBW ({args.bw_unit})'
    header = f"\n---- Performance of {comm_op} on {world_size} devices ---------------------------------------------------------\n"
    duration_str = 'Duration'
    if args.raw:
        duration_str += ' (us)'
    header += f"{'Size (Bytes)':20s} {'Description':25s} {duration_str:20s} {tput:20s} {busbw:20s}\n"
    header += "----------------------------------------------------------------------------------------------------"
    print_rank_0(header)


def get_bw(comm_op, size, duration, args):
    n = dist.get_world_size()
    tput = 0
    busbw = 0
    if comm_op == "all_to_all":
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_gather":
        size *= n
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_reduce":
        tput = (size * 2 / duration)
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op == "pt2pt" or comm_op == "broadcast":
        tput = (size / duration)
        busbw = tput
    else:
        print_rank_0("wrong comm_op specified")
        exit(0)

    if args.bw_unit == 'Gbps':
        tput *= 8
        busbw *= 8

    return tput, busbw
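
# The bus-bandwidth factors above appear to follow the nccl-tests convention:
# busbw rescales the measured algorithm bandwidth by the fraction of the message
# each rank actually moves over the interconnect. Worked example (numbers chosen
# purely for illustration): an all_reduce of a 1e9-byte message across n = 8 ranks
# that completes in 10 ms gives
#   tput  = 2 * 1e9 / 0.01              = 200e9 B/s
#   busbw = (1e9 / 0.01) * (2 * 7 / 8)  = 175e9 B/s
# before the optional Gbps (x8) conversion.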


def get_metric_strings(args, tput, busbw, duration):
    duration_ms = duration * 1e3
    duration_us = duration * 1e6
    tput = f'{tput / 1e9:.3f}'
    busbw = f'{busbw / 1e9:.3f}'

    if duration_us < 1e3 or args.raw:
        duration = f'{duration_us:.3f}'
        if not args.raw:
            duration += ' us'
    else:
        duration = f'{duration_ms:.3f} ms'
    return tput, busbw, duration


def sync_all():
    torch.cuda.synchronize()
    dist.barrier()


def max_numel(comm_op, dtype, mem_factor, local_rank, args):
    dtype_size = _element_size(dtype)
    max_memory_per_gpu = torch.cuda.get_device_properties(local_rank).total_memory * mem_factor
    if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast':
        elements_per_gpu = int(max_memory_per_gpu // dtype_size)
    elif comm_op == 'all_gather':
        # all_gather performance is lower for non-powers of two, and the output buffer size scales with world size
        # Therefore, divide by world size and round down to nearest power of 2
        elements_per_gpu = int(max_memory_per_gpu // dtype_size // dist.get_world_size())
        elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
    elif comm_op == 'all_to_all':
        # Number of elements must be divisible by world_size
        # all_to_all performance is lower for non-powers of two. Round down like all_gather.
        elements_per_gpu = int(max_memory_per_gpu // dtype_size)
        elements_per_gpu = int(dist.get_world_size() * round(elements_per_gpu / dist.get_world_size()))
        elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
    else:
        print(f"This communication operation: {comm_op} is not supported yet")
        exit(0)
    return elements_per_gpu
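
# Illustration (values are hypothetical): on a 16 GiB GPU with mem_factor=0.4 and
# float32 (4-byte) elements, the all_reduce branch yields
#   int(16 * 2**30 * 0.4 // 4) = 1,717,986,918 elements,
# while the all_gather branch on 8 ranks further divides by world size and rounds
# down to the nearest power of two, giving 2**27 = 134,217,728 elements per rank.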


# Helper function to pretty-print message sizes
def convert_size(size_bytes):
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])


# Copied from torch. Need to add the func here for old torch compatibility.
def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f'expected torch.dtype, but got {type(dtype)}')

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3
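
# e.g. _element_size(torch.float16) == 2, _element_size(torch.int8) == 1,
# and _element_size(torch.complex64) == 8.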


def benchmark_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    parser.add_argument("--trials",
                        type=int,
                        default=DEFAULT_TRIALS,
                        help='Number of timed iterations')
    parser.add_argument("--warmups",
                        type=int,
                        default=DEFAULT_WARMUPS,
                        help='Number of warmup (non-timed) iterations')
    parser.add_argument("--maxsize",
                        type=int,
                        default=24,
                        help='Max message size as a power of 2')
    parser.add_argument("--async-op",
                        action="store_true",
                        help='Enables non-blocking communication')
    parser.add_argument("--bw-unit",
                        type=str,
                        default=DEFAULT_UNIT,
                        choices=['Gbps', 'GBps'])
    parser.add_argument("--backend",
                        type=str,
                        default=DEFAULT_BACKEND,
                        choices=['nccl'],
                        help='Communication library to use')
    parser.add_argument("--dist",
                        type=str,
                        default=DEFAULT_DIST,
                        choices=['deepspeed', 'torch'],
                        help='Distributed DL framework to use')
    parser.add_argument("--scan",
                        action="store_true",
                        help='Enables scanning all message sizes')
    parser.add_argument("--raw",
                        action="store_true",
                        help='Print the message size and latency without units')
    parser.add_argument("--all-reduce", action="store_true", help='Run all_reduce')
    parser.add_argument("--all-gather", action="store_true", help='Run all_gather')
    parser.add_argument("--all-to-all", action="store_true", help='Run all_to_all')
    parser.add_argument("--pt2pt", action="store_true", help='Run pt2pt')
    parser.add_argument("--broadcast", action="store_true", help='Run broadcast')
    parser.add_argument("--dtype",
                        type=str,
                        default=DEFAULT_TYPE,
                        help='PyTorch tensor dtype')
    parser.add_argument("--mem-factor",
                        type=float,
                        default=.4,
                        help='Proportion of max available GPU memory to use for single-size evals')
    parser.add_argument("--debug",
                        action="store_true",
                        help='Enables all_to_all debug prints')
    return parser
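
# Minimal end-to-end sketch of how these helpers fit together (illustrative only;
# the actual benchmark drivers under benchmarks/communication/ may differ, and
# `time` is assumed to be imported by the caller):
#
#   args = benchmark_parser().parse_args()
#   init_processes(local_rank=args.local_rank, args=args)
#   local_rank = int(os.environ['LOCAL_RANK'])
#   numel = max_numel('all_reduce', torch.float32, args.mem_factor, local_rank, args)
#   tensor = torch.ones(numel, dtype=torch.float32, device=f'cuda:{local_rank}')
#   sync_all()
#   start = time.perf_counter()
#   for _ in range(args.trials):
#       dist.all_reduce(tensor)
#   sync_all()
#   duration = (time.perf_counter() - start) / args.trials
#   size = numel * _element_size(tensor.dtype)
#   tput, busbw = get_bw('all_reduce', size, duration, args)
#   print_header(args, 'all_reduce')
#   print_rank_0(" ".join([str(size), convert_size(size), *get_metric_strings(args, tput, busbw, duration)]))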