# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import copy
import json
import subprocess
import sys
import threading
import time
import base64
import os
import hjson

from tqdm import tqdm

from ..utils import logger
from .constants import AUTOTUNING, AUTOTUNING_METRIC_PATH, BUFSIZE
from .utils import get_val_by_key, search_error, was_interruptted
"""
thread-0: loop over the experiment queue, dispatching experiments as resources become available
thread-N: run each experiment in its own thread
"""

from deepspeed import comm as dist

TIMEOUT = 5


class ResourceManager:

    def __init__(self, args, hosts, num_gpus_per_node, results_dir, exps_dir, arg_mappings):
        self.results_dir = results_dir
        self.exps_dir = exps_dir
        self.nodes = []
        self.num_gpus_per_node = num_gpus_per_node
        for host in hosts:
            self.nodes.append(Node(host, num_gpus_per_node))

        self.experiment_queue = []
        self.running_experiments = {}
        self.finished_experiments = {}
        self.experiment_count = 0
        self.exp_paths = set()
        self.args = args

        self.arg_mappings = {}
        if arg_mappings is not None:
            for k, v in arg_mappings.items():
                k = k.strip()
                v = v.strip()
                if k not in self.arg_mappings:
                    self.arg_mappings[k] = v
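
    # Illustrative only: arg_mappings maps a DeepSpeed config key to the user-script
    # CLI flag that mirrors it, so tuned values can be forwarded to the script, e.g.
    #   {"train_micro_batch_size_per_gpu": "--per_device_train_batch_size"}
    # (the flag name here is a hypothetical Hugging Face-style example).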

    def schedule_experiments(self, exp_paths):
        for exp_path in exp_paths:
            if exp_path in self.exp_paths:
                continue
            else:
                self.exp_paths.add(exp_path)
                with open(exp_path, "r") as fd:
                    exp = hjson.load(fd)
                    exp["exp_id"] = self.experiment_count
                    self.experiment_count += 1

                    result_dir = exp["result_dir"] = os.path.join(self.results_dir, exp['name'])
                    if AUTOTUNING in exp["ds_config"]:
                        metric_file = os.path.join(result_dir, "metrics.json")
                        exp["ds_config"][AUTOTUNING][AUTOTUNING_METRIC_PATH] = metric_file
                    stderr_file = os.path.join(result_dir, "stderr.log")
                    model_info_file = os.path.join(result_dir, "model_info.json")
                    metric_file = os.path.join(result_dir, "metrics.json")

                    # skip existing experiments (except for the ones that were interrupted)
                    if os.path.exists(result_dir) and os.path.exists(stderr_file):
                        if not was_interruptted(stderr_file):
                            err = search_error(stderr_file)
                            exp_id = exp["exp_id"]
                            self.finished_experiments[exp_id] = (exp, err)
                            if err or os.path.exists(metric_file) or os.path.exists(model_info_file):
                                logger.info(f"Skipping exp {exp['name']} whose result already exists")
                                continue

                    self.experiment_queue.append(exp)

    def run_job(self, exp: dict, reservations):
        exp_id = exp["exp_id"]
        exp["master_port"] = self.args.master_port + exp_id
        exp["result_dir"] = os.path.join(self.results_dir, exp['name'])
        user_script = self.args.user_script
        user_args = self.args.user_args
        # overwrite the user args with the tuned values recorded in arg_mappings
        for key, val in self.arg_mappings.items():
            nval = get_val_by_key(exp, key)
            if nval and str(nval) != "auto":
                if val in user_args:
                    idx = user_args.index(val)
                    user_args[idx + 1] = str(nval)
                else:
                    user_args.append(val)
                    user_args.append(str(nval))
        t = threading.Thread(target=run_experiment, args=(exp, reservations, user_script, user_args))
        t.start()
        self.running_experiments[exp_id] = (t, exp, reservations, time.time())
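
    # Illustrative only: if the experiment's ds_config sets
    # train_micro_batch_size_per_gpu = 4 and arg_mappings maps that key to
    # "--per_device_train_batch_size" (a hypothetical flag), run_job rewrites
    # user_args to contain ["--per_device_train_batch_size", "4"].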

    def experiment_check(self, pbar):
        finished_exps = []
        for exp_id, exp_data in self.running_experiments.items():
            thread, exp_json, reservations, start_time = exp_data
            logger.debug(f"Checking exp_id = {exp_id}, alive = {thread.is_alive()}")
            thread.join(timeout=TIMEOUT)
            if not thread.is_alive():
                exp_dir = exp_json["result_dir"]
                stderr_file = os.path.join(exp_dir, "stderr.log")
                err = search_error(stderr_file)
                finished_exps.append((exp_id, reservations))
                self.finished_experiments[exp_id] = (exp_json, err)
                duration = time.time() - start_time
                logger.debug(f"Finished exp_id = {exp_id}, duration={duration:.2f} sec")
        pbar.update(len(finished_exps))
        for exp_id, reservations in finished_exps:
            for reservation in reservations:
                reservation.restore_slots()
            self.running_experiments.pop(exp_id)
        time.sleep(TIMEOUT)

    def resource_request(self, exp):
        num_gpus, num_nodes = exp['num_gpus'], exp['num_nodes']
        slot_request = num_gpus
        reservations = []
        for node in self.nodes:
            if num_nodes == 0:
                break
            slots = node.reserve_slots(slot_request=slot_request)
            if slots:
                reservations.append(Reservation(node=node, slots=slots))
                num_nodes -= 1

        if num_nodes == 0:
            # request satisfied
            return reservations
        else:
            # request not satisfied; release any partial reservations
            for reservation in reservations:
                reservation.restore_slots()
            return None

    def status(self):
        # e.g. "worker-0 (4 idle gpus), worker-1 (2 idle gpus)"
        return ", ".join(f"{node.host} ({len(node.idle_slots)} idle gpus)" for node in self.nodes)

    def run(self):
        pbar = tqdm(total=len(self.experiment_queue))

        while len(self.experiment_queue) > 0:
            exp = self.experiment_queue.pop(0)
            logger.debug(f'Popped exp_id = {exp["exp_id"]} from the queue')
            logger.debug(f'Resource status: {self.status()}')
            reservations = self.resource_request(exp)

            if not reservations:
                logger.debug(f'Unable to schedule exp_id = {exp["exp_id"]}')
                self.experiment_queue.insert(0, exp)
                logger.debug(f'Put exp_id = {exp["exp_id"]} back into the queue')
                self.experiment_check(pbar)
            else:
                desc = ""
                for reservation in reservations:
                    reservation.slots.sort()
                    slots = ",".join(map(str, reservation.slots))
                    desc += f"{reservation.node.host}:{slots}@"
                desc = desc[:-1]
                logger.debug(f'Running exp_id = {exp["exp_id"]} on {desc}')
                self.run_job(exp, reservations)

        # All pending experiments are scheduled; wait for them to complete
        while len(self.running_experiments) > 0:
            self.experiment_check(pbar)

    def save_exp_results_to_database(self, message, ranks=None, path=None):
        """Append message to the file at path when one of the following conditions is met:
        + not dist.is_initialized()
        + dist.get_rank() is in ranks, or ranks == [-1]

        Args:
            message (dict): result record to append
            ranks (list): ranks that should log
            path (str): path of the output file
        """
        should_log = not dist.is_initialized()
        ranks = ranks or []
        my_rank = dist.get_rank() if dist.is_initialized() else -1
        if ranks and not should_log:
            should_log = ranks[0] == -1
            should_log = should_log or (my_rank in set(ranks))
        logger.debug(f"*** Should log: {should_log}")
        if should_log:
            message['rank'] = my_rank
            with open(path, 'a') as outfile:
                json.dump(message, outfile)
                outfile.write('\n')
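
    # Illustrative only: each call appends one JSON line, e.g.
    #   {"throughput": 123.4, "rank": -1}
    # (field names other than "rank" depend on what the caller passes in message).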

    def parse_results(self, metric):
        """Parses the metric files of the finished experiments to select the optimal DeepSpeed configuration.

        Args:
            metric (str): name of the metric (e.g. throughput) used to rank the experiments.

        Returns:
            The best experiment description and its metric value, or (None, None) if no experiment produced a metric file.
        """
        max_throughput = sys.float_info.min
        best_exp_id = -1
        for exp_id, (exp, err) in self.finished_experiments.items():
            if err:
                logger.info(
                    f"The experiment exp_id = {exp_id}, exp_name = {exp['name']}, did not run successfully with error = {err}, thus a metrics.json does not exist for it. Check the stderr.log in {exp['result_dir']}"
                )
                continue

            metric_file = exp["ds_config"][AUTOTUNING][AUTOTUNING_METRIC_PATH]
            if os.path.exists(metric_file):
                with open(metric_file, 'r') as f:
                    results = hjson.load(f)
                    curr_throughput = results[metric]
                    if curr_throughput > max_throughput:
                        max_throughput = curr_throughput
                        best_exp_id = exp_id
                    exp['results'] = results

        if best_exp_id != -1:
            best_exp, _ = self.finished_experiments[best_exp_id]
            return best_exp, max_throughput

        # no experiment produced a usable metric file
        return None, None
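
    # Illustrative only: a metrics file is hjson/JSON whose keys include the tuning
    # metric, e.g. {"throughput": 123.4}; parse_results looks up results[metric].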

    def clear(self):
        """Clear experiment queues, does not reset self.experiment_count
        """
        self.experiment_queue = []
        # clean up the running experiments
        for exp_id, exp_data in self.running_experiments.items():
            thread, exp_json, reservations, start_time = exp_data
            clean_up(exp_json, reservations)
        self.running_experiments = {}
        self.finished_experiments = {}
        self.exp_paths = set()


class Node:

    def __init__(self, host, max_slots):
        self.host = host
        self.max_slots = max_slots
        self.idle_slots = list(range(max_slots))

    def reserve_slots(self, slot_request: int) -> list:
        # returns None if the node cannot satisfy the request
        if len(self.idle_slots) >= slot_request:
            return [self.idle_slots.pop(0) for _ in range(slot_request)]

    def restore_slots(self, slots: list):
        self.idle_slots += slots


class Reservation:

    def __init__(self, node, slots):
        self.node = node
        self.slots = slots

    def restore_slots(self):
        self.node.restore_slots(self.slots)

    def desc(self):
        slots = ",".join(map(str, self.slots))
        return f"{self.node.host}:{slots}@"
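

# Illustrative only: how Node and Reservation account for GPU slots.
#   node = Node("worker-0", max_slots=4)        # idle_slots == [0, 1, 2, 3]
#   slots = node.reserve_slots(slot_request=2)  # -> [0, 1]; idle_slots == [2, 3]
#   r = Reservation(node=node, slots=slots)
#   r.desc()                                    # -> "worker-0:0,1@"
#   r.restore_slots()                           # idle_slots == [2, 3, 0, 1]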


def get_job_id():
    # Infrastructure-specific job ID
    infra_job_id = None
    if "DLWS_JOB_ID" in os.environ:
        infra_job_id = os.environ["DLWS_JOB_ID"]
    elif "DLTS_JOB_ID" in os.environ:
        infra_job_id = os.environ["DLTS_JOB_ID"]
    else:
        infra_job_id = "unknown-job-id"

    return infra_job_id


def get_user():
    user = None
    if "USER" in os.environ:
        user = os.environ["USER"]
    else:
        user = "unknown-user"
    return user


def run_experiment(exp: dict, reservations, user_script, user_args):
    include_str = ""
    for reservation in reservations:
        reservation.slots.sort()
        slots = ",".join(map(str, reservation.slots))
        include_str += f"{reservation.node.host}:{slots}@"
    include_str = include_str[:-1]
    master_port = exp["master_port"]
    hostfile = exp["hostfile"]
    exp["launcher_args"] = [
        "--hostfile",
        f"{hostfile}",
        "--include",
        f"{include_str}",
        "--master_port",
        str(master_port),
    ]
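    # Illustrative only: with two reservations on worker-0 (GPUs 0,1) and
    # worker-1 (GPUs 0,1), include_str is "worker-0:0,1@worker-1:0,1", so the
    # launcher is invoked with
    #   --hostfile <hostfile> --include worker-0:0,1@worker-1:0,1 --master_port <port>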
    logger.debug(f'launcher args={exp["launcher_args"]}')

    exp["user"] = get_user()
    exp["job_id"] = get_job_id()
    exp_dir = exp["result_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    ds_config_path = os.path.join(exp_dir, "ds_config.json")
    exp["ds_config_path"] = ds_config_path

    ds_config = copy.deepcopy(exp["ds_config"])
    ds_config_json = json.dumps(ds_config).encode('utf-8')
    exp["ds_config_base64"] = base64.urlsafe_b64encode(ds_config_json).decode('utf-8')

    with open(exp["ds_config_path"], "w", buffering=BUFSIZE) as fd:
        json.dump(ds_config, fd)
        fd.flush()
        os.fsync(fd)
        path = exp["ds_config_path"]
        logger.info(f"Scheduler wrote ds_config to {path}, {os.path.abspath(path)}")

    with open(os.path.join(exp_dir, "exp.json"), "w", buffering=BUFSIZE) as fd:
        json.dump(exp, fd)
        fd.flush()
        os.fsync(fd)
        path = os.path.join(exp_dir, "exp.json")
        logger.info(f"Scheduler wrote exp to {path}, {os.path.abspath(path)}")
    # replace the ds_config argument in user_args ("--deepspeed_config ds_config.json")
    if user_args:
        idx = -1
        if "--deepspeed_config" in user_args:
            idx = user_args.index("--deepspeed_config")
        # "--deepspeed_config" is omitted in HF
        elif "--deepspeed" in user_args:
            idx = user_args.index("--deepspeed")
        if idx != -1:
            assert idx + 1 < len(user_args), "there is no ds_config file specified after --deepspeed_config or --deepspeed"
            # pass the base64-serialized ds_config to the launcher instead of the file path
            user_args[idx + 1] = exp["ds_config_base64"]
  303. exp["user_script"] = user_script
  304. exp["user_args"] = user_args
  305. cmd = ["deepspeed"] + exp["launcher_args"] + [user_script] + user_args
  306. assert len(exp["launcher_args"]) > 0, "must provide launcher args"
  307. with open(os.path.join(exp_dir, "cmd.txt"), "w", buffering=BUFSIZE) as fd:
  308. fd.write(" ".join(cmd))
  309. fd.write("\n")
  310. fd.flush()
  311. os.fsync(fd)
  312. logger.info(
  313. f"Launching exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}, and ds_config = {os.path.abspath(ds_config_path)}"
  314. )
  315. with open(os.path.join(exp_dir, "stdout.log"), "wb") as out, open(os.path.join(exp_dir, "stderr.log"),
  316. "wb") as err:
  317. result = subprocess.Popen(cmd, stdout=out, stderr=err)
  318. result.wait()
  319. out.flush()
  320. err.flush()
  321. os.fsync(out)
  322. os.fsync(err)
  323. clean_up(exp, reservations)
  324. logger.info(f"Done running exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}")
  325. PDSH_MAX_FAN_OUT = 1024


def clean_up(exp: dict, reservations):
    env = os.environ.copy()
    env['PDSH_RCMD_TYPE'] = 'ssh'

    nodes_str = ""
    for reservation in reservations:
        nodes_str += f"{reservation.node.host},"
    nodes_str = nodes_str[:-1]
    logger.debug(f"Cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")

    # PDSH flags for max node fan out and specific hosts to launch on
    # See https://linux.die.net/man/1/pdsh for flag details
    pdsh_cmd = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', nodes_str]

    kill_cmd = [
        'pkill',
        '-f',
        exp['name'],
    ]
    cmd = pdsh_cmd + kill_cmd
    logger.debug("cmd = {}".format(' '.join(cmd)))

    result = subprocess.Popen(cmd, env=env)
    result.wait()

    # In case of failure must propagate the error-condition back to the caller (usually shell). The
    # actual error and traceback should have been printed in the subprocess, so in order to avoid
    # unnecessary noise we just quietly exit here with the same code as the subprocess
    if result.returncode > 0:
        sys.exit(result.returncode)

    logger.info(f"Done cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")
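

# Illustrative only: for an experiment named "z1_mbs4" (hypothetical) that ran on
# worker-0 and worker-1, clean_up() effectively executes
#   PDSH_RCMD_TYPE=ssh pdsh -f 1024 -w worker-0,worker-1 pkill -f z1_mbs4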