# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import copy
import json
import subprocess
import sys
import threading
import time
import base64
import os

import hjson
from tqdm import tqdm

from deepspeed import comm as dist

from ..utils import logger
from .constants import AUTOTUNING, AUTOTUNING_METRIC_PATH
from .utils import get_val_by_key, search_error, was_interruptted
"""
thread-0: loop over experiment queue dispatching experiments if they become available
thread-N: start each experiment in its own thread
"""

# Buffer size for experiment artifact files. The original `from numpy import
# BUFSIZE` pulled in numpy for this single constant (removed in numpy 2.0);
# 8192 matches numpy's historical value.
BUFSIZE = 8192

# Seconds to wait when joining a running experiment thread during polling.
TIMEOUT = 5
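
# A minimal usage sketch of the classes below (argument values are hypothetical):
#
#     rm = ResourceManager(args, hosts=["worker-0", "worker-1"], num_gpus_per_node=8,
#                          results_dir="autotuning_results", exps_dir="autotuning_exps",
#                          arg_mappings=None)
#     rm.schedule_experiments(exp_paths)
#     rm.run()
#     best_exp, best_metric_val = rm.parse_results(metric="throughput")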


class ResourceManager:

    def __init__(self, args, hosts, num_gpus_per_node, results_dir, exps_dir, arg_mappings):
        self.results_dir = results_dir
        self.exps_dir = exps_dir
        self.nodes = []
        self.num_gpus_per_node = num_gpus_per_node
        for host in hosts:
            self.nodes.append(Node(host, num_gpus_per_node))
        self.experiment_queue = []
        self.running_experiments = {}
        self.finished_experiments = {}
        self.experiment_count = 0
        self.exp_paths = set()
        self.args = args
        self.arg_mappings = {}
        if arg_mappings is not None:
            for k, v in arg_mappings.items():
                k = k.strip()
                v = v.strip()
                if k not in self.arg_mappings:
                    self.arg_mappings[k] = v
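
    # arg_mappings maps a ds_config key to the training-script flag that mirrors
    # it, so run_job() can keep the two in sync. A hypothetical entry:
    #     {"train_micro_batch_size_per_gpu": "--per_device_train_batch_size"}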

    def schedule_experiments(self, exp_paths):
        for exp_path in exp_paths:
            if exp_path in self.exp_paths:
                continue
            else:
                self.exp_paths.add(exp_path)
                with open(exp_path, "r") as fd:
                    exp = hjson.load(fd)
                    exp["exp_id"] = self.experiment_count
                    self.experiment_count += 1

                    result_dir = exp["result_dir"] = os.path.join(self.results_dir, exp['name'])
                    if AUTOTUNING in exp["ds_config"]:
                        metric_file = os.path.join(result_dir, "metrics.json")
                        exp["ds_config"][AUTOTUNING][AUTOTUNING_METRIC_PATH] = metric_file
                    stderr_file = os.path.join(result_dir, "stderr.log")
                    model_info_file = os.path.join(result_dir, "model_info.json")
                    metric_file = os.path.join(result_dir, "metrics.json")

                    # skip existing experiments (except for the ones that were interrupted)
                    if os.path.exists(result_dir) and os.path.exists(stderr_file):
                        if not was_interruptted(stderr_file):
                            err = search_error(stderr_file)
                            exp_id = exp["exp_id"]
                            self.finished_experiments[exp_id] = (exp, err)
                            if err or os.path.exists(metric_file) or os.path.exists(model_info_file):
                                logger.info(f"Skipping exp {exp['name']} whose result already exists")
                                continue

                    self.experiment_queue.append(exp)

    def run_job(self, exp: dict, reservations):
        exp_id = exp["exp_id"]
        exp["master_port"] = self.args.master_port + exp_id
        exp["result_dir"] = os.path.join(self.results_dir, exp['name'])
        user_script = self.args.user_script
        user_args = self.args.user_args
        # overwrite the user args with the corresponding values from this
        # experiment's config, using arg_mappings to locate the right flags
        for key, val in self.arg_mappings.items():
            nval = get_val_by_key(exp, key)
            if nval and str(nval) != "auto":
                if val in user_args:
                    idx = user_args.index(val)
                    user_args[idx + 1] = str(nval)
                else:
                    user_args.append(val)
                    user_args.append(str(nval))
        t = threading.Thread(target=run_experiment, args=(exp, reservations, user_script, user_args))
        t.start()
        self.running_experiments[exp_id] = (t, exp, reservations, time.time())

    def experiment_check(self, pbar):
        finished_exps = []
        for exp_id, exp_data in self.running_experiments.items():
            thread, exp_json, reservations, start_time = exp_data
            logger.debug(f"Checking exp_id = {exp_id}, alive = {thread.is_alive()}")
            thread.join(timeout=TIMEOUT)
            if not thread.is_alive():
                exp_dir = exp_json["result_dir"]
                stderr_file = os.path.join(exp_dir, "stderr.log")
                err = search_error(stderr_file)
                finished_exps.append((exp_id, reservations))
                self.finished_experiments[exp_id] = (exp_json, err)
                duration = time.time() - start_time
                logger.debug(f"Finished exp_id = {exp_id}, duration={duration:.2f} sec")
        pbar.update(len(finished_exps))
        for exp_id, reservations in finished_exps:
            for reservation in reservations:
                reservation.restore_slots()
            self.running_experiments.pop(exp_id)
        time.sleep(TIMEOUT)
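
    # Each experiment_check() pass joins every running thread for up to TIMEOUT
    # seconds, records the result of any thread that exited, returns its GPU
    # slots to the pool, advances the progress bar, then sleeps TIMEOUT seconds.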

    def resource_request(self, exp):
        num_gpus, num_nodes = exp['num_gpus'], exp['num_nodes']
        slot_request = num_gpus
        reservations = []
        for node in self.nodes:
            if num_nodes == 0:
                break
            slots = node.reserve_slots(slot_request=slot_request)
            if slots:
                reservations.append(Reservation(node=node, slots=slots))
                num_nodes -= 1

        if num_nodes == 0:
            # request satisfied
            return reservations

        # request not satisfied, roll back the partial reservations
        for reservation in reservations:
            reservation.restore_slots()
        return None
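
    # Example: exp = {"num_gpus": 2, "num_nodes": 2} asks for 2 GPU slots on each
    # of 2 nodes. If only one node has 2 idle slots, the partial reservation is
    # rolled back and None is returned, so run() can retry the experiment later.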

    def status(self):
        # e.g. "worker-0 (8 idle gpus), worker-1 (6 idle gpus)"
        return ", ".join(f"{node.host} ({len(node.idle_slots)} idle gpus)" for node in self.nodes)

    def run(self):
        pbar = tqdm(total=len(self.experiment_queue))

        while len(self.experiment_queue) > 0:
            exp = self.experiment_queue.pop(0)
            logger.debug(f'Popped exp_id = {exp["exp_id"]} from the queue')
            logger.debug(f'Resource status: {self.status()}')
            reservations = self.resource_request(exp)

            if not reservations:
                logger.debug(f'Unable to schedule exp_id = {exp["exp_id"]}')
                self.experiment_queue.insert(0, exp)
                logger.debug(f'Put exp_id = {exp["exp_id"]} back into the queue')
                self.experiment_check(pbar)
            else:
                desc = ""
                for reservation in reservations:
                    reservation.slots.sort()
                    slots = ",".join(map(str, reservation.slots))
                    desc += f"{reservation.node.host}:{slots}@"
                desc = desc[:-1]
                logger.debug(f'Running exp_id = {exp["exp_id"]} on {desc}')
                self.run_job(exp, reservations)

        # All pending experiments are scheduled, wait for them to complete
        while len(self.running_experiments) > 0:
            self.experiment_check(pbar)
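
    # run() is "thread-0" from the module docstring: it pops experiments off the
    # queue, reserves GPU slots, and hands each experiment to its own thread via
    # run_job(), while experiment_check() reaps finished threads.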

    def save_exp_results_to_database(self, message, ranks=None, path=None):
        """Appends the message to the file at `path` when one of the following conditions is met:
            + dist is not initialized
            + dist.get_rank() is in `ranks`, or `ranks` is [-1]
        Args:
            message (dict): the experiment result to record
            ranks (list): the ranks that are allowed to log
            path (str): the file to append to
        """
        should_log = not dist.is_initialized()
        ranks = ranks or []
        my_rank = dist.get_rank() if dist.is_initialized() else -1
        if ranks and not should_log:
            should_log = ranks[0] == -1
            should_log = should_log or (my_rank in set(ranks))
        logger.debug(f"*** Should log: {should_log}")
        if should_log:
            message['rank'] = my_rank
            with open(path, 'a') as outfile:
                json.dump(message, outfile)
                outfile.write('\n')

    def parse_results(self, metric):
        """Parses the metric file of each finished experiment to select the optimal DeepSpeed configuration.
        Args:
            metric (str): the performance metric to rank experiments by, e.g. throughput.
        Returns:
            The best experiment description and its metric value; (None, None) if no experiment succeeded.
        """
        max_throughput = sys.float_info.min
        best_exp_id = -1
        for exp_id, (exp, err) in self.finished_experiments.items():
            if err:
                logger.info(
                    f"The experiment exp_id = {exp_id}, exp_name = {exp['name']}, did not run successfully with error = {err}, thus a metrics.json does not exist for it. Check the stderr.log in {exp['result_dir']}"
                )
                continue

            metric_file = exp["ds_config"][AUTOTUNING][AUTOTUNING_METRIC_PATH]
            if os.path.exists(metric_file):
                with open(metric_file, 'r') as f:
                    results = hjson.load(f)
                    curr_throughput = results[metric]
                    if curr_throughput > max_throughput:
                        max_throughput = curr_throughput
                        best_exp_id = exp_id
                    exp['results'] = results

        if best_exp_id != -1:
            best_exp, _ = self.finished_experiments[best_exp_id]
            return best_exp, max_throughput

        return None, None

    def clear(self):
        """Clears the experiment queues; does not reset self.experiment_count."""
        self.experiment_queue = []
        # clean up the running experiments
        for exp_id, exp_data in self.running_experiments.items():
            thread, exp_json, reservations, start_time = exp_data
            clean_up(exp_json, reservations)
        self.running_experiments = {}
        self.finished_experiments = {}
        self.exp_paths = set()


class Node:

    def __init__(self, host, max_slots):
        self.host = host
        self.max_slots = max_slots
        self.idle_slots = list(range(max_slots))

    def reserve_slots(self, slot_request: int) -> list:
        # returns None if the node cannot satisfy the request
        if len(self.idle_slots) >= slot_request:
            return [self.idle_slots.pop(0) for _ in range(slot_request)]
        return None

    def restore_slots(self, slots: list):
        self.idle_slots += slots


class Reservation:

    def __init__(self, node, slots):
        self.node = node
        self.slots = slots

    def restore_slots(self):
        self.node.restore_slots(self.slots)

    def desc(self):
        slots = ",".join(map(str, self.slots))
        return f"{self.node.host}:{slots}@"
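

# Illustration of the reservation round-trip (values shown for a 4-GPU node):
#
#     node = Node("worker-0", max_slots=4)        # idle_slots == [0, 1, 2, 3]
#     slots = node.reserve_slots(slot_request=2)  # slots == [0, 1]
#     r = Reservation(node=node, slots=slots)     # idle_slots == [2, 3]
#     r.restore_slots()                           # idle_slots == [2, 3, 0, 1]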


def get_job_id():
    # Infrastructure-specific job id, if running under DLWS/DLTS
    if "DLWS_JOB_ID" in os.environ:
        return os.environ["DLWS_JOB_ID"]
    if "DLTS_JOB_ID" in os.environ:
        return os.environ["DLTS_JOB_ID"]
    return "unknown-job-id"


def get_user():
    return os.environ.get("USER", "unknown-user")


def run_experiment(exp: dict, reservations, user_script, user_args):
    include_str = ""
    for reservation in reservations:
        reservation.slots.sort()
        slots = ",".join(map(str, reservation.slots))
        include_str += f"{reservation.node.host}:{slots}@"
    include_str = include_str[:-1]
    master_port = exp["master_port"]
    hostfile = exp["hostfile"]
    exp["launcher_args"] = [
        "--hostfile",
        f"{hostfile}",
        "--include",
        f"{include_str}",
        "--master_port",
        str(master_port),
    ]
    logger.debug(f'launcher args={exp["launcher_args"]}')

    exp["user"] = get_user()
    exp["job_id"] = get_job_id()
    exp_dir = exp["result_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    ds_config_path = os.path.join(exp_dir, "ds_config.json")
    exp["ds_config_path"] = ds_config_path

    ds_config = copy.deepcopy(exp["ds_config"])
    ds_config_json = json.dumps(ds_config).encode('utf-8')
    exp["ds_config_base64"] = base64.urlsafe_b64encode(ds_config_json).decode('utf-8')

    with open(exp["ds_config_path"], "w", buffering=BUFSIZE) as fd:
        json.dump(ds_config, fd)
        fd.flush()
        os.fsync(fd)
        path = exp["ds_config_path"]
        logger.info(f"Scheduler wrote ds_config to {path}, {os.path.abspath(path)}")

    with open(os.path.join(exp_dir, "exp.json"), "w", buffering=BUFSIZE) as fd:
        json.dump(exp, fd)
        fd.flush()
        os.fsync(fd)
        path = os.path.join(exp_dir, "exp.json")
        logger.info(f"Scheduler wrote exp to {path}, {os.path.abspath(path)}")

    # point the ds_config argument in user_args at the tuned config of this experiment
    if user_args:
        idx = -1
        if "--deepspeed_config" in user_args:
            idx = user_args.index("--deepspeed_config")
        # Hugging Face integrations pass "--deepspeed" instead of "--deepspeed_config"
        elif "--deepspeed" in user_args:
            idx = user_args.index("--deepspeed")
        if idx >= 0:
            assert idx + 1 < len(user_args), "there is no ds_config file specified after --deepspeed_config or --deepspeed"
            # pass the base64-serialized ds_config to the launcher rather than a file path
            user_args[idx + 1] = exp["ds_config_base64"]

    exp["user_script"] = user_script
    exp["user_args"] = user_args

    cmd = ["deepspeed"] + exp["launcher_args"] + [user_script] + user_args
    assert len(exp["launcher_args"]) > 0, "must provide launcher args"

    with open(os.path.join(exp_dir, "cmd.txt"), "w", buffering=BUFSIZE) as fd:
        fd.write(" ".join(cmd))
        fd.write("\n")
        fd.flush()
        os.fsync(fd)

    logger.info(
        f"Launching exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}, and ds_config = {os.path.abspath(ds_config_path)}"
    )

    with open(os.path.join(exp_dir, "stdout.log"), "wb") as out, open(os.path.join(exp_dir, "stderr.log"),
                                                                      "wb") as err:
        result = subprocess.Popen(cmd, stdout=out, stderr=err)
        result.wait()
        out.flush()
        err.flush()
        os.fsync(out)
        os.fsync(err)

    clean_up(exp, reservations)

    logger.info(f"Done running exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}")
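
# The assembled launch command has the shape (hosts and port are illustrative):
#
#     deepspeed --hostfile <hostfile> --include worker-0:0,1@worker-1:0,1 \
#         --master_port 29500 <user_script> <user_args ...>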

PDSH_MAX_FAN_OUT = 1024


def clean_up(exp: dict, reservations):
    env = os.environ.copy()
    env['PDSH_RCMD_TYPE'] = 'ssh'

    nodes_str = ""
    for reservation in reservations:
        nodes_str += f"{reservation.node.host},"
    nodes_str = nodes_str[:-1]
    logger.debug(f"Cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")

    # PDSH flags for max node fan out and specific hosts to launch on
    # See https://linux.die.net/man/1/pdsh for flag details
    pdsh_cmd = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', nodes_str]

    kill_cmd = [
        'pkill',
        '-f',
        exp['name'],
    ]
    cmd = pdsh_cmd + kill_cmd
    logger.debug("cmd = {}".format(' '.join(cmd)))

    result = subprocess.Popen(cmd, env=env)
    result.wait()

    # In case of failure we must propagate the error condition back to the caller (usually shell).
    # The actual error and traceback should have been printed in the subprocess, so in order to
    # avoid unnecessary noise we just quietly exit here with the same code as the subprocess.
    if result.returncode > 0:
        sys.exit(result.returncode)

    logger.info(f"Done cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")
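
# The cleanup fan-out runs, across the reserved hosts, something like (illustrative):
#
#     PDSH_RCMD_TYPE=ssh pdsh -f 1024 -w worker-0,worker-1 pkill -f <exp_name>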