# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from typing import Union
from enum import Enum

import torch
import json
import hjson
import copy
import base64

from .constants import *
from .fp16.loss_scaler import (
    INITIAL_LOSS_SCALE,
    SCALE_WINDOW,
    DELAYED_SHIFT,
    CONSECUTIVE_HYSTERESIS,
    MIN_LOSS_SCALE,
)
from .config_utils import (
    get_scalar_param,
    dict_raise_error_on_duplicate_keys,
    ScientificNotationEncoder,
)
from .zero.config import get_zero_config, ZeroStageEnum
from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config

from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel

from ..git_version_info import version as __version__
from ..utils import logger

from ..elasticity import (
    elasticity_enabled,
    compute_elastic_config,
    ensure_immutable_elastic_config,
)
from ..elasticity.config import ElasticityConfigError
from ..elasticity.constants import (
    ELASTICITY,
    IGNORE_NON_ELASTIC_BATCH_INFO,
    IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT,
    MODEL_PARALLEL_SIZE,
    MODEL_PARALLEL_SIZE_DEFAULT,
    NUM_GPUS_PER_NODE,
    NUM_GPUS_PER_NODE_DEFAULT,
)

from ..profiling.config import DeepSpeedFlopsProfilerConfig
from ..autotuning.config import DeepSpeedAutotuningConfig
from ..nebula.config import DeepSpeedNebulaConfig

from ..compression.config import get_compression_config, get_quantize_enabled
from ..compression.constants import *
from .swap_tensor.aio_config import get_aio_config

from .data_pipeline.config import get_data_efficiency_enabled, get_data_efficiency_config, get_curriculum_enabled_legacy, get_curriculum_params_legacy
from .data_pipeline.constants import *

TENSOR_CORE_ALIGN_SIZE = 8

ADAGRAD_OPTIMIZER = 'adagrad'
ADAM_OPTIMIZER = 'adam'
ADAMW_OPTIMIZER = 'adamw'
LAMB_OPTIMIZER = 'lamb'
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ZERO_ONE_ADAM_OPTIMIZER = 'zerooneadam'
ONEBIT_LAMB_OPTIMIZER = 'onebitlamb'
DEEPSPEED_OPTIMIZERS = [
    ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER,
    ZERO_ONE_ADAM_OPTIMIZER
]

# extra optimizer parameters for adam/adamw
TORCH_ADAM_PARAM = "torch_adam"

# default to adamw logic for adam/adamw optimizers unless user explicitly opts out
ADAM_W_MODE = "adam_w_mode"
ADAM_W_MODE_DEFAULT = True


class DeepSpeedConfigError(Exception):
    pass


class DtypeEnum(Enum):
    # The torch dtype must always be the first value (so we return torch.dtype)
    fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
    fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
    int8 = torch.int8, "torch.int8", "int8"
    bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"

    # Copied from https://stackoverflow.com/a/43210118
    # Allows us to use multiple values for each Enum index and returns first
    # listed value when Enum is called
    def __new__(cls, *values):
        obj = object.__new__(cls)
        # first value is canonical value
        obj._value_ = values[0]
        for other_value in values[1:]:
            cls._value2member_map_[other_value] = obj
        obj._all_values = values
        return obj

    def __repr__(self):
        return "<%s.%s: %s>" % (
            self.__class__.__name__,
            self._name_,
            ", ".join([repr(v) for v in self._all_values]),
        )
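

# Illustrative sketch of how the multi-valued enum above resolves aliases (a usage
# note derived from __new__, not a definitive reference): every alias registered in
# _value2member_map_ maps back to the same member, and .value is always the torch
# dtype listed first.
#
#   >>> DtypeEnum("fp16") is DtypeEnum(torch.float16) is DtypeEnum("half")
#   True
#   >>> DtypeEnum("bfloat16").value
#   torch.bfloat16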


def get_pld_enabled(param_dict):
    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
        return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP], PLD_ENABLED, PLD_ENABLED_DEFAULT)
    else:
        return False


def get_pld_params(param_dict):
    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
        pld_params = copy.copy(param_dict[PROGRESSIVE_LAYER_DROP])
        pld_params.pop(PLD_ENABLED)
        return pld_params
    else:
        return False


def get_amp_enabled(param_dict):
    if AMP in param_dict.keys():
        return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
    else:
        return False


def get_amp_params(param_dict):
    if AMP in param_dict.keys():
        amp_params = copy.copy(param_dict[AMP])
        amp_params.pop(AMP_ENABLED)
        return amp_params
    else:
        return False


def get_fp16_enabled(param_dict):
    if FP16 in param_dict.keys():
        return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
    else:
        return False


def get_bfloat16_enabled(param_dict):
    for key in [BFLOAT16, BFLOAT16_OLD]:
        if key in param_dict.keys():
            return get_scalar_param(param_dict[key], BFLOAT16_ENABLED, BFLOAT16_ENABLED_DEFAULT)
    return False


def get_fp16_master_weights_and_grads_enabled(param_dict):
    if get_fp16_enabled(param_dict):
        return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT)
    else:
        return False


def get_fp16_auto_cast(param_dict):
    if get_fp16_enabled(param_dict):
        return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT)


def get_loss_scale(param_dict):
    if get_fp16_enabled(param_dict):
        return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT)
    elif get_bfloat16_enabled(param_dict):
        return 1.0
    else:
        return FP16_LOSS_SCALE_DEFAULT


def get_initial_dynamic_scale(param_dict):
    if get_fp16_enabled(param_dict):
        initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER,
                                               FP16_INITIAL_SCALE_POWER_DEFAULT)
    elif get_bfloat16_enabled(param_dict):
        initial_scale_power = 0
    else:
        initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT

    return 2**initial_scale_power


def get_dynamic_loss_scale_args(param_dict):
    loss_scale_args = None
    if get_fp16_enabled(param_dict):
        fp16_dict = param_dict[FP16]
        dynamic_loss_args = [
            FP16_INITIAL_SCALE_POWER,
            FP16_LOSS_SCALE_WINDOW,
            FP16_MIN_LOSS_SCALE,
            FP16_HYSTERESIS,
            FP16_CONSECUTIVE_HYSTERESIS,
        ]
        if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args):
            init_scale = get_scalar_param(fp16_dict, FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT)
            scale_window = get_scalar_param(fp16_dict, FP16_LOSS_SCALE_WINDOW, FP16_LOSS_SCALE_WINDOW_DEFAULT)
            delayed_shift = get_scalar_param(fp16_dict, FP16_HYSTERESIS, FP16_HYSTERESIS_DEFAULT)
            consecutive_hysteresis = get_scalar_param(fp16_dict, FP16_CONSECUTIVE_HYSTERESIS,
                                                      FP16_CONSECUTIVE_HYSTERESIS_DEFAULT)
            min_loss_scale = get_scalar_param(fp16_dict, FP16_MIN_LOSS_SCALE, FP16_MIN_LOSS_SCALE_DEFAULT)
            loss_scale_args = {
                INITIAL_LOSS_SCALE: 2**init_scale,
                SCALE_WINDOW: scale_window,
                DELAYED_SHIFT: delayed_shift,
                CONSECUTIVE_HYSTERESIS: consecutive_hysteresis,
                MIN_LOSS_SCALE: min_loss_scale,
            }

    return loss_scale_args
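

# Illustrative sketch, assuming the documented fp16 JSON keys ("initial_scale_power",
# "loss_scale_window", ...): a config such as
#   {"fp16": {"enabled": True, "initial_scale_power": 16, "loss_scale_window": 1000}}
# yields loss-scaler kwargs roughly like
#   {INITIAL_LOSS_SCALE: 2**16, SCALE_WINDOW: 1000, DELAYED_SHIFT: <default hysteresis>,
#    CONSECUTIVE_HYSTERESIS: <default>, MIN_LOSS_SCALE: <default>},
# while a config with no dynamic-loss-scale keys (or with fp16 disabled) yields None.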


def get_gradient_accumulation_steps(param_dict):
    return get_scalar_param(param_dict, GRADIENT_ACCUMULATION_STEPS, GRADIENT_ACCUMULATION_STEPS_DEFAULT)


def get_sparse_gradients_enabled(param_dict):
    return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)


def get_communication_data_type(param_dict):
    val = get_scalar_param(param_dict, COMMUNICATION_DATA_TYPE, COMMUNICATION_DATA_TYPE_DEFAULT)
    val = val.lower() if val is not None else val
    if val is None:
        return val  # we must determine it by other parameters
    elif val == "fp32":
        return torch.float32
    elif val == "fp16":
        return torch.float16
    elif val == "bfp16":
        return torch.bfloat16

    raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bfp16', 'fp32']. Got: {val}")
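

# Illustrative sketch, assuming the usual "communication_data_type" JSON key and a
# None default (as the early return above suggests): the accepted strings map directly
# onto torch dtypes, and note that bfloat16 is spelled "bfp16" here.
#
#   >>> get_communication_data_type({"communication_data_type": "fp32"})
#   torch.float32
#   >>> get_communication_data_type({}) is None   # unset: decided later by other settings
#   True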


def get_prescale_gradients(param_dict):
    return get_scalar_param(param_dict, PRESCALE_GRADIENTS, PRESCALE_GRADIENTS_DEFAULT)


def get_gradient_predivide_factor(param_dict):
    return get_scalar_param(param_dict, GRADIENT_PREDIVIDE_FACTOR, GRADIENT_PREDIVIDE_FACTOR_DEFAULT)


def get_steps_per_print(param_dict):
    return get_scalar_param(param_dict, STEPS_PER_PRINT, STEPS_PER_PRINT_DEFAULT)


def get_disable_allgather(param_dict):
    return get_scalar_param(param_dict, DISABLE_ALLGATHER, DISABLE_ALLGATHER_DEFAULT)


def get_dump_state(param_dict):
    return get_scalar_param(param_dict, DUMP_STATE, DUMP_STATE_DEFAULT)


def get_gradient_clipping(param_dict):
    return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)


def get_sparse_attention(param_dict):
    if SPARSE_ATTENTION in param_dict.keys():
        sparsity = param_dict[SPARSE_ATTENTION]
        mode = get_sparse_attention_mode(sparsity)

        if mode == SPARSE_DENSE_MODE:
            return get_sparse_dense_config(sparsity)
        elif mode == SPARSE_FIXED_MODE:
            return get_sparse_fixed_config(sparsity)
        elif mode == SPARSE_VARIABLE_MODE:
            return get_sparse_variable_config(sparsity)
        elif mode == SPARSE_BIGBIRD_MODE:
            return get_sparse_bigbird_config(sparsity)
        elif mode == SPARSE_BSLONGFORMER_MODE:
            return get_sparse_bslongformer_config(sparsity)
        else:
            raise NotImplementedError(f"Given sparsity mode, {mode}, has not been implemented yet!")
    else:
        return None


def get_sparse_dense_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    return {SPARSE_MODE: SPARSE_DENSE_MODE, SPARSE_BLOCK: block}


def get_sparse_fixed_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_local_blocks = get_scalar_param(sparsity, SPARSE_NUM_LOCAL_BLOCKS, SPARSE_NUM_LOCAL_BLOCKS_DEFAULT)
    num_global_blocks = get_scalar_param(sparsity, SPARSE_NUM_GLOBAL_BLOCKS, SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
    attention = get_scalar_param(sparsity, SPARSE_ATTENTION_TYPE, SPARSE_ATTENTION_TYPE_DEFAULT)
    horizontal_global_attention = get_scalar_param(
        sparsity,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT,
    )
    num_different_global_patterns = get_scalar_param(
        sparsity,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_FIXED_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_LOCAL_BLOCKS: num_local_blocks,
        SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
        SPARSE_ATTENTION_TYPE: attention,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: num_different_global_patterns,
    }


def get_sparse_variable_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_random_blocks = get_scalar_param(sparsity, SPARSE_NUM_RANDOM_BLOCKS, SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
    local_window_blocks = get_scalar_param(sparsity, SPARSE_LOCAL_WINDOW_BLOCKS, SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT)
    global_block_indices = get_scalar_param(sparsity, SPARSE_GLOBAL_BLOCK_INDICES, SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
    global_block_end_indices = get_scalar_param(
        sparsity,
        SPARSE_GLOBAL_BLOCK_END_INDICES,
        SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT,
    )
    attention = get_scalar_param(sparsity, SPARSE_ATTENTION_TYPE, SPARSE_ATTENTION_TYPE_DEFAULT)
    horizontal_global_attention = get_scalar_param(
        sparsity,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_VARIABLE_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
        SPARSE_LOCAL_WINDOW_BLOCKS: local_window_blocks,
        SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
        SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
        SPARSE_ATTENTION_TYPE: attention,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
    }


def get_sparse_bigbird_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_random_blocks = get_scalar_param(sparsity, SPARSE_NUM_RANDOM_BLOCKS, SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
    num_sliding_window_blocks = get_scalar_param(
        sparsity,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT,
    )
    num_global_blocks = get_scalar_param(sparsity, SPARSE_NUM_GLOBAL_BLOCKS, SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)

    return {
        SPARSE_MODE: SPARSE_BIGBIRD_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
        SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
    }


def get_sparse_bslongformer_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_sliding_window_blocks = get_scalar_param(
        sparsity,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT,
    )
    global_block_indices = get_scalar_param(sparsity, SPARSE_GLOBAL_BLOCK_INDICES, SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
    global_block_end_indices = get_scalar_param(
        sparsity,
        SPARSE_GLOBAL_BLOCK_END_INDICES,
        SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_BSLONGFORMER_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
        SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
        SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
    }


def get_sparse_attention_mode(param_dict):
    if SPARSE_MODE in param_dict.keys():
        return param_dict[SPARSE_MODE]
    else:
        return SPARSE_MODE_DEFAULT


def get_sparse_attention_type(param_dict):
    if SPARSE_ATTENTION_TYPE in param_dict.keys():
        return param_dict[SPARSE_ATTENTION_TYPE]
    else:
        return SPARSE_ATTENTION_TYPE_DEFAULT


def get_pipeline_config(param_dict):
    """Parses pipeline engine configuration."""
    default_pipeline = {
        "stages": "auto",
        "partition": "best",
        "seed_layers": False,
        "activation_checkpoint_interval": 0,
    }
    config = default_pipeline
    for key, val in param_dict.get("pipeline", {}).items():
        config[key] = val
    return config
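

# Illustrative sketch derived from the defaults above (no new keys assumed):
# user-supplied pipeline settings overwrite the defaults one key at a time, and
# unrecognized keys are passed through unchanged.
#
#   >>> get_pipeline_config({"pipeline": {"stages": 4, "activation_checkpoint_interval": 1}})
#   {'stages': 4, 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 1}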


def get_optimizer_name(param_dict):
    if OPTIMIZER in param_dict.keys() and TYPE in param_dict[OPTIMIZER].keys():
        return param_dict[OPTIMIZER][TYPE]
    else:
        return OPTIMIZER_TYPE_DEFAULT


def get_optimizer_params(param_dict):
    if (get_optimizer_name(param_dict) is not None and OPTIMIZER_PARAMS in param_dict[OPTIMIZER].keys()):
        return param_dict[OPTIMIZER][OPTIMIZER_PARAMS]
    else:
        return None


def get_optimizer_gradient_clipping(param_dict):
    optimizer_params = get_optimizer_params(param_dict)
    if optimizer_params is not None and MAX_GRAD_NORM in optimizer_params.keys():
        return optimizer_params[MAX_GRAD_NORM]
    else:
        return None


def get_optimizer_legacy_fusion(param_dict):
    if OPTIMIZER in param_dict.keys() and LEGACY_FUSION in param_dict[OPTIMIZER].keys():
        return param_dict[OPTIMIZER][LEGACY_FUSION]
    else:
        return LEGACY_FUSION_DEFAULT


def get_zero_allow_untested_optimizer(param_dict):
    return get_scalar_param(param_dict, ZERO_ALLOW_UNTESTED_OPTIMIZER, ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)


def get_zero_force_ds_cpu_optimizer(param_dict):
    return get_scalar_param(param_dict, ZERO_FORCE_DS_CPU_OPTIMIZER, ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT)


def get_scheduler_name(param_dict):
    if SCHEDULER in param_dict.keys() and TYPE in param_dict[SCHEDULER].keys():
        return param_dict[SCHEDULER][TYPE]
    else:
        return SCHEDULER_TYPE_DEFAULT


def get_scheduler_params(param_dict):
    if (get_scheduler_name(param_dict) is not None and SCHEDULER_PARAMS in param_dict[SCHEDULER].keys()):
        return param_dict[SCHEDULER][SCHEDULER_PARAMS]
    else:
        return None


def get_train_batch_size(param_dict):
    return get_scalar_param(param_dict, TRAIN_BATCH_SIZE, TRAIN_BATCH_SIZE_DEFAULT)


def get_train_micro_batch_size_per_gpu(param_dict):
    return get_scalar_param(
        param_dict,
        TRAIN_MICRO_BATCH_SIZE_PER_GPU,
        TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT,
    )


def get_wall_clock_breakdown(param_dict):
    return get_scalar_param(param_dict, WALL_CLOCK_BREAKDOWN, WALL_CLOCK_BREAKDOWN_DEFAULT)


def get_memory_breakdown(param_dict):
    return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)


class HybridEngineConfig(DeepSpeedConfigModel):
    enabled: bool = False
    max_out_tokens: int = 512
    inference_tp_size: int = 1
    release_inference_cache: bool = False
    pin_parameters: bool = True
    tp_gather_partition_size: int = 8


def get_hybrid_engine_config(param_dict):
    hybrid_engine_config_dict = param_dict.get("hybrid_engine", {})
    hybrid_engine_config = HybridEngineConfig(**hybrid_engine_config_dict)
    return hybrid_engine_config


def get_eigenvalue_config(param_dict):
    if get_quantize_enabled(param_dict):
        param_dict = param_dict[QUANTIZE_TRAINING]
        assert not get_eigenvalue_enabled(param_dict), "Eigenvalue based MoQ is temporarily disabled"
        return (
            get_eigenvalue_enabled(param_dict),
            get_eigenvalue_verbose(param_dict),
            get_eigenvalue_max_iter(param_dict),
            get_eigenvalue_tol(param_dict),
            get_eigenvalue_stability(param_dict),
            get_eigenvalue_gas_boundary_resolution(param_dict),
            get_eigenvalue_layer_name(param_dict),
            get_eigenvalue_layer_num(param_dict),
        )
    else:
        return (
            EIGENVALUE_ENABLED_DEFAULT,
            EIGENVALUE_VERBOSE_DEFAULT,
            EIGENVALUE_MAX_ITER_DEFAULT,
            EIGENVALUE_TOL_DEFAULT,
            EIGENVALUE_STABILITY_DEFAULT,
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT,
            EIGENVALUE_LAYER_NAME_DEFAULT,
            EIGENVALUE_LAYER_NUM_DEFAULT,
        )


def get_eigenvalue_enabled(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_ENABLED, EIGENVALUE_ENABLED_DEFAULT)
    else:
        return EIGENVALUE_ENABLED_DEFAULT


def get_eigenvalue_verbose(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_VERBOSE, EIGENVALUE_VERBOSE_DEFAULT)
    else:
        return EIGENVALUE_VERBOSE_DEFAULT


def get_eigenvalue_max_iter(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_MAX_ITER, EIGENVALUE_MAX_ITER_DEFAULT)
    else:
        return EIGENVALUE_MAX_ITER_DEFAULT


def get_eigenvalue_tol(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_TOL, EIGENVALUE_TOL_DEFAULT)
    else:
        return EIGENVALUE_TOL_DEFAULT


def get_eigenvalue_stability(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_STABILITY, EIGENVALUE_STABILITY_DEFAULT)
    else:
        return EIGENVALUE_STABILITY_DEFAULT


def get_eigenvalue_gas_boundary_resolution(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(
            param_dict[EIGENVALUE],
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION,
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT,
        )
    else:
        return EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT


def get_eigenvalue_layer_name(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_LAYER_NAME, EIGENVALUE_LAYER_NAME_DEFAULT)
    else:
        return EIGENVALUE_LAYER_NAME_DEFAULT


def get_eigenvalue_layer_num(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_LAYER_NUM, EIGENVALUE_LAYER_NUM_DEFAULT)
    else:
        return EIGENVALUE_LAYER_NUM_DEFAULT


def get_checkpoint_params(param_dict):
    return param_dict.get(CHECKPOINT, {})


def get_data_types_params(param_dict):
    return param_dict.get(DATA_TYPES, {})


def get_checkpoint_tag_validation_mode(checkpoint_params):
    tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION, CHECKPOINT_TAG_VALIDATION_DEFAULT)
    tag_validation_mode = tag_validation_mode.upper()
    if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES:
        return tag_validation_mode
    else:
        raise DeepSpeedConfigError(
            "Checkpoint config contains invalid tag_validation "
            f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}")


def get_checkpoint_parallel_write_pipeline(checkpoint_params):
    par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
    par_write_pipeline = par_write_params.get(CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
                                              CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
    if par_write_pipeline in [True, False]:
        return par_write_pipeline
    else:
        raise DeepSpeedConfigError("checkpoint::parallel_write::pipeline_stage "
                                   f"value of '{par_write_pipeline}' is invalid, expecting: true or false")


def get_dataloader_drop_last(param_dict):
    return get_scalar_param(param_dict, DATALOADER_DROP_LAST, DATALOADER_DROP_LAST_DEFAULT)


'''Write deepspeed config files by modifying basic templates.
Can be used for quickly changing parameters via command line parameters.'''


class DeepSpeedConfigWriter:

    def __init__(self, data=None):
        self.data = data if data is not None else {}

    def add_config(self, key, value):
        self.data[key] = value

    def load_config(self, filename):
        self.data = json.load(open(filename, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)

    def write_config(self, filename):
        with open(filename, "w") as outfile:
            json.dump(self.data, outfile)
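

# Illustrative usage sketch for DeepSpeedConfigWriter; the keys and the output path
# are hypothetical examples, not values required by DeepSpeed:
#
#   writer = DeepSpeedConfigWriter({"train_batch_size": 32})
#   writer.add_config("gradient_accumulation_steps", 2)
#   writer.write_config("/tmp/ds_config.json")  # hypothetical path
#   writer.load_config("/tmp/ds_config.json")   # round-trips through json.load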


class DeepSpeedConfig(object):

    def __init__(self, config: Union[str, dict], mpu=None):
        super(DeepSpeedConfig, self).__init__()
        if isinstance(config, dict):
            self._param_dict = config
        elif os.path.exists(config):
            self._param_dict = hjson.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
        else:
            try:
                config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
                self._param_dict = hjson.loads(config_decoded)
            except (UnicodeDecodeError, AttributeError):
                raise ValueError(
                    f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
                )

        try:
            self.global_rank = dist.get_rank()
            if mpu is None:
                self.world_size = dist.get_world_size()
            else:
                self.world_size = mpu.get_data_parallel_world_size()
        except:
            self.global_rank = 0
            self.world_size = 1

        # If elastic-mode enabled, update compute + update _param_dict
        self.elasticity_enabled = elasticity_enabled(self._param_dict)
        if self.elasticity_enabled:
            logger.info("DeepSpeed elasticity support enabled")
            final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(
                ds_config=self._param_dict,
                target_deepspeed_version=__version__,
                world_size=self.world_size,
            )

            elastic_dict = self._param_dict[ELASTICITY]

            # Ensure the resource scheduler saw the same elastic config we are using at runtime
            ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict)

            self.elastic_model_parallel_size = elastic_dict.get(MODEL_PARALLEL_SIZE, MODEL_PARALLEL_SIZE_DEFAULT)
            if self.elastic_model_parallel_size < 1:
                raise ElasticityConfigError("Model-Parallel size cannot be less than 1, "
                                            f"given model-parallel size: {self.elastic_model_parallel_size}")

            self.num_gpus_per_node = elastic_dict.get(NUM_GPUS_PER_NODE, NUM_GPUS_PER_NODE_DEFAULT)
            if self.num_gpus_per_node < 1:
                raise ElasticityConfigError("Number of GPUs per node cannot be less than 1, "
                                            f"given number of GPUs per node: {self.num_gpus_per_node}")

            ignore_non_elastic_batch_info = elastic_dict.get(IGNORE_NON_ELASTIC_BATCH_INFO,
                                                             IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)

            if not ignore_non_elastic_batch_info:
                batch_params = [
                    TRAIN_BATCH_SIZE,
                    TRAIN_MICRO_BATCH_SIZE_PER_GPU,
                    GRADIENT_ACCUMULATION_STEPS,
                ]
                if any(map(lambda t: t in self._param_dict, batch_params)):
                    raise ElasticityConfigError("One or more batch related parameters were found in your " \
                        f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \
                        f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \
                        "elastic training is enabled, which takes control of these parameters. " \
                        "If you want to suppress this error (the parameters will be silently ignored) " \
                        f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.")

            # micro_bsz * world_size * gas = total_batch_size
            # gas = total_batch_size // (micro_bsz * world_size)
            gradient_accu_steps = final_batch_size // (micro_batch_size * self.world_size)

            if TRAIN_BATCH_SIZE in self._param_dict:
                logger.warning("[Elasticity] overriding training_batch_size: "
                               f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}")
            if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict:
                logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: "
                               f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}")
            if GRADIENT_ACCUMULATION_STEPS in self._param_dict:
                logger.warning("[Elasticity] overriding gradient_accumulation_steps: "
                               f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}")

            logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}")

            self._param_dict[TRAIN_BATCH_SIZE] = final_batch_size
            self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size
            self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps

        # Pass a copy so that user json is unmodified, e.g. for logging
        self._initialize_params(copy.copy(self._param_dict))
        self._configure_train_batch_size()
        self._do_sanity_check()

    def _initialize_params(self, param_dict):
        self.train_batch_size = get_train_batch_size(param_dict)
        #print(f"beginning get_train_batch_size = {get_train_batch_size}")
        self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(param_dict)
        self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict)
        self.steps_per_print = get_steps_per_print(param_dict)
        self.dump_state = get_dump_state(param_dict)
        self.disable_allgather = get_disable_allgather(param_dict)
        self.communication_data_type = get_communication_data_type(param_dict)
        self.prescale_gradients = get_prescale_gradients(param_dict)
        self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
        self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)

        self.zero_config = get_zero_config(param_dict)
        self.mics_shard_size = self.zero_config.mics_shard_size
        self.mics_hierarchial_params_gather = self.zero_config.mics_hierarchical_params_gather
        self.zero_optimization_stage = self.zero_config.stage
        self.zero_enabled = self.zero_optimization_stage > 0

        self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(param_dict)

        self.comms_config = DeepSpeedCommsConfig(param_dict)
        self.monitor_config = get_monitor_config(param_dict)

        self.gradient_clipping = get_gradient_clipping(param_dict)
        self.fp16_enabled = get_fp16_enabled(param_dict)
        self.fp16_auto_cast = get_fp16_auto_cast(param_dict)
        self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
        assert not (self.fp16_enabled
                    and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
        self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict)
        self.amp_enabled = get_amp_enabled(param_dict)
        self.amp_params = get_amp_params(param_dict)
        self.loss_scale = get_loss_scale(param_dict)
        self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
        self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)

        self.compression_config = get_compression_config(param_dict)
        self.optimizer_name = get_optimizer_name(param_dict)
        if (self.optimizer_name is not None and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS):
            self.optimizer_name = self.optimizer_name.lower()

        self.optimizer_params = get_optimizer_params(param_dict)
        self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict)

        self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(param_dict)
        self.zero_force_ds_cpu_optimizer = get_zero_force_ds_cpu_optimizer(param_dict)

        self.scheduler_name = get_scheduler_name(param_dict)
        self.scheduler_params = get_scheduler_params(param_dict)

        self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict)
        self.wall_clock_breakdown = (get_wall_clock_breakdown(param_dict) | self.flops_profiler_config.enabled)
        self.memory_breakdown = get_memory_breakdown(param_dict)
        self.autotuning_config = DeepSpeedAutotuningConfig(param_dict)

        (
            self.eigenvalue_enabled,
            self.eigenvalue_verbose,
            self.eigenvalue_max_iter,
            self.eigenvalue_tol,
            self.eigenvalue_stability,
            self.eigenvalue_gas_boundary_resolution,
            self.eigenvalue_layer_name,
            self.eigenvalue_layer_num,
        ) = get_eigenvalue_config(param_dict)

        self.hybrid_engine = get_hybrid_engine_config(param_dict)

        self.sparse_attention = get_sparse_attention(param_dict)
        self.pipeline = get_pipeline_config(param_dict)

        self.pld_enabled = get_pld_enabled(param_dict)
        self.pld_params = get_pld_params(param_dict)

        self.curriculum_enabled_legacy = get_curriculum_enabled_legacy(param_dict)
        self.curriculum_params_legacy = get_curriculum_params_legacy(param_dict)

        self.data_efficiency_enabled = get_data_efficiency_enabled(param_dict)
        self.data_efficiency_config = get_data_efficiency_config(param_dict)

        checkpoint_params = get_checkpoint_params(param_dict)
        validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
        self.checkpoint_tag_validation_enabled = (validation_mode != ValidationMode.IGNORE)
        self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL
        self.load_universal_checkpoint = checkpoint_params.get(LOAD_UNIVERSAL_CHECKPOINT,
                                                               LOAD_UNIVERSAL_CHECKPOINT_DEFAULT)

        self.use_node_local_storage = checkpoint_params.get(USE_NODE_LOCAL_STORAGE_CHECKPOINT,
                                                            USE_NODE_LOCAL_STORAGE_CHECKPOINT_DEFAULT)

        data_types_params = get_data_types_params(param_dict)
        self.grad_accum_dtype = data_types_params.get(GRAD_ACCUM_DTYPE, GRAD_ACCUM_DTYPE_DEFAULT)

        par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params)
        self.checkpoint_parallel_write_pipeline = par_write_pipe

        self.aio_config = get_aio_config(param_dict)

        self.dataloader_drop_last = get_dataloader_drop_last(param_dict)

        self.nebula_config = DeepSpeedNebulaConfig(param_dict)

    def _batch_assertion(self):
        train_batch = self.train_batch_size
        micro_batch = self.train_micro_batch_size_per_gpu
        grad_acc = self.gradient_accumulation_steps

        assert (train_batch > 0), f"Train batch size: {train_batch} has to be greater than 0"

        assert (micro_batch > 0), f"Micro batch size per gpu: {micro_batch} has to be greater than 0"

        assert (grad_acc > 0), f"Gradient accumulation steps: {grad_acc} has to be greater than 0"

        assert train_batch == micro_batch * grad_acc * self.world_size, (
            f"Check batch related parameters. train_batch_size is not equal "
            "to micro_batch_per_gpu * gradient_acc_step * world_size "
            f"{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}")

    def _set_batch_related_parameters(self):
        train_batch = self.train_batch_size
        micro_batch = self.train_micro_batch_size_per_gpu
        grad_acc = self.gradient_accumulation_steps
        #print(f"train_batch = {train_batch}, micro_batch={micro_batch}")

        # all values are provided, nothing needs to be set
        if train_batch is not None and micro_batch is not None and grad_acc is not None:
            return

        # gradient_accumulation_steps needs to be set
        elif train_batch is not None and micro_batch is not None:
            grad_acc = train_batch // micro_batch
            grad_acc //= self.world_size
            self.gradient_accumulation_steps = grad_acc

        # micro_batch_per_gpu needs to be set
        elif train_batch is not None and grad_acc is not None:
            micro_batch = train_batch // self.world_size
            micro_batch //= grad_acc
            self.train_micro_batch_size_per_gpu = micro_batch

        # train_batch_size needs to be set
        elif micro_batch is not None and grad_acc is not None:
            train_batch_size = micro_batch * grad_acc
            train_batch_size *= self.world_size
            self.train_batch_size = train_batch_size

        # gradient_accumulation_steps and micro_batch_per_gpu need to be set
        elif train_batch is not None:
            self.gradient_accumulation_steps = 1
            self.train_micro_batch_size_per_gpu = train_batch // self.world_size

        # train_batch_size and gradient_accumulation_steps need to be set
        elif micro_batch is not None:
            self.train_batch_size = micro_batch * self.world_size
            self.gradient_accumulation_steps = 1

        # either none of the three parameters are provided or just gradient_accumulation_steps is provided
        else:
            assert False, \
                'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided'

    def _configure_train_batch_size(self):
        self._set_batch_related_parameters()
        self._batch_assertion()
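
    # Illustrative arithmetic sketch for the batch-size logic above (example numbers only):
    # the three knobs are tied together by
    #   train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * world_size.
    # With world_size = 8 and a config providing only
    #   {"train_batch_size": 256, "train_micro_batch_size_per_gpu": 4},
    # _set_batch_related_parameters() derives gradient_accumulation_steps = 256 // 4 // 8 = 8,
    # and _batch_assertion() then verifies 256 == 4 * 8 * 8.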

    def _do_sanity_check(self):
        self._do_error_check()

        self._do_warning_check()

    def print_user_config(self):
        logger.info(" json = {}".format(
            json.dumps(
                self._param_dict,
                sort_keys=True,
                indent=4,
                cls=ScientificNotationEncoder,
                separators=(",", ":"),
            )))

    def print(self, name):
        logger.info("{}:".format(name))
        for arg in sorted(vars(self)):
            if arg != "_param_dict":
                dots = "." * (29 - len(arg))
                logger.info(" {} {} {}".format(arg, dots, getattr(self, arg)))

        self.print_user_config()

    def _do_error_check(self):
        assert (self.train_micro_batch_size_per_gpu
                ), "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)

        assert (
            self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS)

        if self.zero_enabled:
            assert (self.zero_optimization_stage <=
                    ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
                        ZeroStageEnum.max_stage)

        if self.fp16_master_weights_and_gradients:
            assert self.zero_enabled and self.zero_optimization_stage == ZeroStageEnum.gradients, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now."

    def _do_warning_check(self):
        fp16_enabled = self.fp16_enabled

        vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT)
        if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0:
            logger.warning(
                "DeepSpeedConfig: vocabulary size {} is not aligned to {}, may impact tensor core utilization.".format(
                    vocabulary_size, TENSOR_CORE_ALIGN_SIZE))

        if (self.optimizer_params is not None and MAX_GRAD_NORM in self.optimizer_params.keys()
                and self.optimizer_params[MAX_GRAD_NORM] > 0):
            if fp16_enabled:
                if self.global_rank == 0:
                    logger.warning("DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper".format(
                        MAX_GRAD_NORM, self.optimizer_params[MAX_GRAD_NORM]))
            else:
                if self.global_rank == 0:
                    logger.warning(
                        "DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero"
                        .format(self.optimizer_params[MAX_GRAD_NORM]))
                self.optimizer_params[MAX_GRAD_NORM] = 0.0