runner.py

# Copyright 2020 The Microsoft DeepSpeed Team
"""
DeepSpeed runner is the main front-end to launching multi-worker
training jobs with DeepSpeed. By default this uses pdsh to parallel
ssh into multiple worker nodes and launch all the necessary processes
per rank for training.
"""
import os
import sys
import json
import base64
import argparse
import subprocess
import collections
from copy import deepcopy

import torch.cuda

from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger
from ..autotuning import Autotuner

DLTS_HOSTFILE = "/job/hostfile"
EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", "UCX"]
DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), "."]
PDSH_MAX_FAN_OUT = 1024
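
# Illustrative note (not part of the original source): the hostfile named above
# is MPI-style, with one "<hostname> slots=<N>" entry per line, e.g.:
#
#     worker-0 slots=4
#     worker-1 slots=4
#
# An optional .deepspeed_env file, searched for in DEEPSPEED_ENVIRONMENT_PATHS,
# lists one KEY=VALUE pair per line; the runner re-exports those variables when
# launching on worker nodes (see main() below).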


def parse_args(args=None):
    parser = argparse.ArgumentParser(
        description="DeepSpeed runner to help launch distributed "
        "multi-node/multi-gpu training jobs.")
    parser.add_argument("-H",
                        "--hostfile",
                        type=str,
                        default=DLTS_HOSTFILE,
                        help="Hostfile path (in MPI style) that defines the "
                        "resource pool available to the job (e.g., "
                        "worker-0 slots=4)")
    parser.add_argument("-i",
                        "--include",
                        type=str,
                        default="",
                        help='''Specify hardware resources to use during execution.
                        String format is
                                NODE_SPEC[@NODE_SPEC ...],
                        where
                                NODE_SPEC=NAME[:SLOT[,SLOT ...]].
                        If :SLOT is omitted, include all slots on that host.

                        Example: -i "worker-0@worker-1:0,2" will use all slots
                        on worker-0 and slots [0, 2] on worker-1.
                        ''')
    parser.add_argument("-e",
                        "--exclude",
                        type=str,
                        default="",
                        help='''Specify hardware resources to NOT use during execution.
                        Mutually exclusive with --include. Resource formatting
                        is the same as --include.

                        Example: -e "worker-1:0" will use all available
                        resources except slot 0 on worker-1.
                        ''')
    parser.add_argument("--num_nodes",
                        type=int,
                        default=-1,
                        help="Total number of worker nodes to run on; this will use "
                        "the top N hosts from the given hostfile.")
    parser.add_argument("--num_gpus",
                        type=int,
                        default=-1,
                        help="Max number of GPUs to use on each node; will use "
                        "GPU ids [0:N) on each node.")
    parser.add_argument("--master_port",
                        default=TORCH_DISTRIBUTED_DEFAULT_PORT,
                        type=int,
                        help="(optional) Port used by PyTorch distributed for "
                        "communication during training.")
    parser.add_argument("--master_addr",
                        default="",
                        type=str,
                        help="(optional) IP address of node 0; will be "
                        "inferred via 'hostname -I' if not specified.")
    parser.add_argument("--launcher",
                        default=PDSH_LAUNCHER,
                        type=str,
                        help="(optional) choose launcher backend for multi-node "
                        "training. Options currently include PDSH, OpenMPI, MVAPICH.")
    parser.add_argument("--launcher_args",
                        default="",
                        type=str,
                        help="(optional) pass launcher-specific arguments as a "
                        "single quoted argument.")
    parser.add_argument("--force_multi",
                        action="store_true",
                        help="Force multi-node launcher mode; helps in cases where "
                        "the user wants to launch on a single remote node.")
    parser.add_argument("--autotuning",
                        default="",
                        choices=["tune", "run"],
                        type=str,
                        help="Run the DeepSpeed autotuner to discover optimal "
                        "configuration parameters before running the job.")
    parser.add_argument("user_script",
                        type=str,
                        help="User script to launch, followed by any required "
                        "arguments.")
    parser.add_argument("user_args", nargs=argparse.REMAINDER)
    return parser.parse_args(args=args)
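
# Illustrative invocation (hypothetical paths and script name, shown only to
# clarify the flags defined above):
#
#     deepspeed --hostfile=/job/hostfile --include="worker-0@worker-1:0,2" \
#         train.py <user args ...>
#
# Everything after the user script is captured by argparse.REMAINDER in
# args.user_args and appended to the generated launch command (see main() below).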


def fetch_hostfile(hostfile_path):
    if not os.path.isfile(hostfile_path):
        logger.warning("Unable to find hostfile, will proceed with training "
                       "with local resources only.")
        return None

    # e.g., worker-0 slots=16
    with open(hostfile_path, 'r') as fd:
        resource_pool = collections.OrderedDict()
        for line in fd.readlines():
            line = line.strip()
            if line == '':
                # skip empty lines
                continue
            try:
                hostname, slots = line.split()
                _, slot_count = slots.split("=")
                slot_count = int(slot_count)
            except ValueError as err:
                logger.error("Hostfile is not formatted correctly, unable to "
                             "proceed with training.")
                raise err
            if hostname in resource_pool:
                logger.error("Hostfile contains duplicate hosts, unable to "
                             "proceed with training.")
                raise ValueError(f"host {hostname} is already defined")
            resource_pool[hostname] = slot_count

    return resource_pool
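
# Illustrative example (not part of the original source): a hostfile containing
#
#     worker-0 slots=4
#     worker-1 slots=2
#
# is parsed into OrderedDict([('worker-0', 4), ('worker-1', 2)]).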


def parse_resource_filter(host_info, include_str="", exclude_str=""):
    '''Parse an inclusion or exclusion string and filter a hostfile dictionary.

    String format is NODE_SPEC[@NODE_SPEC ...], where
        NODE_SPEC = NAME[:SLOT[,SLOT ...]].
    If :SLOT is omitted, include/exclude all slots on that host.

    Examples:
        include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
            slots [0, 2] on worker-1.
        exclude_str="worker-1:0" will use all available resources except
            slot 0 on worker-1.
    '''
    # Constants that define our syntax
    NODE_SEP = '@'
    SLOT_LIST_START = ':'
    SLOT_SEP = ','

    # Ensure include/exclude are mutually exclusive
    if (include_str != "") and (exclude_str != ""):
        raise ValueError('include_str and exclude_str are mutually exclusive.')

    # no-op
    if (include_str == "") and (exclude_str == ""):
        return host_info

    # Either build from scratch or remove items
    filtered_hosts = dict()
    if include_str:
        parse_str = include_str
    if exclude_str != "":
        filtered_hosts = deepcopy(host_info)
        parse_str = exclude_str

    # foreach node in the list
    for node_config in parse_str.split(NODE_SEP):
        # Node can either be alone or node:slot,slot,slot
        if SLOT_LIST_START in node_config:
            hostname, slots = node_config.split(SLOT_LIST_START)
            slots = [int(x) for x in slots.split(SLOT_SEP)]

            # sanity checks
            if hostname not in host_info:
                raise ValueError(f"Hostname '{hostname}' not found in hostfile")
            for slot in slots:
                if slot not in host_info[hostname]:
                    raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")

            # If include string, build the list from here
            if include_str:
                filtered_hosts[hostname] = slots
            elif exclude_str:
                for slot in slots:
                    logger.info(f'removing {slot} from {hostname}')
                    filtered_hosts[hostname].remove(slot)

        # User just specified the whole node
        else:
            hostname = node_config
            # sanity check hostname
            if hostname not in host_info:
                raise ValueError(f"Hostname '{hostname}' not found in hostfile")

            if include_str:
                filtered_hosts[hostname] = host_info[hostname]
            elif exclude_str:
                filtered_hosts[hostname] = []

    # Post-processing to remove duplicates and empty nodes
    del_keys = []
    for hostname in filtered_hosts:
        # Remove duplicates
        filtered_hosts[hostname] = list(set(filtered_hosts[hostname]))
        # Remove empty hosts
        if len(filtered_hosts[hostname]) == 0:
            del_keys.append(hostname)
    for name in del_keys:
        del filtered_hosts[name]

    # Lastly, go over filtered_hosts and convert to an OrderedDict() to ensure
    # we map ranks to nodes correctly by maintaining host_info ordering.
    ordered_hosts = collections.OrderedDict()
    for host in host_info:
        if host in filtered_hosts:
            ordered_hosts[host] = filtered_hosts[host]

    return ordered_hosts
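
# Illustrative example (not part of the original source), mirroring the
# docstring above; slot ordering within a host is not guaranteed after the
# set() de-duplication step:
#
#     host_info = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}
#     parse_resource_filter(host_info, include_str="worker-0@worker-1:0,2")
#     # -> OrderedDict([('worker-0', [0, 1, 2, 3]), ('worker-1', [0, 2])])
#     parse_resource_filter(host_info, exclude_str="worker-1:0")
#     # -> OrderedDict([('worker-0', [0, 1, 2, 3]), ('worker-1', [1, 2, 3])])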


def parse_inclusion_exclusion(resource_pool, inclusion, exclusion):
    active_resources = collections.OrderedDict()
    for hostname, slots in resource_pool.items():
        active_resources[hostname] = list(range(slots))

    return parse_resource_filter(active_resources,
                                 include_str=inclusion,
                                 exclude_str=exclusion)
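
# Illustrative example (not part of the original source): with
# resource_pool = {'worker-0': 2, 'worker-1': 2} and empty inclusion/exclusion
# strings, this returns OrderedDict([('worker-0', [0, 1]), ('worker-1', [0, 1])]).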


def encode_world_info(world_info):
    world_info_json = json.dumps(world_info).encode('utf-8')
    world_info_base64 = base64.urlsafe_b64encode(world_info_json).decode('utf-8')
    return world_info_base64
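
# Illustrative note (not part of the original source):
# encode_world_info({'localhost': [0, 1]}) returns the urlsafe base64 encoding of
# the JSON text '{"localhost": [0, 1]}'; the original mapping can be recovered
# with json.loads(base64.urlsafe_b64decode(encoded)).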


def run_autotuning(args, active_resources):
    tuner = Autotuner(args, active_resources)
    logger.info("[Start] Running autotuning")
    tuner.tune()
    tuner.print_tuning_results()
    logger.info("[End] Running autotuning")
    if args.autotuning == "run":
        tuner.run_after_tuning()


def main(args=None):
    args = parse_args(args)

    resource_pool = fetch_hostfile(args.hostfile)

    # respect CUDA_VISIBLE_DEVICES for a single node and no explicit resource filters
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if not resource_pool and len(cuda_visible_devices):
        detected_str = f"Detected CUDA_VISIBLE_DEVICES={cuda_visible_devices}"
        if len(args.include) or len(
                args.exclude) or args.num_nodes > 1 or args.num_gpus > 0:
            print(
                f"{detected_str} but ignoring it because one or several of --include/--exclude/--num_gpus/--num_nodes cl args were used. If you want to use CUDA_VISIBLE_DEVICES don't pass any of these arguments to deepspeed."
            )
        else:
            args.include = f"localhost:{cuda_visible_devices}"
            print(f"{detected_str}: setting --include={args.include}")
        del os.environ["CUDA_VISIBLE_DEVICES"]

    if args.num_nodes >= 0 or args.num_gpus >= 0:
        if args.include != "" or args.exclude != "":
            raise ValueError("Cannot specify num_nodes/gpus with include/exclude")

    multi_node_exec = True
    if not resource_pool:
        resource_pool = {}
        device_count = torch.cuda.device_count()
        if device_count == 0:
            raise RuntimeError("Unable to proceed, no GPU resources available")
        resource_pool['localhost'] = device_count
        args.master_addr = "127.0.0.1"
        multi_node_exec = False

    if not multi_node_exec and args.num_nodes > 1:
        raise ValueError("Num nodes is >1 but no extra nodes available via hostfile")

    active_resources = parse_inclusion_exclusion(resource_pool,
                                                 args.include,
                                                 args.exclude)
    env = os.environ.copy()

    if not args.master_addr:
        first_host = list(active_resources.keys())[0]
        hostname_cmd = [f"ssh {first_host} hostname -I"]
        result = subprocess.check_output(hostname_cmd, shell=True)
        args.master_addr = result.decode('utf-8').split()[0]
        logger.info(f"Using IP address of {args.master_addr} for node {first_host}")

    if args.autotuning != "":
        run_autotuning(args, active_resources)
        return

    if args.num_nodes > 0:
        updated_active_resources = collections.OrderedDict()
        for count, hostname in enumerate(active_resources.keys()):
            if args.num_nodes == count:
                break
            updated_active_resources[hostname] = active_resources[hostname]
        active_resources = updated_active_resources

    if args.num_gpus > 0:
        updated_active_resources = collections.OrderedDict()
        for hostname in active_resources.keys():
            updated_active_resources[hostname] = list(range(args.num_gpus))
        active_resources = updated_active_resources

    # encode world info as base64 to make it easier to pass via command line
    world_info_base64 = encode_world_info(active_resources)

    multi_node_exec = args.force_multi or len(active_resources) > 1

    if not multi_node_exec:
        deepspeed_launch = [
            sys.executable,
            "-u",
            "-m",
            "deepspeed.launcher.launch",
            f"--world_info={world_info_base64}",
            f"--master_addr={args.master_addr}",
            f"--master_port={args.master_port}"
        ]
        cmd = deepspeed_launch + [args.user_script] + args.user_args
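
        # Illustrative shape of the resulting single-node command (placeholder
        # values, not taken from a real run):
        #
        #     python -u -m deepspeed.launcher.launch \
        #         --world_info=<base64-encoded active_resources> \
        #         --master_addr=127.0.0.1 --master_port=<port> \
        #         train.py <user args ...>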
    else:
        args.launcher = args.launcher.lower()
        if args.launcher == PDSH_LAUNCHER:
            runner = PDSHRunner(args, world_info_base64)
        elif args.launcher == OPENMPI_LAUNCHER:
            runner = OpenMPIRunner(args, world_info_base64, resource_pool)
        elif args.launcher == MVAPICH_LAUNCHER:
            runner = MVAPICHRunner(args, world_info_base64, resource_pool)
        else:
            raise NotImplementedError(f"Unknown launcher {args.launcher}")

        if not runner.backend_exists():
            raise RuntimeError(f"launcher '{args.launcher}' not installed.")

        curr_path = os.path.abspath('.')
        if 'PYTHONPATH' in env:
            env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
        else:
            env['PYTHONPATH'] = curr_path

        exports = ""
        for var in env.keys():
            if any([var.startswith(name) for name in EXPORT_ENVS]):
                runner.add_export(var, env[var])

        for environ_path in DEEPSPEED_ENVIRONMENT_PATHS:
            environ_file = os.path.join(environ_path, DEEPSPEED_ENVIRONMENT_NAME)
            if os.path.isfile(environ_file):
                with open(environ_file, 'r') as fd:
                    for var in fd.readlines():
                        # Allow '=' inside values and drop the trailing newline
                        # that readlines() keeps on each entry.
                        key, val = var.strip().split('=', maxsplit=1)
                        runner.add_export(key, val)
        cmd = runner.get_cmd(env, active_resources)

    logger.info(f"cmd = {' '.join(cmd)}")
    result = subprocess.Popen(cmd, env=env)
    result.wait()

    # In case of failure we must propagate the error condition back to the caller (usually a
    # shell). The actual error and traceback should have been printed in the subprocess, so to
    # avoid unnecessary noise we just quietly exit here with the same code as the subprocess.
    if result.returncode > 0:
        sys.exit(result.returncode)


if __name__ == "__main__":
    main()