launch.py

# Copyright 2020 The Microsoft DeepSpeed Team
"""
DeepSpeed launcher: similar to torch.distributed.launch, but with support for
additional features such as arbitrary GPU exclusion.

deepspeed.launcher.launch is intended to be run on a single worker node and
will spawn several worker sub-processes depending on how many devices/ranks
are on the worker.
"""
import sys
import subprocess
import os
import json
import base64
import time
import signal
from collections import defaultdict
from argparse import ArgumentParser, REMAINDER

from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger
def parse_args():
    parser = ArgumentParser(description="DeepSpeed distributed training launch"
                            " utility that creates multiple distributed"
                            " processes on a single node")

    # Optional arguments for the launch helper
    parser.add_argument("--node_rank",
                        type=int,
                        default=0,
                        help="The rank of the node for multi-node distributed "
                        "training")
    parser.add_argument("--master_addr",
                        default="127.0.0.1",
                        type=str,
                        help="Master node (rank 0)'s address, should be either"
                        " the IP address or the hostname of node 0, for"
                        " single node multi-proc training, the"
                        " --master_addr can simply be 127.0.0.1")
    parser.add_argument("--master_port",
                        default=TORCH_DISTRIBUTED_DEFAULT_PORT,
                        type=int,
                        help="Master node (rank 0)'s free port that needs to "
                        "be used for communication during distributed "
                        "training")
    parser.add_argument("--world_info",
                        default="None",
                        type=str,
                        help="world info base64 encoded dictionary")

    # positional
    parser.add_argument("training_script",
                        type=str,
                        help="The full path to the single GPU training "
                        "program/script to be launched in parallel, "
                        "followed by all the arguments for the "
                        "training script")

    # rest from the training program
    parser.add_argument('training_script_args', nargs=REMAINDER)
    return parser.parse_args()
def main():
    args = parse_args()
    current_env = os.environ.copy()

    for k in current_env.keys():
        if "NCCL" in k:
            logger.info(f"{args.node_rank} {k}={current_env[k]}")

    if args.world_info == "None":
        raise ValueError("world_info can not be None")
    world_info = base64.urlsafe_b64decode(args.world_info)
    world_info = json.loads(world_info)

    logger.info(f"WORLD INFO DICT: {world_info}")
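    # The decoded world_info maps each node's hostname to the list of GPU ids it
    # contributes, e.g. {"worker-0": [0, 1], "worker-1": [0, 1]} (hostnames and
    # GPU ids here are an assumed example, not from the source).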
    node_list = list(world_info.keys())
    args.nnodes = len(node_list)
    local_node = node_list[args.node_rank]
    local_gpu_ids = world_info[local_node]
    num_local_procs = len(local_gpu_ids)
    logger.info(
        f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}"
    )
    global_rank_mapping = defaultdict(list)
    curr_global_rank = 0
    dist_world_size = 0
    for node_id in node_list:
        gids = world_info[node_id]
        dist_world_size += len(gids)
        for gid in gids:
            global_rank_mapping[node_id].append(curr_global_rank)
            curr_global_rank += 1
    logger.info(f"global_rank_mapping={global_rank_mapping}")
    logger.info(f"dist_world_size={dist_world_size}")
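    # Ranks are assigned contiguously in node order. For the assumed example
    # world_info above, this yields global_rank_mapping =
    # {"worker-0": [0, 1], "worker-1": [2, 3]} and dist_world_size = 4.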
    current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids))
    logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}")

    # set PyTorch distributed related environmental variables
    current_env["MASTER_ADDR"] = args.master_addr
    current_env["MASTER_PORT"] = str(args.master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    processes = []
    cmd = []
    for local_rank in range(0, num_local_procs):
        # each process's rank
        dist_rank = global_rank_mapping[local_node][local_rank]
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)

        # spawn the processes
        cmd = [sys.executable,
               "-u",
               args.training_script,
               f"--local_rank={local_rank}"] + args.training_script_args
        process = subprocess.Popen(cmd, env=current_env)
        processes.append(process)
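    # Each child receives its global rank via the RANK/LOCAL_RANK environment
    # variables and its local rank via --local_rank; e.g. the first child on this
    # node is launched roughly as "python -u train.py --local_rank=0 ..."
    # (script name assumed for illustration).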
    sig_names = {2: "SIGINT", 15: "SIGTERM"}
    last_return_code = None

    def sigkill_handler(signum, frame):
        for process in processes:
            logger.info(f"Killing subprocess {process.pid}")
            try:
                process.kill()
            except Exception:
                pass
        if last_return_code is not None:
            logger.error(f"{cmd} exits with return code = {last_return_code}")
            sys.exit(last_return_code)
        if signum in sig_names:
            logger.info(f"Main process received {sig_names[signum]}, exiting")
        sys.exit(1)

    # pass SIGINT/SIGTERM to children if the parent is being terminated
    signal.signal(signal.SIGINT, sigkill_handler)
    signal.signal(signal.SIGTERM, sigkill_handler)
    alive_processes = set(processes)
    while len(alive_processes):
        finished_processes = []
        for process in alive_processes:
            if process.poll() is None:
                # the process is still running
                continue
            else:
                if process.returncode != 0:
                    last_return_code = process.returncode  # for sigkill_handler
                    sigkill_handler(signal.SIGTERM, None)  # not coming back
                else:
                    # exited cleanly
                    logger.info(f"Process {process.pid} exits successfully.")
                    finished_processes.append(process)
        alive_processes = set(alive_processes) - set(finished_processes)
        time.sleep(1)
if __name__ == "__main__":
    main()