utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import inspect

from deepspeed.utils import get_caller_func


def get_local_rank_from_launcher():
    # DeepSpeed launcher will set it so get from there
    rank = os.environ.get('LOCAL_RANK')

    if rank is None:
        rank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK')

    # Make it a single process job and set rank to 0
    if rank is None:
        rank = 0

    return int(rank)
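

# Example (illustrative, not part of the original file): under the DeepSpeed
# launcher a process started with LOCAL_RANK=3 gets 3 back; under a bare
# OpenMPI run, OMPI_COMM_WORLD_LOCAL_RANK is consulted instead; with neither
# variable set, the job is treated as single-process and 0 is returned.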


def get_world_rank_from_launcher():
    # DeepSpeed launcher will set it so get from there
    rank = os.environ.get('RANK')

    if rank is None:
        rank = os.environ.get('OMPI_COMM_WORLD_RANK')

    # Make it a single process job and set rank to 0
    if rank is None:
        rank = 0

    return int(rank)


def get_world_size_from_launcher():
    # DeepSpeed launcher will set it so get from there
    size = os.environ.get('WORLD_SIZE')
    rank = os.environ.get('RANK')

    if size is None:
        size = os.environ.get('OMPI_COMM_WORLD_SIZE')

    # Make it a single process job and set size to 1
    if size is None:
        size = 1

    # RANK arrives from the environment as a string (or None for a
    # single-process job), so normalize it before comparing against 0;
    # a bare `rank == 0` would never match.
    if rank is None or int(rank) == 0:
        print(f"set world size to {size}")

    return int(size)
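

# Example (illustrative): WORLD_SIZE=8 from the DeepSpeed launcher yields 8,
# OMPI_COMM_WORLD_SIZE=8 from a bare OpenMPI run also yields 8, and with
# neither set the job is treated as single-process, yielding 1.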


def get_default_args(func):
    signature = inspect.signature(func)
    return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
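

# Illustrative sketch of what get_default_args returns (hypothetical
# function, not part of the file):
#     >>> def foo(a, b=1, c='x'):
#     ...     pass
#     >>> get_default_args(foo)
#     {'b': 1, 'c': 'x'}
# Parameters without defaults (here 'a') are skipped because their default
# is inspect.Parameter.empty.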


# We need this hacky function since torch doesn't consistently name or place the input tensor args
def get_tensor_position(func):
    sig_params = inspect.signature(func).parameters
    arg = None
    # most colls
    if 'tensor' in sig_params:
        arg = 'tensor'
    # all_reduce_coalesced coll
    elif 'tensors' in sig_params:
        arg = 'tensors'
    # reduce scatter coll
    elif 'input_list' in sig_params:
        arg = 'input_list'
    # all_to_all and torch multiGPU colls
    elif 'input_tensor_list' in sig_params:
        arg = 'input_tensor_list'
    if arg is None:
        return -1
    else:
        return list(sig_params).index(arg)
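

# Example (illustrative, signatures as in recent torch versions):
# torch.distributed.all_reduce(tensor, op=..., group=None, async_op=False)
# names its input 'tensor', so get_tensor_position(dist.all_reduce) == 0,
# while torch.distributed.reduce_scatter(output, input_list, ...) yields 1
# because 'input_list' is the second parameter.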


def get_tensor_kwarg(func, kwargs):
    func_args = get_default_args(func)
    func_args.update(kwargs)
    arg = None

    if 'tensor' in func_args:
        arg = func_args['tensor']
    elif 'tensors' in func_args:
        arg = func_args['tensors']
    elif 'input_list' in func_args:
        arg = func_args['input_list']
    elif 'input_tensor_list' in func_args:
        arg = func_args['input_tensor_list']

    return arg
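

# Example (illustrative): for a call like dist.all_reduce(tensor=t), kwargs
# is {'tensor': t}, so get_tensor_kwarg(dist.all_reduce, kwargs) returns t;
# signature defaults fill in any parameters the caller omitted.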


def get_msg_size_from_args(func, *args, **kwargs):
    # 3 cases:
    #  - tensor arg is in args
    #  - tensor arg is in kwargs
    #  - tensor arg is not present (e.g. barrier)
    tensor_arg_position = -1
    tensor_arg = None
    # check if tensor arg is in args
    if len(args) > 0:
        tensor_arg_position = get_tensor_position(func)
        if tensor_arg_position > -1:
            tensor_arg = args[tensor_arg_position]
    # check if tensor arg is in kwargs
    if tensor_arg is None and len(kwargs) > 0:
        tensor_arg = get_tensor_kwarg(func, kwargs)
    # if tensor arg is not present, no data is being transmitted
    if tensor_arg is None:
        return 0
    else:
        # Sum of tensor sizes for list colls such as torch's all_to_all
        # NOTE: msg_size for list colls will not be the actual size transmitted by a given
        # MPI/NCCL call within the coll op. Instead, it's the total amount of data transmitted.
        if isinstance(tensor_arg, list):
            return sum(x.element_size() * x.nelement() for x in tensor_arg)
        else:
            return tensor_arg.element_size() * tensor_arg.nelement()
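

# Usage sketch (hypothetical wrapper, not part of the file): a logging shim
# around a collective could compute
#     msg_size = get_msg_size_from_args(dist.all_reduce, t)
# For a float32 tensor t of 1024 elements this is 4 * 1024 = 4096 bytes;
# for dist.barrier, which takes no tensor argument, it is 0.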


def get_debug_log_name(func_args, debug):
    if debug:
        return func_args['log_name'] + ' | [Caller Func: ' + get_caller_func() + ']'
    else:
        return func_args['log_name']
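

# Example (illustrative): with func_args = {'log_name': 'all_reduce'} and
# debug=True this returns a string like
#     'all_reduce | [Caller Func: train_step]'
# where get_caller_func() is assumed to resolve the name of the function
# that invoked the collective.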