# utils.py

import torch
import os
import math
import argparse
from benchmarks.communication.constants import *

global dist


def init_torch_distributed(backend):
    global dist
    import torch.distributed as dist
    torch.distributed.init_process_group(backend)
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)


def init_deepspeed_comm(backend):
    global dist
    import deepspeed
    import deepspeed.comm as dist
    deepspeed.init_distributed(dist_backend=backend)
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)


def init_processes(local_rank, args):
    if args.dist == 'deepspeed':
        init_deepspeed_comm(args.backend)
    elif args.dist == 'torch':
        init_torch_distributed(args.backend)
    else:
        print_rank_0(f"distributed framework {args.dist} not supported")
        exit(0)


def print_rank_0(message):
    if dist.get_rank() == 0:
        print(message)


def print_header(args, comm_op):
    if comm_op == 'pt2pt':
        world_size = 2
    else:
        world_size = dist.get_world_size()
    tput = f'Throughput ({args.bw_unit})'
    busbw = f'BusBW ({args.bw_unit})'
    header = f"\n---- Performance of {comm_op} on {world_size} devices ---------------------------------------------------------\n"
    duration_str = 'Duration'
    if args.raw:
        duration_str += ' (us)'
    header += f"{'Size (Bytes)':20s} {'Description':25s} {duration_str:20s} {tput:20s} {busbw:20s}\n"
    header += "----------------------------------------------------------------------------------------------------"
    print_rank_0(header)


def get_bw(comm_op, size, duration, args):
    n = dist.get_world_size()
    tput = 0
    busbw = 0
    if comm_op == "all_to_all":
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_gather":
        size *= n
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_reduce":
        tput = (size * 2 / duration)
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op == "pt2pt" or comm_op == "broadcast":
        tput = (size / duration)
        busbw = tput
    else:
        print_rank_0("wrong comm_op specified")
        exit(0)

    if args.bw_unit == 'Gbps':
        tput *= 8
        busbw *= 8

    return tput, busbw
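
# Worked example (hypothetical numbers added for illustration; not from the
# original source): an all_reduce of size = 1e9 bytes across n = 8 ranks that
# takes duration = 0.05 s gives
#   tput  = 2 * 1e9 / 0.05                   = 40e9 B/s (40 GBps)
#   busbw = (1e9 / 0.05) * (2 * (8 - 1) / 8) = 35e9 B/s (35 GBps)
# The 2 * (n - 1) / n factor is the ring-style bus-bandwidth convention: each
# byte crosses the wire twice (reduce-scatter then all-gather), except for the
# shard that stays resident on each rank.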


def get_metric_strings(args, tput, busbw, duration):
    duration_ms = duration * 1e3
    duration_us = duration * 1e6
    tput = f'{tput / 1e9:.3f}'
    busbw = f'{busbw / 1e9:.3f}'

    if duration_us < 1e3 or args.raw:
        duration = f'{duration_us:.3f}'
        if not args.raw:
            duration += ' us'
    else:
        duration = f'{duration_ms:.3f} ms'

    return tput, busbw, duration


def sync_all():
    torch.cuda.synchronize()
    dist.barrier()


def max_numel(comm_op, dtype, mem_factor, local_rank, args):
    dtype_size = _element_size(dtype)
    max_memory_per_gpu = torch.cuda.get_device_properties(local_rank).total_memory * mem_factor
    if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast':
        elements_per_gpu = int(max_memory_per_gpu // dtype_size)
    elif comm_op == 'all_gather':
        # all_gather performance is lower for non-powers of two, and the output buffer size scales with world size
        # Therefore, divide by world size and round down to nearest power of 2
        elements_per_gpu = int(max_memory_per_gpu // dtype_size // dist.get_world_size())
        elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
    elif comm_op == 'all_to_all':
        # Number of elements must be divisible by world_size
        # all_to_all performance is lower for non-powers of two. Round down like all_gather.
        elements_per_gpu = int(max_memory_per_gpu // dtype_size)
        elements_per_gpu = int(dist.get_world_size() * round(elements_per_gpu / dist.get_world_size()))
        elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
    else:
        print(f"This communication operation: {comm_op} is not supported yet")
        exit(0)
    return elements_per_gpu
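
# Illustration (hypothetical numbers added for clarity; not from the original
# source): on a 32 GiB GPU with mem_factor = 0.4, dtype = torch.float16
# (2 bytes per element) and 8 ranks, the all_gather branch yields
#   32 * 2**30 * 0.4 // 2 // 8  ≈  8.59e8 elements,
# which is then rounded down to the nearest power of two, 2**29 = 536870912
# elements, i.e. a 1 GiB send buffer per rank.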


# Helper function to pretty-print message sizes
def convert_size(size_bytes):
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])
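
# Examples (added for illustration): convert_size(1536) returns "1.5 KB" and
# convert_size(2**30) returns "1.0 GB".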


# Copied from torch. Need to add the func here for old torch compatibility.
def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f'expected torch.dtype, but got {type(dtype)}')

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3
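
# Examples (added for illustration): _element_size(torch.float16) == 2,
# _element_size(torch.complex64) == 8, _element_size(torch.bool) == 1.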


def benchmark_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    parser.add_argument("--trials",
                        type=int,
                        default=DEFAULT_TRIALS,
                        help='Number of timed iterations')
    parser.add_argument("--warmups",
                        type=int,
                        default=DEFAULT_WARMUPS,
                        help='Number of warmup (non-timed) iterations')
    parser.add_argument("--maxsize",
                        type=int,
                        default=24,
                        help='Max message size as a power of 2')
    parser.add_argument("--async-op",
                        action="store_true",
                        help='Enables non-blocking communication')
    parser.add_argument("--bw-unit",
                        type=str,
                        default=DEFAULT_UNIT,
                        choices=['Gbps', 'GBps'])
    parser.add_argument("--backend",
                        type=str,
                        default=DEFAULT_BACKEND,
                        choices=['nccl'],
                        help='Communication library to use')
    parser.add_argument("--dist",
                        type=str,
                        default=DEFAULT_DIST,
                        choices=['deepspeed', 'torch'],
                        help='Distributed DL framework to use')
    parser.add_argument("--scan",
                        action="store_true",
                        help='Enables scanning all message sizes')
    parser.add_argument("--raw",
                        action="store_true",
                        help='Print the message size and latency without units')
    parser.add_argument("--all-reduce", action="store_true", help='Run all_reduce')
    parser.add_argument("--all-gather", action="store_true", help='Run all_gather')
    parser.add_argument("--all-to-all", action="store_true", help='Run all_to_all')
    parser.add_argument("--pt2pt", action="store_true", help='Run pt2pt')
    parser.add_argument("--broadcast", action="store_true", help='Run broadcast')
    parser.add_argument("--dtype",
                        type=str,
                        default=DEFAULT_TYPE,
                        help='PyTorch tensor dtype')
    parser.add_argument("--mem-factor",
                        type=float,
                        default=.4,
                        help='Proportion of max available GPU memory to use for single-size evals')
    parser.add_argument("--debug",
                        action="store_true",
                        help='Enables all_to_all debug prints')
    return parser
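

# Example driver (a minimal usage sketch added for illustration; not part of
# the original module). It assumes the job launcher (e.g. torchrun or the
# deepspeed runner) has already set LOCAL_RANK and the usual rendezvous
# environment variables:
#
#   args = benchmark_parser().parse_args()
#   init_processes(local_rank=args.local_rank, args=args)
#   print_header(args, 'all_reduce')
#   sync_all()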