  1. """
  2. Copyright (c) Microsoft Corporation
  3. Licensed under the MIT license.
  4. """
  5. import os
  6. from typing import Union
  7. import torch
  8. import json
  9. import copy
  10. from .constants import *
  11. from .fp16.loss_scaler import (
  12. INITIAL_LOSS_SCALE,
  13. SCALE_WINDOW,
  14. DELAYED_SHIFT,
  15. MIN_LOSS_SCALE,
  16. )
  17. from .config_utils import (
  18. get_scalar_param,
  19. dict_raise_error_on_duplicate_keys,
  20. ScientificNotationEncoder,
  21. )
  22. from .zero.config import DeepSpeedZeroConfig
  23. from .zero.constants import *
  24. from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
  25. from ..git_version_info import version as __version__
  26. from ..utils import logger
  27. from ..elasticity import (
  28. elasticity_enabled,
  29. compute_elastic_config,
  30. ensure_immutable_elastic_config,
  31. )
  32. from ..elasticity.config import ElasticityConfigError
  33. from ..elasticity.constants import (
  34. ELASTICITY,
  35. IGNORE_NON_ELASTIC_BATCH_INFO,
  36. IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT,
  37. )
  38. from ..profiling.config import DeepSpeedFlopsProfilerConfig
  39. from ..autotuning.config import DeepSpeedAutotuningConfig
  40. from .swap_tensor.aio_config import get_aio_config
  41. TENSOR_CORE_ALIGN_SIZE = 8

ADAGRAD_OPTIMIZER = 'adagrad'
ADAM_OPTIMIZER = 'adam'
ADAMW_OPTIMIZER = 'adamw'
LAMB_OPTIMIZER = 'lamb'
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ONEBIT_LAMB_OPTIMIZER = 'onebitlamb'
DEEPSPEED_OPTIMIZERS = [
    ADAGRAD_OPTIMIZER,
    ADAM_OPTIMIZER,
    ADAMW_OPTIMIZER,
    LAMB_OPTIMIZER,
    ONEBIT_ADAM_OPTIMIZER,
    ONEBIT_LAMB_OPTIMIZER,
]

# extra optimizer parameters for adam/adamw
TORCH_ADAM_PARAM = "torch_adam"

# default to adamw logic for adam/adamw optimizers unless user explicitly opts out
ADAM_W_MODE = "adam_w_mode"
ADAM_W_MODE_DEFAULT = True


class DeepSpeedConfigError(Exception):
    pass


def get_curriculum_enabled(param_dict):
    if CURRICULUM_LEARNING in param_dict.keys():
        return get_scalar_param(param_dict[CURRICULUM_LEARNING],
                                CURRICULUM_ENABLED,
                                CURRICULUM_ENABLED_DEFAULT)
    else:
        return False


def get_curriculum_params(param_dict):
    if CURRICULUM_LEARNING in param_dict.keys():
        curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING])
        curriculum_params.pop(CURRICULUM_ENABLED)
        return curriculum_params
    else:
        return False


def get_pld_enabled(param_dict):
    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
        return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
                                PLD_ENABLED,
                                PLD_ENABLED_DEFAULT)
    else:
        return False


def get_pld_params(param_dict):
    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
        pld_params = copy.copy(param_dict[PROGRESSIVE_LAYER_DROP])
        pld_params.pop(PLD_ENABLED)
        return pld_params
    else:
        return False


def get_amp_enabled(param_dict):
    if AMP in param_dict.keys():
        return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
    else:
        return False


def get_amp_params(param_dict):
    if AMP in param_dict.keys():
        amp_params = copy.copy(param_dict[AMP])
        amp_params.pop(AMP_ENABLED)
        return amp_params
    else:
        return False


def get_fp16_enabled(param_dict):
    if FP16 in param_dict.keys():
        return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
    else:
        return False


def get_bfloat16_enabled(param_dict):
    if BFLOAT16 in param_dict.keys():
        return get_scalar_param(param_dict[BFLOAT16],
                                BFLOAT16_ENABLED,
                                BFLOAT16_ENABLED_DEFAULT)
    else:
        return False


def get_fp16_master_weights_and_grads_enabled(param_dict):
    if get_fp16_enabled(param_dict):
        return get_scalar_param(param_dict[FP16],
                                FP16_MASTER_WEIGHTS_AND_GRADS,
                                FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT)
    else:
        return False


def get_loss_scale(param_dict):
    if get_fp16_enabled(param_dict):
        return get_scalar_param(param_dict[FP16],
                                FP16_LOSS_SCALE,
                                FP16_LOSS_SCALE_DEFAULT)
    elif get_bfloat16_enabled(param_dict):
        return 1.0
    else:
        return FP16_LOSS_SCALE_DEFAULT


def get_initial_dynamic_scale(param_dict):
    if get_fp16_enabled(param_dict):
        initial_scale_power = get_scalar_param(param_dict[FP16],
                                               FP16_INITIAL_SCALE_POWER,
                                               FP16_INITIAL_SCALE_POWER_DEFAULT)
    elif get_bfloat16_enabled(param_dict):
        initial_scale_power = 0
    else:
        initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT

    return 2**initial_scale_power


def get_dynamic_loss_scale_args(param_dict):
    loss_scale_args = None
    if get_fp16_enabled(param_dict):
        fp16_dict = param_dict[FP16]
        dynamic_loss_args = [
            FP16_INITIAL_SCALE_POWER,
            FP16_LOSS_SCALE_WINDOW,
            FP16_MIN_LOSS_SCALE,
            FP16_HYSTERESIS,
        ]
        if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args):
            init_scale = get_scalar_param(fp16_dict,
                                          FP16_INITIAL_SCALE_POWER,
                                          FP16_INITIAL_SCALE_POWER_DEFAULT)
            scale_window = get_scalar_param(fp16_dict,
                                            FP16_LOSS_SCALE_WINDOW,
                                            FP16_LOSS_SCALE_WINDOW_DEFAULT)
            delayed_shift = get_scalar_param(fp16_dict,
                                             FP16_HYSTERESIS,
                                             FP16_HYSTERESIS_DEFAULT)
            min_loss_scale = get_scalar_param(fp16_dict,
                                              FP16_MIN_LOSS_SCALE,
                                              FP16_MIN_LOSS_SCALE_DEFAULT)
            loss_scale_args = {
                INITIAL_LOSS_SCALE: 2**init_scale,
                SCALE_WINDOW: scale_window,
                DELAYED_SHIFT: delayed_shift,
                MIN_LOSS_SCALE: min_loss_scale,
            }

    return loss_scale_args
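

# Illustrative "fp16" section that exercises the dynamic loss-scale parsing
# above. Key names follow the public DeepSpeed config documentation; the
# concrete values are placeholders, not recommended settings:
#
#   "fp16": {
#       "enabled": true,
#       "loss_scale": 0,
#       "initial_scale_power": 16,
#       "loss_scale_window": 1000,
#       "hysteresis": 2,
#       "min_loss_scale": 1
#   }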


def get_gradient_accumulation_steps(param_dict):
    return get_scalar_param(param_dict,
                            GRADIENT_ACCUMULATION_STEPS,
                            GRADIENT_ACCUMULATION_STEPS_DEFAULT)


def get_sparse_gradients_enabled(param_dict):
    return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)


def get_zero_optimization(param_dict):
    return get_scalar_param(param_dict, ZERO_OPTIMIZATION, ZERO_OPTIMIZATION_DEFAULT)


def get_zero_reduce_scatter(param_dict):
    return get_scalar_param(
        param_dict,
        ZERO_OPTIMIZATION_REDUCE_SCATTER,
        ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
    )


def get_communication_data_type(param_dict):
    val = get_scalar_param(param_dict,
                           COMMUNICATION_DATA_TYPE,
                           COMMUNICATION_DATA_TYPE_DEFAULT)
    val = val.lower() if val is not None else val
    if val is None:
        return val  # we must determine it by other parameters
    elif val == "fp32":
        return torch.float32
    elif val == "fp16":
        return torch.float16
    elif val == "bfp16":
        return torch.bfloat16

    raise ValueError(
        f"Invalid communication_data_type. Supported data types: ['fp16', 'bfp16', 'fp32']. Got: {val}"
    )


def get_prescale_gradients(param_dict):
    return get_scalar_param(param_dict, PRESCALE_GRADIENTS, PRESCALE_GRADIENTS_DEFAULT)


def get_gradient_predivide_factor(param_dict):
    return get_scalar_param(param_dict,
                            GRADIENT_PREDIVIDE_FACTOR,
                            GRADIENT_PREDIVIDE_FACTOR_DEFAULT)


def get_quantize_enabled(param_dict):
    if QUANTIZE_TRAINING in param_dict.keys():
        return get_scalar_param(
            param_dict[QUANTIZE_TRAINING],
            QUANTIZE_TRAINING_ENABLED,
            QUANTIZE_TRAINING_ENABLED_DEFAULT,
        )
    else:
        return False


def get_quantize_training(param_dict):
    if QUANTIZE_TRAINING in param_dict.keys():
        return (
            (param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS][TARGET_BITS]),
            (param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS][START_BITS]
             if START_BITS in param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS].keys() else
             QUANTIZE_START_BITS_DEFAULT),
            (param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE][QUANTIZE_PERIOD]
             if QUANTIZE_SCHEDULE in param_dict[QUANTIZE_TRAINING].keys() else
             QUANTIZE_PERIOD_DEFAULT),
            (param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE][SCHEDULE_OFFSET]
             if QUANTIZE_SCHEDULE in param_dict[QUANTIZE_TRAINING].keys() and
             SCHEDULE_OFFSET in param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE].keys()
             else QUANTIZE_OFFSET_DEFAULT),
            (param_dict[QUANTIZE_TRAINING][QUANTIZE_GROUPS] if QUANTIZE_GROUPS
             in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_GROUPS_DEFAULT),
            (param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE]
             [FP16_MIXED_QUANTIZE_ENABLED]
             if FP16_MIXED_QUANTIZE in param_dict[QUANTIZE_TRAINING].keys()
             and FP16_MIXED_QUANTIZE_ENABLED
             in param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE].keys() else
             FP16_MIXED_QUANTIZE_ENABLED_DEFAULT),
            (param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE][QUANTIZE_CHANGE_RATIO]
             if FP16_MIXED_QUANTIZE in param_dict[QUANTIZE_TRAINING].keys()
             and QUANTIZE_CHANGE_RATIO
             in param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE].keys() else
             QUANTIZE_CHANGE_RATIO_DEFAULT),
            (1 if QUANTIZE_ALGO in param_dict[QUANTIZE_TRAINING]
             and QUANTIZE_TYPE in param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO].keys()
             and param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO][QUANTIZE_TYPE]
             == QUANTIZE_ASYMMETRIC else QUANTIZE_TYPE_DEFAULT),
            (1 if QUANTIZE_ALGO in param_dict[QUANTIZE_TRAINING] and QUANTIZE_ROUNDING
             in param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO].keys()
             and param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO][QUANTIZE_ROUNDING]
             == STOCHASTIC_ROUNDING else QUANTIZE_ROUNDING_DEFAULT),
            (param_dict[QUANTIZE_TRAINING][QUANTIZE_VERBOSE] if QUANTIZE_VERBOSE
             in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_VERBOSE_DEFAULT),
            (param_dict[QUANTIZE_TRAINING][QUANTIZER_KERNEL] if QUANTIZER_KERNEL
             in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZER_KERNEL_DEFAULT),
        )
    else:
        return (
            QUANTIZE_TARGET_BITS_DEFAULT,
            QUANTIZE_START_BITS_DEFAULT,
            QUANTIZE_PERIOD_DEFAULT,
            QUANTIZE_OFFSET_DEFAULT,
            QUANTIZE_GROUPS_DEFAULT,
            FP16_MIXED_QUANTIZE_ENABLED_DEFAULT,
            QUANTIZE_CHANGE_RATIO_DEFAULT,
            QUANTIZE_TYPE_DEFAULT,
            QUANTIZE_ROUNDING_DEFAULT,
            QUANTIZE_VERBOSE_DEFAULT,
            QUANTIZER_KERNEL_DEFAULT,
        )
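

# The tuple returned by get_quantize_training() is unpacked, in this order, in
# DeepSpeedConfig._initialize_params() below:
#   (quantize_target_bits, quantize_start_bits, quantize_period, quantize_offset,
#    quantize_groups, fp16_mixed_quantize, quantize_change_rate, quantize_type,
#    quantize_rounding, quantize_verbose, use_quantizer_kernel)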


def get_steps_per_print(param_dict):
    return get_scalar_param(param_dict, STEPS_PER_PRINT, STEPS_PER_PRINT_DEFAULT)


def get_disable_allgather(param_dict):
    return get_scalar_param(param_dict, DISABLE_ALLGATHER, DISABLE_ALLGATHER_DEFAULT)


def get_dump_state(param_dict):
    return get_scalar_param(param_dict, DUMP_STATE, DUMP_STATE_DEFAULT)


def get_gradient_clipping(param_dict):
    return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)


def get_sparse_attention(param_dict):
    if SPARSE_ATTENTION in param_dict.keys():
        sparsity = param_dict[SPARSE_ATTENTION]
        mode = get_sparse_attention_mode(sparsity)

        if mode == SPARSE_DENSE_MODE:
            return get_sparse_dense_config(sparsity)
        elif mode == SPARSE_FIXED_MODE:
            return get_sparse_fixed_config(sparsity)
        elif mode == SPARSE_VARIABLE_MODE:
            return get_sparse_variable_config(sparsity)
        elif mode == SPARSE_BIGBIRD_MODE:
            return get_sparse_bigbird_config(sparsity)
        elif mode == SPARSE_BSLONGFORMER_MODE:
            return get_sparse_bslongformer_config(sparsity)
        else:
            raise NotImplementedError(
                f"Given sparsity mode, {mode}, has not been implemented yet!")
    else:
        return None


def get_sparse_dense_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    return {SPARSE_MODE: SPARSE_DENSE_MODE, SPARSE_BLOCK: block}


def get_sparse_fixed_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_local_blocks = get_scalar_param(sparsity,
                                        SPARSE_NUM_LOCAL_BLOCKS,
                                        SPARSE_NUM_LOCAL_BLOCKS_DEFAULT)
    num_global_blocks = get_scalar_param(sparsity,
                                         SPARSE_NUM_GLOBAL_BLOCKS,
                                         SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
    attention = get_scalar_param(sparsity,
                                 SPARSE_ATTENTION_TYPE,
                                 SPARSE_ATTENTION_TYPE_DEFAULT)
    horizontal_global_attention = get_scalar_param(
        sparsity,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT,
    )
    num_different_global_patterns = get_scalar_param(
        sparsity,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_FIXED_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_LOCAL_BLOCKS: num_local_blocks,
        SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
        SPARSE_ATTENTION_TYPE: attention,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: num_different_global_patterns,
    }


def get_sparse_variable_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_random_blocks = get_scalar_param(sparsity,
                                         SPARSE_NUM_RANDOM_BLOCKS,
                                         SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
    local_window_blocks = get_scalar_param(sparsity,
                                           SPARSE_LOCAL_WINDOW_BLOCKS,
                                           SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT)
    global_block_indices = get_scalar_param(sparsity,
                                            SPARSE_GLOBAL_BLOCK_INDICES,
                                            SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
    global_block_end_indices = get_scalar_param(
        sparsity,
        SPARSE_GLOBAL_BLOCK_END_INDICES,
        SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT,
    )
    attention = get_scalar_param(sparsity,
                                 SPARSE_ATTENTION_TYPE,
                                 SPARSE_ATTENTION_TYPE_DEFAULT)
    horizontal_global_attention = get_scalar_param(
        sparsity,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_VARIABLE_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
        SPARSE_LOCAL_WINDOW_BLOCKS: local_window_blocks,
        SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
        SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
        SPARSE_ATTENTION_TYPE: attention,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
    }


def get_sparse_bigbird_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_random_blocks = get_scalar_param(sparsity,
                                         SPARSE_NUM_RANDOM_BLOCKS,
                                         SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
    num_sliding_window_blocks = get_scalar_param(
        sparsity,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT,
    )
    num_global_blocks = get_scalar_param(sparsity,
                                         SPARSE_NUM_GLOBAL_BLOCKS,
                                         SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)

    return {
        SPARSE_MODE: SPARSE_BIGBIRD_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
        SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
    }


def get_sparse_bslongformer_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_sliding_window_blocks = get_scalar_param(
        sparsity,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT,
    )
    global_block_indices = get_scalar_param(sparsity,
                                            SPARSE_GLOBAL_BLOCK_INDICES,
                                            SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
    global_block_end_indices = get_scalar_param(
        sparsity,
        SPARSE_GLOBAL_BLOCK_END_INDICES,
        SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_BSLONGFORMER_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
        SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
        SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
    }


def get_sparse_attention_mode(param_dict):
    if SPARSE_MODE in param_dict.keys():
        return param_dict[SPARSE_MODE]
    else:
        return SPARSE_MODE_DEFAULT


def get_sparse_attention_type(param_dict):
    if SPARSE_ATTENTION_TYPE in param_dict.keys():
        return param_dict[SPARSE_ATTENTION_TYPE]
    else:
        return SPARSE_ATTENTION_TYPE_DEFAULT


def get_pipeline_config(param_dict):
    """Parses pipeline engine configuration."""
    default_pipeline = {
        "stages": "auto",
        "partition": "best",
        "seed_layers": False,
        "activation_checkpoint_interval": 0,
    }
    config = default_pipeline
    for key, val in param_dict.get("pipeline", {}).items():
        config[key] = val

    return config
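

# Illustrative "pipeline" section mirroring the defaults above; any of these
# keys may be overridden in the user's DeepSpeed config (the values shown are
# simply the defaults from default_pipeline, not a recommendation):
#
#   "pipeline": {
#       "stages": "auto",
#       "partition": "best",
#       "seed_layers": false,
#       "activation_checkpoint_interval": 0
#   }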


def get_optimizer_name(param_dict):
    if OPTIMIZER in param_dict.keys() and TYPE in param_dict[OPTIMIZER].keys():
        return param_dict[OPTIMIZER][TYPE]
    else:
        return OPTIMIZER_TYPE_DEFAULT


def get_optimizer_params(param_dict):
    if (get_optimizer_name(param_dict) is not None
            and OPTIMIZER_PARAMS in param_dict[OPTIMIZER].keys()):
        return param_dict[OPTIMIZER][OPTIMIZER_PARAMS]
    else:
        return None


def get_optimizer_gradient_clipping(param_dict):
    optimizer_params = get_optimizer_params(param_dict)
    if optimizer_params is not None and MAX_GRAD_NORM in optimizer_params.keys():
        return optimizer_params[MAX_GRAD_NORM]
    else:
        return None


def get_optimizer_legacy_fusion(param_dict):
    if OPTIMIZER in param_dict.keys() and LEGACY_FUSION in param_dict[OPTIMIZER].keys():
        return param_dict[OPTIMIZER][LEGACY_FUSION]
    else:
        return LEGACY_FUSION_DEFAULT


def get_zero_allow_untested_optimizer(param_dict):
    return get_scalar_param(param_dict,
                            ZERO_ALLOW_UNTESTED_OPTIMIZER,
                            ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)


def get_scheduler_name(param_dict):
    if SCHEDULER in param_dict.keys() and TYPE in param_dict[SCHEDULER].keys():
        return param_dict[SCHEDULER][TYPE]
    else:
        return SCHEDULER_TYPE_DEFAULT


def get_scheduler_params(param_dict):
    if (get_scheduler_name(param_dict) is not None
            and SCHEDULER_PARAMS in param_dict[SCHEDULER].keys()):
        return param_dict[SCHEDULER][SCHEDULER_PARAMS]
    else:
        return None


def get_train_batch_size(param_dict):
    return get_scalar_param(param_dict, TRAIN_BATCH_SIZE, TRAIN_BATCH_SIZE_DEFAULT)


def get_train_micro_batch_size_per_gpu(param_dict):
    return get_scalar_param(
        param_dict,
        TRAIN_MICRO_BATCH_SIZE_PER_GPU,
        TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT,
    )


def get_wall_clock_breakdown(param_dict):
    return get_scalar_param(param_dict,
                            WALL_CLOCK_BREAKDOWN,
                            WALL_CLOCK_BREAKDOWN_DEFAULT)


def get_memory_breakdown(param_dict):
    return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)


def get_tensorboard_enabled(param_dict):
    if TENSORBOARD in param_dict.keys():
        return get_scalar_param(param_dict[TENSORBOARD],
                                TENSORBOARD_ENABLED,
                                TENSORBOARD_ENABLED_DEFAULT)
    else:
        return False


def get_eigenvalue_config(param_dict):
    if get_quantize_enabled(param_dict):
        param_dict = param_dict[QUANTIZE_TRAINING]
        return (
            get_eigenvalue_enabled(param_dict),
            get_eigenvalue_verbose(param_dict),
            get_eigenvalue_max_iter(param_dict),
            get_eigenvalue_tol(param_dict),
            get_eigenvalue_stability(param_dict),
            get_eigenvalue_gas_boundary_resolution(param_dict),
            get_eigenvalue_layer_name(param_dict),
            get_eigenvalue_layer_num(param_dict),
        )
    else:
        return (
            EIGENVALUE_ENABLED_DEFAULT,
            EIGENVALUE_VERBOSE_DEFAULT,
            EIGENVALUE_MAX_ITER_DEFAULT,
            EIGENVALUE_TOL_DEFAULT,
            EIGENVALUE_STABILITY_DEFAULT,
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT,
            EIGENVALUE_LAYER_NAME_DEFAULT,
            EIGENVALUE_LAYER_NUM_DEFAULT,
        )


def get_eigenvalue_enabled(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE],
                                EIGENVALUE_ENABLED,
                                EIGENVALUE_ENABLED_DEFAULT)
    else:
        return EIGENVALUE_ENABLED_DEFAULT


def get_eigenvalue_verbose(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE],
                                EIGENVALUE_VERBOSE,
                                EIGENVALUE_VERBOSE_DEFAULT)
    else:
        return EIGENVALUE_VERBOSE_DEFAULT


def get_eigenvalue_max_iter(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE],
                                EIGENVALUE_MAX_ITER,
                                EIGENVALUE_MAX_ITER_DEFAULT)
    else:
        return EIGENVALUE_MAX_ITER_DEFAULT


def get_eigenvalue_tol(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE],
                                EIGENVALUE_TOL,
                                EIGENVALUE_TOL_DEFAULT)
    else:
        return EIGENVALUE_TOL_DEFAULT


def get_eigenvalue_stability(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE],
                                EIGENVALUE_STABILITY,
                                EIGENVALUE_STABILITY_DEFAULT)
    else:
        return EIGENVALUE_STABILITY_DEFAULT


def get_eigenvalue_gas_boundary_resolution(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(
            param_dict[EIGENVALUE],
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION,
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT,
        )
    else:
        return EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT


def get_eigenvalue_layer_name(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE],
                                EIGENVALUE_LAYER_NAME,
                                EIGENVALUE_LAYER_NAME_DEFAULT)
    else:
        return EIGENVALUE_LAYER_NAME_DEFAULT


def get_eigenvalue_layer_num(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE],
                                EIGENVALUE_LAYER_NUM,
                                EIGENVALUE_LAYER_NUM_DEFAULT)
    else:
        return EIGENVALUE_LAYER_NUM_DEFAULT


def get_tensorboard_output_path(param_dict):
    if get_tensorboard_enabled(param_dict):
        return get_scalar_param(
            param_dict[TENSORBOARD],
            TENSORBOARD_OUTPUT_PATH,
            TENSORBOARD_OUTPUT_PATH_DEFAULT,
        )
    else:
        return TENSORBOARD_OUTPUT_PATH_DEFAULT


def get_tensorboard_job_name(param_dict):
    if get_tensorboard_enabled(param_dict):
        return get_scalar_param(param_dict[TENSORBOARD],
                                TENSORBOARD_JOB_NAME,
                                TENSORBOARD_JOB_NAME_DEFAULT)
    else:
        return TENSORBOARD_JOB_NAME_DEFAULT
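

# Illustrative "tensorboard" section exercising the getters above. The key
# names follow the public DeepSpeed config documentation; the path and job
# name are placeholders, not defaults:
#
#   "tensorboard": {
#       "enabled": true,
#       "output_path": "output/ds_logs/",
#       "job_name": "my_training_run"
#   }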


def get_checkpoint_params(param_dict):
    return param_dict.get(CHECKPOINT, {})


def get_checkpoint_tag_validation_mode(checkpoint_params):
    tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION,
                                                CHECKPOINT_TAG_VALIDATION_DEFAULT)
    tag_validation_mode = tag_validation_mode.upper()
    if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES:
        return tag_validation_mode
    else:
        raise DeepSpeedConfigError(
            "Checkpoint config contains invalid tag_validation "
            f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}"
        )


def get_dataloader_drop_last(param_dict):
    return get_scalar_param(param_dict,
                            DATALOADER_DROP_LAST,
                            DATALOADER_DROP_LAST_DEFAULT)


'''Write deepspeed config files by modifying basic templates.
Can be used for quickly changing parameters via command line parameters.'''


class DeepSpeedConfigWriter:
    def __init__(self, data=None):
        self.data = data if data is not None else {}

    def add_config(self, key, value):
        self.data[key] = value

    def load_config(self, filename):
        self.data = json.load(open(filename, "r"),
                              object_pairs_hook=dict_raise_error_on_duplicate_keys)

    def write_config(self, filename):
        with open(filename, "w") as outfile:
            json.dump(self.data, outfile)
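

# Minimal usage sketch for DeepSpeedConfigWriter (file names and the overridden
# key are illustrative only):
#
#   writer = DeepSpeedConfigWriter()
#   writer.load_config("base_ds_config.json")
#   writer.add_config("train_micro_batch_size_per_gpu", 4)
#   writer.write_config("tuned_ds_config.json")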


class DeepSpeedConfig(object):
    def __init__(self, config: Union[str, dict], mpu=None):
        super(DeepSpeedConfig, self).__init__()

        if isinstance(config, dict):
            self._param_dict = config
        elif os.path.exists(config):
            self._param_dict = json.load(
                open(config, "r"),
                object_pairs_hook=dict_raise_error_on_duplicate_keys)
        else:
            raise ValueError(
                f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}"
            )

        try:
            self.global_rank = torch.distributed.get_rank()
            if mpu is None:
                self.world_size = torch.distributed.get_world_size()
            else:
                self.world_size = mpu.get_data_parallel_world_size()
        except:
            self.global_rank = 0
            self.world_size = 1

        # If elastic mode is enabled, compute the elastic configuration and update _param_dict
        self.elasticity_enabled = elasticity_enabled(self._param_dict)
        if self.elasticity_enabled:
            logger.info("DeepSpeed elasticity support enabled")
            final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(
                ds_config=self._param_dict,
                target_deepspeed_version=__version__,
                world_size=self.world_size,
            )

            elastic_dict = self._param_dict[ELASTICITY]

            # Ensure the resource scheduler saw the same elastic config we are using at runtime
            ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict)

            ignore_non_elastic_batch_info = elastic_dict.get(
                IGNORE_NON_ELASTIC_BATCH_INFO,
                IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)

            if not ignore_non_elastic_batch_info:
                batch_params = [
                    TRAIN_BATCH_SIZE,
                    TRAIN_MICRO_BATCH_SIZE_PER_GPU,
                    GRADIENT_ACCUMULATION_STEPS,
                ]
                if any(map(lambda t: t in self._param_dict, batch_params)):
                    raise ElasticityConfigError(
                        "One or more batch related parameters were found in your "
                        f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or "
                        f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since "
                        "elastic training is enabled, which takes control of these parameters. "
                        "If you want to suppress this error (the parameters will be silently ignored) "
                        f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}:true in your elasticity config.")

            # micro_bsz * world_size * gas = total_batch_size
            # gas = total_batch_size // (micro_bsz * world_size)
            gradient_accu_steps = final_batch_size // (micro_batch_size *
                                                       self.world_size)

            if TRAIN_BATCH_SIZE in self._param_dict:
                logger.warning(
                    "[Elasticity] overriding train_batch_size: "
                    f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}")
            if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict:
                logger.warning(
                    "[Elasticity] overriding train_micro_batch_size_per_gpu: "
                    f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}"
                )
            if GRADIENT_ACCUMULATION_STEPS in self._param_dict:
                logger.warning(
                    "[Elasticity] overriding gradient_accumulation_steps: "
                    f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}"
                )

            logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}")

            self._param_dict[TRAIN_BATCH_SIZE] = final_batch_size
            self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size
            self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps

        self._initialize_params(self._param_dict)
        self._configure_train_batch_size()
        self._do_sanity_check()
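
    # Typical construction (illustrative paths/values; see __init__ above):
    #
    #   ds_config = DeepSpeedConfig("ds_config.json")          # from a JSON file path
    #   ds_config = DeepSpeedConfig({"train_batch_size": 32})  # from an in-memory dict
    #
    # ds_config.print("DeepSpeedConfig") dumps the resolved settings via logger.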

    def _initialize_params(self, param_dict):
        self.train_batch_size = get_train_batch_size(param_dict)
        #print(f"beginning get_train_batch_size = {get_train_batch_size}")
        self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(
            param_dict)
        self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict)
        self.steps_per_print = get_steps_per_print(param_dict)
        self.dump_state = get_dump_state(param_dict)

        self.disable_allgather = get_disable_allgather(param_dict)
        self.communication_data_type = get_communication_data_type(param_dict)
        self.prescale_gradients = get_prescale_gradients(param_dict)
        self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
        self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)

        self.zero_config = DeepSpeedZeroConfig(param_dict)
        self.zero_optimization_stage = self.zero_config.stage
        self.zero_enabled = self.zero_optimization_stage > 0

        self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(
            param_dict)

        self.gradient_clipping = get_gradient_clipping(param_dict)
        self.fp16_enabled = get_fp16_enabled(param_dict)
        self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
        assert not (self.fp16_enabled and self.bfloat16_enabled), \
            'bfloat16 and fp16 modes cannot be simultaneously enabled'
        assert not (self.bfloat16_enabled and (self.zero_optimization_stage != 2)), \
            'bfloat16 mode is only supported with ZeRO stage 2 currently'
        self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(
            param_dict)
        self.amp_enabled = get_amp_enabled(param_dict)
        self.amp_params = get_amp_params(param_dict)
        self.loss_scale = get_loss_scale(param_dict)
        self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
        self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)

        self.quantize_training_enabled = get_quantize_enabled(param_dict)
        (
            self.quantize_target_bits,
            self.quantize_start_bits,
            self.quantize_period,
            self.quantize_offset,
            self.quantize_groups,
            self.fp16_mixed_quantize,
            self.quantize_change_rate,
            self.quantize_type,
            self.quantize_rounding,
            self.quantize_verbose,
            self.use_quantizer_kernel,
        ) = get_quantize_training(param_dict)

        self.optimizer_name = get_optimizer_name(param_dict)
        if (self.optimizer_name is not None
                and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS):
            self.optimizer_name = self.optimizer_name.lower()

        self.optimizer_params = get_optimizer_params(param_dict)
        self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict)

        self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(
            param_dict)

        self.scheduler_name = get_scheduler_name(param_dict)
        self.scheduler_params = get_scheduler_params(param_dict)

        self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict)
        self.wall_clock_breakdown = (get_wall_clock_breakdown(param_dict)
                                     | self.flops_profiler_config.enabled)
        self.memory_breakdown = get_memory_breakdown(param_dict)
        self.autotuning_config = DeepSpeedAutotuningConfig(param_dict)
        self.tensorboard_enabled = get_tensorboard_enabled(param_dict)
        self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
        self.tensorboard_job_name = get_tensorboard_job_name(param_dict)

        (
            self.eigenvalue_enabled,
            self.eigenvalue_verbose,
            self.eigenvalue_max_iter,
            self.eigenvalue_tol,
            self.eigenvalue_stability,
            self.eigenvalue_gas_boundary_resolution,
            self.eigenvalue_layer_name,
            self.eigenvalue_layer_num,
        ) = get_eigenvalue_config(param_dict)

        self.sparse_attention = get_sparse_attention(param_dict)
        self.pipeline = get_pipeline_config(param_dict)

        self.pld_enabled = get_pld_enabled(param_dict)
        self.pld_params = get_pld_params(param_dict)

        self.curriculum_enabled = get_curriculum_enabled(param_dict)
        self.curriculum_params = get_curriculum_params(param_dict)

        checkpoint_params = get_checkpoint_params(param_dict)
        validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
        self.checkpoint_tag_validation_enabled = (validation_mode !=
                                                  ValidationMode.IGNORE)
        self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL

        self.aio_config = get_aio_config(param_dict)

        self.dataloader_drop_last = get_dataloader_drop_last(param_dict)

    def _batch_assertion(self):
        train_batch = self.train_batch_size
        micro_batch = self.train_micro_batch_size_per_gpu
        grad_acc = self.gradient_accumulation_steps

        assert (
            train_batch > 0
        ), f"Train batch size: {train_batch} has to be greater than 0"

        assert (
            micro_batch > 0
        ), f"Micro batch size per gpu: {micro_batch} has to be greater than 0"

        assert (
            grad_acc > 0
        ), f"Gradient accumulation steps: {grad_acc} has to be greater than 0"

        assert train_batch == micro_batch * grad_acc * self.world_size, (
            "Check batch related parameters. train_batch_size is not equal "
            "to micro_batch_per_gpu * gradient_acc_step * world_size: "
            f"{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}"
        )

    def _set_batch_related_parameters(self):
        train_batch = self.train_batch_size
        micro_batch = self.train_micro_batch_size_per_gpu
        grad_acc = self.gradient_accumulation_steps
        #print(f"train_batch = {train_batch}, micro_batch={micro_batch}")

        # all values are provided, nothing needs to be set
        if train_batch is not None and micro_batch is not None and grad_acc is not None:
            return

        # gradient_accumulation_steps needs to be set
        elif train_batch is not None and micro_batch is not None:
            grad_acc = train_batch // micro_batch
            grad_acc //= self.world_size
            self.gradient_accumulation_steps = grad_acc

        # micro_batch_per_gpu needs to be set
        elif train_batch is not None and grad_acc is not None:
            micro_batch = train_batch // self.world_size
            micro_batch //= grad_acc
            self.train_micro_batch_size_per_gpu = micro_batch

        # train_batch_size needs to be set
        elif micro_batch is not None and grad_acc is not None:
            train_batch_size = micro_batch * grad_acc
            train_batch_size *= self.world_size
            self.train_batch_size = train_batch_size

        # gradient_accumulation_steps and micro_batch_per_gpu need to be set
        elif train_batch is not None:
            self.gradient_accumulation_steps = 1
            self.train_micro_batch_size_per_gpu = train_batch // self.world_size

        # train_batch_size and gradient_accumulation_steps need to be set
        elif micro_batch is not None:
            self.train_batch_size = micro_batch * self.world_size
            self.gradient_accumulation_steps = 1

        # either none of the three parameters is provided or only gradient_accumulation_steps is provided
        else:
            assert False, \
                'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided'

    def _configure_train_batch_size(self):
        self._set_batch_related_parameters()
        self._batch_assertion()
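
    # Worked example of the invariant enforced above, assuming world_size = 4:
    #   train_batch_size (32) = micro_batch_per_gpu (4) * gradient_accumulation_steps (2) * world_size (4)
    # If only train_batch_size = 32 is given, _set_batch_related_parameters()
    # picks gradient_accumulation_steps = 1 and micro_batch_per_gpu = 32 // 4 = 8.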

    def _do_sanity_check(self):
        self._do_error_check()

        self._do_warning_check()

    def print(self, name):
        logger.info("{}:".format(name))
        for arg in sorted(vars(self)):
            if arg != "_param_dict":
                dots = "." * (29 - len(arg))
                logger.info(" {} {} {}".format(arg, dots, getattr(self, arg)))

        logger.info(" json = {}".format(
            json.dumps(
                self._param_dict,
                sort_keys=True,
                indent=4,
                cls=ScientificNotationEncoder,
                separators=(",", ":"),
            )))

    def _do_error_check(self):
        assert (
            self.train_micro_batch_size_per_gpu
        ), "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)

        assert (
            self.gradient_accumulation_steps
        ), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS)

        if self.zero_enabled:
            assert (
                self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION
            ), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
                MAX_STAGE_ZERO_OPTIMIZATION)

        if self.fp16_master_weights_and_gradients:
            assert self.zero_enabled and self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, \
                "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now."

    def _do_warning_check(self):
        fp16_enabled = self.fp16_enabled

        vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT)
        if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0:
            logger.warning(
                "DeepSpeedConfig: vocabulary size {} is not aligned to {}, may impact tensor core utilization."
                .format(vocabulary_size, TENSOR_CORE_ALIGN_SIZE))

        if (self.optimizer_params is not None
                and MAX_GRAD_NORM in self.optimizer_params.keys()
                and self.optimizer_params[MAX_GRAD_NORM] > 0):
            if fp16_enabled:
                if self.global_rank == 0:
                    logger.warning(
                        "DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper"
                        .format(MAX_GRAD_NORM, self.optimizer_params[MAX_GRAD_NORM]))
            else:
                if self.global_rank == 0:
                    logger.warning(
                        "DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero"
                        .format(self.optimizer_params[MAX_GRAD_NORM]))
                self.optimizer_params[MAX_GRAD_NORM] = 0.0