config.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. """
  2. Copyright (c) Microsoft Corporation
  3. Licensed under the MIT license.
  4. """
  5. from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
  6. from deepspeed.utils import logger
  7. from .constants import *
  8. from .offload_constants import *
  9. from .offload_config import get_offload_param_config, get_default_offload_param_config, \
  10. get_offload_optimizer_config, get_default_offload_optimizer_config
  class DeepSpeedZeroConfig(DeepSpeedConfigObject):
      """ZeRO optimization settings parsed from a DeepSpeed config dict.

      Every attribute is first set to ``None`` in ``__init__`` and then
      populated by ``_initialize`` from the ``ZERO_OPTIMIZATION`` section of
      the top-level config, or from ``ZERO_OPTIMIZATION_DEFAULT`` when that
      section is missing.
      """

      def __init__(self, param_dict):
          """Build the ZeRO config from the top-level DeepSpeed ``param_dict``.

          Args:
              param_dict: top-level config mapping. Its ``ZERO_OPTIMIZATION``
                  entry is either a dict (current format) or a bool
                  (deprecated format, converted by
                  ``read_zero_config_deprecated``).
          """
          super(DeepSpeedZeroConfig, self).__init__()

          self.stage = None
          self.contiguous_gradients = None
          self.reduce_scatter = None
          self.reduce_bucket_size = None
          self.allgather_partitions = None
          self.allgather_bucket_size = None
          self.overlap_comm = None
          self.load_from_fp32_weights = None
          self.elastic_checkpoint = None
          # Assigned in _initialize like every other setting; declared here so
          # the attribute list in __init__ is complete.
          self.legacy_stage1 = None

          # Offload specific parameters
          self.offload_param = None
          self.offload_optimizer = None
          self.sub_group_size = None

          # Stage3 specific parameters
          self.prefetch_bucket_size = None
          self.param_persistence_threshold = None
          self.max_live_parameters = None
          self.max_reuse_distance = None
          self.gather_fp16_weights_on_model_save = None
          self.ignore_unused_parameters = None
          self.round_robin_gradients = None

          if ZERO_OPTIMIZATION in param_dict:
              zero_config_dict = param_dict[ZERO_OPTIMIZATION]
              # Deprecated format: a plain boolean instead of a sub-dict.
              if isinstance(zero_config_dict, bool):
                  zero_config_dict = self.read_zero_config_deprecated(param_dict)
          else:
              zero_config_dict = ZERO_OPTIMIZATION_DEFAULT

          self._initialize(zero_config_dict)

      def read_zero_config_deprecated(self, param_dict):
          """Translate the deprecated boolean ZeRO setting into a config dict.

          The stage becomes 1 when ``param_dict[ZERO_OPTIMIZATION]`` is truthy,
          otherwise 0. For a non-zero stage, the allgather bucket size is read
          from the deprecated top-level key
          ``ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED``. A deprecation
          warning is always logged.

          Args:
              param_dict: top-level config mapping.

          Returns:
              dict: a ZeRO section in the current format.
          """
          zero_config_dict = {}
          zero_config_dict[
              ZERO_OPTIMIZATION_STAGE] = 1 if param_dict[ZERO_OPTIMIZATION] else 0
          if zero_config_dict[ZERO_OPTIMIZATION_STAGE] > 0:
              zero_config_dict[ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE] = get_scalar_param(
                  param_dict,
                  ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED,
                  ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)

          logger.warning(
              'DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}'
              .format(ZERO_FORMAT))
          return zero_config_dict

      def _sanity_check(self, zero_config_dict):
          """Log a warning for each deprecated key present in the ZeRO section.

          Nothing is raised and ``zero_config_dict`` is not modified.

          Args:
              zero_config_dict: the ZeRO section of the config.
          """
          # Must be a dict literal keyed by the constants' *values*. The
          # ``dict(ZERO_OPTIMIZATION_CPU_OFFLOAD=...)`` keyword form used the
          # constant *names* as literal string keys. Real config keys then
          # never matched, so these warnings could not fire.
          deprecated_dict = {
              ZERO_OPTIMIZATION_CPU_OFFLOAD:
              ZERO_OPTIMIZATION_OFFLOAD_OPTIMIZER,
              ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS:
              ZERO_OPTIMIZATION_OFFLOAD_PARAM,
              ZERO_OPTIMIZATION_CPU_OFFLOAD_USE_PIN_MEMORY:
              f'{ZERO_OPTIMIZATION_OFFLOAD_PARAM} or {ZERO_OPTIMIZATION_OFFLOAD_OPTIMIZER}',
          }
          for old_key, new_key in deprecated_dict.items():
              if old_key in zero_config_dict:
                  logger.warning(
                      f'DeepSpeedConfig: {old_key} is deprecated. Please use {new_key}.')

      def _initialize(self, zero_config_dict):
          """Populate every setting from the ZeRO section ``zero_config_dict``.

          Each scalar is read with ``get_scalar_param(dict, KEY, DEFAULT)``.
          NOTE(review): presumably that helper returns DEFAULT when KEY is
          absent; confirm against deepspeed.runtime.config_utils.

          Args:
              zero_config_dict: the ZeRO section of the config.
          """
          self._sanity_check(zero_config_dict)

          self.stage = get_scalar_param(zero_config_dict,
                                        ZERO_OPTIMIZATION_STAGE,
                                        ZERO_OPTIMIZATION_STAGE_DEFAULT)

          # The ZERO3_* defaults apply only when the stage equals
          # ZERO_OPTIMIZATION_WEIGHTS (presumably stage 3).
          uses_zero3_defaults = self.stage == ZERO_OPTIMIZATION_WEIGHTS

          self.contiguous_gradients = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS,
              ZERO3_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT
              if uses_zero3_defaults else
              ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT)

          self.reduce_bucket_size = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE,
              ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT)

          self.reduce_scatter = get_scalar_param(zero_config_dict,
                                                 ZERO_OPTIMIZATION_REDUCE_SCATTER,
                                                 ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT)

          self.overlap_comm = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_OVERLAP_COMM,
              ZERO3_OPTIMIZATION_OVERLAP_COMM_DEFAULT
              if uses_zero3_defaults else ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT)

          self.allgather_partitions = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS,
              ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT)

          self.allgather_bucket_size = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
              ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)

          self.load_from_fp32_weights = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
              ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)

          self.elastic_checkpoint = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
              ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT)

          # Optimizer offload. The legacy boolean key takes precedence:
          # truthy -> default offload-optimizer config; falsy -> left as None.
          # The `else` pairs with the OUTER `if`: the new-style section is
          # parsed only when the legacy key is absent.
          if ZERO_OPTIMIZATION_CPU_OFFLOAD in zero_config_dict:
              cpu_offload_optimizer = get_scalar_param(
                  zero_config_dict,
                  ZERO_OPTIMIZATION_CPU_OFFLOAD,
                  ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
              if cpu_offload_optimizer:
                  self.offload_optimizer = get_default_offload_optimizer_config()
          else:
              self.offload_optimizer = get_offload_optimizer_config(zero_config_dict)

          # Parameter offload: same precedence rules as the optimizer offload.
          if ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS in zero_config_dict:
              cpu_offload_params = get_scalar_param(
                  zero_config_dict,
                  ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS,
                  ZERO_OPTIMIZATION_CPU_OFFLOAD_PARAMS_DEFAULT)
              if cpu_offload_params:
                  self.offload_param = get_default_offload_param_config()
          else:
              self.offload_param = get_offload_param_config(zero_config_dict)

          self.sub_group_size = get_scalar_param(zero_config_dict,
                                                 ZERO_OPTIMIZATION_SUB_GROUP_SIZE,
                                                 ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT)

          self.max_live_parameters = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS,
              ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS_DEFAULT)

          self.max_reuse_distance = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE,
              ZERO_OPTIMIZATION_MAX_REUSE_DISTANCE_DEFAULT)

          self.prefetch_bucket_size = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE,
              ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT)

          self.param_persistence_threshold = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
              ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)

          self.gather_fp16_weights_on_model_save = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
              ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)

          self.ignore_unused_parameters = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS,
              ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT)

          self.legacy_stage1 = get_scalar_param(zero_config_dict,
                                                ZERO_OPTIMIZATION_LEGACY_STAGE1,
                                                ZERO_OPTIMIZATION_LEGACY_STAGE1_DEFAULT)

          self.round_robin_gradients = get_scalar_param(
              zero_config_dict,
              ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS,
              ZERO_OPTIMIZATION_ROUND_ROBIN_GRADIENTS_DEFAULT)