elasticity.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import os
  5. import json
  6. import numpy as np
  7. import math
  8. from packaging import version as pkg_version
  9. from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \
  10. ElasticityIncompatibleWorldSize
  11. from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \
  12. MINIMUM_DEEPSPEED_VERSION, DEEPSPEED_ELASTICITY_CONFIG
  13. from ..git_version_info import version as __version__
  14. from ..utils import logger
  # The thirty-eight smallest highly composite numbers (HCNs), ascending.
  # Candidate batch sizes are built by scaling a base size with one of these;
  # the list should be enough to support up to 720K batch size.
  HCN_LIST = [
      1, 2, 4, 6, 12, 24, 36, 48, 60, 120,
      180, 240, 360, 720, 840, 1260, 1680, 2520, 5040, 7560,
      10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440, 83160, 110880,
      166320, 221760, 277200, 332640, 498960, 554400, 665280, 720720,
  ]
  def get_candidate_batch_sizes(base_list, max_acceptable_batch_size):
      """Scale each base size by an HCN to build the candidate batch sizes.

      A base that is already greater than or equal to
      ``max_acceptable_batch_size`` is kept unchanged. Any other base is
      multiplied by the ``HCN_LIST`` entry located one position before the
      first entry that is strictly greater than
      ``max_acceptable_batch_size // base``.

      Args:
          base_list: iterable of base batch sizes.
          max_acceptable_batch_size: upper bound used to choose the HCN factor.

      Returns:
          list: de-duplicated candidates (the values pass through a ``set``,
          so their order is not meaningful).
      """
      hcn_values = np.asarray(HCN_LIST)
      scaled_sizes = []
      for base in base_list:
          if base >= max_acceptable_batch_size:
              scaled_sizes.append(base)
              continue
          quotient = max_acceptable_batch_size // base
          # argmax over a boolean array gives the position of the first True.
          first_above = np.argmax(hcn_values > quotient)
          scaled_sizes.append(HCN_LIST[first_above - 1] * base)
      candidate_batch_size = list(set(scaled_sizes))
      logger.info(f"Candidate batch size: {candidate_batch_size}")
      return candidate_batch_size
  def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
      """List the GPU counts that can run ``batch_size`` with some micro batch.

      For every micro batch that divides ``batch_size`` evenly, the quotient
      ``batch_size // micro_batch`` and each of its divisors are collected,
      keeping only the counts inside ``[min_valid_gpus, max_valid_gpus]``.

      Args:
          batch_size: total batch size under consideration.
          micro_batches: iterable of allowed micro batch sizes.
          min_valid_gpus: smallest GPU count to accept (inclusive).
          max_valid_gpus: largest GPU count to accept (inclusive).

      Returns:
          list: the accepted GPU counts, unique and sorted ascending.
      """
      accepted = set()
      for micro_batch in micro_batches:
          if batch_size % micro_batch != 0:
              continue
          max_gpus = batch_size // micro_batch
          if min_valid_gpus <= max_gpus <= max_valid_gpus:
              accepted.add(max_gpus)
          # Any divisor other than max_gpus itself is at most max_gpus // 2.
          for gpu_count in range(1, max_gpus // 2 + 1):
              if gpu_count > max_valid_gpus:
                  break
              if gpu_count >= min_valid_gpus and max_gpus % gpu_count == 0:
                  accepted.add(gpu_count)
      return sorted(accepted)
  def get_best_candidates(candidate_batch_sizes,
                          micro_batches,
                          min_gpus,
                          max_gpus,
                          prefer_larger):
      """Pick the candidate batch size that is compatible with the most GPU counts.

      Candidates are scored by the number of GPU counts returned by
      ``get_valid_gpus``. On a tie the larger batch size wins when
      ``prefer_larger`` is true, otherwise the smaller one wins. The running
      best batch size starts at ``int(min(micro_batches))``.

      Returns:
          tuple: ``(final_batch_size, valid_gpus)``; ``valid_gpus`` stays
          ``None`` if no candidate ever replaces the starting value.
      """
      best_count = 0
      best_gpus = None
      best_batch_size = int(min(micro_batches))

      for batch_size in candidate_batch_sizes:
          gpus_for_candidate = get_valid_gpus(batch_size,
                                              micro_batches,
                                              min_gpus,
                                              max_gpus)
          count = len(gpus_for_candidate)
          if count < best_count:
              continue
          if count == best_count:
              # Tie: fall back to the batch-size preference.
              if prefer_larger:
                  wins_tie = batch_size > best_batch_size
              else:
                  wins_tie = batch_size < best_batch_size
              if not wins_tie:
                  continue
          best_count = count
          best_gpus = gpus_for_candidate
          best_batch_size = batch_size

      return best_batch_size, best_gpus
  108. def _get_compatible_gpus_v01(micro_batches,
  109. max_acceptable_batch_size,
  110. min_gpus=None,
  111. max_gpus=None,
  112. prefer_larger=True):
  113. '''We use two heuristics to compute the batch size
  114. 1. We use the Lowest Common Multiple of the micro-batches
  115. as the base batch size and scale it by a HCN such that the result is
  116. the largest batch size less than the max_acceptable batch size
  117. 2. We use each of the micro batches as a base and scale it
  118. by a HCN such that the result is the largest batch size less than the
  119. max_acceptable batch size.
  120. We then use brute force to count the number of compatible GPU count for
  121. each of the aforementioned cases, and return the batch size with the most number of
  122. compatible GPU counts in the min-max GPU range if provided, other wise
  123. we return the batch size with the most number of total compatible GPU counts.
  124. Returns:
  125. final_batch_size
  126. valid_gpus
  127. '''
  128. min_gpus = min_gpus or 1
  129. max_gpus = max_gpus or max_acceptable_batch_size // min(micro_batches)
  130. if not all(mb <= max_acceptable_batch_size for mb in micro_batches):
  131. raise ValueError(f"All micro batches must be less than \
  132. or equal to max_acceptable_batch_size: {max_acceptable_batch_size}")
  133. lcm = np.lcm.reduce(micro_batches)
  134. base_list = []
  135. base_list.extend(micro_batches)
  136. base_list.append(lcm)
  137. candidate_batch_sizes = get_candidate_batch_sizes(base_list,
  138. max_acceptable_batch_size)
  139. final_batch_size, valid_gpus = get_best_candidates(
  140. candidate_batch_sizes,
  141. micro_batches,
  142. min_gpus,
  143. max_gpus,
  144. prefer_larger)
  145. return final_batch_size, valid_gpus
  def _get_compatible_gpus_v02(micro_batches,
                               max_acceptable_batch_size,
                               current_num_gpus,
                               min_gpus=None,
                               max_gpus=None,
                               prefer_larger=True,
                               num_gpus_per_node=1,
                               model_parallel_size=1):
      '''Compute batch size, valid world sizes and micro batch (elasticity v0.2).

      The v0.1 search is run at node granularity: the batch-size bound is
      divided by the data-parallel size per node and the GPU bounds are
      divided by the number of GPUs per node. If the current data-parallel
      size is not among the resulting world sizes, a batch size is derived
      directly from the current data-parallel size instead.

      Args:
          micro_batches: list of allowed micro batch sizes.
          max_acceptable_batch_size: upper bound for the total batch size.
          current_num_gpus: number of GPUs currently in use.
          min_gpus: minimum number of GPUs, or None to use the v0.1 default.
          max_gpus: maximum number of GPUs, or None to use the v0.1 default.
          prefer_larger: prefer larger batch / micro batch sizes when true.
          num_gpus_per_node: GPUs per node; must be divisible by
              model_parallel_size.
          model_parallel_size: model-parallel degree.

      Raises:
          ElasticityError: if num_gpus_per_node is not divisible by
              model_parallel_size.

      Returns:
          final_batch_size (may be a float on the fallback path, because the
              current data-parallel size is computed with true division)
          valid_gpus (list of valid data-parallel world sizes)
          micro-batch size (None when no micro batch divides the per-GPU share)
      '''
      if num_gpus_per_node % model_parallel_size != 0:
          raise ElasticityError(
              f"In Elasticity v0.2, number of GPUs per node:" \
              f"{num_gpus_per_node} should be divisible by " \
              f"model parallel size {model_parallel_size}")

      def get_microbatch(final_batch_size):
          # Select among the micro batches that divide
          # final_batch_size // current_num_gpus: the first match, or the
          # largest match when prefer_larger is set.
          candidate_microbatch = None
          for micro_batch in micro_batches:
              if final_batch_size // current_num_gpus % micro_batch == 0:
                  if candidate_microbatch is None:
                      candidate_microbatch = micro_batch
                  if prefer_larger and candidate_microbatch < micro_batch:
                      candidate_microbatch = micro_batch
          return candidate_microbatch

      dp_size_per_node = num_gpus_per_node // model_parallel_size

      # Elasticity v2 works at node level, so the GPU bounds are converted to
      # node counts. None is forwarded untouched so the v0.1 defaults apply
      # instead of crashing on None / int.
      min_nodes = None if min_gpus is None else int(min_gpus / num_gpus_per_node)
      max_nodes = None if max_gpus is None else int(max_gpus / num_gpus_per_node)

      final_batch_size, valid_world_size = _get_compatible_gpus_v01(
          micro_batches,
          int(max_acceptable_batch_size / dp_size_per_node),
          min_nodes,
          max_nodes,
          prefer_larger=prefer_larger)

      final_batch_size = int(final_batch_size) * dp_size_per_node
      valid_dp_world_size = [i * dp_size_per_node for i in valid_world_size]
      if current_num_gpus // model_parallel_size in valid_dp_world_size:
          candidate_microbatch = get_microbatch(final_batch_size)
          return final_batch_size, valid_dp_world_size, candidate_microbatch

      # Fallback: the current data-parallel size is not in the valid list, so
      # build batch sizes that are multiples of (micro batch * current DP size).
      current_dp_size = (current_num_gpus / num_gpus_per_node) * dp_size_per_node
      candidate_batch_sizes = []
      for micro_batch in micro_batches:
          min_batch_size = micro_batch * current_dp_size
          factor = math.floor(max_acceptable_batch_size / float(min_batch_size))
          candidate_batch_sizes.append(factor * min_batch_size)

      if prefer_larger:
          candidate_batch_size = max(candidate_batch_sizes)
      else:
          candidate_batch_size = min(candidate_batch_sizes)

      candidate_microbatch = get_microbatch(candidate_batch_size)

      return candidate_batch_size, [int(current_dp_size)], candidate_microbatch
  def _compatible_ds_version_check(target_deepspeed_version: str):
      """Verify the target DeepSpeed version is new enough for elasticity.

      Args:
          target_deepspeed_version: version string to compare against
              ``MINIMUM_DEEPSPEED_VERSION``.

      Raises:
          ElasticityError: if the target version is lower than the minimum.

      Returns:
          bool: always ``True`` when no error is raised.
      """
      minimum = pkg_version.parse(MINIMUM_DEEPSPEED_VERSION)
      target = pkg_version.parse(target_deepspeed_version)
      if target < minimum:
          raise ElasticityError(
              f"Target deepspeed version of {target_deepspeed_version} is not compatible "
              f"with minimum version {MINIMUM_DEEPSPEED_VERSION} supporting elasticity.")
      return True
  def elasticity_enabled(ds_config: dict):
      """Tell whether the given DeepSpeed config turns elasticity on.

      Returns ``False`` when the config has no elasticity section; otherwise
      returns that section's enabled value, or ``ENABLED_DEFAULT`` when the
      section does not set it.
      """
      if ELASTICITY in ds_config:
          elastic_section = ds_config[ELASTICITY]
          return elastic_section.get(ENABLED, ENABLED_DEFAULT)
      return False
  def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
      """
      Check that the runtime elastic config matches the one the resource
      scheduler saw.

      The scheduler's config is read as JSON from the
      DEEPSPEED_ELASTICITY_CONFIG environment variable. The
      max_acceptable_batch_size, micro_batches and version fields are compared
      in that order, and the first mismatch raises ElasticityConfigError.
      When the environment variable is absent only a warning is logged.
      """
      if DEEPSPEED_ELASTICITY_CONFIG not in os.environ:
          logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \
                         "guarantee resource scheduler will scale this job using compatible GPU counts.")
          return

      scheduler_config = ElasticityConfig(
          json.loads(os.environ[DEEPSPEED_ELASTICITY_CONFIG]))
      runtime_config = ElasticityConfig(runtime_elastic_config_dict)
      mismatch_template = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}"

      for field_name in ('max_acceptable_batch_size', 'micro_batches', 'version'):
          runtime_value = getattr(runtime_config, field_name)
          scheduler_value = getattr(scheduler_config, field_name)
          if runtime_value != scheduler_value:
              raise ElasticityConfigError(
                  mismatch_template.format(field_name,
                                           scheduler_value,
                                           field_name,
                                           runtime_value))
  def _select_largest_micro_batch(final_batch_size, world_size, micro_batches):
      """Return the largest micro batch dividing ``final_batch_size // world_size``."""
      batch_per_worker = final_batch_size // world_size
      micro_batch_size = None
      for mbsz in sorted(set(micro_batches), reverse=True):
          if batch_per_worker % mbsz == 0:
              micro_batch_size = mbsz
              break
      assert micro_batch_size is not None, "Unable to find divisible micro batch size" \
          f" world_size={world_size}, final_batch_size={final_batch_size}, and " \
          f" micro_batches={micro_batches}."
      return micro_batch_size


  def compute_elastic_config(ds_config: dict,
                             target_deepspeed_version: str,
                             world_size=0,
                             return_microbatch=False):
      """Core deepspeed elasticity API. Given an elastic config (similar to the example below)
      DeepSpeed will compute a total train batch size and a corresponding valid GPU count list
      that provides a high level of elasticity. Elasticity in this case means we are safe to
      scale the training job up/down across the GPU count list *without* any negative impacts
      on training convergence. This is achievable primarily due to DeepSpeed's gradient
      accumulation feature which allows us to decompose a global training batch size into:
      micro-batch-size * gradient-accumulation-steps * world-size.

      "elasticity": {
          "enabled": true,
          "max_train_batch_size": 2000,
          "micro_batch_sizes": [2,4,6],
          "min_gpus": 1,
          "max_gpus" : 10000,
          "min_time": 20,
          "version": 0.1
      }

      Intended to be called both by scheduling infrastructure and deepspeed runtime.
      For the same `ds_config` we should return deterministic results.

      Args:
          ds_config (dict): DeepSpeed config dictionary/json
          target_deepspeed_version (str): When called from scheduling
              infrastructure we want to ensure that the target deepspeed version is
              compatible with the elasticity version used in the backend.
          world_size (int, optional): Intended/current DP world size, will do some sanity
              checks to ensure world size is actually valid with the config. With
              elasticity version 0.2 and world_size == 0, the WORLD_SIZE environment
              variable is used instead.
          return_microbatch (bool, optional): whether to return micro batch size or not.

      Raises:
          ValueError: ds_config is not a dictionary
          ElasticityConfigError: Missing required elasticity config, elasticity disabled,
              unsupported elasticity version, model parallelism requested with a version
              other than 0.2, WORLD_SIZE unavailable for version 0.2, or a micro batch
              requested for version 0.1 without a positive world_size
          ElasticityError: If target deepspeed version is not compatible with current version
          ElasticityIncompatibleWorldSize: world_size is not in the list of valid GPU counts
          NotImplementedError: No elastic logic exists for the configured version

      Returns:
          final_batch_size (int): total batch size used for training
          valid_gpus (list(int)): list of valid GPU counts with this config
          micro_batch_size (int, optional): returned as a third element when
              world_size > 0 or return_microbatch is set
      """
      if not isinstance(ds_config, dict):
          raise ValueError("Expected ds_config to be a dictionary but received " \
                           f"a {type(ds_config)}, containing: {ds_config}")

      if ELASTICITY not in ds_config:
          raise ElasticityConfigError(f"'{ELASTICITY}' is missing from config json," \
                                      " please add it if running an elastic training job.")

      elastic_config_dict = ds_config[ELASTICITY]
      if not elastic_config_dict.get(ENABLED, ENABLED_DEFAULT):
          raise ElasticityConfigError("Elasticity is disabled, please enable it " \
                                      "('enabled':true) if running an elastic training job.")

      elastic_config = ElasticityConfig(elastic_config_dict)
      model_parallel_size = elastic_config.model_parallel_size
      num_gpus_per_node = elastic_config.num_gpus_per_node
      elasticity_version = float(elastic_config.version)

      if model_parallel_size > 1 and elasticity_version != 0.2:
          raise ElasticityConfigError(f"Elasticity V{elastic_config.version} " \
                                      f"does not support model-parallel training. Given model-parallel size: " \
                                      f"{model_parallel_size}")

      if elasticity_version > LATEST_ELASTICITY_VERSION:
          raise ElasticityConfigError("Attempting to run elasticity version " \
                                      f"{elastic_config.version} but runtime only supports up " \
                                      f"to {LATEST_ELASTICITY_VERSION}")

      # Ensure target deepspeed version works with intended elasticity version
      if not _compatible_ds_version_check(target_deepspeed_version):
          raise ElasticityError("Unable to run elasticity on target deepspeed version of" \
                                f" {target_deepspeed_version}, currently {__version__}")

      if elasticity_version == 0.1:
          final_batch_size, valid_gpus = _get_compatible_gpus_v01(
              micro_batches=elastic_config.micro_batches,
              max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
              min_gpus=elastic_config.min_gpus,
              max_gpus=elastic_config.max_gpus,
              prefer_larger=elastic_config.prefer_larger_batch_size)
          # ensure batch size is int dtype
          final_batch_size = int(final_batch_size)
      elif elasticity_version == 0.2:
          if world_size != 0:
              current_num_gpus = world_size
          else:
              if "WORLD_SIZE" in os.environ and \
                      os.getenv('WORLD_SIZE').isnumeric():
                  current_num_gpus = int(os.getenv('WORLD_SIZE'))
              else:
                  WORLD_SIZE = os.getenv('WORLD_SIZE')
                  raise ElasticityConfigError(
                      'Elasticity V 0.2 needs WORLD_SIZE '\
                      'to compute valid batch size. '\
                      'Either give it as argument to function compute_elastic_config '\
                      'or set it as an environment variable. '\
                      f'Value of WORLD_SIZE as environment variable is {WORLD_SIZE}')

          final_batch_size, valid_gpus, candidate_microbatch_size = _get_compatible_gpus_v02(
              micro_batches=elastic_config.micro_batches,
              max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
              current_num_gpus=current_num_gpus,
              min_gpus=elastic_config.min_gpus,
              max_gpus=elastic_config.max_gpus,
              prefer_larger=elastic_config.prefer_larger_batch_size,
              num_gpus_per_node=num_gpus_per_node,
              model_parallel_size=model_parallel_size)
          # ensure batch size is int dtype
          final_batch_size = int(final_batch_size)
      else:
          raise NotImplementedError(
              f"Unable to find elastic logic for version: {elastic_config.version}")

      logger.info(f"Valid World Size (GPUs / Model Parallel Size): {valid_gpus}")

      if world_size > 0:
          if world_size not in valid_gpus:
              raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \
                                                    f"with the current list of valid GPU counts: {valid_gpus}")

          # Pick largest valid micro batch size
          micro_batch_size = _select_largest_micro_batch(final_batch_size,
                                                         world_size,
                                                         elastic_config.micro_batches)
          return final_batch_size, valid_gpus, micro_batch_size

      if return_microbatch:
          # Pick a valid micro batch size
          if elasticity_version == 0.2:
              return final_batch_size, valid_gpus, candidate_microbatch_size

          # Reaching this point means world_size <= 0; dividing by it would
          # either raise ZeroDivisionError or give a meaningless result.
          if world_size <= 0:
              raise ElasticityConfigError(
                  "Elasticity V 0.1 needs a positive world_size to pick a micro batch size "
                  f"when return_microbatch is set, but received world_size={world_size}")
          micro_batch_size = _select_largest_micro_batch(final_batch_size,
                                                         world_size,
                                                         elastic_config.micro_batches)
          return final_batch_size, valid_gpus, micro_batch_size

      return final_batch_size, valid_gpus