autotuner.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
import copy
import datetime
import json
import math
import os
import shutil
import subprocess
import time

import hjson

from ..runtime.config_utils import dict_raise_error_on_duplicate_keys
from ..runtime.constants import *
from ..runtime.zero.config import ZERO_OPTIMIZATION, ZeroStageEnum
from ..utils import logger
from .config import DeepSpeedAutotuningConfig
from .constants import *
from .scheduler import ResourceManager
from .tuner import GridSearchTuner, RandomTuner, ModelBasedTuner
from .utils import *
from deepspeed.accelerator import get_accelerator

try:
    from tabulate import tabulate
except ImportError:
    tabulate = None

try:
    import mlflow
    has_mlflow = True
except Exception as e:
    has_mlflow = False
  29. ZERO_OPTIMIZATION_STAGE = "stage"
  30. OFFLOAD_OPTIMIZER = "offload_optimizer"
  31. OFFLOAD_PARAM = "offload_param"
  32. ZERO_OPTIMIZATION_STAGE_DEFAULT = ZeroStageEnum.disabled
  33. class Autotuner:
    """The DeepSpeed Autotuner automatically discovers the optimal DeepSpeed configuration that delivers good training speed. The Autotuner uses model information, system information, and heuristics to efficiently tune system knobs that affect compute and memory efficiencies, such as ZeRO optimization stages, micro-batch sizes, and many other ZeRO optimization configurations. It not only reduces the time and resources users spend on tuning, but can also discover configurations that are better than hand-tuned ones.

    Autotuning with DeepSpeed requires no code change from DeepSpeed users. Please refer to the README for usage details.
    """
  37. def __init__(self, args, active_resources):
  38. self.args = args
  39. self.selected_exp_dir = None
  40. assert tabulate is not None, "Missing required package `tabulate`, please install with `pip install deepspeed[autotuning]`."
  41. logger.debug(f"autotuning args={args}")
  42. self.user_config = self._get_user_config(args.user_args)
  43. assert self.user_config is not None, "DeepSpeed configuration is not provided"
  44. self.autotuning_config = DeepSpeedAutotuningConfig(self.user_config)
  45. if self.user_config[AUTOTUNING]:
  46. if AUTOTUNING_EXPS_DIR in self.user_config[AUTOTUNING].keys():
  47. del self.user_config[AUTOTUNING][AUTOTUNING_EXPS_DIR]
  48. if AUTOTUNING_RESULTS_DIR in self.user_config[AUTOTUNING].keys():
  49. del self.user_config[AUTOTUNING][AUTOTUNING_RESULTS_DIR]
  50. self.exps_dir = self.autotuning_config.exps_dir
  51. if self.autotuning_config.overwrite and os.path.exists(self.exps_dir):
  52. shutil.rmtree(self.exps_dir, ignore_errors=True)
  53. if not os.path.exists(self.exps_dir):
  54. try:
  55. os.makedirs(self.exps_dir, exist_ok=True)
  56. logger.info(f"Created autotuning experiments directory: {self.exps_dir}")
  57. except:
  58. logger.error(
  59. f"Failed to create {self.exps_dir}, please check `exps_dir` in the autotuning config file is accessible by all the nodes in the job."
  60. )
  61. exit(-1)
  62. self.results_dir = self.autotuning_config.results_dir
  63. if self.autotuning_config.overwrite and os.path.exists(self.results_dir):
  64. shutil.rmtree(self.results_dir, ignore_errors=True)
  65. if not os.path.exists(self.results_dir):
  66. try:
  67. os.makedirs(self.results_dir, exist_ok=True)
  68. logger.info(f"Created autotuning results directory: {self.exps_dir}")
  69. except:
  70. logger.error(
  71. f"Failed to create {self.results_dir}, please check `results_dir` in the autotuning config file is accessible by all the nodes in the job."
  72. )
  73. exit(-1)
  74. # set the active resource for the autotuner resource manager
  75. self.rm = self._get_resource_manager(active_resources)
  76. # get resource requirement for each autotuning experiment
  77. self.exp_num_nodes, self.exp_num_gpus = self._get_exp_resources(args)
  78. assert self.exp_num_gpus <= self.rm.num_gpus_per_node, "num_gpus in the autotuning configuration must not be less than the --num_gpus value in the train script if any"
  79. assert self.exp_num_nodes <= len(
  80. self.rm.nodes
  81. ), "num_nodes in the autotuning configuration must not be less than the --num_nodes value in the train script if any"
  82. self.records = {}
  83. self.optimal_cmd = None
  84. self.optimal_ds_config = None
  85. self.mlflow_parent_id = None
  86. def print_tuning_results(self):
  87. """Print the autotuning results in tabular format.
  88. """
  89. best_space_records = self.get_best_space_records()
  90. tab = []
  91. if best_space_records:
  92. for key, val in best_space_records.items():
  93. if not val:
  94. continue
  95. row = []
  96. row.append(key)
  97. num_exps = 0
  98. if key == GLOBAL_TUNING_SPACE:
  99. cnt = 0
  100. for k, v in best_space_records.items():
  101. if k != GLOBAL_TUNING_SPACE:
  102. cnt += v[2]
  103. num_exps = cnt
  104. else:
  105. num_exps = val[2]
  106. row.append(num_exps)
  107. row.append(val[1])
  108. row.append(val[0]['name'])
  109. tab.append(row)
  110. summary = tabulate(tab,
  111. headers=["tuning_space", "num_experiments", "best_metric_val", "best_exp_name"],
  112. tablefmt="pipe")
  113. print(summary)
  114. with open(os.path.join(self.results_dir, 'summary.txt'), 'w', buffering=BUFSIZE) as fd:
  115. fd.write(summary)
  116. fd.flush()
  117. os.fsync(fd)
  118. if GLOBAL_TUNING_SPACE in best_space_records:
  119. best_exp, best_metric_val, total_num_exps = best_space_records[GLOBAL_TUNING_SPACE]
  120. if best_exp:
  121. logger.info(
  122. f"{best_exp['name']} is the optimal setup after tuning. The exp result is at {best_exp['result_dir']}."
  123. )
  124. else:
  125. logger.info(f"No optimal setup is found. Please check that experiments were run successfully.")
  126. tuning_duration = datetime.timedelta(seconds=(time.time() - self.start_time))
  127. logger.info(f"Tuning completed in {tuning_duration}")
  128. with open(os.path.join(self.results_dir, 'summary.txt'), 'a') as f:
  129. f.write(
  130. f"\n\nTuning completed in {tuning_duration}. Total number of experiments: {self.rm.experiment_count - 1}."
  131. )
  132. f.flush()
  133. def _get_user_config(self, user_args):
  134. """Get DeepSpeed configuration from the user arguments passed to the launcher.
  135. Args:
  136. user_args ([list]): user arguments passed to the DeepSpeed launcher
  137. Returns:
  138. [dict]: DeepSpeed configuration dictionary
  139. """
  140. user_config_file = None
  141. if "--deepspeed_config" in user_args:
  142. idx = user_args.index("--deepspeed_config")
  143. assert ".json" in user_args[
  144. idx + 1], "DeepSpeed --deepspeed_config requires a json file to specify the configuration"
  145. user_config_file = user_args[idx + 1]
  146. elif "--deepspeed" in user_args:
  147. idx = user_args.index("--deepspeed")
  148. if ".json" in user_args[idx + 1]:
  149. user_config_file = user_args[idx + 1]
  150. logger.debug(f"user_config_file = {user_config_file}")
  151. if user_config_file is not None:
  152. assert os.path.isfile(user_config_file), "DeepSpeed configuration file: {} is not an existing file".format(
  153. user_config_file)
  154. if os.path.exists(user_config_file):
  155. return json.load(open(user_config_file, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
  156. return None
  157. def _get_resource_manager(self, active_resources):
  158. """Initialize and return a resource manager
  159. Args:
  160. active_resources ([dict]): A dictionary of hostname and its slots (GPUs), e.g. {"worker-0": "0,1,2,3,4,5,6,7,8"}
  161. Raises:
  162. RuntimeError: raises the error if no GPU is available
  163. Returns:
  164. [ResourceManager]: A resource manager that schedules and runs autotuning experiments.
  165. """
  166. logger.info(f"active_resources = {active_resources}")
  167. hosts = []
  168. ngpus_per_node = 100
  169. for hostname, slots in active_resources.items():
  170. hosts.append(hostname)
  171. ngpus_per_node = min(len(slots), ngpus_per_node)
  172. assert ngpus_per_node > 0, "no gpu is available"
  173. return ResourceManager(args=self.args,
  174. hosts=hosts,
  175. num_gpus_per_node=ngpus_per_node,
  176. results_dir=self.results_dir,
  177. exps_dir=self.exps_dir,
  178. arg_mappings=self.autotuning_config.arg_mappings)
  179. def _get_exp_resources(self, args):
  180. """Get resource requirement for each autotuning experiment
  181. Args:
  182. args (dict): user args
  183. Returns:
  184. num_nodes, num_gpus: the number of gpus and number of nodes used in the autotuning experiments
  185. """
  186. if args.num_nodes > 0:
  187. num_nodes = args.num_nodes
  188. else:
  189. num_nodes = len(self.rm.nodes)
  190. if args.num_gpus > 0:
  191. num_gpus = args.num_gpus
  192. else:
  193. num_gpus = self.rm.num_gpus_per_node
  194. return num_nodes, num_gpus
  195. def metric(self):
  196. return self.autotuning_config.metric
  197. def fast_enabled(self):
  198. return self.autotuning_config.fast
  199. def max_train_batch_size(self):
  200. return self.autotuning_config.max_train_batch_size
  201. def mp_size(self):
  202. return self.autotuning_config.mp_size
  203. def max_train_micro_batch_size_per_gpu(self):
  204. if self.max_train_batch_size(
  205. ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size
  206. max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size() // (
  207. self.exp_num_gpus * self.exp_num_nodes) # gradient accumulation steps >=1
  208. return min(self.autotuning_config.max_train_micro_batch_size_per_gpu, max_train_micro_batch_size)
  209. else:
  210. return self.autotuning_config.max_train_micro_batch_size_per_gpu
  211. def min_train_micro_batch_size_per_gpu(self):
  212. return self.autotuning_config.min_train_micro_batch_size_per_gpu
  213. def num_tuning_micro_batch_sizes(self):
  214. return self.autotuning_config.num_tuning_micro_batch_sizes
  215. def fp16_enabled(self):
  216. if FP16 in self.user_config.keys():
  217. return self.user_config[FP16].get(FP16_ENABLED, FP16_ENABLED_DEFAULT)
  218. else:
  219. return False
  220. def get_gpu_memory_info(self):
  221. return get_accelerator().total_memory()
  222. def get_activation_memory_per_gpu(self):
  223. if self.model_info and "activation_mem_per_gpu" in self.model_info:
  224. return self.model_info["activation_mem_per_gpu"]
  225. def get_instantiation_memory_required_per_gpu(self, zero_stage):
  226. num_params = self.get_model_num_params()
  227. total_gpus = self.exp_num_nodes * self.exp_num_gpus
  228. fp16_enabled = self.fp16_enabled()
  229. if not num_params:
  230. return 0
  231. # assume the model uses Adam optimizer
  232. # ZeroStageEnum.disabled:
  233. params_mem = num_params * (2 if fp16_enabled else 4)
  234. gradients_mem = num_params * (2 if fp16_enabled else 4)
  235. optimizer_mem = num_params * (16 if fp16_enabled else 8)
  236. if zero_stage >= ZeroStageEnum.optimizer_states:
  237. optimizer_mem = optimizer_mem / total_gpus
  238. if zero_stage >= ZeroStageEnum.gradients:
  239. gradients_mem = gradients_mem / total_gpus
  240. if zero_stage >= ZeroStageEnum.weights:
  241. params_mem = params_mem / total_gpus
  242. mem_per_gpu = (params_mem + gradients_mem + optimizer_mem) / self.mp_size()
  243. return mem_per_gpu
  244. def _generate_experiments(self, tuning_space, max_train_batch_size_per_gpu):
  245. """Generates a list of autotuning experiments given a tuning_space.
  246. The corresponding parameter values are replaced by user-defined values in the DeepSpeed configuration file.
  247. Args:
  248. tuning_space ([dict]): A DeepSpeed configuration dictionary where a value can be a list (called a tuning parameter). For example,
  249. {
  250. "zero_optimization": {
  251. "stage": 1,
  252. "reduce_bucket_size": [5e7,
  253. 5e8,
  254. 1e9],
  255. "allgather_bucket_size": [5e7,
  256. 5e8,
  257. 1e9],
  258. }
  259. }
  260. reduce_bucket_size and allgather_bucket_size are the tuning parameters in this tuning space.
  261. Returns:
  262. [list]: a list of experiments generated by taking combinations of values of the tuning space. The above tuning space generates 3*3 = 9 experiments if the user DeepSpeed configuration file does not overwrite the two tuning parameters or define more tuning parameters.
  263. """
  264. exps = []
  265. # each zero stage uses a different template configuration file
  266. config_zero = tuning_space.get(ZERO_OPTIMIZATION, {})
  267. stage = config_zero.get(ZERO_OPTIMIZATION_STAGE, ZERO_OPTIMIZATION_STAGE_DEFAULT)
  268. template_config = {}
  269. if stage == 0:
  270. template_path = DEFAULT_TEMPLATE_PATH_ZERO_0
  271. template_config = hjson.load(open(template_path, 'r'))
  272. prefix = "z0_"
  273. elif stage == 1:
  274. template_path = DEFAULT_TEMPLATE_PATH_ZERO_1
  275. template_config = hjson.load(open(template_path, 'r'))
  276. prefix = "z1_"
  277. elif stage == 2:
  278. template_path = DEFAULT_TEMPLATE_PATH_ZERO_2
  279. template_config = hjson.load(open(template_path, 'r'))
  280. prefix = "z2_"
  281. elif stage == 3:
  282. template_path = DEFAULT_TEMPLATE_PATH_ZERO_3
  283. template_config = hjson.load(open(template_path, 'r'))
  284. model_info = self.model_info
  285. if model_info and "hidden_size" in model_info:
  286. hs = model_info["hidden_size"]
  287. template_config[ZERO_OPTIMIZATION]['reduce_bucket_size'] = hs * hs
  288. template_config[ZERO_OPTIMIZATION]['stage3_prefetch_bucket_size'] = 0.9 * hs * hs
  289. template_config[ZERO_OPTIMIZATION]['stage3_param_persistence_threshold'] = 10 * hs
  290. prefix = "z3_"
  291. else:
  292. return exps
  293. # replace the corresponding parameter values if the user specifies them in the DeepSpeed configuration file
  294. replace_dict(tuning_space, self.user_config, [ZERO_OPTIMIZATION, TRAIN_MICRO_BATCH_SIZE_PER_GPU])
  295. logger.debug(f"tuning_space = {json.dumps(tuning_space)}")
  296. all_configs = get_all_configs(tuning_space, ignore_keys=["optimizer"])
  297. tuning_keys = get_tuning_keys(tuning_space)
  298. logger.debug(f"tuning_keys = {tuning_keys}")
  299. logger.debug(f"before pruning total configs = {len(all_configs)}")
  300. pruned_list = prune_configs(all_configs)
  301. logger.debug(f"after pruning total configs = {len(pruned_list)}")
  302. for config in pruned_list:
  303. exp_config = copy.deepcopy(template_config)
  304. # fill the template with the expr config
  305. replace_dict(exp_config, config)
  306. # if the config does not use offloading, remove the offloading section
  307. config_zero = config.get(ZERO_OPTIMIZATION, None)
  308. if config_zero:
  309. if OFFLOAD_OPTIMIZER not in config_zero and OFFLOAD_OPTIMIZER in exp_config[ZERO_OPTIMIZATION]:
  310. del exp_config[ZERO_OPTIMIZATION][OFFLOAD_OPTIMIZER]
  311. if OFFLOAD_PARAM not in config_zero and OFFLOAD_PARAM in exp_config[ZERO_OPTIMIZATION]:
  312. del exp_config[ZERO_OPTIMIZATION][OFFLOAD_PARAM]
  313. # set gradient accumulation steps according to max_train_batch_size_per_gpu
  314. mbs = exp_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU]
  315. gas = max_train_batch_size_per_gpu // mbs
  316. exp_config[GRADIENT_ACCUMULATION_STEPS] = gas
  317. exp_config[TRAIN_BATCH_SIZE] = mbs * gas * \
  318. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  319. exp = {}
  320. # generate the expr name
  321. exp_name = canonical_name(exp_config, tuning_keys, prefix)
  322. exp['name'] = exp_name
  323. exp[DS_CONFIG] = exp_config
  324. exp['num_gpus'] = self.exp_num_gpus
  325. exp['num_nodes'] = self.exp_num_nodes
  326. exps.append(exp)
  327. return exps
  328. def tune(self):
  329. """ Tunes Zero stages, micro batch size per GPU, and other Zero configurations. Performance metrics of different tuning spaces are recorded in self.records.
  330. """
  331. if has_mlflow:
  332. self.mlflow_parent_id = os.environ['MLFLOW_RUN_ID']
  333. mlflow.start_run(run_id=self.mlflow_parent_id)
  334. self.start_time = time.time()
  335. if self.fast_enabled():
  336. logger.info(f"Fast mode is enabled. Tuning micro batch size only.")
  337. # model info profile run with DEFAULT_MIN_MEM_CONFIG
  338. model_info = self.model_info_profile_run()
  339. if model_info:
  340. self.model_info = model_info
  341. else:
  342. return
  343. logger.info(f"The model has {number_to_string(self.get_model_num_params())} parameters.")
  344. self.gpu_mem = self.get_gpu_memory_info()
  345. logger.info(f"Memory per GPU in the system is {memory_to_string(self.gpu_mem, postfix='B')}.")
  346. self.activation_mem = self.get_activation_memory_per_gpu()
  347. logger.info(
  348. f"The model requires at least {memory_to_string(self.activation_mem, postfix='B')} activation memory for micro batch size 1."
  349. )
  350. stage = self.user_config.get(ZERO_OPTIMIZATION, {}).get(ZERO_OPTIMIZATION_STAGE, 0)
  351. user_zero_stages = [stage] if not isinstance(stage, list) else stage
  352. logger.info(f"User-defined zero stages are {stage}.")
  353. mbs = 0
  354. max_mbs = 0
  355. metric_val = 0
  356. required_gpu_mem = self.get_instantiation_memory_required_per_gpu(ZeroStageEnum.disabled) + self.activation_mem
  357. if self.gpu_mem > required_gpu_mem:
  358. if "all" in user_zero_stages or ZeroStageEnum.disabled in user_zero_stages:
  359. logger.info(
  360. f"The model might be runable with ZERO 0 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1), adding DEFAULT_TUNING_SPACE_ZERO_0 to the global tuning space"
  361. )
  362. next_max_mbs, next_mbs, next_metric_val = self.tune_space(DEFAULT_TUNING_SPACE_ZERO_0)
  363. if next_mbs > mbs:
  364. mbs = next_mbs
  365. max_mbs = next_max_mbs
  366. metric_val = next_metric_val
  367. if has_mlflow:
  368. mlflow.log_metric(f"z0{self.metric()}", next_metric_val)
  369. else:
  370. logger.info(
  371. f"The model is not runable with ZERO stage {ZeroStageEnum.disabled} (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1)"
  372. )
  373. required_gpu_mem = self.get_instantiation_memory_required_per_gpu(
  374. ZeroStageEnum.optimizer_states) + self.activation_mem
  375. if self.gpu_mem > required_gpu_mem:
  376. if "all" in user_zero_stages or ZeroStageEnum.optimizer_states in user_zero_stages:
  377. logger.info(
  378. f"The model might be runable with ZERO 1 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_1 to the global tuning space"
  379. )
  380. next_max_mbs, next_mbs, next_metric_val = self.tune_space(DEFAULT_TUNING_SPACE_ZERO_1,
  381. prev_max_mbs=max_mbs,
  382. prev_best_mbs=mbs,
  383. prev_best_metric_val=metric_val)
  384. if next_mbs > mbs:
  385. mbs = next_mbs
  386. max_mbs = next_max_mbs
  387. metric_val = next_metric_val
  388. if has_mlflow:
  389. mlflow.log_metric(f"z1{self.metric()}", next_metric_val)
  390. else:
  391. logger.info(
  392. f"The model is not runable with ZERO stage {ZeroStageEnum.optimizer_states} (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1)"
  393. )
  394. required_gpu_mem = self.get_instantiation_memory_required_per_gpu(
  395. ZeroStageEnum.gradients) + self.activation_mem
  396. if self.gpu_mem > required_gpu_mem:
  397. if "all" in user_zero_stages or ZeroStageEnum.gradients in user_zero_stages:
  398. logger.info(
  399. f"The model might be runable with ZERO 2 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_2 to the global tuning space"
  400. )
  401. next_max_mbs, next_mbs, next_metric_val = self.tune_space(DEFAULT_TUNING_SPACE_ZERO_2,
  402. prev_max_mbs=max_mbs,
  403. prev_best_mbs=mbs,
  404. prev_best_metric_val=metric_val)
  405. if next_mbs > mbs:
  406. mbs = next_mbs
  407. max_mbs = next_max_mbs
  408. metric_val = next_metric_val
  409. if has_mlflow:
  410. mlflow.log_metric(f"z2{self.metric()}", next_metric_val)
  411. else:
  412. logger.info(
  413. f"The model is not runable with ZERO stage {ZeroStageEnum.gradients} (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1)"
  414. )
  415. required_gpu_mem = self.get_instantiation_memory_required_per_gpu(ZeroStageEnum.weights) + self.activation_mem
  416. if self.gpu_mem > required_gpu_mem:
  417. if "all" in user_zero_stages or ZeroStageEnum.weights in user_zero_stages:
  418. logger.info(
  419. f"The model might be runable with ZERO 3 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_3 to the global tuning space"
  420. )
  421. _, _, next_metric_val = self.tune_space(DEFAULT_TUNING_SPACE_ZERO_3,
  422. prev_max_mbs=max_mbs,
  423. prev_best_mbs=mbs,
  424. prev_best_metric_val=metric_val)
  425. if has_mlflow:
  426. mlflow.log_metric(f"z3{self.metric()}", next_metric_val)
  427. else:
  428. logger.info(
  429. f"The model has {self.get_model_num_params()} parameters and requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory per GPU with DeepSpeed Zero stage {ZeroStageEnum.weights} optimization. Memory per GPU in system is {memory_to_string(self.gpu_mem)}. No tuning is performed."
  430. )
  431. return
  432. if has_mlflow:
  433. mlflow.end_run()
  434. def tune_space(self, tuning_space, prev_max_mbs=0, prev_best_mbs=0, prev_best_metric_val=0):
  435. config_zero = tuning_space.get(ZERO_OPTIMIZATION, {})
  436. stage = config_zero.get(ZERO_OPTIMIZATION_STAGE, None)
  437. tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)
  438. tuning_micro_batch_sizes = []
  439. max_train_batch_size_per_gpu = 0
  440. tuning_micro_batch_sizes_overwritten = False
  441. # calculate max micro batch size using gpu memory, model instantiation memory and activation memory
  442. # calculated_max_micro_batch_size = (memory_per_gpu - instantiation_memory) // activation_memory_micro_batch_size_1
  443. calculated_max_micro_batch_size = int(
  444. self.gpu_mem - self.get_instantiation_memory_required_per_gpu(stage)) // self.activation_mem
  445. logger.info(
  446. f"Start tuning for space {tuning_space_name}, calculated_max_micro_batch_size = {calculated_max_micro_batch_size}"
  447. )
  448. if calculated_max_micro_batch_size < prev_max_mbs:
  449. logger.info(f"No need to tune Zero stage {stage}. End tuning for space {tuning_space_name}")
  450. return 0, 0, 0
  451. if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self.user_config and isinstance(
  452. self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU], list):
  453. # user-specified micro batch size per gpu is a list which overwrites the default tuning behavior
  454. tuning_micro_batch_sizes = [
  455. s for s in self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] if isinstance(s, int)
  456. ]
  457. gas = self.get_gas_from_user_config()
  458. min_micro_batch_size = min(tuning_micro_batch_sizes)
  459. max_micro_batch_size = max(tuning_micro_batch_sizes)
  460. max_train_batch_size_per_gpu = max_micro_batch_size * gas
  461. tuning_micro_batch_sizes_overwritten = True
  462. else:
  463. # auto-detects the list of micro batch sizes to tune
  464. min_micro_batch_size, max_micro_batch_size = self.get_min_max_micro_batch_size(
  465. stage, prev_max_mbs, calculated_max_micro_batch_size)
  466. if max_micro_batch_size < prev_max_mbs:
  467. logger.info(f"No need to tune Zero stage {stage}. End tuning for space {tuning_space_name}")
  468. return 0, 0, 0
  469. tuning_micro_batch_sizes, max_train_batch_size_per_gpu = self.get_tuning_micro_batch_size_list(
  470. min_micro_batch_size,
  471. max_micro_batch_size,
  472. num_tuning_micro_batch_sizes=self.num_tuning_micro_batch_sizes())
  473. logger.info(
  474. f"tuning_micro_batch_sizes = {tuning_micro_batch_sizes}, max_train_batch_size_per_gpu = {max_train_batch_size_per_gpu}"
  475. )
  476. # return if the tuning_micro_batch_sizes list is empty
  477. if not tuning_micro_batch_sizes:
  478. logger.info(f"End tuning for space {tuning_space_name}")
  479. return 0, 0, 0
  480. # tune micro batch sizes and gradient accumulation steps given max_train_batch_size_per_gpu
  481. tuning_micro_batch_sizes = self.run_tuning_micro_batch_sizes(tuning_micro_batch_sizes,
  482. max_train_batch_size_per_gpu,
  483. min_micro_batch_size, stage,
  484. tuning_micro_batch_sizes_overwritten)
  485. fast_best_record = self.get_best_space_record(tuning_space_name)
  486. fast_best_metric_val = fast_best_record[1] if fast_best_record else 0
  487. fast_best_mbs = fast_best_record[0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU] if fast_best_record else 0
  488. logger.info(f"fast_best_mbs = {fast_best_mbs}, name = {fast_best_record[0]['name']}")
  489. if self.fast_enabled() or stage == 0:
  490. logger.info(f"End tuning for space: {tuning_space_name}")
  491. return max_micro_batch_size, fast_best_mbs, fast_best_metric_val
  492. # if the best metric or the micro batch size for that best metric in the current Zero stage after tuning micro batch size is less than the corresponding value in the previous Zero stage, return, do not tune other Zero configuration parameters
  493. if stage > 0:
  494. if fast_best_mbs <= prev_best_mbs or fast_best_metric_val < prev_best_metric_val:
  495. logger.info(
  496. f"End tuning for space: {tuning_space_name}. No need to tune other Zero configuration parameters.")
  497. return max_micro_batch_size, fast_best_mbs, fast_best_metric_val
  498. tuning_space[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = tuning_micro_batch_sizes
  499. tuning_space_name = canonical_name(tuning_space,
  500. tuning_keys=get_tuning_keys(tuning_space),
  501. prefix="z" + str(stage) + "_",
  502. omit_val=True)
  503. logger.info(f'Tuning space is {tuning_space}')
  504. logger.info(f'Tuning space name is {tuning_space_name}')
  505. exps = self._generate_experiments(tuning_space, max_train_batch_size_per_gpu)
  506. logger.info(f'Tuner type is {self.autotuning_config.tuner_type}')
  507. if self.autotuning_config.tuner_type == AUTOTUNING_TUNER_MODELBASED:
  508. t = ModelBasedTuner(exps, self.rm, self.metric(), tuning_space)
  509. elif self.autotuning_config.tuner_type == AUTOTUNING_TUNER_RANDOM:
  510. t = RandomTuner(exps, self.rm, self.metric())
  511. else:
  512. t = GridSearchTuner(exps, self.rm, self.metric())
  513. sample_size = len(self.rm.nodes) * self.rm.num_gpus_per_node // (self.exp_num_gpus * self.exp_num_nodes)
  514. num_exps = t.tune(sample_size=sample_size,
  515. n_trials=self.autotuning_config.tuner_num_trials,
  516. early_stopping=self.autotuning_config.tuner_early_stopping)
  517. exp = t.best_exp
  518. metric_val = t.best_metric_val
  519. if exp:
  520. self.update_records(tuning_space_name, exp, metric_val, num_exps)
  521. full_best_record = self.get_best_space_record(tuning_space_name)
  522. full_best_metric_val = full_best_record[1] if full_best_record else -1
  523. if full_best_metric_val > fast_best_metric_val:
  524. best_metric_val = full_best_metric_val
  525. best_mbs = full_best_record[0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU] if full_best_record else -1
  526. else:
  527. best_metric_val = fast_best_metric_val
  528. best_mbs = fast_best_mbs
  529. logger.info(f"End tuning for space: {tuning_space_name}")
  530. return max_micro_batch_size, best_mbs, best_metric_val
  531. def get_plateau_mbs(self, tuning_space_name):
  532. if tuning_space_name not in self.records:
  533. return 0
  534. space_records = self.records[tuning_space_name]
  535. sorted_space_records = sorted(space_records, key=lambda x: x[0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU])
  536. prev_metric_val = None
  537. prev_micro_batch_size = 0
  538. for (exp, metric_val, _) in sorted_space_records:
  539. if prev_metric_val:
  540. if metric_val < prev_metric_val:
  541. break
  542. if (metric_val >= prev_metric_val
  543. and (metric_val - prev_metric_val) / prev_metric_val < METRIC_PERCENT_DIFF_CONST):
  544. break
  545. prev_metric_val = metric_val
  546. prev_micro_batch_size = exp[DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU]
  547. plateau_mbs = prev_micro_batch_size
  548. return plateau_mbs
  549. def get_model_num_params(self):
  550. if self.model_info and "num_params" in self.model_info:
  551. return self.model_info["num_params"]
  552. def model_info_profile_run(self):
  553. """Does a model information profiling experiment that collects the number of model parameters and activation memory.\
  554. The experiment produces a "profile_model_info" folder under self.results_dir.
  555. Returns:
  556. [dict]: a model information dictionary, e.g., {"num_params": 335144976, "trainable_num_params": 335144976, "activation_mem_per_gpu": 324358144, "rank": 0}
  557. """
  558. logger.info("Starting model info profile run.")
  559. model_info = self.autotuning_config.model_info
  560. if model_info and MODEL_INFO_NUM_PARAMS in model_info:
  561. return model_info
  562. ds_config = copy.deepcopy(self.user_config)
  563. replace_dict(ds_config, DEFAULT_MIN_MEM_CONFIG)
  564. model_info_path = os.path.join(self.results_dir, "profile_model_info", "model_info.json")
  565. ds_config[AUTOTUNING] = {"enabled": True, "model_info_path": model_info_path, "model_info": {"profile": True}}
  566. exp_config = {}
  567. exp_name = "profile_model_info"
  568. exp_config['name'] = exp_name
  569. exp_config[DS_CONFIG] = ds_config
  570. exp_config['num_gpus'] = self.exp_num_gpus
  571. exp_config['num_nodes'] = self.exp_num_nodes
  572. exp_config['hostfile'] = self.args.hostfile
  573. exp_path = os.path.join(self.exps_dir, f'{exp_name}.json')
  574. with open(exp_path, 'w', buffering=BUFSIZE) as fd:
  575. json.dump(exp_config, fd)
  576. fd.flush()
  577. os.fsync(fd)
  578. self.rm.schedule_experiments([exp_path])
  579. self.rm.run()
  580. for exp_id, (exp_json, err) in self.rm.finished_experiments.items():
  581. self.rm.clear()
  582. if err:
  583. logger.error(f"The model is not runnable with DeepSpeed with error = {err}")
  584. return None
  585. if os.path.exists(model_info_path):
  586. with open(model_info_path, 'r') as f:
  587. model_info = hjson.load(f)
  588. return model_info
  589. def update_records(self, space_name, exp, metric_val, num_exps):
  590. if space_name not in self.records:
  591. self.records[space_name] = [(exp, metric_val, num_exps)]
  592. else:
  593. self.records[space_name].append((exp, metric_val, num_exps))
  594. def get_best_space_record(self, space_name):
  595. if space_name not in self.records:
  596. return None
  597. space_records = self.records[space_name]
  598. best_space_record = None
  599. space_num_exps = 0
  600. for (exp, metric_val, num_exps) in space_records:
  601. space_num_exps += num_exps
  602. if best_space_record is None or metric_val > best_space_record[1]:
  603. best_space_record = (exp, metric_val)
  604. if best_space_record:
  605. best_space_record = best_space_record + (space_num_exps, )
  606. return best_space_record
  607. def get_best_space_records(self):
  608. best_space_records = {}
  609. global_best_record = None
  610. for space_name, space_records in self.records.items():
  611. best_space_record = self.get_best_space_record(space_name)
  612. if best_space_record:
  613. best_space_records[space_name] = best_space_record
  614. if not global_best_record or best_space_record[1] > global_best_record[1]:
  615. global_best_record = best_space_record
  616. if global_best_record:
  617. best_space_records[GLOBAL_TUNING_SPACE] = global_best_record
  618. return best_space_records
  619. def run_tuning_micro_batch_sizes(self, tuning_micro_batch_sizes, max_train_batch_size_per_gpu,
  620. min_micro_batch_size, stage, tuning_micro_batch_sizes_overwritten):
  621. assert tuning_micro_batch_sizes, "the tuning micro batch size list is empty"
  622. tuning_micro_batch_sizes.sort()
  623. max_micro_batch_size = tuning_micro_batch_sizes[-1]
  624. max_micro_batch_size_metric_val = 0
  625. ds_config = get_first_config(self.user_config)
  626. ds_config[ZERO_OPTIMIZATION] = {ZERO_OPTIMIZATION_STAGE: stage}
  627. tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)
  628. exp_paths = []
  629. for mbs in tuning_micro_batch_sizes:
  630. ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
  631. gas = max_train_batch_size_per_gpu // mbs
  632. ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
  633. ds_config[TRAIN_BATCH_SIZE] = mbs * gas * \
  634. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  635. exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
  636. exp_config = {}
  637. exp_config['name'] = exp_name
  638. exp_config[DS_CONFIG] = ds_config
  639. exp_config['num_gpus'] = self.exp_num_gpus
  640. exp_config['num_nodes'] = self.exp_num_nodes
  641. exp_config['hostfile'] = self.args.hostfile
  642. exp_path = os.path.join(self.exps_dir, f'{exp_name}.json')
  643. with open(exp_path, 'w', buffering=BUFSIZE) as fd:
  644. json.dump(exp_config, fd)
  645. fd.flush()
  646. os.fsync(fd)
  647. exp_paths.append(exp_path)
  648. self.rm.schedule_experiments(exp_paths)
  649. self.rm.run()
  650. for exp_id, (exp, err) in self.rm.finished_experiments.items():
  651. if exp:
  652. metric_file = exp[DS_CONFIG][AUTOTUNING][AUTOTUNING_METRIC_PATH]
  653. if os.path.exists(metric_file):
  654. with open(metric_file, 'r') as f:
  655. results = hjson.load(f)
  656. metric_val = results[self.metric()]
  657. self.update_records(tuning_space_name, exp, metric_val, 1)
  658. if max_micro_batch_size == exp[DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU]:
  659. max_micro_batch_size_metric_val = metric_val
  660. if has_mlflow:
  661. os.environ.pop('MLFLOW_RUN_ID')
  662. mlflow.start_run(nested=True, run_name=exp['name'])
  663. for metric in results:
  664. mlflow.log_metric(metric, results[metric])
  665. mlflow.end_run()
  666. os.environ['MLFLOW_RUN_ID'] = self.mlflow_parent_id
  667. else:
  668. self.update_records(tuning_space_name, exp, 0, 1)
  669. else:
  670. mbs = exp[DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU]
  671. logger.info(f"micro batch size = {mbs} was not run successfully")
  672. self.rm.clear()
  673. if tuning_micro_batch_sizes_overwritten:
  674. return tuning_micro_batch_sizes
  675. # in a auto-detected tuning_micro_batch_sizes list, max_micro_batch_size might not be performant as the memory consumption is close to max
  676. # try smaller values while gas stays the same
  677. # if finding a more performant mbs value, use it to replace max_micro_batch_size in the list
  678. min_micro_batch_size_with_same_gas = (tuning_micro_batch_sizes[-2] +
  679. 1) if len(tuning_micro_batch_sizes) > 1 else min_micro_batch_size
  680. prev_best_metric_val = max_micro_batch_size_metric_val
  681. prev_best_mbs = max_micro_batch_size
  682. stride = (max_micro_batch_size - min_micro_batch_size_with_same_gas) // 3
  683. if stride == 0:
  684. stride = 1
  685. for mbs in reversed(range(min_micro_batch_size_with_same_gas, max_micro_batch_size, stride)):
  686. ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
  687. gas = max_train_batch_size_per_gpu // mbs
  688. ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
  689. ds_config[TRAIN_BATCH_SIZE] = mbs * gas * \
  690. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  691. exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
  692. exp, metric_val = self.run_ds_config(ds_config, exp_name)
  693. if metric_val:
  694. with open(metric_file, 'r') as f:
  695. results = hjson.load(f)
  696. metric_val = results[self.metric()]
  697. if has_mlflow:
  698. os.environ.pop('MLFLOW_RUN_ID')
  699. mlflow.start_run(nested=True, run_name=exp_name)
  700. for metric in results:
  701. mlflow.log_metric(metric, results[metric])
  702. mlflow.end_run()
  703. os.environ['MLFLOW_RUN_ID'] = self.mlflow_parent_id
  704. self.update_records(tuning_space_name, exp, metric_val, 1)
  705. if metric_val > prev_best_metric_val * (1 + METRIC_PERCENT_DIFF_CONST):
  706. prev_best_metric_val = metric_val
  707. prev_best_mbs = mbs
  708. else:
  709. break
  710. else:
  711. self.update_records(tuning_space_name, exp, 0, 1)
  712. break
  713. if prev_best_mbs != max_micro_batch_size:
  714. tuning_micro_batch_sizes[-1] = prev_best_mbs
  715. return tuning_micro_batch_sizes
  716. def get_min_max_micro_batch_size(self, stage, min_micro_batch_size, calculated_max_micro_batch_size):
  717. # get min and max micro batch size with gradient accumulation steps = 1
  718. if min_micro_batch_size > calculated_max_micro_batch_size:
  719. return -1, -1
  720. used_micro_batch_sizes = []
  721. tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)
  722. ds_config = get_first_config(self.user_config)
  723. ds_config[ZERO_OPTIMIZATION] = {ZERO_OPTIMIZATION_STAGE: stage}
  724. gas = self.get_gas_from_user_config()
  725. ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
  726. # search for the min micro batch size
  727. if min_micro_batch_size < 1:
  728. if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self.user_config and isinstance(
  729. self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU], int):
  730. # user specifies train_micro_batch_size_per_gpu as an int
  731. mbs = int(self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU])
  732. else:
  733. # user does not specify train_micro_batch_size_per_gpu or sets it to "auto" when using Hugging Face
  734. val = self.get_val_from_user_args(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
  735. if val:
  736. mbs = int(val)
  737. else:
  738. mbs = 1
  739. assert mbs > 0, "The micro batch size per GPU must be greater than 0."
  740. ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
  741. ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
  742. ds_config[TRAIN_BATCH_SIZE] = mbs * gas * \
  743. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  744. exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
  745. exp, metric_val = self.run_ds_config(ds_config, exp_name)
  746. if metric_val:
  747. self.update_records(tuning_space_name, exp, metric_val, 1)
  748. used_micro_batch_sizes.append(mbs)
  749. min_micro_batch_size = mbs
  750. else:
  751. self.update_records(tuning_space_name, exp, 0, 1)
  752. logger.info(f"User-specified micro batch size per GPU {mbs} does not run")
  753. if self.min_train_micro_batch_size_per_gpu() == mbs:
  754. return -1, -1
  755. mbs = self.min_train_micro_batch_size_per_gpu()
  756. ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
  757. ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
  758. ds_config[TRAIN_BATCH_SIZE] = mbs * gas * \
  759. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  760. exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
  761. exp, metric_val = self.run_ds_config(ds_config, exp_name)
  762. if not metric_val:
  763. self.update_records(tuning_space_name, exp, 0, 1)
  764. logger.info(f"min_train_micro_batch_size_per_gpu {mbs} is not runnable.")
  765. return -1, -1
  766. self.update_records(tuning_space_name, exp, metric_val, 1)
  767. min_micro_batch_size = mbs
  768. used_micro_batch_sizes.append(mbs)
  769. else:
  770. ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = min_micro_batch_size
  771. ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
  772. ds_config[TRAIN_BATCH_SIZE] = min_micro_batch_size * gas * \
  773. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  774. exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(min_micro_batch_size)
  775. exp, metric_val = self.run_ds_config(ds_config, exp_name)
  776. if metric_val:
  777. self.update_records(tuning_space_name, exp, metric_val, 1)
  778. used_micro_batch_sizes.append(min_micro_batch_size)
  779. else:
  780. self.update_records(tuning_space_name, exp, 0, 1)
  781. return -1, -1
  782. # search for the max micro batch size
  783. max_micro_batch_size = min(calculated_max_micro_batch_size, self.max_train_micro_batch_size_per_gpu())
  784. for mbs in [math.ceil(1.05 * max_micro_batch_size), max_micro_batch_size, int(0.95 * max_micro_batch_size)]:
  785. if mbs > self.max_train_micro_batch_size_per_gpu():
  786. continue
  787. if mbs in used_micro_batch_sizes:
  788. return min_micro_batch_size, mbs
  789. ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
  790. ds_config[TRAIN_BATCH_SIZE] = mbs * gas * \
  791. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  792. exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
  793. exp, metric_val = self.run_ds_config(ds_config, exp_name)
  794. if metric_val:
  795. logger.info(f"mbs = {mbs} is found as max mbs")
  796. self.update_records(tuning_space_name, exp, metric_val, 1)
  797. used_micro_batch_sizes.append(mbs)
  798. return min_micro_batch_size, mbs
  799. else:
  800. self.update_records(tuning_space_name, exp, 0, 1)
  801. space_records = self.records[tuning_space_name] if tuning_space_name in self.records else []
  802. if space_records:
  803. prev_idx = min(range(len(space_records)),
  804. key=lambda i: abs(space_records[i][0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU] -
  805. min_micro_batch_size))
  806. prev_metric_val = space_records[prev_idx][1]
  807. else:
  808. prev_metric_val = None
  809. low = min_micro_batch_size
  810. high = max_micro_batch_size
  811. # binary search until low is the smallest micro batch size that OOMs.
  812. while low <= high:
  813. mid = int((low + high) // 2)
  814. logger.debug(f"trying mbs = {mid}, low = {low}, high = {high}")
  815. if mid not in used_micro_batch_sizes:
  816. ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mid
  817. ds_config[TRAIN_BATCH_SIZE] = mid * gas * \
  818. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  819. exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mid)
  820. exp, metric_val = self.run_ds_config(ds_config, exp_name)
  821. if metric_val:
  822. low = mid + 1
  823. self.update_records(tuning_space_name, exp, metric_val, 1)
  824. used_micro_batch_sizes.append(mid)
  825. if prev_metric_val and (
  826. (metric_val - prev_metric_val) / prev_metric_val) < METRIC_PERCENT_DIFF_CONST:
  827. logger.info(f"performance plateaus at mbs = {low}")
  828. break
  829. prev_metric_val = metric_val
  830. else:
  831. self.update_records(tuning_space_name, exp, 0, 1)
  832. high = mid - 1
  833. else:
  834. low = mid + 1
  835. max_micro_batch_size = low - 1
  836. logger.info(f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}.")
  837. return min_micro_batch_size, max_micro_batch_size
  838. def get_gas_from_user_config(self):
  839. gas = 1
  840. if GRADIENT_ACCUMULATION_STEPS in self.user_config:
  841. gas_in_config = self.user_config[GRADIENT_ACCUMULATION_STEPS]
  842. if isinstance(gas_in_config, int):
  843. gas = gas_in_config
  844. elif gas_in_config == "auto": # GRADIENT_ACCUMULATION_STEPS: "auto"
  845. val = self.get_val_from_user_args(GRADIENT_ACCUMULATION_STEPS)
  846. if val:
  847. gas = int(val)
  848. elif isinstance(gas_in_config, list):
  849. logger.info(
  850. f"Specifying a list of {GRADIENT_ACCUMULATION_STEPS} to tune is not supported. 1 would be used.")
  851. assert gas > 0, "Gradient accumulation steps must be positive."
  852. return gas
  853. def get_val_from_user_args(self, ds_name):
  854. arg_mappings = self.autotuning_config.arg_mappings
  855. user_args = self.args.user_args
  856. if arg_mappings and ds_name in arg_mappings:
  857. arg_name = arg_mappings[ds_name]
  858. if arg_name in user_args:
  859. idx = user_args.index(arg_name)
  860. if user_args[idx + 1].isnumeric():
  861. return (user_args[idx + 1])
  862. return None
  863. def get_tuning_micro_batch_size_list(self, min_micro_batch_size, max_micro_batch_size,
  864. num_tuning_micro_batch_sizes):
  865. """Get a list of micro batch sizes to tune based on min and max values, as well as the size of the list.
  866. Args:
  867. min_micro_batch_size ([int]): min micro batch size per GPU
  868. max_micro_batch_size ([int]): max micro batch size per GPU
  869. num_tuning_micro_batch_sizes (int): the number of items in the returned list
  870. Returns:
  871. [list]: a list of micro batch sizes to tune.
  872. """
  873. if min_micro_batch_size <= 0 or max_micro_batch_size <= 0:
  874. logger.info(
  875. f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}")
  876. return [], 0
  877. # NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} ))
  878. # DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) ))
  879. # GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) ))
  880. if self.max_train_batch_size(
  881. ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size
  882. max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size() // (self.exp_num_gpus *
  883. self.exp_num_nodes)
  884. else:
  885. gas = self.get_gas_from_user_config()
  886. max_train_batch_size_per_gpu = max_micro_batch_size * gas // self.mp_size()
  887. logger.info(f"max_train_batch_size_per_gpu = {max_train_batch_size_per_gpu}")
  888. if min_micro_batch_size < max_micro_batch_size // 2:
  889. min_micro_batch_size = max_micro_batch_size // 2
  890. # constant stride
  891. stride = (max_micro_batch_size - min_micro_batch_size) // num_tuning_micro_batch_sizes
  892. if stride == 0:
  893. stride = 1
  894. ls = []
  895. min_gas = max_train_batch_size_per_gpu // max_micro_batch_size
  896. # if gas is the same as min_gas, do not add mbs to the tuning list
  897. for mbs in range(min_micro_batch_size, max_micro_batch_size, stride):
  898. if max_train_batch_size_per_gpu // mbs != min_gas:
  899. ls.append(mbs)
  900. ls.append(max_micro_batch_size)
  901. return ls, max_train_batch_size_per_gpu
  902. def run_ds_config(self, ds_config, exp_name):
  903. exp_config = {}
  904. exp_config['name'] = exp_name
  905. exp_config[DS_CONFIG] = ds_config
  906. exp_config['num_gpus'] = self.exp_num_gpus
  907. exp_config['num_nodes'] = self.exp_num_nodes
  908. exp_config['hostfile'] = self.args.hostfile
  909. exp_path = os.path.join(self.exps_dir, f'{exp_name}.json')
  910. logger.debug(f'run_ds_config exp_name = {exp_name}')
  911. with open(exp_path, 'w', buffering=BUFSIZE) as fd:
  912. json.dump(exp_config, fd)
  913. fd.flush()
  914. os.fsync(fd)
  915. self.rm.schedule_experiments([exp_path])
  916. self.rm.run()
  917. exp, metric_val = self.rm.parse_results(self.metric())
  918. self.rm.clear()
  919. return exp, metric_val
  920. def write_optimal_config(self):
  921. best_space_records = self.get_best_space_records()
  922. if GLOBAL_TUNING_SPACE not in best_space_records:
  923. return
  924. best_exp, best_metric_val, _ = best_space_records[GLOBAL_TUNING_SPACE]
  925. if best_exp:
  926. exp_dir = best_exp["result_dir"]
  927. cmd = None
  928. with open(os.path.join(exp_dir, "cmd.txt"), "r") as f:
  929. cmd = [str(i) for i in f.read().split()]
  930. ds_config = hjson.load(open(os.path.join(exp_dir, "ds_config.json"), "r"))
  931. ds_config.pop(AUTOTUNING)
  932. ds_config_path = os.path.join(self.results_dir, "ds_config_optimal.json")
  933. json.dump(ds_config, open(ds_config_path, "w"))
  934. cmd_path = os.path.join(self.results_dir, "cmd_optimal.txt")
  935. with open(cmd_path, "w") as fd:
  936. fd.write(" ".join(cmd))
  937. fd.write("\n")
  938. fd.flush()
  939. self.optimal_cmd = cmd
  940. self.optimal_ds_config = ds_config
  941. logger.info(
  942. f"Wrote the optimal DeepSpeed configuration found by autotuning to {ds_config_path}, and the corresponding DeepSpeed command to {cmd_path}"
  943. )
  944. def run_after_tuning(self):
  945. """ Launches the training with the optimal DeepSpeed configuration found through the autotuning process.
  946. "ds_config_optimal.json" describing the optimal DeepSpeed configuration as well the command used to launch training "cmd_optimal.txt" are saved to self.results_dir.
  947. """
  948. if self.optimal_cmd:
  949. result = subprocess.Popen(self.optimal_cmd)
  950. result.wait()
  951. logger.info(f"Done running with the optimal DeepSpeed configuration using {self.optimal_cmd}")
  952. else:
  953. logger.info(f"No optimal DeepSpeed configuration found by autotuning.")