utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import re
  5. import collections.abc
  6. import os
  7. import json
  8. from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, TRAIN_MICRO_BATCH_SIZE_PER_GPU
  9. import itertools
  10. import copy
  11. from ..utils import logger
  12. def search_error(filename):
  13. if not os.path.exists(filename):
  14. return "stderr.log does not exist"
  15. with open(filename) as f:
  16. for line in f:
  17. for s in ["Error", "error", "ERROR"]:
  18. idx = line.find(s)
  19. if idx != -1:
  20. return line[idx + len(s):].lstrip(": ")
  21. return None
  22. def was_interruptted(filename):
  23. if not os.path.exists(filename):
  24. return "stderr.log does not exist"
  25. with open(filename) as f:
  26. for line in f:
  27. s = "KeyboardInterrupt"
  28. idx = line.find(s)
  29. if idx != -1:
  30. return True
  31. return False
  32. def find_replace_str(value, replace_dict):
  33. if not isinstance(value, str):
  34. return str(value)
  35. matches = re.findall(r"\$[A-Za-z0-9_]+", value)
  36. for var in matches:
  37. var_key = var.replace("$", "").lower()
  38. if var_key == "nvme_path":
  39. continue
  40. assert var_key in replace_dict, f"unknown var key: {var_key}, in {replace_dict}"
  41. if isinstance(replace_dict[var_key], str):
  42. value = value.replace(var, replace_dict[var_key])
  43. else:
  44. assert len(matches) == 1, "unable to replace multiple non-string matches"
  45. value = replace_dict[var_key]
  46. return value
  47. def find_replace(target, replace_dict):
  48. if isinstance(target, dict):
  49. for key, value in target.items():
  50. if isinstance(value, str):
  51. target[key] = find_replace_str(value, replace_dict)
  52. if isinstance(value, list):
  53. for i in range(len(value)):
  54. value[i] = find_replace_str(value[i], replace_dict)
  55. if isinstance(value, dict):
  56. find_replace(value, replace_dict)
  57. elif isinstance(target, list):
  58. for i in range(len(target)):
  59. target[i] = str(find_replace_str(target[i], replace_dict))
  60. def get_list(val):
  61. if not isinstance(val, list):
  62. return [val]
  63. else:
  64. return val
  65. def combine_dict(d, u):
  66. for k, v in u.items():
  67. if isinstance(v, collections.abc.Mapping):
  68. d[k] = combine_dict(d.get(k, {}), v)
  69. else:
  70. if k not in d:
  71. d[k] = v
  72. else:
  73. if not isinstance(d[k], list):
  74. d[k] = [d[k]]
  75. d[k].extend(i for i in get_list(v) if i not in d[k])
  76. return d
  77. def del_if_exists(t, d):
  78. """Deletes a key from a dictionary if it exists.
  79. Args:
  80. t (string): target key to delete
  81. d (dict): dictionary to delete from
  82. """
  83. if t in d:
  84. del d[t]
  85. return
  86. for k, v in d.items():
  87. if isinstance(v, collections.abc.Mapping):
  88. del_if_exists(t, v)
  89. def replace_dict(d, u, ignored_keys=[]):
  90. """Replaces values in dict d with values in dict u.
  91. Args:
  92. d (dict): the target dict to overwrite
  93. u (dict): the dict containing the values to overwrite the target dict
  94. Returns:
  95. dict d with values overwritten by the corresponding ones in dict u.
  96. """
  97. if u is not None:
  98. for k, v in u.items():
  99. if k not in ignored_keys:
  100. if v is None:
  101. del_if_exists(k, d)
  102. continue
  103. if isinstance(v, collections.abc.Mapping):
  104. d[k] = replace_dict(d.get(k, {}), v, ignored_keys)
  105. else:
  106. d[k] = v
  107. return d
  108. def get_val_by_key(d: dict, k):
  109. if k in d:
  110. return d[k]
  111. for v in d.values():
  112. if isinstance(v, dict):
  113. return get_val_by_key(v, k)
  114. return None
  115. def set_val_by_key(d: dict, k, vv):
  116. if k in d:
  117. d[k] = vv
  118. for v in d.values():
  119. if isinstance(v, dict):
  120. set_val_by_key(v, k, vv)
  121. def fetch_hostfile(hostfile_path):
  122. if not os.path.isfile(hostfile_path):
  123. logger.warning("Unable to find hostfile, will proceed with training "
  124. "with local resources only.")
  125. return None
  126. # e.g., worker-0 slots=16
  127. with open(hostfile_path, 'r') as fd:
  128. resource_pool = collections.OrderedDict()
  129. for line in fd.readlines():
  130. line = line.strip()
  131. if line == '':
  132. # skip empty lines
  133. continue
  134. try:
  135. hostname, slots = line.split()
  136. _, slot_count = slots.split("=")
  137. slot_count = int(slot_count)
  138. except ValueError as err:
  139. logger.error("Hostfile is not formatted correctly, unable to "
  140. "proceed with training.")
  141. raise err
  142. if hostname in resource_pool:
  143. logger.error("Hostfile contains duplicate hosts, unable to "
  144. "proceed with training.")
  145. raise ValueError("host {} is already defined".format(hostname))
  146. resource_pool[hostname] = slot_count
  147. return resource_pool
  148. def validate_ds_config(config: dict):
  149. def is_False(config: dict, key):
  150. if config is None:
  151. return False
  152. return bool(config.get(key))
  153. config_zero = config.get("zero_optimization", {})
  154. if not config_zero:
  155. return True
  156. stage = config_zero.get("stage")
  157. offload = False
  158. if stage == 1:
  159. return True
  160. elif stage == 2:
  161. if is_False(config_zero, "cpu_offload") and is_False(config_zero, "cpu_offload_params"):
  162. return False
  163. elif stage == 3:
  164. offload_devices = ["cpu", "nvme"]
  165. if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
  166. offload = True
  167. if config_zero.get("offload_param", {}).get("device") in offload_devices:
  168. offload = True
  169. else:
  170. return True
  171. # HF requires that "ZeRO Offload can only work with DeepSpeed optimizers"
  172. if offload and not config.get("optimizer"):
  173. return False
  174. return True
  175. def remove_dupe_dicts(l):
  176. """ Removes duplicate dictionaries from a list. Uses list comprehension and the json library to sort and stringify each dictionary and the set data type to ensure unique values. Works with nested data structures.
  177. Args:
  178. l (list): a list of (nested) data structures.
  179. Returns:
  180. A list of unique values.
  181. """
  182. list_of_strings = [json.dumps(d, sort_keys=True) for d in l]
  183. list_of_strings = set(list_of_strings)
  184. return [json.loads(s) for s in list_of_strings]
  185. def prune_config(config, ignored_keys=[]):
  186. """ Prunes the input configurations
  187. Args:
  188. configs (dict): A configuration dictionary.
  189. ignored_keys (list, optional): the keys of the sections to delete. Defaults to [].
  190. Returns:
  191. A configuration dictionary.
  192. """
  193. if ignored_keys:
  194. for k in ignored_keys:
  195. def find_del_key(d: dict, k: str):
  196. if k in d:
  197. del d[k]
  198. else:
  199. for dd in d.values():
  200. if isinstance(dd, dict):
  201. find_del_key(dd, k)
  202. find_del_key(config, k)
  203. def prune_configs(configs, ignored_keys=[]):
  204. """ Prunes the input list of configurations
  205. Args:
  206. configs (list): A list of configuration dictionaries.
  207. ignored_keys (list, optional): the keys of the sections to delete. Defaults to [].
  208. Returns:
  209. A list of valid and unique configuration dictionaries.
  210. """
  211. pruned_list = []
  212. for config in configs:
  213. prune_config(config, ignored_keys)
  214. pruned_list.append(config)
  215. return remove_dupe_dicts(pruned_list)
  216. def get_tuning_keys(tuning_space: dict):
  217. """Outputs the list of tunable parameters in the tuning space dict.
  218. Args:
  219. tuning_space (dict): a configuration dictionary containing tunable parameters as lists of values.
  220. Returns:
  221. A list of strings
  222. """
  223. tuning_keys = []
  224. for key, val in tuning_space.items():
  225. if isinstance(val, dict):
  226. tuning_keys.extend(get_tuning_keys(val))
  227. if isinstance(val, list) and len(val) > 1:
  228. tuning_keys.append(key)
  229. return tuning_keys
  230. def get_all_configs(tuning_space: dict, ignore_keys=None):
  231. """ Splits the tuning space dictionary to result in all combinations of values.
  232. Args:
  233. tuning_space (dict): the tuning space where tunable parameters are lists of values.
  234. """
  235. def gen_combinations(d: dict):
  236. keys, values = d.keys(), d.values()
  237. for v in values:
  238. if not isinstance(v, list):
  239. v = [v]
  240. values_choices = (gen_combinations(v) if isinstance(v, dict) else get_list(v) for v in values)
  241. for comb in itertools.product(*values_choices):
  242. yield dict(zip(keys, comb))
  243. all_configs = []
  244. ignored_key_vals = {}
  245. for ik in ignore_keys:
  246. ignored_key_vals[ik] = tuning_space.get(ik, {})
  247. del_if_exists(ik, tuning_space)
  248. for c in gen_combinations(tuning_space):
  249. replace_dict(c, ignored_key_vals)
  250. all_configs.append(c)
  251. return all_configs
  252. def canonical_name(config: dict, tuning_keys=None, prefix="", omit_val=False):
  253. """ Generates a name from the acronyms of the tuning keys in the config dict. TRAIN_MICRO_BATCH_SIZE_PER_GPU is always included in the tuning keys.
  254. Args:
  255. config (dict): the config dict used to generate the name
  256. tuning_keys (list, optional): the tuning keys used to generate the name. Defaults to None.
  257. prefix (str, optional): a string added to the beginning of the name. Defaults to None.
  258. """
  259. if TRAIN_MICRO_BATCH_SIZE_PER_GPU not in tuning_keys:
  260. tuning_keys.append(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
  261. if GRADIENT_ACCUMULATION_STEPS not in tuning_keys:
  262. tuning_keys.append(GRADIENT_ACCUMULATION_STEPS)
  263. tuning_keys.sort()
  264. def get_offload_name(offload_config):
  265. cname = ""
  266. if offload_config is None:
  267. return "None_"
  268. for key, val in offload_config.items():
  269. key = "".join(map(lambda c: c[0], key.split('_')))
  270. if (isinstance(val, int) or isinstance(val, float)) and val > 9000:
  271. cname += key + '{:.1e}'.format(val) + "_"
  272. else:
  273. if isinstance(val, bool):
  274. val = "T" if val else "F"
  275. cname += f"{key}{val}_"
  276. return cname
  277. def get_name_by_keys(config: dict, tuning_keys=None, omit_val=False):
  278. cname = ""
  279. if not tuning_keys or config is None:
  280. return cname
  281. for key, val in config.items():
  282. # skip the arg_mappings section when naming the exp file
  283. if key == "arg_mappings":
  284. continue
  285. if key == "offload_param":
  286. cname += "op_"
  287. if not omit_val:
  288. cname += get_offload_name(val)
  289. continue
  290. if key == "offload_optimizer":
  291. cname += "oo_"
  292. if not omit_val:
  293. cname += get_offload_name(val)
  294. continue
  295. # recursively call the func to get name for the child dicts
  296. if isinstance(val, dict):
  297. n = get_name_by_keys(val, tuning_keys, omit_val=omit_val)
  298. if n != "":
  299. cname += n + "_"
  300. if tuning_keys and key not in tuning_keys:
  301. continue
  302. key_str = "".join(map(lambda c: c[0], key.split('_')))
  303. if not omit_val:
  304. if (isinstance(val, int) or isinstance(val, float)) and val > 9000:
  305. cname += key_str + '{:.1e}'.format(val) + "_"
  306. else:
  307. if isinstance(val, bool):
  308. val = "T" if val else "F"
  309. cname += f"{key_str}{val}_"
  310. else:
  311. cname += key_str + "_"
  312. return cname[:-1]
  313. name = get_name_by_keys(config, tuning_keys, omit_val=omit_val)
  314. return prefix + (name if name != "" else "exp")
  315. def get_first_config(config: dict):
  316. if not config:
  317. return None
  318. cfg = copy.deepcopy(config)
  319. for key, val in cfg.items():
  320. if isinstance(val, dict):
  321. if key == "optimizer": # use user defined optimizer which might have lists of values as params
  322. cfg[key] = val
  323. else:
  324. cfg[key] = get_first_config(val)
  325. if isinstance(val, list) and len(val) > 0:
  326. cfg[key] = val[0]
  327. return cfg
  328. def write_experiments(exps: list, exps_dir: str):
  329. exp_paths = []
  330. for exp in exps:
  331. exp_name = exp['name']
  332. # write the expr config to a json file
  333. exp_path = os.path.join(exps_dir, f'{exp_name}.json')
  334. with open(exp_path, 'w') as fd:
  335. json.dump(exp, fd)
  336. exp_paths.append(exp_path)
  337. return exp_paths
  338. def memory_to_string(n, postfix="", units=None, precision=2):
  339. if units is None:
  340. if n // 10**12 > 0:
  341. return str(round(n / 1024**4, precision)) + " T" + postfix
  342. if n // 10**9 > 0:
  343. return str(round(n / 1024**3, precision)) + " G" + postfix
  344. elif n // 10**6 > 0:
  345. return str(round(n / 1024**2, precision)) + " M" + postfix
  346. elif n // 10**3 > 0:
  347. return str(round(n / 1014, precision)) + " K" + postfix
  348. else:
  349. return str(n) + " "
  350. else:
  351. if units == "T":
  352. return str(round(n / 1024**4, precision)) + " " + units
  353. if units == "G" + postfix:
  354. return str(round(n / 1024**3, precision)) + " " + units
  355. elif units == "M" + postfix:
  356. return str(round(n / 1024**2, precision)) + " " + units
  357. elif units == "K" + postfix:
  358. return str(round(n / 1024, precision)) + " " + units
  359. else:
  360. return str(n) + " "
  361. def number_to_string(n, postfix="", units=None, precision=2):
  362. if units is None:
  363. if n // 10**9 > 0:
  364. return str(round(n / 1000**3, precision)) + " B" + postfix
  365. if n // 10**6 > 0:
  366. return str(round(n / 1000**2, precision)) + " M" + postfix
  367. elif n // 10**3 > 0:
  368. return str(round(n / 1000**1, precision)) + " K" + postfix
  369. else:
  370. return str(n) + " "
  371. else:
  372. if units == "B" + postfix:
  373. return str(round(n / 1000**3, precision)) + " " + units
  374. elif units == "M" + postfix:
  375. return str(round(n / 1000**2, precision)) + " " + units
  376. elif units == "K" + postfix:
  377. return str(round(n / 1000**1, precision)) + " " + units
  378. else:
  379. return str(n) + " "