'''Copyright The Microsoft DeepSpeed Team'''

import os
import inspect

import torch

from deepspeed.utils import get_caller_func


def older_torch():
    '''
    Helper to look up the torch version. For versions older than 1.8, torch.distributed
    used torch.distributed.group.WORLD as the default group argument instead of None.
    See more details at: https://github.com/pytorch/pytorch/pull/48767
    '''
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    return TORCH_MAJOR == 1 and TORCH_MINOR < 8
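

# Illustrative usage (not part of the original module): the version check is meant to pick
# the right default process-group argument, as described in the docstring above.
# A minimal sketch, assuming a torch.distributed-style wrapper:
#
#   default_group = torch.distributed.group.WORLD if older_torch() else None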


def has_allgather_base():
    '''
    Helper to check if torch.distributed has _all_gather_base
    '''
    return hasattr(torch.distributed, "_all_gather_base")


def has_reduce_scatter_base():
    '''
    Helper to check if torch.distributed has _reduce_scatter_base
    '''
    return hasattr(torch.distributed, "_reduce_scatter_base")
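

# Illustrative usage (not part of the original module): these capability checks let a backend
# fall back to the public collectives when the private fused ops are unavailable. A minimal
# sketch, where output_tensor / output_tensor_list / input_tensor are assumed placeholders:
#
#   if has_allgather_base():
#       torch.distributed._all_gather_base(output_tensor, input_tensor)
#   else:
#       torch.distributed.all_gather(output_tensor_list, input_tensor)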


def get_local_rank_from_launcher():
    # DeepSpeed launcher will set it so get from there
    rank = os.environ.get('LOCAL_RANK')

    if rank is None:
        rank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK')

    # Make it a single process job and set rank to 0
    if rank is None:
        rank = 0

    return int(rank)


def get_world_rank_from_launcher():
    # DeepSpeed launcher will set it so get from there
    rank = os.environ.get('RANK')

    if rank is None:
        rank = os.environ.get('OMPI_COMM_WORLD_RANK')

    # Make it a single process job and set rank to 0
    if rank is None:
        rank = 0

    return int(rank)


def get_world_size_from_launcher():
    # DeepSpeed launcher will set it so get from there
    size = os.environ.get('WORLD_SIZE')
    rank = os.environ.get('RANK')

    if size is None:
        size = os.environ.get('OMPI_COMM_WORLD_SIZE')

    # Make it a single process job and set size to 1
    if size is None:
        size = 1

    # RANK comes from the environment as a string, so normalize before comparing
    if rank is None or int(rank) == 0:
        print(f"set world size to {size}")

    return int(size)
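

# Illustrative usage (not part of the original module): the three helpers above read the
# environment set either by the DeepSpeed/torch launcher (LOCAL_RANK, RANK, WORLD_SIZE) or
# by OpenMPI (OMPI_COMM_WORLD_*), and fall back to a single-process default. For example,
# with RANK=3 and WORLD_SIZE=8 exported by the launcher:
#
#   get_world_rank_from_launcher()  # -> 3
#   get_world_size_from_launcher()  # -> 8
#
# With none of the variables set, they return the single-process defaults 0 and 1.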


def get_default_args(func):
    signature = inspect.signature(func)
    return {
        k: v.default
        for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty
    }
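

# Illustrative usage (not part of the original module): get_default_args maps keyword
# parameters to their declared defaults. For a hypothetical collective signature:
#
#   def fake_coll(tensor=None, group=None, async_op=False):
#       pass
#
#   get_default_args(fake_coll)  # -> {'tensor': None, 'group': None, 'async_op': False}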


# We need this hacky function since torch doesn't consistently name or place the input tensor args
def get_tensor_position(func):
    sig_params = inspect.signature(func).parameters
    arg = None
    # most colls
    if 'tensor' in sig_params:
        arg = 'tensor'
    # reduce scatter coll
    elif 'input_list' in sig_params:
        arg = 'input_list'
    # all_to_all and torch multiGPU colls
    elif 'input_tensor_list' in sig_params:
        arg = 'input_tensor_list'

    if arg is None:
        return -1
    else:
        return list(sig_params).index(arg)
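

# Illustrative usage (not part of the original module), assuming recent torch.distributed
# signatures:
#
#   get_tensor_position(torch.distributed.all_reduce)      # -> 0  ('tensor' is first)
#   get_tensor_position(torch.distributed.reduce_scatter)  # -> 1  ('input_list' is second)
#   get_tensor_position(torch.distributed.barrier)         # -> -1 (no tensor argument)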


def get_tensor_kwarg(func, kwargs):
    func_args = get_default_args(func)
    func_args.update(kwargs)
    arg = None

    if 'tensor' in func_args:
        arg = func_args['tensor']
    elif 'input_list' in func_args:
        arg = func_args['input_list']
    elif 'input_tensor_list' in func_args:
        arg = func_args['input_tensor_list']

    return arg
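

# Illustrative usage (not part of the original module): the kwarg variant merges the
# call-site kwargs over the signature defaults and returns whichever tensor-like argument
# it finds, e.g.:
#
#   t = torch.zeros(4)
#   get_tensor_kwarg(torch.distributed.all_reduce, {'tensor': t})  # -> t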


def get_msg_size_from_args(func, *args, **kwargs):
    # 3 cases:
    #  - tensor arg is in args
    #  - tensor arg is in kwargs
    #  - tensor arg is not present (e.g. barrier)
    tensor_arg_position = -1
    tensor_arg = None
    # check if tensor arg is in args
    if len(args) > 0:
        tensor_arg_position = get_tensor_position(func)
        if tensor_arg_position > -1:
            tensor_arg = args[tensor_arg_position]
    # check if tensor arg is in kwargs
    if tensor_arg is None and len(kwargs) > 0:
        tensor_arg = get_tensor_kwarg(func, kwargs)
    # if tensor arg is not present, no data is being transmitted
    if tensor_arg is None:
        return 0
    else:
        # Sum of tensor sizes for list colls such as torch's all_to_all
        # NOTE: msg_size for list colls will not be the actual size transmitted by a given
        # MPI/NCCL call within the coll op. Instead, it's the total amount of data transmitted.
        if type(tensor_arg) is list:
            return sum(x.element_size() * x.nelement() for x in tensor_arg)
        else:
            return tensor_arg.element_size() * tensor_arg.nelement()
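

# Illustrative usage (not part of the original module): a 1024-element fp32 tensor is
# 4096 bytes, whether it is passed positionally or as a keyword argument:
#
#   t = torch.zeros(1024, dtype=torch.float32)
#   get_msg_size_from_args(torch.distributed.all_reduce, t)         # -> 4096
#   get_msg_size_from_args(torch.distributed.all_reduce, tensor=t)  # -> 4096
#   get_msg_size_from_args(torch.distributed.barrier)               # -> 0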


def get_debug_log_name(func_args, debug):
    if debug:
        return func_args['log_name'] + ' | [Caller Func: ' + get_caller_func() + ']'
    else:
        return func_args['log_name']
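

# Illustrative usage (not part of the original module): func_args is expected to carry a
# 'log_name' entry; with debug enabled, the caller's name is appended via get_caller_func.
#
#   get_debug_log_name({'log_name': 'all_reduce'}, debug=False)  # -> 'all_reduce'
#   get_debug_log_name({'log_name': 'all_reduce'}, debug=True)   # -> 'all_reduce | [Caller Func: <caller>]'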