comms_logging.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. import math
  3. from deepspeed.utils import log_dist
  4. def get_caller_func(frame=3):
  5. import sys
  6. return sys._getframe(frame).f_code.co_name
  7. # Helper function to pretty-print message sizes
  8. def convert_size(size_bytes):
  9. if size_bytes == 0:
  10. return "0B"
  11. size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
  12. i = int(math.floor(math.log(size_bytes, 1024)))
  13. p = math.pow(1024, i)
  14. s = round(size_bytes / p, 2)
  15. return "%s %s" % (s, size_name[i])
  16. # Helper function to calculate algbw and busbw.
  17. # See https://gist.github.com/jeffra/b5e80466b4c86be00ea3b6f130fb7a36 and https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
  18. def calc_bw_log(comm_op, size, duration):
  19. import deepspeed.comm as dist
  20. n = dist.get_world_size()
  21. tput = 0
  22. busbw = 0
  23. if comm_op == "all_to_all_single":
  24. tput = (size / duration)
  25. busbw = (size / duration) * ((n - 1) / n)
  26. elif comm_op == "all_gather" or comm_op == "all_gather_base" or comm_op == "reduce_scatter" or comm_op == "reduce_scatter_base":
  27. size *= n
  28. tput = (size / duration)
  29. busbw = (size / duration) * ((n - 1) / n)
  30. elif comm_op == "all_reduce":
  31. tput = (size * 2 / duration)
  32. busbw = (size / duration) * (2 * (n - 1) / n)
  33. elif comm_op == "send" or comm_op == "recv" or comm_op == "isend" or comm_op == "irecv" or comm_op == "broadcast" or comm_op == "reduce" or comm_op == "gather" or comm_op == "scatter" or comm_op == "barrier":
  34. tput = (size / duration)
  35. busbw = tput
  36. else:
  37. print_rank_0("wrong comm_op specified") # noqa: F821
  38. exit(0)
  39. # convert to Gbps
  40. tput *= 8
  41. busbw *= 8
  42. tput /= 1e6
  43. busbw /= 1e6
  44. return tput, busbw
  45. class CommsLogger:
  46. def __init__(self):
  47. from deepspeed.comm.constants import COMMS_LOGGER_VERBOSE_DEFAULT, COMMS_LOGGER_DEBUG_DEFAULT, COMMS_LOGGER_PROF_OPS_DEFAULT, COMMS_LOGGER_PROF_ALL_DEFAULT, COMMS_LOGGER_ENABLED_DEFAULT
  48. self.comms_dict = {}
  49. self.verbose = COMMS_LOGGER_VERBOSE_DEFAULT
  50. self.debug = COMMS_LOGGER_DEBUG_DEFAULT
  51. self.prof_ops = COMMS_LOGGER_PROF_OPS_DEFAULT
  52. self.prof_all = COMMS_LOGGER_PROF_ALL_DEFAULT
  53. self.enabled = COMMS_LOGGER_ENABLED_DEFAULT
  54. def configure(self, comms_config):
  55. self.enabled = comms_config.comms_logger_enabled
  56. if self.enabled:
  57. self.verbose = comms_config.comms_logger.verbose
  58. self.debug = comms_config.comms_logger.debug
  59. self.prof_ops = comms_config.comms_logger.prof_ops
  60. self.prof_all = comms_config.comms_logger.prof_all
  61. # There are three settings for the op profiler:
  62. # - Global profiling (profile all comms)
  63. # - Op-type profiling (e.g. profile all all_reduce comms)
  64. # - Op profiling (e.g. profile a specific all_reduce op)
  65. def start_profiling_comms(self):
  66. self.prof_all = True
  67. def stop_profiling_comms(self):
  68. self.prof_all = True
  69. # E.g. start_profiling_op('all_reduce')
  70. def start_profiling_op(self, op_name_list):
  71. self.prof_ops = list(set(self.prof_ops) | set(op_name_list))
  72. def stop_profiling_op(self, op_name_list):
  73. self.prof_ops = [op for op in self.prof_ops if op not in op_name_list]
  74. # Add log entry
  75. def append(self, raw_name, record_name, latency, msg_size):
  76. import deepspeed.comm as dist
  77. algbw, busbw = calc_bw_log(raw_name, msg_size, latency)
  78. if record_name in self.comms_dict.keys():
  79. # If this comm_op has already been logged with this message size, just add to existing record
  80. if msg_size in self.comms_dict[record_name].keys():
  81. self.comms_dict[record_name][msg_size][0] += 1
  82. self.comms_dict[record_name][msg_size][1].append(latency)
  83. self.comms_dict[record_name][msg_size][2].append(algbw)
  84. self.comms_dict[record_name][msg_size][3].append(busbw)
  85. # If this is a new message size for this comm_op, add new record under existing comm_op
  86. else:
  87. self.comms_dict[record_name][msg_size] = [1, [latency], [algbw], [busbw]]
  88. else:
  89. # Create entirely new record
  90. self.comms_dict[record_name] = {msg_size: [1, [latency], [algbw], [busbw]]}
  91. # If verbose, print every comm op
  92. # TODO: Add to tensorboard
  93. if self.verbose:
  94. n = dist.get_world_size()
  95. log_str = f"rank={dist.get_rank()} | comm op: " + record_name + " | time (ms): {:.2f}".format(
  96. latency)
  97. log_str += " | msg size: " + convert_size(msg_size)
  98. log_str += " | algbw (Gbps): {:.2f} ".format(algbw)
  99. log_str += " | busbw (Gbps): {:.2f} ".format(busbw)
  100. log_dist(log_str, [0])
  101. # Print summary at end of iteration, epoch, or training
  102. def log_all(self):
  103. from deepspeed.utils.timer import trim_mean
  104. print(
  105. f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}"
  106. )
  107. for record_name in self.comms_dict.keys():
  108. print(record_name)
  109. for msg_size, vals in sorted(self.comms_dict[record_name].items()):
  110. # vals[0] is the count for each msg size
  111. count = vals[0]
  112. # vals[1] is a list of latency records for each msg size
  113. total_lat = sum(vals[1])
  114. # vals[2] and vals[3] are the lists of algbw and busbw, respectively
  115. # Get rid of outliers when we print
  116. avg_lat = trim_mean(vals[1], 0.1)
  117. avg_algbw = trim_mean(vals[2], 0.1)
  118. avg_busbw = trim_mean(vals[3], 0.1)
  119. print(
  120. f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}"
  121. )