config.py

'''Copyright The Microsoft DeepSpeed Team'''
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
import sys
from typing import Optional
from enum import Enum

from pydantic import Field, validator

from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum

# ZeRO optimization. By default, this optimization is not enabled.
# Users must configure the desired optimization (0 means disabled) in params.json, as in the example below:
ZERO_FORMAT = """
ZeRO optimization should be enabled as:
"session_params": {
  "zero_optimization": {
    "stage": [0|1|2|3],
    "stage3_max_live_parameters" : 1000000000,
    "stage3_max_reuse_distance" : 1000000000,
    "allgather_partitions": [true|false],
    "allgather_bucket_size": 500000000,
    "reduce_scatter": [true|false],
    "contiguous_gradients" : [true|false],
    "overlap_comm": [true|false],
    "reduce_bucket_size": 500000000,
    "load_from_fp32_weights": [true|false],
    "cpu_offload": [true|false] (deprecated),
    "cpu_offload_params" : [true|false] (deprecated),
    "cpu_offload_use_pin_memory": [true|false] (deprecated),
    "sub_group_size" : 1000000000000,
    "offload_param": {...},
    "offload_optimizer": {...},
    "ignore_unused_parameters": [true|false],
    "round_robin_gradients": [true|false]
  }
}
"""

ZERO_OPTIMIZATION = "zero_optimization"


def read_zero_config_deprecated(param_dict):
    zero_config_dict = {}
    zero_config_dict["stage"] = 1 if param_dict[ZERO_OPTIMIZATION] else 0
    if zero_config_dict["stage"] > 0:
        zero_config_dict["allgather_bucket_size"] = get_scalar_param(param_dict, "allgather_size", 5e8)
    logger.warning(
        "DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}"
        .format(ZERO_FORMAT))
    return zero_config_dict


def get_zero_config(param_dict):
    if ZERO_OPTIMIZATION in param_dict:
        zero_config_dict = param_dict[ZERO_OPTIMIZATION]
        if isinstance(zero_config_dict, bool):
            zero_config_dict = read_zero_config_deprecated(param_dict)
    else:
        zero_config_dict = {}
    return DeepSpeedZeroConfig(**zero_config_dict)
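

# Usage sketch (illustrative only, not part of this module): `get_zero_config`
# takes the full DeepSpeed config dict and accepts both the current nested form
# and the deprecated boolean form. The config dicts below are hypothetical
# examples, not defaults shipped with DeepSpeed.
#
#   get_zero_config({"zero_optimization": {"stage": 2, "overlap_comm": True}})
#   # -> DeepSpeedZeroConfig(stage=2, overlap_comm=True, ...)
#
#   get_zero_config({"zero_optimization": True, "allgather_size": 2e8})
#   # -> deprecated form: stage is set to 1 and "allgather_size" is mapped to
#   #    allgather_bucket_size; a deprecation warning is logged.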


class ZeroStageEnum(int, Enum):
    """ Enum class for possible zero stages """
    disabled = 0
    optimizer_states = 1
    gradients = 2
    weights = 3
    max_stage = 3
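
# Because ZeroStageEnum subclasses int, stage values compare directly against
# plain integers, e.g. `ZeroStageEnum.gradients == 2` is True. This lets the
# rest of the runtime keep using bare integers for stage checks.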


class DeepSpeedZeroConfig(DeepSpeedConfigModel):
    """
    Sets parameters for ZeRO optimizations.
    """

    stage: ZeroStageEnum = 0
    """
    Chooses different stages of the ZeRO optimizer. Stages 0, 1, 2, and 3
    refer to disabled, optimizer state partitioning, optimizer+gradient state
    partitioning, and optimizer+gradient+parameter partitioning, respectively.
    """

    contiguous_gradients: bool = True
    """
    Copies the gradients to a contiguous buffer as they are produced. Avoids
    memory fragmentation during the backward pass.
    """

    reduce_scatter: bool = True
    """
    Uses reduce or reduce-scatter instead of allreduce to average gradients.
    """

    reduce_bucket_size: int = Field(pp_int(5e8), ge=0)
    """
    Number of elements reduced/allreduced at a time. Limits the memory
    required for the reduction for large model sizes.
    """

    allgather_partitions: bool = True
    """
    Chooses between an allgather collective and a series of broadcast
    collectives to gather updated parameters from all GPUs at the end of
    each step.
    """

    allgather_bucket_size: int = Field(pp_int(5e8), ge=0)
    """
    Number of elements allgathered at a time. Limits the memory required for
    the allgather for large model sizes.
    """

    overlap_comm: Optional[bool] = None  # None selects a dynamic default (see the `overlap_comm_valid` validator below)
    """
    Attempts to overlap the reduction of the gradients with backward computation.
    """

    load_from_fp32_weights: bool = True
    """
    Boolean indicating whether to initialize fp32 master weights from fp32
    copies in the checkpoint (no precision loss) or from the model's fp16
    copies (with precision loss). This can be used to initialize optimizer
    state even when the checkpoint is missing optimizer state.
    """

    elastic_checkpoint: bool = False
    """
    Enable loading a checkpoint that was saved by a job with a different GPU
    count. No longer supported.
    """

    offload_param: Optional[DeepSpeedZeroOffloadParamConfig] = None
    """
    Enable offloading of model parameters to CPU or NVMe. This frees up GPU
    memory for larger models or batch sizes. Valid only with stage 3. Expects
    a dictionary containing values for :any:`DeepSpeedZeroOffloadParamConfig`.
    """

    offload_optimizer: Optional[DeepSpeedZeroOffloadOptimizerConfig] = None
    """
    Enable offloading of optimizer state to CPU or NVMe, and of optimizer
    computation to CPU. This frees up GPU memory for larger models or batch
    sizes. Valid for ZeRO stages 1, 2, and 3. Expects a dictionary containing
    values for :any:`DeepSpeedZeroOffloadOptimizerConfig`.
    """

    sub_group_size: int = Field(pp_int(1e9), ge=0)
    """
    Tile size for parameter processing to fit massive models (with trillions
    of parameters). Used by ZeRO3-Offload and ZeRO-Infinity.
    """

    cpu_offload_param: bool = Field(
        None,
        deprecated=True,
        new_param="offload_param",
        new_param_fn=(
            lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu)
            if val else None),
    )
    """ Deprecated, please use ``offload_param`` """

    cpu_offload_use_pin_memory: bool = Field(
        None,
        deprecated=True,
        new_param="offload_param or offload_optimizer",
        set_new_param=False,
    )
    """ Deprecated, please use ``offload_param`` or ``offload_optimizer`` """

    cpu_offload: bool = Field(
        None,
        deprecated=True,
        new_param="offload_optimizer",
        new_param_fn=(
            lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu)
            if val else None),
    )
    """ Deprecated, please use ``offload_optimizer`` """

    prefetch_bucket_size: int = Field(pp_int(5e7), ge=0, alias="stage3_prefetch_bucket_size")
    """
    Maximum number of parameter elements to fetch ahead of use. Used by
    ZeRO3, ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.
    """

    param_persistence_threshold: int = Field(pp_int(1e5), ge=0, alias="stage3_param_persistence_threshold")
    """
    Do not partition parameters smaller than this threshold. Smaller values
    use less memory, but can greatly increase communication (especially
    latency-bound messages).
    """

    model_persistence_threshold: int = Field(pp_int(sys.maxsize, "sys.maxsize"),
                                             ge=0,
                                             alias="stage3_model_persistence_threshold")
    """
    Maximum number of parameter elements that can be persisted in GPU and not
    partitioned. This imposes an upper bound on the number of unpartitioned
    parameters resulting from the ``param_persistence_threshold`` setting.
    Used by ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.
    """

    max_live_parameters: int = Field(pp_int(1e9), ge=0, alias="stage3_max_live_parameters")
    """
    The maximum number of parameters resident per GPU before releasing.
    Smaller values use less memory, but perform more communication.
    """

    max_reuse_distance: int = Field(pp_int(1e9), ge=0, alias="stage3_max_reuse_distance")
    """
    Do not release a parameter if it will be reused within this threshold of
    parameters. Smaller values use less memory, but perform more communication.
    """

    gather_16bit_weights_on_model_save: bool = Field(
        False,
        alias="stage3_gather_16bit_weights_on_model_save")
    """
    Consolidate the weights before saving the model by ``save_16bit_model()``.
    Since the weights are partitioned across GPUs, they are not part of
    ``state_dict``, so this function automatically gathers the weights when
    this option is enabled and then saves the fp16 model weights.
    """

    stage3_gather_fp16_weights_on_model_save: bool = Field(
        False,
        deprecated=True,
        new_param="gather_16bit_weights_on_model_save")
    """ Deprecated, please use ``gather_16bit_weights_on_model_save`` """

    ignore_unused_parameters: bool = True
    """
    Unused parameters in modules may be unexpected in static networks, but
    could be normal in dynamic networks. This controls whether or not training
    should terminate with an error message when unused parameters are
    detected. This is set to ``True`` by default, which means unused
    parameters are ignored and training continues. Currently this is only
    used in ZeRO stage 2.
    """

    legacy_stage1: bool = False
    """
    For backward compatibility, enable the old ZeRO stage 1 implementation.
    Use at your own risk; it will be deprecated soon.
    """

    round_robin_gradients: bool = False
    """
    Stage 1 and 2 optimization for CPU offloading that parallelizes gradient
    copying to CPU memory among ranks by fine-grained gradient partitioning.
    Performance benefit grows with gradient accumulation steps (more copying
    between optimizer steps) or GPU count (increased parallelism).
    """

    # Validators
    @validator("overlap_comm")
    def overlap_comm_valid(cls, field_value, values):
        if field_value is None:
            assert (
                "stage" in values
            ), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
            # Default: overlap communication only for ZeRO stage 3 (weights).
            field_value = values["stage"] == ZeroStageEnum.weights
        return field_value
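

# A quick check of the dynamic default above (illustrative, assuming the base
# config model validates default values, i.e. pydantic's validate_all):
# overlap_comm defaults to True only for stage 3, and an explicit value
# always wins.
#
#   DeepSpeedZeroConfig(stage=3).overlap_comm                      # -> True
#   DeepSpeedZeroConfig(stage=2).overlap_comm                      # -> False
#   DeepSpeedZeroConfig(stage=3, overlap_comm=False).overlap_comm  # -> False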