# multinode_runner.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import sys
import shutil
import subprocess
import warnings
from shlex import split
from abc import ABC, abstractmethod

from deepspeed.accelerator import get_accelerator
from ..utils import logger, get_numactl_cmd
from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
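
# Each runner below turns DeepSpeed's parsed launcher arguments and the
# resolved resource pool (a mapping of hostname -> slot count) into the
# command line for one launch backend: pdsh, Open MPI, MPICH, Intel MPI,
# Slurm, or MVAPICH.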
class MultiNodeRunner(ABC):

    def __init__(self, args, world_info_base64):
        self.args = args
        self.validate_args()
        self.user_arguments = self.parse_user_args()
        self.user_script = args.user_script
        self.world_info_base64 = world_info_base64
        self.exports = {}

    @abstractmethod
    def backend_exists(self):
        """Return whether the corresponding backend exists"""

    @abstractmethod
    def get_cmd(self, environment, active_resources):
        """Return the command to execute on node"""

    def add_export(self, key, var):
        self.exports[key.strip()] = var.strip()

    def parse_user_args(self):
        return self.args.user_args

    @property
    def name(self):
        """Return the name of the backend"""
        return self.__class__.__name__

    def validate_args(self):
        """Validate self.args"""
class PDSHRunner(MultiNodeRunner):

    def __init__(self, args, world_info_base64):
        super().__init__(args, world_info_base64)

    def backend_exists(self):
        return shutil.which('pdsh')

    def parse_user_args(self):
        processed_args = []
        for arg in self.args.user_args:
            # With pdsh, if we are passing a string as an argument, it will get
            # split on whitespace. To avoid this and support strings that
            # contain '"', we do this extra processing step:
            if " " in arg:
                arg = '"{}"'.format(arg.replace('"', '\\"'))
            processed_args.append(arg)
        return processed_args

    @property
    def name(self):
        return "pdsh"

    def get_cmd(self, environment, active_resources):
        environment['PDSH_RCMD_TYPE'] = 'ssh'
        if self.args.ssh_port is not None:  # only specify ssh port if it is specified
            environment["PDSH_SSH_ARGS_APPEND"] = f"{environment.get('PDSH_SSH_ARGS_APPEND', '')} -p {self.args.ssh_port}"

        active_workers = ",".join(active_resources.keys())
        logger.info("Running on the following workers: %s" % active_workers)

        # PDSH flags for max node fan out and specific hosts to launch on
        # See https://linux.die.net/man/1/pdsh for flag details
        pdsh_cmd_args = ['pdsh', '-S', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] + split(
            self.args.launcher_args)

        exports = ""
        for key, val in self.exports.items():
            exports += "export {}={}; ".format(key, val)

        # https://linux.die.net/man/1/pdsh
        # %n will be replaced by pdsh command
        deepspeed_launch = [
            exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "deepspeed.launcher.launch",
            f'--world_info={self.world_info_base64}', "--node_rank=%n", f"--master_addr={self.args.master_addr}",
            f"--master_port={self.args.master_port}"
        ]
        if self.args.no_python:
            deepspeed_launch.append("--no_python")
        if self.args.module:
            deepspeed_launch.append("--module")
        if self.args.no_local_rank:
            deepspeed_launch.append("--no_local_rank")
        if self.args.save_pid:
            deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
        if self.args.elastic_training:
            deepspeed_launch.append("--enable_elastic_training")
            deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}")
            deepspeed_launch.append(f"--min_elastic_nodes={self.args.min_elastic_nodes}")

        cmd_to_search = [i + "\\" for i in deepspeed_launch[2:6]]

        kill_command = pdsh_cmd_args + ["pkill -f ", " ".join(cmd_to_search)[:-2]]
        return pdsh_cmd_args + deepspeed_launch + [self.user_script] + self.user_arguments, kill_command, environment
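
# OpenMPIRunner launches through Open MPI's mpirun: the process count and
# hostfile are passed directly, exports travel via `-x`, the openib BTL is
# disabled, and TCP traffic is pinned to eth0 (plus UCX_TLS=tcp from __init__).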
class OpenMPIRunner(MultiNodeRunner):

    def __init__(self, args, world_info_base64, resource_pool):
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool
        self.add_export('UCX_TLS', 'tcp')

    def backend_exists(self):
        #TODO: if IB is available we should suggest mvapich
        return shutil.which('ompi_info')

    @property
    def name(self):
        return "openmpi"

    def validate_args(self):
        super().validate_args()
        #TODO: Allow for include/exclude at node-level but not gpu-level
        if self.args.include != "" or self.args.exclude != "":
            raise ValueError(f"{self.name} backend does not support worker include/exclusion")
        if self.args.num_nodes != -1 or self.args.num_gpus != -1:
            raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

    def get_cmd(self, environment, active_resources):
        total_process_count = sum(self.resource_pool.values())

        mpirun_cmd = [
            'mpirun',
            '-n',
            f'{total_process_count}',
            '-hostfile',
            f'{self.args.hostfile}',
            '--mca',
            'btl',
            '^openib',
            '--mca',
            'btl_tcp_if_include',
            'eth0',
        ] + split(self.args.launcher_args)

        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-x', "{}={}".format(k, v)]

        python_exec = []
        if not self.args.no_python:
            python_exec = [sys.executable, "-u"]
            if self.args.module:
                python_exec.append("-m")

        return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
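
# MPICHRunner requires a uniform device count per node (-n/-ppn) and forwards
# exports with -genv. Rather than invoking the user script directly, it routes
# it through the adjacent launcher_helper.py script.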
class MPICHRunner(MultiNodeRunner):

    def __init__(self, args, world_info_base64, resource_pool):
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool

    def backend_exists(self):
        #TODO: if IB is available we should suggest mpich
        return shutil.which('mpirun')  #mpich_info

    @property
    def name(self):
        return "mpich"

    def validate_args(self):
        super().validate_args()
        #TODO: Allow for include/exclude at node-level but not gpu-level
        if self.args.include != "" or self.args.exclude != "":
            raise ValueError(f"{self.name} backend does not support worker include/exclusion")
        if self.args.num_nodes != -1 or self.args.num_gpus != -1:
            raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

    def get_cmd(self, environment, active_resources):
        devices_per_node = self.resource_pool.values()
        total_process_count = sum(devices_per_node)
        process_per_node = list(devices_per_node)[0]
        if not all([n == process_per_node for n in devices_per_node]):
            raise ValueError("MPICH requires same number of devices per node")

        mpirun_cmd = [
            'mpirun',
            '-n',
            f'{total_process_count}',
            '-ppn',
            f'{process_per_node}',
        ] + split(self.args.launcher_args)

        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-genv', "{}={}".format(k, v)]

        export_cmd += ['-genv', 'MASTER_ADDR', str(self.args.master_addr)]
        export_cmd += ['-genv', 'MASTER_PORT', str(self.args.master_port)]
        export_cmd += ['-genv', 'WORLD_SIZE', str(total_process_count)]
        export_cmd += ['-genv', 'LOCAL_SIZE', str(process_per_node)]

        export_cmd += ['-hosts']
        hosts = ""
        for i, host in enumerate(self.resource_pool.keys()):
            if i == 0:
                hosts = f"{host}"
            else:
                hosts += f",{host}"
        export_cmd += [hosts]

        helper_args = ["--launcher"] + [self.args.launcher]
        python_exec = []
        if not self.args.no_python:
            python_exec += [sys.executable, "-u"]
            if self.args.module:
                python_exec.append("-m")
                helper_args.append("--module")
        else:
            helper_args.append("--no_python")

        helper_cmd = str(os.path.dirname(os.path.realpath(__file__))) + '/launcher_helper.py'
        helper_cmd = [helper_cmd] + helper_args + [self.user_script] + self.user_arguments

        return mpirun_cmd + export_cmd + python_exec + helper_cmd
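
# IMPIRunner targets Intel MPI. RANK/LOCAL_RANK are injected explicitly per
# process, so get_cmd emits one ':'-separated mpirun block per rank, e.g.
# (illustrative, 2 ranks on one node):
#   mpirun -ppn 2 -genv ... -hosts node1 \
#     -n 1 -env RANK 0 -env LOCAL_RANK 0 python -u train.py : \
#     -n 1 -env RANK 1 -env LOCAL_RANK 1 python -u train.py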
class IMPIRunner(MultiNodeRunner):

    def __init__(self, args, world_info_base64, resource_pool):
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool

    def backend_exists(self):
        #TODO: if IB is available we should suggest mpich
        return shutil.which('mpirun')  #mpich_info

    @property
    def name(self):
        return "impi"

    def validate_args(self):
        super().validate_args()
        #TODO: Allow for include/exclude at node-level but not gpu-level
        if self.args.include != "" or self.args.exclude != "":
            raise ValueError(f"{self.name} backend does not support worker include/exclusion")
        if self.args.num_nodes != -1 or self.args.num_gpus != -1:
            raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

    def get_cmd(self, environment, active_resources):
        devices_per_node = self.resource_pool.values()
        total_process_count = sum(devices_per_node)
        process_per_node = list(devices_per_node)[0]
        if not all([n == process_per_node for n in devices_per_node]):
            raise ValueError("Intel MPI requires same number of devices per node")

        mpirun_cmd = [
            'mpirun',
            '-ppn',
            f'{process_per_node}',
        ] + split(self.args.launcher_args)

        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-genv', f'{k}', f'{v}']

        if self.args.bind_cores_to_rank:
            cores_per_rank, _ = get_numactl_cmd(self.args.bind_core_list, process_per_node, 0)
            export_cmd += ['-genv', 'OMP_NUM_THREADS', str(cores_per_rank)]

        export_cmd += ['-genv', 'MASTER_ADDR', str(self.args.master_addr)]
        export_cmd += ['-genv', 'MASTER_PORT', str(self.args.master_port)]
        export_cmd += ['-genv', 'WORLD_SIZE', str(total_process_count)]
        export_cmd += ['-genv', 'LOCAL_SIZE', str(process_per_node)]

        # turn off IMPI core binding, use deepspeed's own core binding
        export_cmd += ['-genv', 'I_MPI_PIN', '0']

        export_cmd += ['-hosts']
        hosts = ""
        for i, host in enumerate(self.resource_pool.keys()):
            if i == 0:
                hosts = f"{host}"
            else:
                hosts += f",{host}"
        export_cmd += [hosts]

        per_host_cmd = []
        for i in range(total_process_count):
            local_rank = i % process_per_node
            python_exec = []
            if self.args.bind_cores_to_rank:
                _, numactl_cmd = get_numactl_cmd(self.args.bind_core_list, process_per_node, local_rank)
                python_exec += numactl_cmd

            if not self.args.no_python:
                python_exec += [sys.executable, "-u"]
                if self.args.module:
                    python_exec.append("-m")
            env_mapping = ['-env', 'RANK', str(i)]
            env_mapping += ['-env', 'LOCAL_RANK', str(local_rank)]
            if i == 0:
                per_host_cmd = ['-n', '1'] + env_mapping + python_exec + [self.user_script] + self.user_arguments
            else:
                per_host_cmd = per_host_cmd + [':', '-n', '1'] + env_mapping + python_exec + [
                    self.user_script
                ] + self.user_arguments

        print(mpirun_cmd + export_cmd + per_host_cmd)
        return mpirun_cmd + export_cmd + per_host_cmd
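
# SlurmRunner hands process placement to srun; exports are forwarded in a
# single --export=ALL,KEY=VAL argument rather than via per-rank env flags.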
class SlurmRunner(MultiNodeRunner):

    def __init__(self, args, world_info_base64, resource_pool):
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool

    def backend_exists(self):
        return shutil.which('sinfo')

    @property
    def name(self):
        return 'slurm'

    def get_cmd(self, environment, active_resources):
        assert not getattr(self.args, 'detect_nvlink_pairs',
                           False), "slurm backend does not support remapping visible devices"
        total_process_count = sum(self.resource_pool.values())
        srun_cmd = [
            'srun',
            '-n',
            f'{total_process_count}',
        ] + split(self.args.launcher_args)

        if getattr(self.args, 'slurm_comment', ''):
            srun_cmd += ['--comment', self.args.slurm_comment]

        if self.args.include != "":
            srun_cmd.append('--include')
            srun_cmd.append(f'{self.args.include}')
        if self.args.exclude != "":
            srun_cmd.append('--exclude')
            srun_cmd.append(f'{self.args.exclude}')
        if self.args.num_nodes > 0:
            srun_cmd.append('--nodes')
            srun_cmd.append(f'{self.args.num_nodes}')
        if self.args.num_gpus > 0:
            srun_cmd.append('--gpus')
            srun_cmd.append(f'{self.args.num_gpus}')

        exports = '--export=ALL'
        for key, val in self.exports.items():
            exports += f",{key}={val}"

        python_exec = [sys.executable, "-u"]
        command = srun_cmd + [exports] + python_exec + [self.user_script] + self.user_arguments
        return command
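
# MVAPICHRunner targets MVAPICH2-GDR: it pre-sets MV2_* tuning variables in
# __init__, verifies the install via `mpiname`, and writes the host list to a
# temporary hostfile consumed by mpirun's --hostfile flag.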
class MVAPICHRunner(MultiNodeRunner):

    def __init__(self, args, world_info_base64, resource_pool):
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool

        # Disable the CMA kernel module, not available on Ubuntu systems
        self.add_export('MV2_SMP_USE_CMA', '0')

        # If we fail this will output more verbose logging
        self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1')

        # Enable CUDA-aware communication
        if get_accelerator().device_name() == 'cuda':
            self.add_export('MV2_USE_CUDA', '1')

        # Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/
        self.add_export('MV2_SUPPORT_DL', '1')

        # Support MPI_THREAD_MULTIPLE
        self.add_export('MV2_ENABLE_AFFINITY', '0')

        # Performance tuning flags for allgather
        self.add_export('MV2_INTER_ALLGATHER_TUNING', '5')
        self.add_export('MV2_CUDA_USE_NAIVE', '0')

    def backend_exists(self):
        #TODO: if IB is available we should suggest mvapich
        mpiname_exists = shutil.which('mpiname')
        exists = False
        if not mpiname_exists:
            warnings.warn("mpiname does not exist, mvapich is not installed properly")
        else:
            results = subprocess.check_output('mpiname', shell=True)
            mpiname_results = results.decode('utf-8').strip()
            if "MVAPICH2-GDR" in mpiname_results:
                exists = True
            else:
                warnings.warn(f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}")
        return exists

    @property
    def name(self):
        return "mvapich"

    def validate_args(self):
        super().validate_args()
        #TODO: Allow for include/exclude at node-level but not gpu-level
        if self.args.include != "" or self.args.exclude != "":
            raise ValueError(f"{self.name} backend does not support worker include/exclusion")
        if self.args.num_nodes != -1 or self.args.num_gpus != -1:
            raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

    def get_cmd(self, environment, active_resources):
        devices_per_node = self.resource_pool.values()
        total_process_count = sum(devices_per_node)
        process_per_node = list(devices_per_node)[0]
        if not all([n == process_per_node for n in devices_per_node]):
            raise ValueError("mvapich requires same number of devices per node")

        with open(MVAPICH_TMP_HOSTFILE, 'w') as fd:
            for host in self.resource_pool.keys():
                fd.write(f'{host}\n')

        mpirun_cmd = [
            'mpirun',
            '-np',
            f'{total_process_count}',
            '-ppn',
            f'{process_per_node}',
            '--hostfile',
            f'{MVAPICH_TMP_HOSTFILE}',
        ] + split(self.args.launcher_args)

        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-env', "{}={}".format(k, v)]

        python_exec = []
        if not self.args.no_python:
            python_exec = [sys.executable, "-u"]
            if self.args.module:
                python_exec.append("-m")

        return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
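
# A minimal sketch of how these runners are typically driven (hypothetical
# variable values; the real control flow lives in deepspeed.launcher.runner,
# which builds `args` via argparse and the resource pool from the hostfile).
# Note that PDSHRunner.get_cmd returns a (cmd, kill_cmd, environment) tuple,
# while the MPI-style runners return only the command list:
#
#   runner = PDSHRunner(args, world_info_base64)
#   if not runner.backend_exists():
#       raise RuntimeError(f"{runner.name} launcher is not installed")
#   runner.add_export('NCCL_DEBUG', 'INFO')
#   cmd, kill_cmd, env = runner.get_cmd(os.environ.copy(), resource_pool)
#   subprocess.Popen(cmd, env=env).wait()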