# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Various symbolic constants used for model checkpointing
"""

#########################################
# Optimizer checkpoint keys
#########################################
OPTIMIZER_STATE_DICT = "optimizer_state_dict"
FP32_GROUPS = "fp32_groups"
FP32_FLAT_GROUPS = 'fp32_flat_groups'
BASE_OPTIMIZER_STATE = 'base_optimizer_state'
BASE_OPTIMIZER_STATE_STEP = 'base_optimizer_state_step'
SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
PARAM_GROUPS = 'param_groups'
GROUP_PADDINGS = 'group_paddings'
PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
CLIP_GRAD = 'clip_grad'
FP32_WEIGHT_KEY = "fp32"
LOSS_SCALER = 'loss_scaler'

#########################################
# Module checkpoint keys
#########################################
PARAM = 'param'
PARAM_SHAPES = 'param_shapes'
BUFFER_NAMES = 'buffer_names'
FROZEN_PARAM_SHAPES = 'frozen_param_shapes'
FROZEN_PARAM_FRAGMENTS = 'frozen_param_fragments'

#########################################
# Checkpoint naming constants
#########################################
MODEL_FILE_PREFIX = 'mp_rank_'
ZERO_FILE_PREFIX = 'zero_pp_rank_'
OPTIM_FILE_SUFFIX = '_optim_states.pt'
MODEL_FILE_SUFFIX = '_model_states.pt'
LAYER_FILE_PREFIX = 'layer_'
# The precision-specific prefixes are built from ZERO_FILE_PREFIX, so they
# follow it automatically if that prefix is ever changed.
BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX
FP16_ZERO_FILE_PREFIX = 'fp16_' + ZERO_FILE_PREFIX

#########################################
# Checkpoint utility keys
#########################################
DS_VERSION = 'ds_version'

#########################################
# Universal Checkpoint keys
#########################################
UNIVERSAL_CHECKPOINT_INFO = 'universal_checkpoint_info'
UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version'
# Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training
# NOTE: the version value is a float, not a string.
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2

# Vocabulary padding
VOCAB_TENSOR = 'vocab_tensor'
PADDED_VOCAB_SIZE = 'padded_vocab_size'
ORIGINAL_VOCAB_SIZE = 'original_vocab_size'

# Parameter splitting/merging
PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
CAT_DIM = "cat_dim"
# Following is a special case where a parameter effectively contains sub parameters.
# As an example, consider Megatron-DeepSpeed GPT SWIGLU implementation (mlp.h_to_4h).
# In this case, a single parameter is allocated contiguously, but used as separate parameters.
# When using universal checkpoint, we have to normalize the representation of the full parameter.
# We normalize it by concatenating all slices of the sub params and then concatenating the sub params.
# All concat operations are done on CAT_DIM (currently, no support for different concat dims sub params and TP slicing).
# Similarly, load_hp_checkpoint_state has to take the needed actions when loading from universal.
PARAM_N_SUB_PARAMS = "param_n_sub_params"
SUB_PARAM_SHAPE = "sub_param_shape"

# Regex list of parameters that require special handling
VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns'
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 = 'parameter_with_2_sub_params_cat_dim_0'
PARAMETER_WITH_SUB_PARAMS = 'parameter_with_sub_params'
SUB_PARAMS_SHAPE = 'sub_params_shape'