comms_logging.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import math
  5. from deepspeed.utils import log_dist
  6. def get_caller_func(frame=3):
  7. import sys
  8. return sys._getframe(frame).f_code.co_name
  9. def print_rank_0(message):
  10. import deepspeed.comm as dist
  11. if dist.get_rank() == 0:
  12. print(message)
  13. # Helper function to pretty-print message sizes
  14. def convert_size(size_bytes):
  15. if size_bytes == 0:
  16. return "0B"
  17. size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
  18. i = int(math.floor(math.log(size_bytes, 1024)))
  19. p = math.pow(1024, i)
  20. s = round(size_bytes / p, 2)
  21. return "%s %s" % (s, size_name[i])
  22. # Helper function to calculate algbw and busbw.
  23. # See https://gist.github.com/jeffra/b5e80466b4c86be00ea3b6f130fb7a36 and https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
  24. def calc_bw_log(comm_op, size, duration):
  25. import deepspeed.comm as dist
  26. n = dist.get_world_size()
  27. tput = 0
  28. busbw = 0
  29. if comm_op == "all_to_all_single":
  30. tput = (size / duration)
  31. busbw = (size / duration) * ((n - 1) / n)
  32. elif comm_op == "all_gather" or comm_op == "all_gather_into_tensor" or comm_op == "reduce_scatter" or comm_op == "reduce_scatter_tensor":
  33. size *= n
  34. tput = (size / duration)
  35. busbw = (size / duration) * ((n - 1) / n)
  36. elif comm_op == "all_reduce" or comm_op == "all_reduce_coalesced" or comm_op == "inference_all_reduce":
  37. tput = (size * 2 / duration)
  38. busbw = (size / duration) * (2 * (n - 1) / n)
  39. elif comm_op == "send" or comm_op == "recv" or comm_op == "isend" or comm_op == "irecv" or comm_op == "broadcast" or comm_op == "reduce" or comm_op == "gather" or comm_op == "scatter" or comm_op == "barrier":
  40. tput = (size / duration)
  41. busbw = tput
  42. else:
  43. print_rank_0("wrong comm_op specified") # noqa: F821
  44. exit(0)
  45. # convert to Gbps
  46. tput *= 8
  47. busbw *= 8
  48. tput /= 1e6
  49. busbw /= 1e6
  50. return tput, busbw
  51. class CommsLogger:
  52. def __init__(self):
  53. from deepspeed.comm.constants import COMMS_LOGGER_VERBOSE_DEFAULT, COMMS_LOGGER_DEBUG_DEFAULT, COMMS_LOGGER_PROF_OPS_DEFAULT, COMMS_LOGGER_PROF_ALL_DEFAULT, COMMS_LOGGER_ENABLED_DEFAULT
  54. self.comms_dict = {}
  55. self.verbose = COMMS_LOGGER_VERBOSE_DEFAULT
  56. self.debug = COMMS_LOGGER_DEBUG_DEFAULT
  57. self.prof_ops = COMMS_LOGGER_PROF_OPS_DEFAULT
  58. self.prof_all = COMMS_LOGGER_PROF_ALL_DEFAULT
  59. self.enabled = COMMS_LOGGER_ENABLED_DEFAULT
  60. def configure(self, comms_config):
  61. self.enabled = comms_config.comms_logger_enabled
  62. if self.enabled:
  63. self.verbose = comms_config.comms_logger.verbose
  64. self.debug = comms_config.comms_logger.debug
  65. self.prof_ops = comms_config.comms_logger.prof_ops
  66. self.prof_all = comms_config.comms_logger.prof_all
  67. # There are three settings for the op profiler:
  68. # - Global profiling (profile all comms)
  69. # - Op-type profiling (e.g. profile all all_reduce comms)
  70. # - Op profiling (e.g. profile a specific all_reduce op)
  71. def start_profiling_comms(self):
  72. self.prof_all = True
  73. def stop_profiling_comms(self):
  74. self.prof_all = True
  75. # E.g. start_profiling_op('all_reduce')
  76. def start_profiling_op(self, op_name_list):
  77. self.prof_ops = list(set(self.prof_ops) | set(op_name_list))
  78. def stop_profiling_op(self, op_name_list):
  79. self.prof_ops = [op for op in self.prof_ops if op not in op_name_list]
  80. # Add log entry
  81. def append(self, raw_name, record_name, latency, msg_size):
  82. algbw, busbw = calc_bw_log(raw_name, msg_size, latency)
  83. if record_name in self.comms_dict.keys():
  84. # If this comm_op has already been logged with this message size, just add to existing record
  85. if msg_size in self.comms_dict[record_name].keys():
  86. self.comms_dict[record_name][msg_size][0] += 1
  87. self.comms_dict[record_name][msg_size][1].append(latency)
  88. self.comms_dict[record_name][msg_size][2].append(algbw)
  89. self.comms_dict[record_name][msg_size][3].append(busbw)
  90. # If this is a new message size for this comm_op, add new record under existing comm_op
  91. else:
  92. self.comms_dict[record_name][msg_size] = [1, [latency], [algbw], [busbw]]
  93. else:
  94. # Create entirely new record
  95. self.comms_dict[record_name] = {msg_size: [1, [latency], [algbw], [busbw]]}
  96. # If verbose, print every comm op
  97. # TODO: Add to tensorboard
  98. if self.verbose:
  99. log_str = f"comm op: {record_name} | time (ms): {latency:.2f} | msg size: {convert_size(msg_size)} | algbw (Gbps): {algbw:.2f} | busbw (Gbps): {busbw:.2f}"
  100. log_dist(log_str, [0])
  101. # Print summary at end of iteration, epoch, or training
  102. def log_all(self, print_log=True, show_straggler=False):
  103. import torch
  104. from deepspeed.utils.timer import trim_mean
  105. import deepspeed.comm as dist
  106. from deepspeed.comm.reduce_op import ReduceOp
  107. if print_log:
  108. print(
  109. f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}"
  110. )
  111. for record_name in self.comms_dict.keys():
  112. if print_log:
  113. print(record_name)
  114. for msg_size, vals in sorted(self.comms_dict[record_name].items()):
  115. # vals[0] is the count for each msg size
  116. count = vals[0]
  117. # vals[1] is a list of latency records for each msg size
  118. total_lat = sum(vals[1])
  119. # vals[2] and vals[3] are the lists of algbw and busbw, respectively
  120. # Get rid of outliers when we print
  121. avg_lat = trim_mean(vals[1], 0.1)
  122. avg_algbw = trim_mean(vals[2], 0.1)
  123. avg_busbw = trim_mean(vals[3], 0.1)
  124. if print_log:
  125. print(
  126. f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}"
  127. )
  128. if show_straggler:
  129. if print_log:
  130. print("_______________________________")
  131. print("Breakdown with straggler effect")
  132. print("-------------------------------")
  133. print(
  134. f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}"
  135. )
  136. for record_name in self.comms_dict.keys():
  137. if print_log:
  138. print(record_name)
  139. for msg_size, vals in sorted(self.comms_dict[record_name].items()):
  140. # vals[0] is the count for each msg size
  141. count = vals[0]
  142. # vals[1] is a list of latency records for each msg size
  143. lats = torch.tensor(vals[1])
  144. min_lats = torch.tensor(vals[1])
  145. dist.all_reduce(min_lats, op=ReduceOp.MIN)
  146. total_lat = min_lats.sum().item()
  147. total_straggler = (lats - min_lats).sum().item()
  148. avg_lat = trim_mean(min_lats.tolist(), 0.1)
  149. avg_straggler = trim_mean((lats - min_lats).tolist(), 0.1)
  150. if print_log:
  151. print(
  152. f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{total_straggler: <20.2f}{avg_lat: <20.2f}{avg_straggler: <20.2f}"
  153. )