elasticity.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import json
  6. import numpy as np
  7. import math
  8. from packaging import version as pkg_version
  9. from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \
  10. ElasticityIncompatibleWorldSize
  11. from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \
  12. MINIMUM_DEEPSPEED_VERSION, DEEPSPEED_ELASTICITY_CONFIG
  13. from ..git_version_info import version as __version__
  14. from ..utils import logger
  15. # Thirty eight smallest highly composite numbers. The list should
  16. # be enough to support up to 720K batch size.
  17. HCN_LIST = [
  18. 1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680, 2520, 5040, 7560, 10080, 15120, 20160,
  19. 25200, 27720, 45360, 50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280, 720720
  20. ]
  21. def get_candidate_batch_sizes(base_list, max_acceptable_batch_size):
  22. candidate_batch_size = []
  23. for base in base_list:
  24. if base >= max_acceptable_batch_size:
  25. candidate_batch_size.append(base)
  26. else:
  27. value = max_acceptable_batch_size // base
  28. index = np.argmax(np.asarray(HCN_LIST) > value)
  29. candidate_batch_size.append(HCN_LIST[index - 1] * base)
  30. candidate_batch_size = list(set(candidate_batch_size))
  31. logger.info(f"Candidate batch size: {candidate_batch_size}")
  32. return candidate_batch_size
  33. def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
  34. valid_gpus = []
  35. for micro_batch in micro_batches:
  36. if batch_size % micro_batch == 0:
  37. max_gpus = batch_size // micro_batch
  38. if min_valid_gpus <= max_gpus <= max_valid_gpus:
  39. valid_gpus.append(max_gpus)
  40. # find all factors less than max_gpus / 2
  41. for i in range(1, max_gpus // 2 + 1):
  42. if i > max_valid_gpus:
  43. break
  44. if i < min_valid_gpus:
  45. continue
  46. if max_gpus % i == 0:
  47. valid_gpus.append(i)
  48. valid_gpus = set(valid_gpus)
  49. valid_gpus = sorted(list(valid_gpus))
  50. return valid_gpus
  51. def get_best_candidates(candidate_batch_sizes, micro_batches, min_gpus, max_gpus, prefer_larger):
  52. max_valid_gpus = 0
  53. valid_gpus = None
  54. final_batch_size = int(min(micro_batches))
  55. for batch_size in candidate_batch_sizes:
  56. current_valid_gpus = get_valid_gpus(batch_size, micro_batches, min_gpus, max_gpus)
  57. if (len(current_valid_gpus) > max_valid_gpus or (len(current_valid_gpus) == max_valid_gpus and
  58. ((prefer_larger and batch_size > final_batch_size) or
  59. (not prefer_larger and batch_size < final_batch_size)))):
  60. max_valid_gpus = len(current_valid_gpus)
  61. valid_gpus = current_valid_gpus
  62. final_batch_size = batch_size
  63. return final_batch_size, valid_gpus
  64. def _get_compatible_gpus_v01(micro_batches,
  65. max_acceptable_batch_size,
  66. min_gpus=None,
  67. max_gpus=None,
  68. prefer_larger=True):
  69. '''We use two heuristics to compute the batch size
  70. 1. We use the Lowest Common Multiple of the micro-batches
  71. as the base batch size and scale it by a HCN such that the result is
  72. the largest batch size less than the max_acceptable batch size
  73. 2. We use each of the micro batches as a base and scale it
  74. by a HCN such that the result is the largest batch size less than the
  75. max_acceptable batch size.
  76. We then use brute force to count the number of compatible GPU count for
  77. each of the aforementioned cases, and return the batch size with the most number of
  78. compatible GPU counts in the min-max GPU range if provided, other wise
  79. we return the batch size with the most number of total compatible GPU counts.
  80. Returns:
  81. final_batch_size
  82. valid_gpus
  83. '''
  84. min_gpus = min_gpus or 1
  85. max_gpus = max_gpus or max_acceptable_batch_size // min(micro_batches)
  86. if not all(mb <= max_acceptable_batch_size for mb in micro_batches):
  87. raise ValueError(f"All micro batches must be less than \
  88. or equal to max_acceptable_batch_size: {max_acceptable_batch_size}")
  89. lcm = np.lcm.reduce(micro_batches)
  90. base_list = []
  91. base_list.extend(micro_batches)
  92. base_list.append(lcm)
  93. candidate_batch_sizes = get_candidate_batch_sizes(base_list, max_acceptable_batch_size)
  94. final_batch_size, valid_gpus = get_best_candidates(candidate_batch_sizes, micro_batches, min_gpus, max_gpus,
  95. prefer_larger)
  96. return final_batch_size, valid_gpus
  97. def _get_compatible_gpus_v02(micro_batches,
  98. max_acceptable_batch_size,
  99. current_num_gpus,
  100. min_gpus=None,
  101. max_gpus=None,
  102. prefer_larger=True,
  103. num_gpus_per_node=1,
  104. model_parallel_size=1):
  105. '''
  106. Returns:
  107. final_batch_size
  108. valid_gpus
  109. micro-batch size
  110. '''
  111. if num_gpus_per_node % model_parallel_size != 0:
  112. raise ElasticityError(
  113. f"In Elasticity v0.2, number of GPUs per node:" \
  114. f"{num_gpus_per_node} should be divisible by " \
  115. f"model parallel size {model_parallel_size}")
  116. def get_microbatch(final_batch_size):
  117. candidate_microbatch = None
  118. for micro_batch in micro_batches:
  119. if final_batch_size // current_num_gpus % micro_batch == 0:
  120. if candidate_microbatch is None:
  121. candidate_microbatch = micro_batch
  122. if prefer_larger and candidate_microbatch < micro_batch:
  123. candidate_microbatch = micro_batch
  124. return candidate_microbatch
  125. dp_size_per_node = num_gpus_per_node // model_parallel_size
  126. final_batch_size, valid_world_size = _get_compatible_gpus_v01(
  127. micro_batches,
  128. int(max_acceptable_batch_size / dp_size_per_node),
  129. int(min_gpus / num_gpus_per_node),
  130. int(max_gpus / num_gpus_per_node), # Passing number of max nodes as Elasticity v2 works at node level
  131. prefer_larger=prefer_larger)
  132. final_batch_size = int(final_batch_size) * dp_size_per_node
  133. valid_dp_world_size = [i * dp_size_per_node for i in valid_world_size]
  134. if current_num_gpus // model_parallel_size in valid_dp_world_size:
  135. candidate_microbatch = get_microbatch(final_batch_size)
  136. return final_batch_size, valid_dp_world_size, candidate_microbatch
  137. current_dp_size = (current_num_gpus / num_gpus_per_node) * dp_size_per_node
  138. candidate_batch_sizes = []
  139. for micro_batch in micro_batches:
  140. min_batch_size = micro_batch * current_dp_size
  141. factor = math.floor(max_acceptable_batch_size / float(min_batch_size))
  142. candidate_batch_sizes.append(factor * min_batch_size)
  143. used_microbatch = None
  144. if prefer_larger:
  145. candidate_batch_size = max(candidate_batch_sizes)
  146. else:
  147. candidate_batch_size = min(candidate_batch_sizes)
  148. candidate_microbatch = get_microbatch(candidate_batch_size)
  149. return candidate_batch_size, [int(current_dp_size)], candidate_microbatch
  150. def _compatible_ds_version_check(target_deepspeed_version: str):
  151. min_version = pkg_version.parse(MINIMUM_DEEPSPEED_VERSION)
  152. target_version = pkg_version.parse(target_deepspeed_version)
  153. err_str = f"Target deepspeed version of {target_deepspeed_version} is not compatible " \
  154. f"with minimum version {MINIMUM_DEEPSPEED_VERSION} supporting elasticity."
  155. if target_version < min_version:
  156. raise ElasticityError(err_str)
  157. return True
  158. def elasticity_enabled(ds_config: dict):
  159. if ELASTICITY not in ds_config:
  160. return False
  161. return ds_config[ELASTICITY].get(ENABLED, ENABLED_DEFAULT)
  162. def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
  163. """
  164. Ensure the resource scheduler saw the same elastic config we are using at runtime
  165. """
  166. if DEEPSPEED_ELASTICITY_CONFIG in os.environ:
  167. scheduler_elastic_config_dict = json.loads(os.environ[DEEPSPEED_ELASTICITY_CONFIG])
  168. scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict)
  169. runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict)
  170. err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}"
  171. if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size:
  172. raise ElasticityConfigError(
  173. err_str.format('max_acceptable_batch_size', scheduler_elastic_config.max_acceptable_batch_size,
  174. 'max_acceptable_batch_size', runtime_elastic_config.max_acceptable_batch_size))
  175. if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches:
  176. raise ElasticityConfigError(
  177. err_str.format('micro_batches', scheduler_elastic_config.micro_batches, 'micro_batches',
  178. runtime_elastic_config.micro_batches))
  179. if runtime_elastic_config.version != scheduler_elastic_config.version:
  180. raise ElasticityConfigError(
  181. err_str.format('version', scheduler_elastic_config.version, 'version', runtime_elastic_config.version))
  182. else:
  183. logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \
  184. "guarantee resource scheduler will scale this job using compatible GPU counts.")
  185. def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0, return_microbatch=False):
  186. """Core deepspeed elasticity API. Given an elastic config (similar to the example below)
  187. DeepSpeed will compute a total train batch size corresponding valid GPU count list that
  188. provides a high level of elasticity. Elasticity in this case means we are safe to scale
  189. the training job up/down across the GPU count list *without* any negative impacts on
  190. training convergence. This is achievable primarily due to DeepSpeed's gradient accumulation
  191. feature which allows us to decompose a global training batch size into:
  192. micro-batch-size * gradient-accumulation-steps * world-size.
  193. "elasticity": {
  194. "enabled": true,
  195. "max_train_batch_size": 2000,
  196. "micro_batch_sizes": [2,4,6],
  197. "min_gpus": 1,
  198. "max_gpus" : 10000
  199. "min_time": 20
  200. "version": 0.1
  201. }
  202. Intended to be called both by scheduling infrastructure and deepspeed runtime.
  203. For the same `ds_config` we should return deterministic results.
  204. Args:
  205. ds_config (dict): DeepSpeed config dictionary/json
  206. target_deepspeed_version (str): When called from scheduling
  207. infrastructure we want to ensure that the target deepspeed version is
  208. compatible with the elasticity version used in the backend.
  209. world_size (int, optional): Intended/current DP world size, will do some sanity
  210. checks to ensure world size is actually valid with the config.
  211. return_microbatch (bool, optional): whether to return micro batch size or not.
  212. Raises:
  213. ElasticityConfigError: Missing required elasticity config or elasticity disabled
  214. ElasticityError: If target deepspeed version is not compatible with current version
  215. Returns:
  216. final_batch_size (int): total batch size used for training
  217. valid_gpus (list(int)): list of valid GPU counts with this config
  218. micro_batch_size (int, optional): if world_size is provided will return
  219. specific micro batch size
  220. """
  221. if not isinstance(ds_config, dict):
  222. raise ValueError("Expected ds_config to be a dictionary but received " \
  223. f"a {type(ds_config)}, containing: {ds_config}")
  224. if ELASTICITY not in ds_config:
  225. raise ElasticityConfigError(f"'{ELASTICITY}' is missing from config json," \
  226. " please add it if running an elastic training job.")
  227. elastic_config_dict = ds_config[ELASTICITY]
  228. if not elastic_config_dict.get(ENABLED, ENABLED_DEFAULT):
  229. raise ElasticityConfigError("Elasticity is disabled, please enable it " \
  230. "('enabled':true) if running an elastic training job.")
  231. elastic_config = ElasticityConfig(elastic_config_dict)
  232. model_parallel_size = elastic_config.model_parallel_size
  233. num_gpus_per_node = elastic_config.num_gpus_per_node
  234. if model_parallel_size > 1 and float(elastic_config.version) != 0.2:
  235. raise ElasticityConfigError(f"Elasticity V{elastic_config.version} " \
  236. f"does not support model-parallel training. Given model-parallel size: " \
  237. f"{model_parallel_size}")
  238. if float(elastic_config.version) > LATEST_ELASTICITY_VERSION:
  239. raise ElasticityConfigError("Attempting to run elasticity version " \
  240. f"{elastic_config.version} but runtime only supports up " \
  241. f"to {LATEST_ELASTICITY_VERSION}")
  242. # Ensure target deepspeed version works with intended elasticity version
  243. if not _compatible_ds_version_check(target_deepspeed_version):
  244. raise ElasticityError("Unable to run elasticity on target deepspeed version of" \
  245. f" {target_deepspeed_version}, currently {__version__}")
  246. if float(elastic_config.version) == 0.1:
  247. final_batch_size, valid_gpus = _get_compatible_gpus_v01(
  248. micro_batches=elastic_config.micro_batches,
  249. max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
  250. min_gpus=elastic_config.min_gpus,
  251. max_gpus=elastic_config.max_gpus,
  252. prefer_larger=elastic_config.prefer_larger_batch_size)
  253. # ensure batch size is int dtype
  254. final_batch_size = int(final_batch_size)
  255. elif float(elastic_config.version) == 0.2:
  256. if world_size != 0:
  257. current_num_gpus = world_size
  258. else:
  259. if "WORLD_SIZE" in os.environ and \
  260. os.getenv('WORLD_SIZE').isnumeric():
  261. current_num_gpus = int(os.getenv('WORLD_SIZE'))
  262. else:
  263. WORLD_SIZE = os.getenv('WORLD_SIZE')
  264. raise ElasticityConfigError(
  265. 'Elasticity V 0.2 needs WORLD_SIZE '\
  266. 'to compute valid batch size. '\
  267. 'Either give it as argument to function compute_elastic_config '\
  268. 'or set it as an environment variable. '\
  269. f'Value of WORLD_SIZE as environment variable is {WORLD_SIZE}')
  270. final_batch_size, valid_gpus, candidate_microbatch_size = _get_compatible_gpus_v02(
  271. micro_batches=elastic_config.micro_batches,
  272. max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
  273. current_num_gpus=current_num_gpus,
  274. min_gpus=elastic_config.min_gpus,
  275. max_gpus=elastic_config.max_gpus,
  276. prefer_larger=elastic_config.prefer_larger_batch_size,
  277. num_gpus_per_node=num_gpus_per_node,
  278. model_parallel_size=model_parallel_size)
  279. # ensure batch size is int dtype
  280. final_batch_size = int(final_batch_size)
  281. else:
  282. raise NotImplementedError(f"Unable to find elastic logic for version: {elastic_config.version}")
  283. logger.info(f"Valid World Size (GPUs / Model Parallel Size): {valid_gpus}")
  284. if world_size > 0:
  285. if world_size not in valid_gpus:
  286. raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \
  287. f"with the current list of valid GPU counts: {valid_gpus}")
  288. # Pick largest valid micro batch size
  289. micro_batch_size = None
  290. for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True):
  291. if final_batch_size // world_size % mbsz == 0:
  292. micro_batch_size = mbsz
  293. break
  294. assert micro_batch_size is not None, "Unable to find divisible micro batch size" \
  295. f" world_size={world_size}, final_batch_size={final_batch_size}, and " \
  296. f" micro_batches={elastic_config.micro_batches}."
  297. return final_batch_size, valid_gpus, micro_batch_size
  298. if return_microbatch:
  299. # Pick a valid micro batch size
  300. if float(elastic_config.version) == 0.2:
  301. return final_batch_size, valid_gpus, candidate_microbatch_size
  302. else:
  303. micro_batch_size = None
  304. for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True):
  305. if final_batch_size // world_size % mbsz == 0:
  306. micro_batch_size = mbsz
  307. break
  308. assert micro_batch_size is not None, "Unable to find divisible micro batch size" \
  309. f" world_size={world_size}, final_batch_size={final_batch_size}, and " \
  310. f" micro_batches={elastic_config.micro_batches}."
  311. return final_batch_size, valid_gpus, micro_batch_size
  312. return final_batch_size, valid_gpus