# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import sys
from typing import Optional
from enum import Enum
from deepspeed.pydantic_v1 import Field, validator, root_validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum

# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as in the example below:
ZERO_FORMAT = """
ZeRO optimization should be enabled as:
"session_params": {
  "zero_optimization": {
    "stage": [0|1|2|3],
    "stage3_max_live_parameters" : 1000000000,
    "stage3_max_reuse_distance" : 1000000000,
    "allgather_partitions": [true|false],
    "use_multi_rank_bucket_allreduce": [true|false],
    "allgather_bucket_size": 500000000,
    "reduce_scatter": [true|false],
    "contiguous_gradients" : [true|false],
    "overlap_comm": [true|false],
    "reduce_bucket_size": 500000000,
    "load_from_fp32_weights": [true|false],
    "cpu_offload": [true|false] (deprecated),
    "cpu_offload_params" : [true|false] (deprecated),
    "cpu_offload_use_pin_memory": [true|false] (deprecated),
    "sub_group_size" : 1000000000000,
    "offload_param": {...},
    "offload_optimizer": {...},
    "ignore_unused_parameters": [true|false],
    "round_robin_gradients": [true|false],
    "zero_hpz_partition_size": 1,
    "zero_quantized_weights": [true|false],
    "zero_quantized_nontrainable_weights": [true|false],
    "zero_quantized_gradients": [true|false],
    "memory_efficient_linear": [true|false],
    "override_module_apply": [true|false]
  }
}
"""

ZERO_OPTIMIZATION = "zero_optimization"


def read_zero_config_deprecated(param_dict):
    zero_config_dict = {}
    zero_config_dict["stage"] = 1 if param_dict[ZERO_OPTIMIZATION] else 0
    if zero_config_dict["stage"] > 0:
        zero_config_dict["allgather_bucket_size"] = get_scalar_param(param_dict, "allgather_size", 5e8)
    logger.warning(
        "DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}".
        format(ZERO_FORMAT))
    return zero_config_dict


def get_zero_config(param_dict):
    if ZERO_OPTIMIZATION in param_dict:
        zero_config_dict = param_dict[ZERO_OPTIMIZATION]
        if isinstance(zero_config_dict, bool):
            zero_config_dict = read_zero_config_deprecated(param_dict)
    else:
        zero_config_dict = {}
    return DeepSpeedZeroConfig(**zero_config_dict)
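
# Illustrative sketch (not executed) of the two shapes of the "zero_optimization"
# entry that get_zero_config() accepts; the field values are arbitrary examples,
# not recommended settings:
#
#   get_zero_config({"zero_optimization": {"stage": 3, "overlap_comm": True}})
#       -> DeepSpeedZeroConfig with stage 3 and overlap_comm enabled
#   get_zero_config({"zero_optimization": True})   # deprecated boolean form
#       -> DeepSpeedZeroConfig with stage 1; allgather_bucket_size is read from
#          the top-level "allgather_size" key (default 5e8)
#   get_zero_config({})                            # key absent
#       -> DeepSpeedZeroConfig with all defaults (stage 0, i.e. disabled)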


class ZeroStageEnum(int, Enum):
    """ Enum class for possible zero stages """
    disabled = 0
    optimizer_states = 1
    gradients = 2
    weights = 3
    max_stage = 3
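
# Because ZeroStageEnum subclasses int, its members compare directly with the
# plain integers used in config files, e.g. (a sketch, not executed; `stage_value`
# is a hypothetical int or ZeroStageEnum value):
#
#   ZeroStageEnum(2) is ZeroStageEnum.gradients   # True
#   ZeroStageEnum.weights == 3                    # True
#   stage_value >= ZeroStageEnum.gradients        # works for int or enum values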


class DeepSpeedZeroConfig(DeepSpeedConfigModel):
    """
    Sets parameters for ZeRO optimizations.
    """

    stage: ZeroStageEnum = 0
    """
    Chooses different stages of ZeRO Optimizer. Stages 0, 1, 2, and 3 refer
    to disabled, optimizer state partitioning, optimizer+gradient state
    partitioning, and optimizer+gradient+parameter partitioning, respectively.
    """

    contiguous_gradients: bool = True
    """
    Copies the gradients to a contiguous buffer as they are produced. Avoids
    memory fragmentation during the backward pass.
    """

    reduce_scatter: bool = True
    """
    Uses reduce or reduce-scatter instead of allreduce to average gradients
    """

    reduce_bucket_size: int = Field(pp_int(5e8), ge=0)
    """
    Number of elements reduced/allreduced at a time. Limits the memory required
    for gradient reduction for large model sizes
    """

    use_multi_rank_bucket_allreduce: bool = True
    """
    Combine the reduce buckets of the different ranks and do an All-Reduce
    instead of multiple Reduce ops. This feature is useful when the model is
    small and is scaled across many GPUs, where the individual Reduce messages
    would otherwise be very small.
    """

    allgather_partitions: bool = True
    """
    Chooses between allgather collective or a series of broadcast collectives
    to gather updated parameters from all the GPUs at the end of each step
    """

    allgather_bucket_size: int = Field(pp_int(5e8), ge=0)
    """
    Number of elements allgathered at a time. Limits the memory required for
    the allgather for large model sizes
    """

    overlap_comm: bool = None  # None for dynamic default value (see validator `overlap_comm_valid` below)
    """
    Attempts to overlap the reduction of the gradients with backward computation
    """

    load_from_fp32_weights: bool = True
    """
    Boolean indicating whether to initialize fp32 master weights from fp32
    copies in checkpoint (no precision loss) or from model's fp16 copies (with
    precision loss). This can be used to initialize optimizer state even when
    checkpoint is missing optimizer state.
    """

    elastic_checkpoint: bool = False
    """
    Enable loading a checkpoint that was saved by a job with a different GPU
    count. No longer supported.
    """

    offload_param: Optional[DeepSpeedZeroOffloadParamConfig] = None
    """
    Enable offloading of model parameters to CPU or NVMe. This frees up GPU
    memory for larger models or batch sizes. Valid only with stage 3. Expects a
    dictionary containing values for :any:`DeepSpeedZeroOffloadParamConfig`.
    """

    offload_optimizer: Optional[DeepSpeedZeroOffloadOptimizerConfig] = None
    """
    Enable offloading of optimizer state to CPU or NVMe, and optimizer
    computation to CPU. This frees up GPU memory for larger models or batch
    sizes. Valid for ZeRO stage 1, 2, 3. Expects a dictionary containing values
    for :any:`DeepSpeedZeroOffloadOptimizerConfig`.
    """

    sub_group_size: int = Field(pp_int(1e9), ge=0)
    """
    Tile size for parameter processing to fit massive models (with trillions of
    parameters). Used by ZeRO3-Offload and ZeRO-Infinity
    """

    cpu_offload_param: bool = Field(
        None,
        deprecated=True,
        new_param="offload_param",
        new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None),
    )
    """ Deprecated, please use ``offload_param`` """

    cpu_offload_use_pin_memory: bool = Field(
        None,
        deprecated=True,
        new_param="offload_param or offload_optimizer",
        set_new_param=False,
    )
    """ Deprecated, please use ``offload_param`` or ``offload_optimizer`` """

    cpu_offload: bool = Field(
        None,
        deprecated=True,
        new_param="offload_optimizer",
        new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None),
    )
    """ Deprecated, please use ``offload_optimizer`` """
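
    # Illustrative mapping performed by the deprecation handling above (a sketch
    # of what new_param_fn produces, not additional behavior):
    #   {"cpu_offload": true}        -> offload_optimizer = DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu)
    #   {"cpu_offload": false}       -> offload_optimizer left as None
    #   {"cpu_offload_param": true}  -> offload_param = DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu)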

    prefetch_bucket_size: int = Field(pp_int(5e7), ge=0, alias="stage3_prefetch_bucket_size")
    """
    Maximum number of parameter elements to fetch ahead of use. Used by ZeRO3,
    ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.
    """

    param_persistence_threshold: int = Field(pp_int(1e5), ge=0, alias="stage3_param_persistence_threshold")
    """
    Do not partition parameters smaller than this threshold. Smaller values use
    less memory, but can greatly increase communication (especially
    latency-bound messages).
    """

    model_persistence_threshold: int = Field(pp_int(sys.maxsize, "sys.maxsize"),
                                             ge=0,
                                             alias="stage3_model_persistence_threshold")
    """
    Maximum number of parameter elements that can be persisted in GPU and not
    partitioned. This imposes an upper bound on the number of unpartitioned
    parameters resulting from param_persistence_threshold setting. Used by
    ZeRO3-Offload, ZeRO-Infinity and ZeRO-Inference.
    """

    max_live_parameters: int = Field(pp_int(1e9), ge=0, alias="stage3_max_live_parameters")
    """
    The maximum number of parameters resident per GPU before releasing. Smaller
    values use less memory, but perform more communication.
    """

    max_reuse_distance: int = Field(pp_int(1e9), ge=0, alias="stage3_max_reuse_distance")
    """
    Do not release a parameter if it will be reused within this threshold of
    parameters. Smaller values use less memory, but perform more communication.
    """

    gather_16bit_weights_on_model_save: bool = Field(False, alias="stage3_gather_16bit_weights_on_model_save")
    """
    Consolidate the weights before saving the model by ``save_16bit_model()``.
    Since the weights are partitioned across GPUs, they aren't part of
    ``state_dict``, so this function automatically gathers the weights when
    this option is enabled and then saves the fp16 model weights.
    """

    stage3_gather_fp16_weights_on_model_save: bool = Field(False,
                                                           deprecated=True,
                                                           new_param="gather_16bit_weights_on_model_save")
    """ Deprecated, please use ``gather_16bit_weights_on_model_save`` """

    ignore_unused_parameters: bool = True
    """
    Unused parameters in modules may be unexpected in static networks, but
    could be normal in dynamic networks. This controls whether or not training
    should terminate with an error message when unused parameters are detected.
    This is set to ``True`` by default, which means unused parameters are
    ignored and training continues. It is currently only used in ZeRO stage 2.
    """

    legacy_stage1: bool = False
    """
    For backward compatibility, enable the old ZeRO stage 1 implementation. Use
    at your own risk; it will be deprecated soon.
    """

    round_robin_gradients: bool = False
    """
    Stage 1 and 2 optimization for CPU offloading that parallelizes gradient
    copying to CPU memory among ranks by fine-grained gradient partitioning.
    Performance benefit grows with gradient accumulation steps (more copying
    between optimizer steps) or GPU count (increased parallelism).
    """

    zero_hpz_partition_size: int = Field(1, ge=0)
    """
    Number of ranks in the secondary group for ZeRO parameter partitioning (hpZ)
    """

    zero_quantized_weights: bool = False
    """
    Boolean indicating whether to quantize zero parameters (weights)
    for efficient all_gather comm
    """

    zero_quantized_nontrainable_weights: bool = False
    """
    Boolean indicating whether to quantize non-trainable zero parameters (weights)
    for efficient memory usage and communication. Unlike zero_quantized_weights,
    which stores the weights in original precision and only performs quantization
    during communication, this flag stores the weights in quantized precision.
    This is useful for LoRA training.
    """

    zero_quantized_gradients: bool = False
    """
    Boolean indicating whether to use quantized zero gradients
    for efficient all_2_all_reduce comm
    """

    mics_shard_size: int = Field(-1, new_param="mics_shard_size")

    mics_hierarchical_params_gather: bool = False

    memory_efficient_linear: bool = True
    """
    Use the memory-efficient linear implementation, for Stage 3.
    """

    pipeline_loading_checkpoint: bool = False
    """
    Whether to force loading the checkpoint in pipeline mode, currently only
    for Stage 3.
    """

    override_module_apply: bool = True
    """
    Override nn.Module apply function, for Stage 3.
    """

    # Validators
    @validator("overlap_comm")
    def overlap_comm_valid(cls, field_value, values):
        if field_value is None:
            assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
            field_value = values["stage"] == ZeroStageEnum.weights
        return field_value

    @root_validator
    def offload_ratio_check(cls, values):
        # `values` is a plain dict in a pydantic v1 root_validator, so look the
        # field up with .get() rather than getattr().
        offload_config = values.get("offload_optimizer")
        if offload_config and offload_config.ratio < 1.0:
            assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
        return values
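
# Illustrative behavior of the validators above (a sketch, not executed, assuming
# the base DeepSpeedConfigModel validates default values so `overlap_comm_valid`
# also runs when `overlap_comm` is left unset):
#
#   DeepSpeedZeroConfig(stage=3).overlap_comm   -> True   (defaults on for stage 3)
#   DeepSpeedZeroConfig(stage=2).overlap_comm   -> False
#   DeepSpeedZeroConfig(stage=2, offload_optimizer={"device": "cpu", "ratio": 0.5})
#       -> AssertionError, since partial offloading (ratio < 1.0) requires stage 3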