elasticity.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import os
  5. import re
  6. import json
  7. import numpy as np
  8. from packaging import version as pkg_version
  9. from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \
  10. ElasticityIncompatibleWorldSize
  11. from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \
  12. MINIMUM_DEEPSPEED_VERSION, IGNORE_NON_ELASTIC_BATCH_INFO, \
  13. IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT, DEEPSPEED_ELASTICITY_CONFIG
  14. from ..git_version_info import version as __version__
  15. from ..utils import logger
  16. # Thirty eight smallest highly composite numbers. The list should
  17. # be enough to support up to 720K batch size.
  18. HCN_LIST = [
  19. 1,
  20. 2,
  21. 4,
  22. 6,
  23. 12,
  24. 24,
  25. 36,
  26. 48,
  27. 60,
  28. 120,
  29. 180,
  30. 240,
  31. 360,
  32. 720,
  33. 840,
  34. 1260,
  35. 1680,
  36. 2520,
  37. 5040,
  38. 7560,
  39. 10080,
  40. 15120,
  41. 20160,
  42. 25200,
  43. 27720,
  44. 45360,
  45. 50400,
  46. 55440,
  47. 83160,
  48. 110880,
  49. 166320,
  50. 221760,
  51. 277200,
  52. 332640,
  53. 498960,
  54. 554400,
  55. 665280,
  56. 720720
  57. ]
  58. def get_candidate_batch_sizes(base_list, max_acceptable_batch_size):
  59. candidate_batch_size = []
  60. for base in base_list:
  61. if base >= max_acceptable_batch_size:
  62. candidate_batch_size.append(base)
  63. else:
  64. value = max_acceptable_batch_size // base
  65. index = np.argmax(np.asarray(HCN_LIST) > value)
  66. candidate_batch_size.append(HCN_LIST[index - 1] * base)
  67. candidate_batch_size = list(set(candidate_batch_size))
  68. logger.info(f"Candidate batch size: {candidate_batch_size}")
  69. return candidate_batch_size
  70. def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
  71. valid_gpus = []
  72. for micro_batch in micro_batches:
  73. if batch_size % micro_batch == 0:
  74. max_gpus = batch_size // micro_batch
  75. if max_gpus >= min_valid_gpus and max_gpus <= max_valid_gpus:
  76. valid_gpus.append(max_gpus)
  77. # find all factors less than max_gpus / 2
  78. for i in range(1, max_gpus // 2 + 1):
  79. if i > max_valid_gpus:
  80. break
  81. if i < min_valid_gpus:
  82. continue
  83. if max_gpus % i == 0:
  84. valid_gpus.append(i)
  85. valid_gpus = set(valid_gpus)
  86. valid_gpus = sorted(list(valid_gpus))
  87. logger.info(f"Valid GPUs: {valid_gpus}")
  88. return valid_gpus
  89. def get_best_candidates(candidate_batch_sizes,
  90. micro_batches,
  91. min_gpus,
  92. max_gpus,
  93. prefer_larger):
  94. max_valid_gpus = 0
  95. valid_gpus = None
  96. final_batch_size = int(min(micro_batches))
  97. for batch_size in candidate_batch_sizes:
  98. current_valid_gpus = get_valid_gpus(batch_size,
  99. micro_batches,
  100. min_gpus,
  101. max_gpus)
  102. if (len(current_valid_gpus) > max_valid_gpus
  103. or (len(current_valid_gpus) == max_valid_gpus and
  104. ((prefer_larger and batch_size > final_batch_size) or
  105. (not prefer_larger and batch_size < final_batch_size)))):
  106. max_valid_gpus = len(current_valid_gpus)
  107. valid_gpus = current_valid_gpus
  108. final_batch_size = batch_size
  109. return final_batch_size, valid_gpus
  110. def _get_compatible_gpus_v01(micro_batches,
  111. max_acceptable_batch_size,
  112. min_gpus=None,
  113. max_gpus=None,
  114. prefer_larger=True):
  115. '''We use two heuristics to compute the batch size
  116. 1. We use the Lowest Common Multiple of the micro-batches
  117. as the base batch size and scale it by a HCN such that the result is
  118. the largest batch size less than the max_acceptable batch size
  119. 2. We use each of the micro batches as a base and scale it
  120. by a HCN such that the result is the largest batch size less than the
  121. max_acceptable batch size.
  122. We then use brute force to count the number of compatible GPU count for
  123. each of the aforementioned cases, and return the batch size with the most number of
  124. compatible GPU counts in the min-max GPU range if provided, other wise
  125. we return the batch size with the most number of total compatible GPU counts.
  126. Returns:
  127. final_batch_size
  128. valid_gpus
  129. '''
  130. min_gpus = min_gpus or 1
  131. max_gpus = max_gpus or max_acceptable_batch_size // min(micro_batches)
  132. if not all(mb <= max_acceptable_batch_size for mb in micro_batches):
  133. raise ValueError(f"All micro batches must be less than \
  134. or equal to max_acceptable_batch_size: {max_acceptable_batch_size}")
  135. lcm = np.lcm.reduce(micro_batches)
  136. base_list = []
  137. base_list.extend(micro_batches)
  138. base_list.append(lcm)
  139. candidate_batch_sizes = get_candidate_batch_sizes(base_list,
  140. max_acceptable_batch_size)
  141. final_batch_size, valid_gpus = get_best_candidates(
  142. candidate_batch_sizes,
  143. micro_batches,
  144. min_gpus,
  145. max_gpus,
  146. prefer_larger)
  147. return final_batch_size, valid_gpus
  148. def _compatible_ds_version_check(target_deepspeed_version: str):
  149. min_version = pkg_version.parse(MINIMUM_DEEPSPEED_VERSION)
  150. target_version = pkg_version.parse(target_deepspeed_version)
  151. err_str = f"Target deepspeed version of {target_deepspeed_version} is not compatible " \
  152. f"with minimum version {MINIMUM_DEEPSPEED_VERSION} supporting elasticity."
  153. if target_version < min_version:
  154. raise ElasticityError(err_str)
  155. return True
  156. def elasticity_enabled(ds_config: dict):
  157. if ELASTICITY not in ds_config:
  158. return False
  159. return ds_config[ELASTICITY].get(ENABLED, ENABLED_DEFAULT)
  160. def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
  161. """
  162. Ensure the resource scheduler saw the same elastic config we are using at runtime
  163. """
  164. if DEEPSPEED_ELASTICITY_CONFIG in os.environ:
  165. scheduler_elastic_config_dict = json.loads(
  166. os.environ[DEEPSPEED_ELASTICITY_CONFIG])
  167. scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict)
  168. runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict)
  169. err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}"
  170. if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size:
  171. raise ElasticityConfigError(
  172. err_str.format('max_acceptable_batch_size',
  173. scheduler_elastic_config.max_acceptable_batch_size,
  174. 'max_acceptable_batch_size',
  175. runtime_elastic_config.max_acceptable_batch_size))
  176. if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches:
  177. raise ElasticityConfigError(
  178. err_str.format('micro_batches',
  179. scheduler_elastic_config.micro_batches,
  180. 'micro_batches',
  181. runtime_elastic_config.micro_batches))
  182. if runtime_elastic_config.version != scheduler_elastic_config.version:
  183. raise ElasticityConfigError(
  184. err_str.format('version',
  185. scheduler_elastic_config.version,
  186. 'version',
  187. runtime_elastic_config.version))
  188. else:
  189. logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \
  190. "guarantee resource scheduler will scale this job using compatible GPU counts.")
  191. def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0):
  192. """Core deepspeed elasticity API. Given an elastic config (similar to the example below)
  193. DeepSpeed will compute a total train batch size corresponding valid GPU count list that
  194. provides a high level of elasticity. Elasticity in this case means we are safe to scale
  195. the training job up/down across the GPU count list *without* any negative impacts on
  196. training convergence. This is achievable primarily due to DeepSpeed's gradient accumulation
  197. feature which allows us to decompose a global training batch size into:
  198. micro-batch-size * gradient-accumulation-steps * world-size.
  199. "elasticity": {
  200. "enabled": true,
  201. "max_train_batch_size": 2000,
  202. "micro_batch_sizes": [2,4,6],
  203. "min_gpus": 1,
  204. "max_gpus" : 10000
  205. "min_time": 20
  206. "version": 0.1
  207. }
  208. Intended to be called both by scheduling infrastructure and deepspeed runtime.
  209. For the same `ds_config` we should return deterministic results.
  210. Args:
  211. ds_config (dict): DeepSpeed config dictionary/json
  212. target_deepspeed_version (str): When called from scheduling
  213. infrastructure we want to ensure that the target deepspeed version is
  214. compatible with the elasticity version used in the backend.
  215. world_size (int, optional): Intended/current world size, will do some sanity
  216. checks to ensure world size is actually valid with the config.
  217. Raises:
  218. ElasticityConfigError: Missing required elasticity config or elasticity disabled
  219. ElasticityError: If target deepspeed version is not compatible with current version
  220. Returns:
  221. final_batch_size (int): total batch size used for training
  222. valid_gpus (list(int)): list of valid GPU counts with this config
  223. micro_batch_size (int, optional): if world_size is provided will return
  224. specific micro batch size
  225. """
  226. if not isinstance(ds_config, dict):
  227. raise ValueError("Expected ds_config to be a dictionary but received " \
  228. f"a {type(ds_config)}, containing: {ds_config}")
  229. if ELASTICITY not in ds_config:
  230. raise ElasticityConfigError(f"'{ELASTICITY}' is missing from config json," \
  231. " please add it if running an elastic training job.")
  232. elastic_config_dict = ds_config[ELASTICITY]
  233. if not elastic_config_dict.get(ENABLED, ENABLED_DEFAULT):
  234. raise ElasticityConfigError("Elasticity is disabled, please enable it " \
  235. "('enabled':true) if running an elastic training job.")
  236. elastic_config = ElasticityConfig(elastic_config_dict)
  237. if float(elastic_config.version) > LATEST_ELASTICITY_VERSION:
  238. raise ElasticityConfigError("Attempting to run elasticity version " \
  239. f"{elastic_config.version} but runtime only supports up " \
  240. f"to {LATEST_ELASTICITY_VERSION}")
  241. # Ensure target deepspeed version works with intended elasticity version
  242. if not _compatible_ds_version_check(target_deepspeed_version):
  243. raise ElasticityError("Unable to run elasticity on target deepspeed version of" \
  244. f" {target_deepspeed_version}, currently {__version__}")
  245. if float(elastic_config.version) == 0.1:
  246. final_batch_size, valid_gpus = _get_compatible_gpus_v01(
  247. micro_batches=elastic_config.micro_batches,
  248. max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
  249. min_gpus=elastic_config.min_gpus,
  250. max_gpus=elastic_config.max_gpus,
  251. prefer_larger=elastic_config.prefer_larger_batch_size)
  252. # ensure batch size is int dtype
  253. final_batch_size = int(final_batch_size)
  254. else:
  255. raise NotImplementedError(
  256. f"Unable to find elastic logic for version: {elastic_config.version}")
  257. if world_size > 0:
  258. if world_size not in valid_gpus:
  259. raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \
  260. f"with the current list of valid GPU counts: {valid_gpus}")
  261. # Pick largest valid micro batch size
  262. micro_batch_size = None
  263. for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True):
  264. if final_batch_size // world_size % mbsz == 0:
  265. micro_batch_size = mbsz
  266. break
  267. assert micro_batch_size is not None, "Unable to find divisible micro batch size" \
  268. f" world_size={world_size}, final_batch_size={final_batch_size}, and " \
  269. f" micro_batches={elastic_config.micro_batches}."
  270. return final_batch_size, valid_gpus, micro_batch_size
  271. return final_batch_size, valid_gpus