'''Copyright The Microsoft DeepSpeed Team'''
import copy
from numpy import BUFSIZE
import json
import subprocess
import sys
import threading
import time
import base64
import os
import hjson
from tqdm import tqdm

from ..utils import logger
from .constants import AUTOTUNING, AUTOTUNING_METRIC_PATH
from .utils import get_val_by_key, search_error, was_interruptted
"""
thread-0: loops over the experiment queue, dispatching experiments as resources become available
thread-N: runs each experiment in its own thread
"""

from deepspeed import comm as dist

# Seconds to wait when polling a running experiment thread.
TIMEOUT = 5


class ResourceManager:
    def __init__(self, args, hosts, num_gpus_per_node, results_dir, exps_dir, arg_mappings):
        self.results_dir = results_dir
        self.exps_dir = exps_dir
        self.nodes = []
        self.num_gpus_per_node = num_gpus_per_node
        for host in hosts:
            self.nodes.append(Node(host, num_gpus_per_node))
        self.experiment_queue = []
        self.running_experiments = {}
        self.finished_experiments = {}
        self.experiment_count = 0
        self.exp_paths = set()
        self.args = args
        self.arg_mappings = {}
        if arg_mappings is not None:
            for k, v in arg_mappings.items():
                k = k.strip()
                v = v.strip()
                if k not in self.arg_mappings:
                    self.arg_mappings[k] = v

    def schedule_experiments(self, exp_paths):
        for exp_path in exp_paths:
            if exp_path in self.exp_paths:
                continue
            else:
                self.exp_paths.add(exp_path)
                with open(exp_path, "r") as fd:
                    exp = hjson.load(fd)
                    exp["exp_id"] = self.experiment_count
                    self.experiment_count += 1

                    result_dir = exp["result_dir"] = os.path.join(self.results_dir, exp['name'])
                    if AUTOTUNING in exp["ds_config"]:
                        metric_file = os.path.join(result_dir, "metrics.json")
                        exp["ds_config"][AUTOTUNING][AUTOTUNING_METRIC_PATH] = metric_file

                    stderr_file = os.path.join(result_dir, "stderr.log")
                    model_info_file = os.path.join(result_dir, "model_info.json")
                    metric_file = os.path.join(result_dir, "metrics.json")

                    # skip existing experiments (except for the ones that were interrupted)
                    if os.path.exists(result_dir) and os.path.exists(stderr_file):
                        if not was_interruptted(stderr_file):
                            err = search_error(stderr_file)
                            exp_id = exp["exp_id"]
                            self.finished_experiments[exp_id] = (exp, err)
                            if err or os.path.exists(metric_file) or os.path.exists(model_info_file):
                                logger.info(f"Skipping exp {exp['name']} whose result already exists")
                                continue

                    self.experiment_queue.append(exp)

    def run_job(self, exp: dict, reservations):
        exp_id = exp["exp_id"]
        exp["master_port"] = self.args.master_port + exp_id
        exp["result_dir"] = os.path.join(self.results_dir, exp['name'])
        user_script = self.args.user_script
        user_args = self.args.user_args
        # overwrite the user args listed in arg_mappings with the experiment's values
        for key, val in self.arg_mappings.items():
            nval = get_val_by_key(exp, key)
            if nval and str(nval) != "auto":
                if val in user_args:
                    idx = user_args.index(val)
                    user_args[idx + 1] = str(nval)
                else:
                    user_args.append(val)
                    user_args.append(str(nval))
        t = threading.Thread(target=run_experiment, args=(exp, reservations, user_script, user_args))
        t.start()
        self.running_experiments[exp_id] = (t, exp, reservations, time.time())

    def experiment_check(self, pbar):
        finished_exps = []
        for exp_id, exp_data in self.running_experiments.items():
            thread, exp_json, reservations, start_time = exp_data
            logger.debug(f"Checking exp_id = {exp_id}, alive = {thread.is_alive()}")
            thread.join(timeout=TIMEOUT)
            if not thread.is_alive():
                exp_dir = exp_json["result_dir"]
                stderr_file = os.path.join(exp_dir, "stderr.log")
                err = search_error(stderr_file)
                finished_exps.append((exp_id, reservations))
                self.finished_experiments[exp_id] = (exp_json, err)
                duration = time.time() - start_time
                logger.debug(f"Finished exp_id = {exp_id}, duration={duration:.2f} sec")
        pbar.update(len(finished_exps))
        for exp_id, reservations in finished_exps:
            for reservation in reservations:
                reservation.restore_slots()
            self.running_experiments.pop(exp_id)
        time.sleep(TIMEOUT)

    def resource_request(self, exp):
        num_gpus, num_nodes = exp['num_gpus'], exp['num_nodes']
        slot_request = num_gpus
        reservations = []
        for node in self.nodes:
            if num_nodes == 0:
                break
            slots = node.reserve_slots(slot_request=slot_request)
            if slots:
                reservations.append(Reservation(node=node, slots=slots))
                num_nodes -= 1

        if num_nodes == 0:
            # request satisfied
            return reservations
        else:
            # request not satisfied
            for reservation in reservations:
                reservation.restore_slots()

    def status(self):
        status = ""
        for node in self.nodes:
            status += f"{node.host} ({len(node.idle_slots)} idle gpus), "
        return status[:-1]

    def run(self):
        pbar = tqdm(total=len(self.experiment_queue))

        while len(self.experiment_queue) > 0:
            exp = self.experiment_queue.pop(0)
            logger.debug(f'Popped exp_id = {exp["exp_id"]} from the queue')
            logger.debug(f'Resource status: {self.status()}')
            reservations = self.resource_request(exp)

            if not reservations:
                logger.debug(f'Unable to schedule exp_id = {exp["exp_id"]}')
                self.experiment_queue.insert(0, exp)
                logger.debug(f'Put exp_id = {exp["exp_id"]} back into the queue')
                self.experiment_check(pbar)
            else:
                desc = ""
                for reservation in reservations:
                    reservation.slots.sort()
                    slots = ",".join(map(str, reservation.slots))
                    desc += f"{reservation.node.host}:{slots}@"
                desc = desc[:-1]
                logger.debug(f'Running exp_id = {exp["exp_id"]} on {desc}')
                self.run_job(exp, reservations)

        # All pending experiments are scheduled, wait for them to complete
        while len(self.running_experiments) > 0:
            self.experiment_check(pbar)

    def save_exp_results_to_database(self, message, ranks=None, path=None):
        """Print message when one of the following conditions is met:
        + not dist.is_initialized()
        + dist.get_rank() is in ranks, or ranks == [-1]
        Args:
            message (str)
            ranks (list)
            path (str)
        """
        should_log = not dist.is_initialized()
        ranks = ranks or []
        my_rank = dist.get_rank() if dist.is_initialized() else -1
        if ranks and not should_log:
            should_log = ranks[0] == -1
            should_log = should_log or (my_rank in set(ranks))
        logger.debug(f"*** Should log: {should_log}")
        if should_log:
            message['rank'] = my_rank
            with open(path, 'a') as outfile:
                json.dump(message, outfile)
                outfile.write('\n')

    def parse_results(self, metric):
        """Parses the metric file of each finished experiment to select the optimal DeepSpeed configuration.
        Args:
            metric (str): name of the metric (e.g. throughput) used to rank the experiments.
        Returns:
            The description of the experiment with the optimal configuration and its metric value.
        """
        max_throughput = sys.float_info.min
        best_exp_id = -1
        for exp_id, (exp, err) in self.finished_experiments.items():
            if err:
                logger.info(
                    f"The experiment exp_id = {exp_id}, exp_name = {exp['name']}, did not run successfully with error = {err}, thus a metrics.json does not exist for it. Check the stderr.log in {exp['result_dir']}"
                )
                continue

            metric_file = exp["ds_config"][AUTOTUNING][AUTOTUNING_METRIC_PATH]
            if os.path.exists(metric_file):
                with open(metric_file, 'r') as f:
                    results = hjson.load(f)
                    curr_throughput = results[metric]
                    if curr_throughput > max_throughput:
                        max_throughput = curr_throughput
                        best_exp_id = exp_id
                    exp['results'] = results

        if best_exp_id != -1:
            best_exp, _ = self.finished_experiments[best_exp_id]
            return best_exp, max_throughput

        return exp, None

    def clear(self):
        """Clear experiment queues, does not reset self.experiment_count
        """
        self.experiment_queue = []
        # clean up the running experiments
        for exp_id, exp_data in self.running_experiments.items():
            thread, exp_json, reservations, start_time = exp_data
            clean_up(exp_json, reservations)
        self.running_experiments = {}
        self.finished_experiments = {}
        self.exp_paths = set()
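

# Typical driver flow (illustrative only; the actual call sites live in the autotuner
# that owns this ResourceManager, not in this module, and the variable names below are
# placeholders):
#
#     rm = ResourceManager(args, hosts, num_gpus_per_node, results_dir, exps_dir, arg_mappings)
#     rm.schedule_experiments(exp_paths)       # enqueue experiment descriptions (hjson files)
#     rm.run()                                 # dispatch experiments as GPU slots free up
#     best_exp, best_val = rm.parse_results(metric)   # pick the best finished configuration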


class Node:
    def __init__(self, host, max_slots):
        self.host = host
        self.max_slots = max_slots
        self.idle_slots = list(range(max_slots))

    def reserve_slots(self, slot_request: int) -> list:
        if len(self.idle_slots) >= slot_request:
            return [self.idle_slots.pop(0) for _ in range(slot_request)]

    def restore_slots(self, slots: list):
        self.idle_slots += slots


class Reservation:
    def __init__(self, node, slots):
        self.node = node
        self.slots = slots

    def restore_slots(self):
        self.node.restore_slots(self.slots)

    def desc(self):
        slots = ",".join(map(str, self.slots))
        return f"{self.node.host}:{slots}@"
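
# Note on the slot-accounting contract: Node.reserve_slots returns a list of GPU indices
# when enough idle slots exist and implicitly returns None otherwise, which is why callers
# check its truthiness before building a Reservation. A Reservation only remembers which
# slots were taken from which node so they can be handed back via restore_slots() once the
# experiment finishes; see the runnable sketch at the end of this file.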


def get_job_id():
    # Infrastructure-specific job-id
    infra_job_id = None
    if "DLWS_JOB_ID" in os.environ:
        infra_job_id = os.environ["DLWS_JOB_ID"]
    elif "DLTS_JOB_ID" in os.environ:
        infra_job_id = os.environ["DLTS_JOB_ID"]
    else:
        infra_job_id = "unknown-job-id"
    return infra_job_id


def get_user():
    user = None
    if "USER" in os.environ:
        user = os.environ["USER"]
    else:
        user = "unknown-user"
    return user


def run_experiment(exp: dict, reservations, user_script, user_args):
    include_str = ""
    for reservation in reservations:
        reservation.slots.sort()
        slots = ",".join(map(str, reservation.slots))
        include_str += f"{reservation.node.host}:{slots}@"
    include_str = include_str[:-1]
    master_port = exp["master_port"]
    exp["launcher_args"] = [
        "--include",
        f"{include_str}",
        "--master_port",
        str(master_port),
    ]
    logger.debug(f'launcher args={exp["launcher_args"]}')

    exp["user"] = get_user()
    exp["job_id"] = get_job_id()
    exp_dir = exp["result_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    ds_config_path = os.path.join(exp_dir, "ds_config.json")
    exp["ds_config_path"] = ds_config_path

    ds_config = copy.deepcopy(exp["ds_config"])
    ds_config_json = json.dumps(ds_config).encode('utf-8')
    exp["ds_config_base64"] = base64.urlsafe_b64encode(ds_config_json).decode('utf-8')

    with open(exp["ds_config_path"], "w", buffering=BUFSIZE) as fd:
        json.dump(ds_config, fd)
        fd.flush()
        os.fsync(fd)
        path = exp["ds_config_path"]
        logger.info(f"Scheduler wrote ds_config to {path}, {os.path.abspath(path)}")

    with open(os.path.join(exp_dir, "exp.json"), "w", buffering=BUFSIZE) as fd:
        json.dump(exp, fd)
        fd.flush()
        os.fsync(fd)
        path = os.path.join(exp_dir, "exp.json")
        logger.info(f"Scheduler wrote exp to {path}, {os.path.abspath(path)}")

    # remove "--deepspeed_config ds_config.json" from user_args
    if user_args:
        if "--deepspeed_config" in user_args:
            idx = user_args.index("--deepspeed_config")
        # "--deepspeed_config" is omitted in HF
        elif "--deepspeed" in user_args:
            idx = user_args.index("--deepspeed")
        assert idx < len(user_args), "there is no ds_config file specified after --deepspeed_config or --deepspeed"
        # user_args[idx + 1] = exp["ds_config_path"]
        # pass base64 serialized ds_config to launcher
        user_args[idx + 1] = exp["ds_config_base64"]

    exp["user_script"] = user_script
    exp["user_args"] = user_args

    cmd = ["deepspeed"] + exp["launcher_args"] + [user_script] + user_args

    assert len(exp["launcher_args"]) > 0, "must provide launcher args"

    with open(os.path.join(exp_dir, "cmd.txt"), "w", buffering=BUFSIZE) as fd:
        fd.write(" ".join(cmd))
        fd.write("\n")
        fd.flush()
        os.fsync(fd)

    logger.info(
        f"Launching exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}, and ds_config = {os.path.abspath(ds_config_path)}"
    )

    with open(os.path.join(exp_dir, "stdout.log"), "wb") as out, open(
            os.path.join(exp_dir, "stderr.log"), "wb") as err:
        result = subprocess.Popen(cmd, stdout=out, stderr=err)
        result.wait()
        out.flush()
        err.flush()
        os.fsync(out)
        os.fsync(err)

    clean_up(exp, reservations)

    logger.info(
        f"Done running exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}"
    )


PDSH_MAX_FAN_OUT = 1024


def clean_up(exp: dict, reservations):
    env = os.environ.copy()
    env['PDSH_RCMD_TYPE'] = 'ssh'

    nodes_str = ""
    for reservation in reservations:
        nodes_str += f"{reservation.node.host},"
    nodes_str = nodes_str[:-1]
    logger.debug(f"Cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")

    # PDSH flags for max node fan out and specific hosts to launch on
    # See https://linux.die.net/man/1/pdsh for flag details
    pdsh_cmd = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', nodes_str]

    kill_cmd = [
        'pkill',
        '-f',
        exp['name'],
    ]
    cmd = pdsh_cmd + kill_cmd
    logger.debug("cmd = {}".format(' '.join(cmd)))

    result = subprocess.Popen(cmd, env=env)
    result.wait()

    # In case of failure, propagate the error condition back to the caller (usually a shell).
    # The actual error and traceback should have been printed in the subprocess, so to avoid
    # unnecessary noise we quietly exit here with the same code as the subprocess.
    if result.returncode > 0:
        sys.exit(result.returncode)

    logger.info(f"Done cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")
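

if __name__ == "__main__":
    # Minimal, runnable sketch of the slot-accounting primitives above; it is not part of
    # the autotuner's runtime path, and the host name and GPU count are made-up example
    # values chosen only for illustration.
    node = Node(host="worker-0", max_slots=4)

    # Reserve two GPU slots; reserve_slots returns None if too few slots are idle.
    slots = node.reserve_slots(slot_request=2)
    reservation = Reservation(node=node, slots=slots)
    print(f"reserved {reservation.desc()} idle slots left: {node.idle_slots}")

    # Hand the slots back, as experiment_check() does when an experiment finishes.
    reservation.restore_slots()
    print(f"restored, idle slots: {node.idle_slots}")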