# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
DeepSpeed launcher. This is similar to torch's distributed.launch, but supports
additional features such as arbitrary GPU exclusion.

deepspeed.launcher.launch is intended to be run on a single worker node and
will spawn several worker sub-processes depending on how many devices/ranks
are on the worker.
"""
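# Illustrative invocation (hypothetical values; in practice this command is
# typically assembled by the deepspeed.launcher.runner front-end rather than
# typed by hand). --world_info is a urlsafe-base64-encoded JSON dict mapping
# each node name to the list of accelerator ids to use on that node:
#
#   python -m deepspeed.launcher.launch --node_rank=0 \
#       --master_addr=127.0.0.1 --master_port=29500 \
#       --world_info=<urlsafe-b64 JSON, e.g. {"localhost": [0, 1]}> \
#       train.py <training_script_args...>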
import sys
import subprocess
import os
import json
import base64
import time
import signal
import psutil
from collections import defaultdict
from typing import Dict
from argparse import ArgumentParser, REMAINDER

from deepspeed.accelerator import get_accelerator
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..nebula.constants import DLTS_POD_ENV_PATH
from ..utils import logger, get_numactl_cmd
from ..elasticity import is_torch_elastic_compatible
from .constants import ELASTIC_TRAINING_ID_DEFAULT

PID_FILE_BASEPATH = "/tmp"
def parse_args():
    parser = ArgumentParser(description="DeepSpeed distributed training launch"
                            " utility that creates multiple distributed"
                            " processes on a single node")

    # Optional arguments for the launch helper
    parser.add_argument("--node_rank",
                        type=int,
                        default=0,
                        help="The rank of the node for multi-node distributed "
                        "training")
    parser.add_argument("--master_addr",
                        default="127.0.0.1",
                        type=str,
                        help="Master node (rank 0)'s address, should be either"
                        " the IP address or the hostname of node 0. For"
                        " single node multi-proc training, "
                        "--master_addr can simply be 127.0.0.1")
    parser.add_argument("--master_port",
                        default=TORCH_DISTRIBUTED_DEFAULT_PORT,
                        type=int,
                        help="Master node (rank 0)'s free port that needs to "
                        "be used for communication during distributed "
                        "training")
    parser.add_argument("--world_info", default="None", type=str, help="base64-encoded world info dictionary")
    parser.add_argument("--module",
                        action="store_true",
                        help="Change each process to interpret the launch "
                        "script as a Python module, executing with the same "
                        "behavior as 'python -m'.")
    parser.add_argument("--no_python",
                        action="store_true",
                        help="Skip prepending the training script with "
                        "'python' - just execute it directly.")
    parser.add_argument("--enable_elastic_training", action="store_true", help="Enable elastic training support.")
    parser.add_argument("--min_elastic_nodes", type=int, default=-1, help="Min number of nodes in elastic training.")
    parser.add_argument("--max_elastic_nodes", type=int, default=-1, help="Max number of nodes in elastic training.")
    parser.add_argument("--no_local_rank",
                        action="store_true",
                        help="Do not pass local_rank as an argument when calling "
                        "the user's training script.")
    parser.add_argument("--save_pid",
                        type=int,
                        default=0,
                        help="main launching process pid, for internal pid tracking")
    parser.add_argument("--enable_each_rank_log",
                        default="None",
                        type=str,
                        help="redirect the stdout and stderr from each rank into different log files")
    parser.add_argument("--bind_cores_to_rank",
                        action="store_true",
                        help="Bind each rank to different cores of the host. "
                        "This improves host efficiency, especially for the CPU backend")
    parser.add_argument("--bind_core_list",
                        type=str,
                        default=None,
                        help="Comma-separated list of cores to bind to, given as "
                        "numbers and ranges, e.g. 1,3-5,7 => [1,3,4,5,7]. When not "
                        "specified, all cores on the system are used for rank binding")

    # positional
    parser.add_argument("training_script",
                        type=str,
                        help="The full path to the single GPU training "
                        "program/script to be launched in parallel, "
                        "followed by all the arguments for the "
                        "training script")

    # rest from the training program
    parser.add_argument('training_script_args', nargs=REMAINDER)
    return parser.parse_args()
# Adapted from https://psutil.readthedocs.io/en/latest/#kill-process-tree
def terminate_process_tree(pid):
    process = psutil.Process(pid)
    children = process.children(recursive=True)
    children.append(process)
    for child in children:
        try:
            child.terminate()
        except psutil.NoSuchProcess:
            pass
    # Give the process tree up to 30 seconds to exit gracefully, then force-kill survivors.
    gone, alive = psutil.wait_procs(children, timeout=30)
    for p in alive:
        p.kill()
def main():
    args = parse_args()
    current_env = os.environ.copy()

    for k in current_env.keys():
        if "NCCL" in k:
            logger.info(f"{args.node_rank} {k}={current_env[k]}")

    if args.world_info == "None":
        raise ValueError("world_info can not be None")

    world_info = base64.urlsafe_b64decode(args.world_info)
    world_info = json.loads(world_info)

    logger.info(f"WORLD INFO DICT: {world_info}")
    node_list = list(world_info.keys())
    args.nnodes = len(node_list)
    local_node = node_list[args.node_rank]
    local_accelerator_ids = world_info[local_node]
    num_local_procs = len(local_accelerator_ids)
    logger.info(f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}")
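    # Derive the global rank layout from world_info: ranks are assigned
    # consecutively in node_list order, one per accelerator, so dist_world_size
    # equals the total number of accelerators across all nodes.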
    global_rank_mapping = defaultdict(list)
    curr_global_rank = 0
    dist_world_size = 0
    for node_id in node_list:
        gids = world_info[node_id]
        dist_world_size += len(gids)
        for gid in gids:
            global_rank_mapping[node_id].append(curr_global_rank)
            curr_global_rank += 1
    logger.info(f"global_rank_mapping={global_rank_mapping}")
    logger.info(f"dist_world_size={dist_world_size}")

    get_accelerator().set_visible_devices_envs(current_env, local_accelerator_ids)
    for env in get_accelerator().visible_devices_envs():
        logger.info(f"Setting {env}={current_env[env]}")

    # set PyTorch distributed related environment variables
    current_env["MASTER_ADDR"] = args.master_addr
    current_env["MASTER_PORT"] = str(args.master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    current_env["CROSS_RANK"] = str(args.node_rank)
    current_env["CROSS_SIZE"] = str(args.nnodes)
    current_env["LOCAL_SIZE"] = str(num_local_procs)
    if args.save_pid:
        print(f"launcher pid: {os.getpid()}")

    pid_file = None
    if args.save_pid:
        launcher_pid = os.getpid()
        pid_file = os.path.join(PID_FILE_BASEPATH, f"{args.save_pid}.deepspeed")
        assert not os.path.isfile(pid_file), "pid file exists but shouldn't"
        with open(pid_file, 'w') as fd:
            fd.write(f"{launcher_pid}")

    if not is_torch_elastic_compatible():
        if args.enable_elastic_training:
            logger.info("Disabling elastic training support; it requires PyTorch version greater than 1.11.x")
            args.enable_elastic_training = False

    if os.path.exists(DLTS_POD_ENV_PATH):
        with open(DLTS_POD_ENV_PATH) as file:
            lines = file.readlines()
            lines = [line.rstrip() for line in lines]
            for line in lines:
                if line.startswith('export FC_TASKROLE_NAME') or line.startswith('export FC_TASK_INDEX'):
                    key_val = line.split()[1]
                    key, val = key_val.split('=')
                    current_env[key] = val

    processes = []
    cmd = []
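    # Two launch paths: without elastic training, spawn one training subprocess
    # per local accelerator and track it in `processes`; with elastic training,
    # delegate process management to torch elastic via DSElasticAgent.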
    if not args.enable_elastic_training:
        if args.enable_each_rank_log != "None":
            # prepare the log path and the file name prefix
            if os.path.isfile(args.enable_each_rank_log):
                raise ValueError(f"{args.enable_each_rank_log} should not be a file, it should be a directory.")
            if not os.path.exists(args.enable_each_rank_log):
                try:
                    os.makedirs(args.enable_each_rank_log)
                except Exception as e:
                    print(e)
                    raise ValueError(f"unable to create directory {args.enable_each_rank_log} for each rank log.")
            log_name_prefix = time.strftime("%Y%m%d%H%M%S", time.localtime())

        for local_proc in range(0, num_local_procs):
            # each process's rank
            dist_rank = global_rank_mapping[local_node][local_proc]
            local_rank = dist_rank % num_local_procs
            current_env["RANK"] = str(dist_rank)
            current_env["LOCAL_RANK"] = str(local_rank)

            # spawn the processes
            cmd = []
            if args.bind_cores_to_rank:
                cores_per_rank, numactl_cmd = get_numactl_cmd(args.bind_core_list, num_local_procs, local_rank)
                current_env["OMP_NUM_THREADS"] = f"{cores_per_rank}"
                cmd = cmd + numactl_cmd
            if not args.no_python:
                cmd.append(sys.executable)
                cmd.append("-u")
                if args.module:
                    cmd.append("-m")
            else:
                if args.module:
                    raise ValueError("Don't use both the '--no_python' flag"
                                     " and the '--module' flag at the same time.")
            cmd.append(args.training_script)

            # A user may not want to pass local_rank as a keyword arg so we make this optional.
            if not args.no_local_rank:
                cmd.append(f"--local_rank={local_rank}")
            cmd += args.training_script_args

            if args.enable_each_rank_log != "None":
                log_file = os.path.join(args.enable_each_rank_log, f"{log_name_prefix}_rank{dist_rank}.log")
                log_fd = open(log_file, 'w')
                process = subprocess.Popen(cmd, env=current_env, stdout=log_fd, stderr=log_fd)
            else:
                process = subprocess.Popen(cmd, env=current_env)
            # log the command used for each spawned process
            logger.info(f"process {process.pid} spawned with command: {cmd}")
            processes.append(process)
    else:
        from ..elasticity import DSElasticAgent
        from torch.distributed.elastic.rendezvous import RendezvousParameters
        from torch.distributed.elastic.agent.server.api import WorkerSpec
        import torch.distributed.elastic.rendezvous.registry as rdzv_registry
        from torch.distributed.elastic.multiprocessing import Std

        if args.min_elastic_nodes == -1:
            args.min_elastic_nodes = 1
        if args.max_elastic_nodes == -1:
            args.max_elastic_nodes = args.nnodes
        assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0, "Max and Min nodes should be positive"

        current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)

        # Get config and arguments
        cmd = []
        if not args.no_python:
            cmd = [sys.executable, "-u"]
            if args.module:
                cmd.append("-m")
        else:
            if args.module:
                raise ValueError("Don't use both the '--no_python' flag"
                                 " and the '--module' flag at the same time.")
        cmd.append(args.training_script)
        cmd += args.training_script_args
        cmd_args = cmd[1:]

        rdzv_configs: Dict[str, str] = {'timeout': 100}
        run_id = os.environ.get("ELASTIC_RUN_ID", ELASTIC_TRAINING_ID_DEFAULT)

        # Creating config for rendezvous class
        rdzv_parameters = RendezvousParameters(backend='c10d',
                                               endpoint=args.master_addr + ":" + str(args.master_port),
                                               run_id=run_id,
                                               min_nodes=args.min_elastic_nodes,
                                               max_nodes=args.max_elastic_nodes,
                                               **rdzv_configs)

        spec = WorkerSpec(
            role='trainer',
            local_world_size=num_local_procs,
            entrypoint=cmd[0],
            args=cmd[1:],
            rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
            max_restarts=100,
            monitor_interval=5,
            redirects=Std.from_str("0"),
            tee=Std.from_str("0"),
            master_addr=None,
            master_port=None,
        )
        agent = DSElasticAgent(spec, current_env)
        agent.run()
    sig_names = {2: "SIGINT", 15: "SIGTERM"}
    last_return_code = None

    def sigkill_handler(signum, frame):
        for process in processes:
            logger.info(f"Killing subprocess {process.pid}")
            try:
                terminate_process_tree(process.pid)
            except Exception:
                pass
        if last_return_code is not None:
            logger.error(f"{cmd} exits with return code = {last_return_code}")
            sys.exit(last_return_code)
        if signum in sig_names:
            logger.info(f"Main process received {sig_names[signum]}, exiting")
        if args.save_pid:
            if os.path.isfile(pid_file):
                os.remove(pid_file)
        sys.exit(1)

    # pass SIGINT/SIGTERM to children if the parent is being terminated
    signal.signal(signal.SIGINT, sigkill_handler)
    signal.signal(signal.SIGTERM, sigkill_handler)
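    # Monitor the spawned ranks: if any rank exits with a non-zero return code,
    # tear down all remaining ranks via sigkill_handler; otherwise keep polling
    # until every rank has exited cleanly.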
    alive_processes = set(processes)
    while len(alive_processes):
        finished_processes = []
        for process in alive_processes:
            if process.poll() is None:
                # the process is still running
                continue
            else:
                if process.returncode != 0:
                    last_return_code = process.returncode  # for sigkill_handler
                    sigkill_handler(signal.SIGTERM, None)  # not coming back
                else:
                    # exited cleanly
                    logger.info(f"Process {process.pid} exits successfully.")
                    finished_processes.append(process)
        alive_processes = set(alive_processes) - set(finished_processes)

        time.sleep(1)


if __name__ == "__main__":
    main()