autotuner.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132
  1. import copy
  2. import json
  3. import os
  4. from random import sample
  5. import shutil
  6. import subprocess
  7. import hjson
  8. import torch
  9. import time
  10. import datetime
  11. import math
  12. from ..runtime.config_utils import dict_raise_error_on_duplicate_keys
  13. from ..runtime.constants import *
  14. from ..runtime.zero.constants import *
  15. from ..utils import logger
  16. from .config import DeepSpeedAutotuningConfig
  17. from .constants import *
  18. from .scheduler import ResourceManager, run_experiment
  19. from .tuner import GridSearchTuner, RandomTuner, ModelBasedTuner
  20. from .utils import *
  21. try:
  22. from tabulate import tabulate
  23. except ImportError:
  24. tabulate = None
  25. class Autotuner:
  26. """The DeepSpeed Autotuner automatically discovers the optimal DeepSpeed configuration that delivers good training speed. The Autotuner uses model information, system information, and heuristics to efficiently tune system knobs that affect compute and memory efficiencies, such as ZeRO optimization stages, micro-batch sizes, and many other ZeRO optimization configurations. It not only reduces the time and resources user spend on tuning, but also can discover configurations better than hand-tuned methods.
  27. Autotuning with DeepSpeed requires no code change from DeepSpeed users. Please refer to the README for usage details.
  28. """
  29. def __init__(self, args, active_resources):
  30. self.args = args
  31. self.selected_exp_dir = None
  32. assert tabulate is not None, "Missing required package `tabulate`, please install with `pip install deepspeed[autotuning]`."
  33. logger.debug(f"autotunning args={args}")
  34. self.user_config = self._get_user_config(args.user_args)
  35. assert self.user_config is not None, "DeepSpeed configuration is not provided"
  36. self.autotuning_config = DeepSpeedAutotuningConfig(self.user_config)
  37. self.exps_dir = DEFAULT_EXPRS_DIR
  38. if self.autotuning_config.exps_dir and self.autotuning_config.exps_dir != "":
  39. self.exps_dir = self.autotuning_config.exps_dir
  40. if self.autotuning_config.overwrite and os.path.exists(self.exps_dir):
  41. shutil.rmtree(self.exps_dir, ignore_errors=True)
  42. if not os.path.exists(self.exps_dir):
  43. os.makedirs(self.exps_dir, exist_ok=True)
  44. self.results_dir = DEFAULT_RESULTS_DIR
  45. if self.autotuning_config.results_dir and self.autotuning_config.results_dir != "":
  46. self.results_dir = self.autotuning_config.results_dir
  47. if self.autotuning_config.overwrite and os.path.exists(self.results_dir):
  48. shutil.rmtree(self.results_dir, ignore_errors=True)
  49. if not os.path.exists(self.results_dir):
  50. os.makedirs(self.results_dir, exist_ok=True)
  51. # set the active resource for the autotuner resource manager
  52. self.rm = self._get_resource_manager(active_resources)
  53. # get resource requirement for each autotuning experiment
  54. self.exp_num_nodes, self.exp_num_gpus = self._get_exp_resources(args)
  55. assert self.exp_num_gpus <= self.rm.num_gpus_per_node, "num_gpus in the autotuning configuration must not be less than the --num_gpus value in the train script if any"
  56. assert self.exp_num_nodes <= len(
  57. self.rm.nodes), "num_nodes in the autotuning configuration must not be less than the --num_nodes value in the train script if any"
  58. self.records = {}
  59. def print_tuning_results(self):
  60. """Print the autotuning results in tabular format.
  61. """
  62. best_space_records = self.get_best_space_records()
  63. tab = []
  64. if best_space_records:
  65. for key, val in best_space_records.items():
  66. if not val:
  67. continue
  68. row = []
  69. row.append(key)
  70. num_exps = 0
  71. if key == GLOBAL_TUNING_SPACE:
  72. cnt = 0
  73. for k, v in best_space_records.items():
  74. if k != GLOBAL_TUNING_SPACE:
  75. cnt += v[2]
  76. num_exps = cnt
  77. else:
  78. num_exps = val[2]
  79. row.append(num_exps)
  80. row.append(val[1])
  81. row.append(val[0]['name'])
  82. tab.append(row)
  83. summary = tabulate(tab,
  84. headers=[
  85. "tuning_space",
  86. "num_experiments",
  87. "best_metric_val",
  88. "best_exp_name"
  89. ],
  90. tablefmt="pipe")
  91. print(summary)
  92. with open(os.path.join(self.results_dir,
  93. 'summary.txt'),
  94. 'w',
  95. buffering=BUFSIZE) as fd:
  96. fd.write(summary)
  97. fd.flush()
  98. os.fsync(fd)
  99. if GLOBAL_TUNING_SPACE in best_space_records:
  100. best_exp, best_metric_val, total_num_exps = best_space_records[GLOBAL_TUNING_SPACE]
  101. if best_exp:
  102. logger.info(
  103. f"{best_exp['name']} is the optimal setup after tuning. The exp result is at {best_exp['result_dir']}."
  104. )
  105. else:
  106. logger.info(
  107. f"No optimal setup is found. Please check that experiments were run successfully."
  108. )
  109. tuning_duration = datetime.timedelta(seconds=(time.time() - self.start_time))
  110. logger.info(f"Tuning completed in {tuning_duration}")
  111. with open(os.path.join(self.results_dir, 'summary.txt'), 'a') as f:
  112. f.write(
  113. f"\n\nTuning completed in {tuning_duration}. Total number of experiments: {self.rm.experiment_count - 1}."
  114. )
  115. f.flush()
  116. def _get_user_config(self, user_args):
  117. """Get DeepSpeed configuration from the user arguments passed to the launcher.
  118. Args:
  119. user_args ([list]): user arguments passed to the DeepSpeed launcher
  120. Returns:
  121. [dict]: DeepSpeed configuration dictionary
  122. """
  123. user_config_file = None
  124. if "--deepspeed_config" in user_args:
  125. idx = user_args.index("--deepspeed_config")
  126. assert ".json" in user_args[idx +
  127. 1], "DeepSpeed --deepspeed_config requires a json file to specify the configuration"
  128. user_config_file = user_args[idx + 1]
  129. elif "--deepspeed" in user_args:
  130. idx = user_args.index("--deepspeed")
  131. if ".json" in user_args[idx + 1]:
  132. user_config_file = user_args[idx + 1]
  133. logger.debug(f"user_config_file = {user_config_file}")
  134. if user_config_file is not None:
  135. assert os.path.isfile(
  136. user_config_file
  137. ), "DeepSpeed configuration file: {} is not an existing file".format(
  138. user_config_file
  139. )
  140. if os.path.exists(user_config_file):
  141. return json.load(open(user_config_file,
  142. "r"),
  143. object_pairs_hook=dict_raise_error_on_duplicate_keys)
  144. return None
  145. def _get_resource_manager(self, active_resources):
  146. """Initialize and return a resource manager
  147. Args:
  148. active_resources ([dict]): A dictionary of hostname and its slots (GPUs), e.g. {"worker-0": "0,1,2,3,4,5,6,7,8"}
  149. Raises:
  150. RuntimeError: raises the error if no GPU is available
  151. Returns:
  152. [ResourceManager]: A resource manager that schedules and runs autotuning experiments.
  153. """
  154. logger.info(f"active_resources = {active_resources}")
  155. hosts = []
  156. ngpus_per_node = 100
  157. for hostname, slots in active_resources.items():
  158. hosts.append(hostname)
  159. ngpus_per_node = min(len(slots), ngpus_per_node)
  160. assert ngpus_per_node > 0, "no gpu is available"
  161. return ResourceManager(args=self.args,
  162. hosts=hosts,
  163. num_gpus_per_node=ngpus_per_node,
  164. results_dir=self.results_dir,
  165. exps_dir=self.exps_dir,
  166. arg_mappings=self.autotuning_config.arg_mappings)
  167. def _get_exp_resources(self, args):
  168. """Get resource requirement for each autotuning experiment
  169. Args:
  170. args (dict): user args
  171. Returns:
  172. num_nodes, num_gpus: the number of gpus and number of nodes used in the autotuning experiments
  173. """
  174. if args.num_nodes > 0:
  175. num_nodes = args.num_nodes
  176. else:
  177. num_nodes = len(self.rm.nodes)
  178. if args.num_gpus > 0:
  179. num_gpus = args.num_gpus
  180. else:
  181. num_gpus = self.rm.num_gpus_per_node
  182. return num_nodes, num_gpus
  183. def metric(self):
  184. return self.autotuning_config.metric
  185. def fast_enabled(self):
  186. return self.autotuning_config.fast
  187. def max_train_batch_size(self):
  188. return self.autotuning_config.max_train_batch_size
  189. def mp_size(self):
  190. return self.autotuning_config.mp_size
  191. def max_train_micro_batch_size_per_gpu(self):
  192. if self.max_train_batch_size() and self.max_train_batch_size(
  193. ) > 0: # if the user specifies a max_train_batch_size
  194. max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size(
  195. ) // (self.exp_num_gpus * self.exp_num_nodes
  196. ) # gradient accumulation steps >=1
  197. return min(self.autotuning_config.max_train_micro_batch_size_per_gpu,
  198. max_train_micro_batch_size)
  199. else:
  200. return self.autotuning_config.max_train_micro_batch_size_per_gpu
  201. def min_train_micro_batch_size_per_gpu(self):
  202. return self.autotuning_config.min_train_micro_batch_size_per_gpu
  203. def num_tuning_micro_batch_sizes(self):
  204. return self.autotuning_config.num_tuning_micro_batch_sizes
  205. def fp16_enabled(self):
  206. if FP16 in self.user_config.keys():
  207. return self.user_config[FP16].get(FP16_ENABLED, FP16_ENABLED_DEFAULT)
  208. else:
  209. return False
  210. def get_gpu_memory_info(self):
  211. return torch.cuda.get_device_properties(0).total_memory
  212. def get_activation_memory_per_gpu(self):
  213. if self.model_info and "activation_mem_per_gpu" in self.model_info:
  214. return self.model_info["activation_mem_per_gpu"]
  215. def get_instantiation_memory_required_per_gpu(self, zero_stage):
  216. num_params = self.get_model_num_params()
  217. total_gpus = self.exp_num_nodes * self.exp_num_gpus
  218. fp16_enabled = self.fp16_enabled()
  219. if not num_params:
  220. return 0
  221. # assume the model uses Adam optimizer
  222. # ZERO_OPTIMIZATION_DISABLED:
  223. params_mem = num_params * (2 if fp16_enabled else 4)
  224. gradients_mem = num_params * (2 if fp16_enabled else 4)
  225. optimizer_mem = num_params * (16 if fp16_enabled else 8)
  226. if zero_stage >= ZERO_OPTIMIZATION_OPTIMIZER_STATES:
  227. optimizer_mem = optimizer_mem / total_gpus
  228. if zero_stage >= ZERO_OPTIMIZATION_GRADIENTS:
  229. gradients_mem = gradients_mem / total_gpus
  230. if zero_stage >= ZERO_OPTIMIZATION_WEIGHTS:
  231. params_mem = params_mem / total_gpus
  232. mem_per_gpu = (params_mem + gradients_mem + optimizer_mem) / self.mp_size()
  233. return mem_per_gpu
  234. def _generate_experiments(self, tuning_space, max_train_batch_size_per_gpu):
  235. """Generates a list of autotuning experiments given a tuning_space.
  236. The corresponding parameter values are replaced by user-defined values in the DeepSpeed configuration file.
  237. Args:
  238. tuning_space ([dict]): A DeepSpeed configuration dictionary where a value can be a list (called a tuning parameter). For example,
  239. {
  240. "zero_optimization": {
  241. "stage": 1,
  242. "reduce_bucket_size": [5e7,
  243. 5e8,
  244. 1e9],
  245. "allgather_bucket_size": [5e7,
  246. 5e8,
  247. 1e9],
  248. }
  249. }
  250. reduce_bucket_size and allgather_bucket_size are the tuning parameters in this tuning space.
  251. Returns:
  252. [list]: a list of experiments generated by taking combinations of values of the tuning space. The above tuning space generates 3*3 = 9 experiments if the user DeepSpeed configuration file does not overwrite the two tuning parameters or define more tuning parameters.
  253. """
  254. exps = []
  255. # each zero stage uses a different template configuration file
  256. config_zero = tuning_space.get(ZERO_OPTIMIZATION, {})
  257. stage = config_zero.get(ZERO_OPTIMIZATION_STAGE, None)
  258. template_config = {}
  259. if stage == 0:
  260. template_path = DEFAULT_TEMPLATE_PATH_ZERO_0
  261. template_config = hjson.load(open(template_path, 'r'))
  262. prefix = "z0_"
  263. elif stage == 1:
  264. template_path = DEFAULT_TEMPLATE_PATH_ZERO_1
  265. template_config = hjson.load(open(template_path, 'r'))
  266. prefix = "z1_"
  267. elif stage == 2:
  268. template_path = DEFAULT_TEMPLATE_PATH_ZERO_2
  269. template_config = hjson.load(open(template_path, 'r'))
  270. prefix = "z2_"
  271. elif stage == 3:
  272. template_path = DEFAULT_TEMPLATE_PATH_ZERO_3
  273. template_config = hjson.load(open(template_path, 'r'))
  274. model_info = self.model_info
  275. if model_info and "hidden_size" in model_info:
  276. hs = model_info["hidden_size"]
  277. template_config[ZERO_OPTIMIZATION][
  278. ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE] = hs * hs
  279. template_config[ZERO_OPTIMIZATION][
  280. ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE] = 0.9 * hs * hs
  281. template_config[ZERO_OPTIMIZATION][
  282. ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD] = 10 * hs
  283. prefix = "z3_"
  284. else:
  285. return exps
  286. # replace the corresponding parameter values if the user specfies them in the DeepSpeed configuration file
  287. replace_dict(tuning_space,
  288. self.user_config,
  289. [ZERO_OPTIMIZATION,
  290. TRAIN_MICRO_BATCH_SIZE_PER_GPU])
  291. logger.debug(f"tuning_space = {json.dumps(tuning_space)}")
  292. all_configs = get_all_configs(tuning_space)
  293. tuning_keys = get_tuning_keys(tuning_space)
  294. logger.debug(f"tuning_keys = {tuning_keys}")
  295. logger.debug(f"before prunning total configs = {len(all_configs)}")
  296. pruned_list = prune_configs(all_configs)
  297. logger.debug(f"after prunning total configs = {len(pruned_list)}")
  298. for config in pruned_list:
  299. exp_config = copy.deepcopy(template_config)
  300. # fill the template with the expr config
  301. replace_dict(exp_config, config)
  302. # if the config does not use offloading, remove the offloading section
  303. config_zero = config.get(ZERO_OPTIMIZATION, None)
  304. if config_zero:
  305. if OFFLOAD_OPTIMIZER not in config_zero and OFFLOAD_OPTIMIZER in exp_config[
  306. ZERO_OPTIMIZATION]:
  307. del exp_config[ZERO_OPTIMIZATION][OFFLOAD_OPTIMIZER]
  308. if OFFLOAD_PARAM not in config_zero and OFFLOAD_PARAM in exp_config[
  309. ZERO_OPTIMIZATION]:
  310. del exp_config[ZERO_OPTIMIZATION][OFFLOAD_PARAM]
  311. # set gradient accumulation steps according to max_train_batch_size_per_gpu
  312. mbs = exp_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU]
  313. gas = max_train_batch_size_per_gpu // mbs
  314. exp_config[GRADIENT_ACCUMULATION_STEPS] = gas
  315. exp_config[TRAIN_BATCH_SIZE] = mbs * gas * \
  316. self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
  317. exp = {}
  318. # generate the expr name
  319. exp_name = canonical_name(exp_config, tuning_keys, prefix)
  320. exp['name'] = exp_name
  321. exp[DS_CONFIG] = exp_config
  322. exp['num_gpus'] = self.exp_num_gpus
  323. exp['num_nodes'] = self.exp_num_nodes
  324. exps.append(exp)
  325. return exps
def tune(self):
    """Tunes Zero stages, micro batch size per GPU, and other Zero configurations.

    Performance metrics of different tuning spaces are recorded in self.records.

    Flow:
    1. Profile the model, or reuse model info from the autotuning config.
    2. For ZeRO stages 0, 1, 2 and 3, in that order, estimate the per-GPU
       memory needed at micro batch size 1 (model states + activations).
    3. Call ``tune_space`` on the stage's default tuning space when that
       estimate fits in GPU memory and the user's config allows the stage.
    4. Pass the best (max_mbs, mbs, metric_val) seen so far to the next stage's
       ``tune_space`` call.
    If stage 3 does not fit either, a message is logged and the method
    returns.

    Sets ``self.start_time``, ``self.model_info``, ``self.gpu_mem`` and
    ``self.activation_mem``.
    """
    self.start_time = time.time()
    if self.fast_enabled():
        logger.info(f"Fast mode is enabled. Tuning micro batch size only.")
    # model info profile run with DEFAULT_MIN_MEM_CONFIG
    model_info = self.model_info_profile_run()
    if model_info:
        self.model_info = model_info
    else:
        # no model info available (e.g. the profiling run reported an error)
        return
    logger.info(
        f"The model has {number_to_string(self.get_model_num_params())} parameters.")
    self.gpu_mem = self.get_gpu_memory_info()
    logger.info(
        f"Memory per GPU in the system is {memory_to_string(self.gpu_mem, postfix='B')}."
    )
    # NOTE(review): get_activation_memory_per_gpu() returns None when model_info
    # lacks "activation_mem_per_gpu" (e.g. user-supplied model_info with only
    # num_params); the memory arithmetic below would then fail -- confirm
    # whether that input is supported.
    self.activation_mem = self.get_activation_memory_per_gpu()
    logger.info(
        f"The model requires at least {memory_to_string(self.activation_mem, postfix='B')} activation memory for micro batch size 1."
    )
    # ZeRO stage(s) requested by the user: a single value, a list, or "all" when unset
    stage = self.user_config.get(ZERO_OPTIMIZATION,
                                 {}).get(ZERO_OPTIMIZATION_STAGE,
                                         "all")
    user_zero_stages = [stage] if not isinstance(stage, list) else stage
    logger.info(f"User-defined zero stages are {stage}.")
    # best results so far; threaded into each later stage's tune_space call
    mbs = 0
    max_mbs = 0
    metric_val = 0
    # --- ZeRO stage 0: estimated model-state + activation memory at mbs = 1 ---
    required_gpu_mem = self.get_instantiation_memory_required_per_gpu(
        ZERO_OPTIMIZATION_DISABLED) + self.activation_mem
    if self.gpu_mem > required_gpu_mem:
        if "all" in user_zero_stages or ZERO_OPTIMIZATION_DISABLED in user_zero_stages:
            logger.info(
                f"The model might be runable with ZERO 0 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1), adding DEFAULT_TUNING_SPACE_ZERO_0 to the global tuning space"
            )
            next_max_mbs, next_mbs, next_metric_val = self.tune_space(
                DEFAULT_TUNING_SPACE_ZERO_0)
            # keep the stage result only if it reached a larger micro batch size
            if next_mbs > mbs:
                mbs = next_mbs
                max_mbs = next_max_mbs
                metric_val = next_metric_val
    else:
        logger.info(
            f"The model is not runable with ZERO stage {ZERO_OPTIMIZATION_DISABLED} (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1)"
        )
    # --- ZeRO stage 1 ---
    required_gpu_mem = self.get_instantiation_memory_required_per_gpu(
        ZERO_OPTIMIZATION_OPTIMIZER_STATES) + self.activation_mem
    if self.gpu_mem > required_gpu_mem:
        if "all" in user_zero_stages or ZERO_OPTIMIZATION_OPTIMIZER_STATES in user_zero_stages:
            logger.info(
                f"The model might be runable with ZERO 1 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_1 to the global tuning space"
            )
            next_max_mbs, next_mbs, next_metric_val = self.tune_space(
                DEFAULT_TUNING_SPACE_ZERO_1, prev_max_mbs = max_mbs, prev_best_mbs=mbs, prev_best_metric_val=metric_val)
            if next_mbs > mbs:
                mbs = next_mbs
                max_mbs = next_max_mbs
                metric_val = next_metric_val
    else:
        logger.info(
            f"The model is not runable with ZERO stage {ZERO_OPTIMIZATION_OPTIMIZER_STATES} (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1)"
        )
    # --- ZeRO stage 2 ---
    required_gpu_mem = self.get_instantiation_memory_required_per_gpu(
        ZERO_OPTIMIZATION_GRADIENTS) + self.activation_mem
    if self.gpu_mem > required_gpu_mem:
        if "all" in user_zero_stages or ZERO_OPTIMIZATION_GRADIENTS in user_zero_stages:
            logger.info(
                f"The model might be runable with ZERO 2 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_2 to the global tuning space"
            )
            next_max_mbs, next_mbs, next_metric_val = self.tune_space(
                DEFAULT_TUNING_SPACE_ZERO_2, prev_max_mbs = max_mbs, prev_best_mbs=mbs, prev_best_metric_val=metric_val)
            if next_mbs > mbs:
                mbs = next_mbs
                max_mbs = next_max_mbs
                metric_val = next_metric_val
    else:
        logger.info(
            f"The model is not runable with ZERO stage {ZERO_OPTIMIZATION_GRADIENTS} (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1)"
        )
    # --- ZeRO stage 3: last stage, so its result is not carried forward ---
    required_gpu_mem = self.get_instantiation_memory_required_per_gpu(
        ZERO_OPTIMIZATION_WEIGHTS) + self.activation_mem
    if self.gpu_mem > required_gpu_mem:
        if "all" in user_zero_stages or ZERO_OPTIMIZATION_WEIGHTS in user_zero_stages:
            logger.info(
                f"The model might be runable with ZERO 3 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_3 to the global tuning space"
            )
            _, _, _ = self.tune_space(
                DEFAULT_TUNING_SPACE_ZERO_3, prev_max_mbs = max_mbs, prev_best_mbs=mbs, prev_best_metric_val=metric_val)
    else:
        logger.info(
            f"The model has {self.get_model_num_params()} parameters and requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory per GPU with DeepSpeed Zero stage {ZERO_OPTIMIZATION_WEIGHTS} optimization. Memory per GPU in system is {memory_to_string(self.gpu_mem)}. No tuning is performed."
        )
        return
  421. def tune_space(self,
  422. tuning_space,
  423. prev_max_mbs=0,
  424. prev_best_mbs=0,
  425. prev_best_metric_val=0):
  426. config_zero = tuning_space.get(ZERO_OPTIMIZATION, {})
  427. stage = config_zero.get(ZERO_OPTIMIZATION_STAGE, ZERO_OPTIMIZATION_STAGE_DEFAULT)
  428. tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)
  429. tuning_micro_batch_sizes = []
  430. max_train_batch_size_per_gpu = 0
  431. tuning_micro_batch_sizes_overwritten = False
  432. # calcuate max micro batch size using gpu memory, model instatiation memory and activation memory
  433. # calculated_max_micro_batch_size = (memory_per_gpu - instantiation_memory) // activation_memory_micro_batch_size_1
  434. calculated_max_micro_batch_size = int(
  435. self.gpu_mem -
  436. self.get_instantiation_memory_required_per_gpu(stage)) // self.activation_mem
  437. logger.info(
  438. f"Start tuning for space {tuning_space_name}, calculated_max_micro_batch_size = {calculated_max_micro_batch_size}"
  439. )
  440. if calculated_max_micro_batch_size < prev_max_mbs:
  441. logger.info(
  442. f"No need to tune Zero stage {stage}. End tuning for space {tuning_space_name}"
  443. )
  444. return 0, 0, 0
  445. if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self.user_config and isinstance(
  446. self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU],
  447. list):
  448. # user-specified micro batch size per gpu is a list which overwrites the default tuning behavior
  449. tuning_micro_batch_sizes = [
  450. s for s in self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU]
  451. if isinstance(s,
  452. int)
  453. ]
  454. gas = self.get_gas_from_user_config()
  455. min_micro_batch_size = min(tuning_micro_batch_sizes)
  456. max_micro_batch_size = max(tuning_micro_batch_sizes)
  457. max_train_batch_size_per_gpu = max_micro_batch_size * gas
  458. tuning_micro_batch_sizes_overwritten = True
  459. else:
  460. # auto-detects the list of micro batch sizes to tune
  461. min_micro_batch_size, max_micro_batch_size = self.get_min_max_micro_batch_size(
  462. stage, prev_max_mbs, calculated_max_micro_batch_size)
  463. if max_micro_batch_size < prev_max_mbs:
  464. logger.info(
  465. f"No need to tune Zero stage {stage}. End tuning for space {tuning_space_name}"
  466. )
  467. return 0, 0, 0
  468. tuning_micro_batch_sizes, max_train_batch_size_per_gpu = self.get_tuning_micro_batch_size_list(
  469. min_micro_batch_size,
  470. max_micro_batch_size,
  471. num_tuning_micro_batch_sizes=self.num_tuning_micro_batch_sizes())
  472. logger.info(
  473. f"tuning_micro_batch_sizes = {tuning_micro_batch_sizes}, max_train_batch_size_per_gpu = {max_train_batch_size_per_gpu}"
  474. )
  475. # return if the tuning_micro_batch_sizes list is empty
  476. if not tuning_micro_batch_sizes:
  477. logger.info(f"End tuning for space {tuning_space_name}")
  478. return 0, 0, 0
  479. # tune micro batch sizes and gradient accumulation steps given max_train_batch_size_per_gpu
  480. tuning_micro_batch_sizes = self.run_tuning_micro_batch_sizes(
  481. tuning_micro_batch_sizes,
  482. max_train_batch_size_per_gpu,
  483. min_micro_batch_size,
  484. stage,
  485. tuning_micro_batch_sizes_overwritten)
  486. fast_best_record = self.get_best_space_record(tuning_space_name)
  487. fast_best_metric_val = fast_best_record[1] if fast_best_record else 0
  488. fast_best_mbs = fast_best_record[0][DS_CONFIG][
  489. TRAIN_MICRO_BATCH_SIZE_PER_GPU] if fast_best_record else 0
  490. logger.info(
  491. f"fast_best_mbs = {fast_best_mbs}, name = {fast_best_record[0]['name']}")
  492. if self.fast_enabled() or stage == 0:
  493. logger.info(f"End tuning for space: {tuning_space_name}")
  494. return max_micro_batch_size, fast_best_mbs, fast_best_metric_val
  495. # if the best metric or the micro batch size for that best metric in the current Zero stage after tuning micro batch size is less than the corrresponding value in the prevous Zero stage, return, do not tune other Zero configuration paramerts
  496. if stage > 0:
  497. if fast_best_mbs <= prev_best_mbs or fast_best_metric_val < prev_best_metric_val:
  498. logger.info(
  499. f"End tuning for space: {tuning_space_name}. No need to tune other Zero configuration paramerts."
  500. )
  501. return max_micro_batch_size, fast_best_mbs, fast_best_metric_val
  502. tuning_space[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = tuning_micro_batch_sizes
  503. tuning_space_name = canonical_name(tuning_space,
  504. tuning_keys=get_tuning_keys(tuning_space),
  505. prefix="z" + str(stage) + "_",
  506. omit_val=True)
  507. logger.info(f'Tuning space is {tuning_space}')
  508. logger.info(f'Tuning space name is {tuning_space_name}')
  509. exps = self._generate_experiments(tuning_space, max_train_batch_size_per_gpu)
  510. logger.info(f'Tuner type is {self.autotuning_config.tuner_type}')
  511. if self.autotuning_config.tuner_type == AUTOTUNING_TUNER_MODELBASED:
  512. t = ModelBasedTuner(exps, self.rm, self.metric(), tuning_space)
  513. elif self.autotuning_config.tuner_type == AUTOTUNING_TUNER_RANDOM:
  514. t = RandomTuner(exps, self.rm, self.metric())
  515. else:
  516. t = GridSearchTuner(exps, self.rm, self.metric())
  517. sample_size = len(self.rm.nodes) * self.rm.num_gpus_per_node // (
  518. self.exp_num_gpus * self.exp_num_nodes)
  519. num_exps = t.tune(sample_size=sample_size,
  520. n_trials=self.autotuning_config.tuner_num_trials,
  521. early_stopping=self.autotuning_config.tuner_early_stopping)
  522. exp = t.best_exp
  523. metric_val = t.best_metric_val
  524. if exp:
  525. self.update_records(tuning_space_name, exp, metric_val, num_exps)
  526. full_best_record = self.get_best_space_record(tuning_space_name)
  527. full_best_metric_val = full_best_record[1] if full_best_record else -1
  528. if full_best_metric_val > fast_best_metric_val:
  529. best_metric_val = full_best_metric_val
  530. best_mbs = full_best_record[0][DS_CONFIG][
  531. TRAIN_MICRO_BATCH_SIZE_PER_GPU] if full_best_record else -1
  532. else:
  533. best_metric_val = fast_best_metric_val
  534. best_mbs = fast_best_mbs
  535. logger.info(f"End tuning for space: {tuning_space_name}")
  536. return max_micro_batch_size, best_mbs, best_metric_val
  537. def get_plauteu_mbs(self, tuning_space_name):
  538. if tuning_space_name not in self.records:
  539. return 0
  540. space_records = self.records[tuning_space_name]
  541. sorted_space_records = sorted(
  542. space_records,
  543. key=lambda x: x[0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU])
  544. prev_metric_val = None
  545. prev_micro_batch_size = 0
  546. for (exp, metric_val, _) in sorted_space_records:
  547. if prev_metric_val:
  548. if metric_val < prev_metric_val:
  549. break
  550. if (metric_val >= prev_metric_val
  551. and (metric_val - prev_metric_val) / prev_metric_val <
  552. METRIC_PERCENT_DIFF_CONST):
  553. break
  554. prev_metric_val = metric_val
  555. prev_micro_batch_size = exp[DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU]
  556. plateau_mbs = prev_micro_batch_size
  557. return plateau_mbs
  558. def get_model_num_params(self):
  559. if self.model_info and "num_params" in self.model_info:
  560. return self.model_info["num_params"]
  def model_info_profile_run(self):
      """Run a model-information profiling experiment, or reuse known model info.

      The experiment collects the number of model parameters and activation
      memory, and produces a "profile_model_info" folder under
      ``self.results_dir``.

      Returns:
          dict or None: a model information dictionary, e.g.
          ``{"num_params": 335144976, "trainable_num_params": 335144976,
          "activation_mem_per_gpu": 324358144, "rank": 0}``.

          * If ``self.autotuning_config.model_info`` already contains
            ``MODEL_INFO_NUM_PARAMS``, it is returned without running anything.
          * ``None`` is returned when the profiling experiment reports an error.
          * If the experiment wrote no model info file, the (possibly ``None``)
            ``model_info`` from the autotuning config is returned.
      """
      logger.info("Starting model info profile run.")
      model_info = self.autotuning_config.model_info
      if model_info and MODEL_INFO_NUM_PARAMS in model_info:
          return model_info

      # Profile with the user's config overridden by the minimal-memory settings.
      ds_config = copy.deepcopy(self.user_config)
      replace_dict(ds_config, DEFAULT_MIN_MEM_CONFIG)

      model_info_path = os.path.join(self.results_dir,
                                     "profile_model_info",
                                     "model_info.json")
      ds_config[AUTOTUNING] = {
          "enabled": True,
          "model_info_path": model_info_path,
          "model_info": {
              "profile": True
          }
      }

      exp_config = {}
      exp_name = "profile_model_info"
      exp_config['name'] = exp_name
      exp_config[DS_CONFIG] = ds_config
      exp_config['num_gpus'] = self.exp_num_gpus
      exp_config['num_nodes'] = self.exp_num_nodes
      exp_path = os.path.join(self.exps_dir, f'{exp_name}.json')

      with open(exp_path, 'w', buffering=BUFSIZE) as fd:
          json.dump(exp_config, fd)
          fd.flush()
          os.fsync(fd)

      self.rm.schedule_experiments([exp_path])
      self.rm.run()

      # Snapshot the outcomes first, then clear exactly once. This way the
      # resource manager is reset even when no experiment finished, and it is
      # never cleared while its finished_experiments mapping is being iterated.
      finished = list(self.rm.finished_experiments.values())
      self.rm.clear()
      for _exp_json, err in finished:
          if err:
              logger.error(
                  f"The model is not runnable with DeepSpeed with error = {err}")
              return None

      if os.path.exists(model_info_path):
          with open(model_info_path, 'r') as f:
              model_info = hjson.load(f)
      return model_info
  def update_records(self, space_name, exp, metric_val, num_exps):
      """Append one ``(exp, metric_val, num_exps)`` record to a tuning space.

      ``self.records`` maps a tuning-space name to the list of records
      collected for it. The list is created the first time a space is seen.
      """
      space_list = self.records.setdefault(space_name, [])
      space_list.append((exp, metric_val, num_exps))
  def get_best_space_record(self, space_name):
      """Summarise the records collected for one tuning space.

      Returns:
          ``(exp, metric_val, total_num_exps)``, where ``exp`` and
          ``metric_val`` come from the first record with the highest
          ``metric_val``, and ``total_num_exps`` is the sum of ``num_exps``
          over every record of the space. Returns ``None`` when the space is
          unknown or has no records.
      """
      if space_name not in self.records:
          return None
      space_records = self.records[space_name]
      if not space_records:
          return None
      # max() keeps the first maximal item, matching a strict ">" scan.
      top_exp, top_metric, _ = max(space_records, key=lambda rec: rec[1])
      total_exps = sum(rec[2] for rec in space_records)
      return top_exp, top_metric, total_exps
  def get_best_space_records(self):
      """Collect the best record of every tuning space plus the overall best.

      Returns:
          dict: tuning-space name -> the tuple produced by
          ``get_best_space_record``. When at least one space has a record, the
          one with the highest metric (first seen on ties) is also stored under
          ``GLOBAL_TUNING_SPACE``.
      """
      summary = {}
      overall_best = None
      for name in self.records:
          candidate = self.get_best_space_record(name)
          if not candidate:
              continue
          summary[name] = candidate
          if not overall_best or candidate[1] > overall_best[1]:
              overall_best = candidate
      if overall_best:
          summary[GLOBAL_TUNING_SPACE] = overall_best
      return summary
  def run_tuning_micro_batch_sizes(self,
                                   tuning_micro_batch_sizes,
                                   max_train_batch_size_per_gpu,
                                   min_micro_batch_size,
                                   stage,
                                   tuning_micro_batch_sizes_overwritten):
      """Run one experiment per candidate micro batch size for a ZeRO stage.

      Each candidate ``mbs`` runs with gradient accumulation steps
      ``max_train_batch_size_per_gpu // mbs``. Its metric is recorded in the
      ``TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)`` tuning space (0 when no
      metric file was produced).

      Unless ``tuning_micro_batch_sizes_overwritten`` is truthy, a few values
      between the second-largest and the largest candidate are then probed.
      The largest entry is replaced by a smaller value when that value's metric
      is more than ``METRIC_PERCENT_DIFF_CONST`` (relative) better.

      Args:
          tuning_micro_batch_sizes: non-empty list of candidate micro batch
              sizes; it is sorted in place.
          max_train_batch_size_per_gpu: per-GPU batch size used to derive the
              gradient accumulation steps of each candidate.
          min_micro_batch_size: lower bound of the refinement probes when the
              list has a single candidate.
          stage: ZeRO optimization stage to tune.
          tuning_micro_batch_sizes_overwritten: when truthy, return right after
              the first round without refinement.

      Returns:
          The (possibly updated) ``tuning_micro_batch_sizes`` list.
      """
      assert tuning_micro_batch_sizes, "the tuning micro batch size list is empty"
      tuning_micro_batch_sizes.sort()
      max_micro_batch_size = tuning_micro_batch_sizes[-1]
      max_micro_batch_size_metric_val = 0

      ds_config = get_first_config(self.user_config)
      ds_config[ZERO_OPTIMIZATION] = {ZERO_OPTIMIZATION_STAGE: stage}
      tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)

      exp_paths = []
      for mbs in tuning_micro_batch_sizes:
          ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
          gas = max_train_batch_size_per_gpu // mbs
          ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
          ds_config[TRAIN_BATCH_SIZE] = (mbs * gas * self.exp_num_gpus *
                                         self.exp_num_nodes // self.mp_size())
          exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
          exp_config = {}
          exp_config['name'] = exp_name
          exp_config[DS_CONFIG] = ds_config
          exp_config['num_gpus'] = self.exp_num_gpus
          exp_config['num_nodes'] = self.exp_num_nodes
          exp_path = os.path.join(self.exps_dir, f'{exp_name}.json')
          # ds_config is shared across iterations, so serialize it now, before
          # the next candidate overwrites its fields.
          with open(exp_path, 'w', buffering=BUFSIZE) as fd:
              json.dump(exp_config, fd)
              fd.flush()
              os.fsync(fd)
          exp_paths.append(exp_path)

      self.rm.schedule_experiments(exp_paths)
      self.rm.run()

      for exp_id, (exp, err) in self.rm.finished_experiments.items():
          if not exp:
              # There is no experiment description to read the micro batch size
              # from (indexing ``exp`` here raised a TypeError), so report the
              # experiment id and error instead.
              logger.info(
                  f"micro batch size experiment {exp_id} was not run successfully, error = {err}"
              )
              continue
          metric_file = exp[DS_CONFIG][AUTOTUNING][AUTOTUNING_METRIC_PATH]
          if os.path.exists(metric_file):
              with open(metric_file, 'r') as f:
                  results = hjson.load(f)
              metric_val = results[self.metric()]
              self.update_records(tuning_space_name, exp, metric_val, 1)
              if max_micro_batch_size == exp[DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU]:
                  max_micro_batch_size_metric_val = metric_val
          else:
              # No metric file: record the run with metric 0.
              self.update_records(tuning_space_name, exp, 0, 1)
      self.rm.clear()

      if tuning_micro_batch_sizes_overwritten:
          return tuning_micro_batch_sizes

      # In an auto-detected tuning_micro_batch_sizes list, max_micro_batch_size
      # might not be performant because its memory consumption is close to the
      # max. Try smaller values, and if a more performant mbs is found, use it
      # to replace max_micro_batch_size in the list.
      min_micro_batch_size_with_same_gas = (
          tuning_micro_batch_sizes[-2] + 1
          if len(tuning_micro_batch_sizes) > 1 else min_micro_batch_size)

      prev_best_metric_val = max_micro_batch_size_metric_val
      prev_best_mbs = max_micro_batch_size

      # Aim for about three probes. The stride must stay positive, or range()
      # would walk *upwards* past max_micro_batch_size when the lower bound
      # exceeds it.
      stride = max(
          1, (max_micro_batch_size - min_micro_batch_size_with_same_gas) // 3)
      for mbs in reversed(
              range(min_micro_batch_size_with_same_gas,
                    max_micro_batch_size,
                    stride)):
          ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
          gas = max_train_batch_size_per_gpu // mbs
          ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
          ds_config[TRAIN_BATCH_SIZE] = (mbs * gas * self.exp_num_gpus *
                                         self.exp_num_nodes // self.mp_size())
          exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
          exp, metric_val = self.run_ds_config(ds_config, exp_name)
          if not metric_val:
              self.update_records(tuning_space_name, exp, 0, 1)
              break
          self.update_records(tuning_space_name, exp, metric_val, 1)
          # Only accept a smaller mbs if it beats the best so far by a margin.
          if metric_val > prev_best_metric_val * (1 + METRIC_PERCENT_DIFF_CONST):
              prev_best_metric_val = metric_val
              prev_best_mbs = mbs
          else:
              break

      if prev_best_mbs != max_micro_batch_size:
          tuning_micro_batch_sizes[-1] = prev_best_mbs
      return tuning_micro_batch_sizes
  def get_min_max_micro_batch_size(self,
                                   stage,
                                   min_micro_batch_size,
                                   calculated_max_micro_batch_size):
      """Search for the smallest and largest runnable micro batch size per GPU.

      Every probe is run through ``run_ds_config`` with ZeRO ``stage``. It is
      recorded via ``update_records`` in the
      ``TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)`` tuning space; failed
      probes are recorded with metric 0.

      Min search:
          If ``min_micro_batch_size`` < 1, the user's micro batch size is tried
          first: an int in the user config, else a numeric value from the user
          script arguments, else 1. If it fails,
          ``self.min_train_micro_batch_size_per_gpu()`` is tried as a fallback.
          Otherwise ``min_micro_batch_size`` itself is tried.

      Max search:
          Let ``bound = min(calculated_max_micro_batch_size,
          self.max_train_micro_batch_size_per_gpu())``. The values
          ceil(1.05 * bound), bound and int(0.95 * bound) are probed in that
          order. A value is returned as soon as it runs or was already run
          successfully; values above ``max_train_micro_batch_size_per_gpu()``
          are skipped. If none runs, a binary search between the min and
          ``bound`` follows. It stops early once a successful probe improves
          the metric by less than ``METRIC_PERCENT_DIFF_CONST`` (relative)
          over the previous reference metric.

      Args:
          stage: ZeRO optimization stage.
          min_micro_batch_size: lower bound to try; values < 1 mean "derive
              it from the user config / arguments".
          calculated_max_micro_batch_size: estimated upper bound.

      Returns:
          ``(min_micro_batch_size, max_micro_batch_size)``, or ``(-1, -1)`` when
          ``min_micro_batch_size`` exceeds ``calculated_max_micro_batch_size``
          or no min candidate runs.
      """
      # get min and max micro batch size; gradient accumulation steps come from
      # get_gas_from_user_config() (1 unless the user config sets them)
      if min_micro_batch_size > calculated_max_micro_batch_size:
          return -1, -1

      # micro batch sizes that have already run successfully
      used_micro_batch_sizes = []
      tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)

      ds_config = get_first_config(self.user_config)
      ds_config[ZERO_OPTIMIZATION] = {ZERO_OPTIMIZATION_STAGE: stage}
      gas = self.get_gas_from_user_config()
      ds_config[GRADIENT_ACCUMULATION_STEPS] = gas

      # search for the min micro batch size
      if min_micro_batch_size < 1:
          if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self.user_config and isinstance(
                  self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU],
                  int):
              # user specifies train_micro_batch_size_per_gpu as an int
              mbs = int(self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU])
          else:
              # user does not specify train_micro_batch_size_per_gpu or sets it to "auto" when using Hugging Face
              val = self.get_val_from_user_args(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
              if val:
                  mbs = int(val)
              else:
                  mbs = 1
          assert mbs > 0, "The micro batch size per GPU must be greater than 0."
          ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
          ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
          ds_config[TRAIN_BATCH_SIZE] = mbs * gas * \
              self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
          exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
          exp, metric_val = self.run_ds_config(ds_config, exp_name)
          if metric_val:
              self.update_records(tuning_space_name, exp, metric_val, 1)
              used_micro_batch_sizes.append(mbs)
              min_micro_batch_size = mbs
          else:
              self.update_records(tuning_space_name, exp, 0, 1)
              logger.info(
                  f"User-specified micro batch size per GPU {mbs} does not run")
              # nothing smaller to fall back to
              if self.min_train_micro_batch_size_per_gpu() == mbs:
                  return -1, -1
              # fall back to the configured minimum micro batch size
              mbs = self.min_train_micro_batch_size_per_gpu()
              ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
              ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
              ds_config[TRAIN_BATCH_SIZE] = mbs * gas * \
                  self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
              exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
              exp, metric_val = self.run_ds_config(ds_config, exp_name)
              if not metric_val:
                  self.update_records(tuning_space_name, exp, 0, 1)
                  logger.info(
                      f"min_train_micro_batch_size_per_gpu {mbs} is not runnable.")
                  return -1, -1
              self.update_records(tuning_space_name, exp, metric_val, 1)
              min_micro_batch_size = mbs
              used_micro_batch_sizes.append(mbs)
      else:
          # an explicit min was given: it must run, otherwise give up
          ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = min_micro_batch_size
          ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
          ds_config[TRAIN_BATCH_SIZE] = min_micro_batch_size * gas * \
              self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
          exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(
              min_micro_batch_size)
          exp, metric_val = self.run_ds_config(ds_config, exp_name)
          if metric_val:
              self.update_records(tuning_space_name, exp, metric_val, 1)
              used_micro_batch_sizes.append(min_micro_batch_size)
          else:
              self.update_records(tuning_space_name, exp, 0, 1)
              return -1, -1

      # search for the max micro batch size
      max_micro_batch_size = min(calculated_max_micro_batch_size,
                                 self.max_train_micro_batch_size_per_gpu())
      # probe ~105%, 100% and ~95% of the estimated max; the first one that runs
      # (or already ran) is taken as the max.
      # NOTE(review): int(0.95 * max_micro_batch_size) can fall below
      # min_micro_batch_size (even to 0 when the max is 1) - confirm this is intended.
      for mbs in [
              math.ceil(1.05 * max_micro_batch_size),
              max_micro_batch_size,
              int(0.95 * max_micro_batch_size)
      ]:
          if mbs > self.max_train_micro_batch_size_per_gpu():
              continue
          if mbs in used_micro_batch_sizes:
              return min_micro_batch_size, mbs
          ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
          ds_config[TRAIN_BATCH_SIZE] = mbs * gas * \
              self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
          exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mbs)
          exp, metric_val = self.run_ds_config(ds_config, exp_name)
          if metric_val:
              logger.info(f"mbs = {mbs} is found as max mbs")
              self.update_records(tuning_space_name, exp, metric_val, 1)
              used_micro_batch_sizes.append(mbs)
              return min_micro_batch_size, mbs
          else:
              self.update_records(tuning_space_name, exp, 0, 1)

      # reference metric for the plateau check: the record whose mbs is closest
      # to min_micro_batch_size
      space_records = self.records[
          tuning_space_name] if tuning_space_name in self.records else []
      if space_records:
          prev_idx = min(range(len(space_records)),
                         key=lambda i: abs(space_records[i][0][DS_CONFIG][
                             TRAIN_MICRO_BATCH_SIZE_PER_GPU] - min_micro_batch_size))
          prev_metric_val = space_records[prev_idx][1]
      else:
          prev_metric_val = None

      # binary search: low is always a runnable mbs; high shrinks on failures
      # NOTE(review): after `high = mid - 1`, `high` itself may never be tried
      # (e.g. low=5, high=6 gives mid=5 == low and stops), so the result can be
      # one below the true max.
      low = min_micro_batch_size
      high = max_micro_batch_size
      while low < high:
          mid = int((low + high) // 2)
          logger.debug(f"trying mbs = {mid}, low = {low}, high = {high}")
          if mid == low:
              break
          if mid not in used_micro_batch_sizes:
              ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mid
              ds_config[TRAIN_BATCH_SIZE] = mid * gas * \
                  self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
              exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mid)
              exp, metric_val = self.run_ds_config(ds_config, exp_name)
              if metric_val:
                  low = mid
                  self.update_records(tuning_space_name, exp, metric_val, 1)
                  used_micro_batch_sizes.append(mid)
                  # stop once the relative gain drops below the threshold
                  if prev_metric_val and ((metric_val - prev_metric_val) /
                                          prev_metric_val) < METRIC_PERCENT_DIFF_CONST:
                      logger.info(f"performance plateaus at mbs = {low}")
                      break
                  prev_metric_val = metric_val
              else:
                  self.update_records(tuning_space_name, exp, 0, 1)
                  high = mid - 1
          else:
              # already known to run
              low = mid

      max_micro_batch_size = low

      logger.info(
          f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}."
      )

      return min_micro_batch_size, max_micro_batch_size
  def get_gas_from_user_config(self):
      """Return the gradient accumulation steps (gas) to use for experiments.

      Resolution rules for ``GRADIENT_ACCUMULATION_STEPS`` in the user config:
        * an int is used as is;
        * ``"auto"`` is resolved from the user's script arguments through
          ``get_val_from_user_args``, keeping 1 if no numeric value is found;
        * a list (tuning gas is not supported) or any other value yields 1;
        * a missing key yields 1.

      Returns:
          The gradient accumulation steps.

      Raises:
          AssertionError: if the resolved value is not positive.
      """
      gas = 1
      if GRADIENT_ACCUMULATION_STEPS in self.user_config:
          gas_in_config = self.user_config[GRADIENT_ACCUMULATION_STEPS]
          if isinstance(gas_in_config, int):
              gas = gas_in_config
          elif gas_in_config == "auto":  # GRADIENT_ACCUMULATION_STEPS: "auto"
              # Resolve "auto" the same way an "auto" micro batch size is
              # resolved: from the user's script arguments. This used to call
              # self.get_val_from_config, for which no definition is visible
              # next to these helpers (presumably a typo for this method).
              val = self.get_val_from_user_args(GRADIENT_ACCUMULATION_STEPS)
              if val:
                  gas = int(val)
          elif isinstance(gas_in_config, list):
              logger.info(
                  f"Specifying a list of {GRADIENT_ACCUMULATION_STEPS} to tune is not supported. 1 would be used."
              )
      assert gas > 0, "Gradient accumulation steps must be positive."
      return gas
  def get_val_from_user_args(self, ds_name):
      """Return the numeric value the user passed for a DeepSpeed config key.

      ``self.autotuning_config.arg_mappings`` maps a DeepSpeed config name to
      the name of the user-script argument that controls it. The argument is
      looked up in ``self.args.user_args`` in two forms:

      * ``--arg value`` is checked first;
      * ``--arg=value`` (also accepted by argparse for long options) is used
        only when the first form is absent.

      Args:
          ds_name: DeepSpeed config key to look up.

      Returns:
          The value as a string if ``str.isnumeric()`` accepts it, otherwise
          ``None``. ``None`` is also returned when no mapping exists, the
          argument is absent, or the flag is the last token with no value
          after it.
      """
      arg_mappings = self.autotuning_config.arg_mappings
      if not arg_mappings or ds_name not in arg_mappings:
          return None
      arg_name = arg_mappings[ds_name]
      user_args = self.args.user_args or []

      if arg_name in user_args:
          idx = user_args.index(arg_name)
          # The flag may be the last token, with no value after it.
          if idx + 1 < len(user_args) and user_args[idx + 1].isnumeric():
              return user_args[idx + 1]
          return None

      prefix = f"{arg_name}="
      for token in user_args:
          if token.startswith(prefix):
              value = token[len(prefix):]
              return value if value.isnumeric() else None
      return None
  def get_tuning_micro_batch_size_list(self,
                                       min_micro_batch_size,
                                       max_micro_batch_size,
                                       num_tuning_micro_batch_sizes):
      """Get a list of micro batch sizes to tune from min/max bounds and a target count.

      Only the upper half of ``[min_micro_batch_size, max_micro_batch_size]``
      is considered. Candidates are spaced by a constant stride. A candidate
      is dropped when its gradient accumulation steps
      (``max_train_batch_size_per_gpu // mbs``) equal those of
      ``max_micro_batch_size``, because it would duplicate that run.

      Args:
          min_micro_batch_size: min micro batch size per GPU.
          max_micro_batch_size: max micro batch size per GPU.
          num_tuning_micro_batch_sizes: target number of candidates. It only
              sets the stride, so the returned list may be shorter or slightly
              longer. Non-positive values are treated as 1.

      Returns:
          ``(candidates, max_train_batch_size_per_gpu)``. ``candidates`` is in
          ascending order and always ends with ``max_micro_batch_size``.
          ``([], 0)`` is returned if either bound is not positive.
      """
      if min_micro_batch_size <= 0 or max_micro_batch_size <= 0:
          logger.info(
              f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}"
          )
          return [], 0

      # NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} ))
      # DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) ))
      # GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) ))
      max_train_batch_size = self.max_train_batch_size()
      if max_train_batch_size and max_train_batch_size > 0:
          # the user specifies a max_train_batch_size
          max_train_batch_size_per_gpu = max_train_batch_size * self.mp_size(
          ) // (self.exp_num_gpus * self.exp_num_nodes)
      else:
          gas = self.get_gas_from_user_config()
          max_train_batch_size_per_gpu = max_micro_batch_size * gas // self.mp_size()

      logger.info(f"max_train_batch_size_per_gpu = {max_train_batch_size_per_gpu}")

      # only tune the upper half of the range
      if min_micro_batch_size < max_micro_batch_size // 2:
          min_micro_batch_size = max_micro_batch_size // 2

      # Constant stride. Keep it positive so range() never steps past the max
      # (or raises on a zero step).
      stride = max(1, (max_micro_batch_size - min_micro_batch_size) //
                   max(1, num_tuning_micro_batch_sizes))

      # Gas is derived per candidate exactly as run_tuning_micro_batch_sizes
      # does it. A candidate with the same gas as max_micro_batch_size is
      # redundant, so it is not added to the tuning list.
      min_gas = max_train_batch_size_per_gpu // max_micro_batch_size
      ls = [
          mbs for mbs in range(min_micro_batch_size, max_micro_batch_size, stride)
          if max_train_batch_size_per_gpu // mbs != min_gas
      ]
      ls.append(max_micro_batch_size)
      return ls, max_train_batch_size_per_gpu
  def run_ds_config(self, ds_config, exp_name):
      """Run a single experiment with ``ds_config`` and return its outcome.

      The experiment description (name, DeepSpeed config, GPU and node counts)
      is written to ``<exps_dir>/<exp_name>.json`` and flushed and fsync'ed.
      It is then scheduled and run on ``self.rm``; the resource manager is
      cleared after the results are parsed.

      Returns:
          ``(exp, metric_val)`` as returned by
          ``self.rm.parse_results(self.metric())``.
      """
      experiment = {
          'name': exp_name,
          DS_CONFIG: ds_config,
          'num_gpus': self.exp_num_gpus,
          'num_nodes': self.exp_num_nodes,
      }
      experiment_file = os.path.join(self.exps_dir, f'{exp_name}.json')
      logger.debug(f'run_ds_config exp_name = {exp_name}')

      with open(experiment_file, 'w', buffering=BUFSIZE) as handle:
          json.dump(experiment, handle)
          handle.flush()
          os.fsync(handle)

      self.rm.schedule_experiments([experiment_file])
      self.rm.run()
      finished_exp, metric = self.rm.parse_results(self.metric())
      self.rm.clear()
      return finished_exp, metric
  def run_after_tuning(self):
      """Launch training with the optimal DeepSpeed configuration found by autotuning.

      The best record across all tuning spaces (``GLOBAL_TUNING_SPACE``) is
      looked up. Two files are saved to ``self.results_dir``:

      * "ds_config_optimal.json": that experiment's ``ds_config.json`` without
        its ``AUTOTUNING`` section;
      * "cmd_optimal.txt": the experiment's launch command, with its config
        path replaced by the optimal one.

      The command is then run and waited on. Nothing happens when no best
      record (or no best experiment) exists.

      Raises:
          ValueError: if the recorded command does not contain the
              experiment's ``ds_config.json`` path.
      """
      best_space_records = self.get_best_space_records()
      if GLOBAL_TUNING_SPACE not in best_space_records:
          return
      best_exp, _best_metric_val, _ = best_space_records[GLOBAL_TUNING_SPACE]
      if not best_exp:
          return

      logger.info(
          "Start training with the optmimal DeepSpeed configuration found through the tuning process"
      )
      exp_dir = best_exp["result_dir"]
      exp_ds_config_path = os.path.join(exp_dir, "ds_config.json")

      with open(os.path.join(exp_dir, "cmd.txt"), "r") as f:
          cmd = [str(i) for i in f.read().split()]

      # Context managers ensure the optimal config is fully written and closed
      # before the training subprocess reads it.
      with open(exp_ds_config_path, "r") as f:
          ds_config = hjson.load(f)
      # Strip the autotuning section written for the tuning experiment, if present.
      ds_config.pop(AUTOTUNING, None)

      ds_config_path = os.path.join(self.results_dir, "ds_config_optimal.json")
      with open(ds_config_path, "w") as f:
          json.dump(ds_config, f)

      # Point the recorded command at the optimal config instead of the
      # experiment's own copy.
      idx = cmd.index(exp_ds_config_path)
      cmd[idx] = ds_config_path

      cmd_path = os.path.join(self.results_dir, "cmd_optimal.txt")
      with open(cmd_path, "w") as fd:
          fd.write(" ".join(cmd))
          fd.write("\n")
          fd.flush()

      result = subprocess.Popen(cmd)
      result.wait()

      logger.info(
          f"Done running with the optimal DeepSpeed configuration found by autotuning: {ds_config_path}"
      )