utils.py

'''Copyright The Microsoft DeepSpeed Team'''

import os
import torch
import inspect

from deepspeed.utils import get_caller_func


def older_torch():
    '''
    Helper to look up the torch version. For versions older than 1.8, torch.distributed
    used torch.distributed.group.WORLD as the default group argument instead of None.
    See more details at: https://github.com/pytorch/pytorch/pull/48767
    '''
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    return TORCH_MAJOR == 1 and TORCH_MINOR < 8


def has_allgather_base():
    '''
    Helper to check if torch.distributed has _all_gather_base
    '''
    return hasattr(torch.distributed, "_all_gather_base")


def has_reduce_scatter_base():
    '''
    Helper to check if torch.distributed has _reduce_scatter_base
    '''
    return hasattr(torch.distributed, "_reduce_scatter_base")


def get_local_rank_from_launcher():
    # DeepSpeed launcher will set it so get from there
    rank = os.environ.get('LOCAL_RANK')

    if rank is None:
        rank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK')

    # Make it a single process job and set rank to 0
    if rank is None:
        rank = 0

    return int(rank)


def get_world_rank_from_launcher():
    # DeepSpeed launcher will set it so get from there
    rank = os.environ.get('RANK')

    if rank is None:
        rank = os.environ.get('OMPI_COMM_WORLD_RANK')

    # Make it a single process job and set rank to 0
    if rank is None:
        rank = 0

    return int(rank)


def get_world_size_from_launcher():
    # DeepSpeed launcher will set it so get from there
    size = os.environ.get('WORLD_SIZE')
    rank = os.environ.get('RANK')

    if size is None:
        size = os.environ.get('OMPI_COMM_WORLD_SIZE')

    # Make it a single process job and set size to 1
    if size is None:
        size = 1

    # Only log from global rank 0; RANK comes from the environment as a string,
    # so it must be converted before comparing against 0.
    if rank is None or int(rank) == 0:
        print(f"set world size to {size}")

    return int(size)
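
# For example, a job launched across 4 processes typically exports WORLD_SIZE=4 and
# RANK=0..3, so get_world_size_from_launcher() returns 4; under mpirun the
# OMPI_COMM_WORLD_* variables are used as the fallback, and with neither set the
# helpers above treat the job as a single process (world size 1, ranks 0).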


def get_default_args(func):
    signature = inspect.signature(func)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }
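
# For example, given a collective with a signature like
#   def all_reduce(tensor, op=None, group=None, async_op=False): ...
# get_default_args would return {'op': None, 'group': None, 'async_op': False},
# since only parameters that declare a default value are included.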


# We need this hacky function since torch doesn't consistently name or place the input tensor args
def get_tensor_position(func):
    sig_params = inspect.signature(func).parameters
    arg = None
    # most colls
    if 'tensor' in sig_params:
        arg = 'tensor'
    # reduce scatter coll
    elif 'input_list' in sig_params:
        arg = 'input_list'
    # all_to_all and torch multiGPU colls
    elif 'input_tensor_list' in sig_params:
        arg = 'input_tensor_list'

    if arg is None:
        return -1
    else:
        return list(sig_params).index(arg)
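
# For example, torch.distributed.all_reduce(tensor, ...) names its input 'tensor' at
# position 0, reduce_scatter(output, input_list, ...) keeps 'input_list' at position 1,
# and barrier() has no tensor parameter at all, so get_tensor_position returns -1 for it.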


def get_tensor_kwarg(func, kwargs):
    func_args = get_default_args(func)
    func_args.update(kwargs)
    arg = None

    if 'tensor' in func_args:
        arg = func_args['tensor']
    elif 'input_list' in func_args:
        arg = func_args['input_list']
    elif 'input_tensor_list' in func_args:
        arg = func_args['input_tensor_list']

    return arg


def get_msg_size_from_args(func, *args, **kwargs):
    # 3 cases:
    #  - tensor arg is in args
    #  - tensor arg is in kwargs
    #  - tensor arg is not present (e.g. barrier)
    tensor_arg_position = -1
    tensor_arg = None
    # check if tensor arg is in args
    if len(args) > 0:
        tensor_arg_position = get_tensor_position(func)
        if tensor_arg_position > -1:
            tensor_arg = args[tensor_arg_position]
    # check if tensor arg is in kwargs
    if tensor_arg is None and len(kwargs) > 0:
        tensor_arg = get_tensor_kwarg(func, kwargs)
    # if tensor arg is not present, no data is being transmitted
    if tensor_arg is None:
        return 0
    else:
        # Sum of tensor sizes for list colls such as torch's all_to_all
        # NOTE: msg_size for list colls will not be the actual size transmitted by a given
        # MPI/NCCL call within the coll op. Instead, it's the total amount of data transmitted.
        if type(tensor_arg) is list:
            return sum(x.element_size() * x.nelement() for x in tensor_arg)
        else:
            return tensor_arg.element_size() * tensor_arg.nelement()
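
# For example, an all_reduce on a float32 tensor with 1024 elements is reported as
# 1024 * 4 = 4096 bytes, while a list-based collective such as all_to_all is reported
# as the sum of the per-tensor sizes rather than the bytes moved by any single
# NCCL/MPI call.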


def get_debug_log_name(func_args, debug):
    if debug:
        return func_args['log_name'] + ' | [Caller Func: ' + get_caller_func() + ']'
    else:
        return func_args['log_name']
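

if __name__ == "__main__":
    # Minimal usage sketch of the helpers above; `fake_all_reduce` is a hypothetical
    # stand-in that only mimics the parameter layout of a torch.distributed collective.
    def fake_all_reduce(tensor, op=None, group=None, async_op=False):
        pass

    t = torch.zeros(1024, dtype=torch.float32)
    print(get_tensor_position(fake_all_reduce))        # 0 ('tensor' is the first parameter)
    print(get_default_args(fake_all_reduce))           # {'op': None, 'group': None, 'async_op': False}
    print(get_msg_size_from_args(fake_all_reduce, t))  # 4096 (1024 float32 elements * 4 bytes)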