constants.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Various symbolic constants used for model checkpointing
"""

#########################################
# Optimizer checkpoint keys
#########################################
OPTIMIZER_STATE_DICT = "optimizer_state_dict"
FP32_GROUPS = "fp32_groups"
FP32_FLAT_GROUPS = 'fp32_flat_groups'
BASE_OPTIMIZER_STATE = 'base_optimizer_state'
BASE_OPTIMIZER_STATE_STEP = 'base_optimizer_state_step'
SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
GROUP_PADDINGS = 'group_paddings'
PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
CLIP_GRAD = 'clip_grad'
FP32_WEIGHT_KEY = "fp32"
LOSS_SCALER = 'loss_scaler'
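# A hedged sketch of how these keys might appear in a saved optimizer
# checkpoint dict (illustrative only; `optimizer`, `world_size`,
# `fp32_partitions`, and `path` are hypothetical names):
#
#   import torch
#   sd = {
#       OPTIMIZER_STATE_DICT: optimizer.state_dict(),
#       ZERO_STAGE: 2,
#       PARTITION_COUNT: world_size,
#       CLIP_GRAD: 1.0,
#       SINGLE_PARTITION_OF_FP32_GROUPS: fp32_partitions,
#   }
#   torch.save(sd, path)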
#########################################
# Module checkpoint keys
#########################################
PARAM = 'param'
PARAM_SHAPES = 'param_shapes'
BUFFER_NAMES = 'buffer_names'
FROZEN_PARAM_SHAPES = 'frozen_param_shapes'
FROZEN_PARAM_FRAGMENTS = 'frozen_param_fragments'

#########################################
# Checkpoint naming constants
#########################################
MODEL_FILE_PREFIX = 'mp_rank_'
ZERO_FILE_PREFIX = 'zero_pp_rank_'
OPTIM_FILE_SUFFIX = '_optim_states.pt'
MODEL_FILE_SUFFIX = '_model_states.pt'
LAYER_FILE_PREFIX = 'layer_'
BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX
FP16_ZERO_FILE_PREFIX = 'fp16_' + ZERO_FILE_PREFIX
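# A hedged sketch of how checkpoint filenames are typically composed from
# these pieces (`mp_rank` and `dp_rank` are illustrative values):
#
#   mp_rank, dp_rank = 0, 0
#   model_file = f"{MODEL_FILE_PREFIX}{mp_rank:02d}{MODEL_FILE_SUFFIX}"
#   # -> "mp_rank_00_model_states.pt"
#   optim_file = f"{BF16_ZERO_FILE_PREFIX}{dp_rank}_{MODEL_FILE_PREFIX}{mp_rank:02d}{OPTIM_FILE_SUFFIX}"
#   # -> "bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt"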
#########################################
# Checkpoint utility keys
#########################################
DS_VERSION = 'ds_version'

#########################################
# Universal Checkpoint keys
#########################################
UNIVERSAL_CHECKPOINT_INFO = 'universal_checkpoint_info'
UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version'
# Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2

# Vocabulary padding
VOCAB_TENSOR = 'vocab_tensor'
PADDED_VOCAB_SIZE = 'padded_vocab_size'
ORIGINAL_VOCAB_SIZE = 'original_vocab_size'
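# A hedged sketch of vocabulary padding under these keys, assuming the
# embedding is zero-padded along its first dimension from the original to
# the padded vocab size (`embedding` and the sizes are illustrative):
#
#   import torch
#   original_vocab_size, padded_vocab_size = 50257, 50304
#   padded = torch.zeros(padded_vocab_size, embedding.shape[1], dtype=embedding.dtype)
#   padded[:original_vocab_size] = embedding
#   info = {ORIGINAL_VOCAB_SIZE: original_vocab_size, PADDED_VOCAB_SIZE: padded_vocab_size}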
# Parameter splitting/merging
PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
CAT_DIM = "cat_dim"

# The following handles the special case where a parameter effectively contains sub-parameters.
# As an example, consider the Megatron-DeepSpeed GPT SwiGLU implementation (mlp.h_to_4h).
# In this case, a single parameter is allocated contiguously but used as separate parameters.
# When using a universal checkpoint, we have to normalize the representation of the full parameter.
# We normalize it by concatenating all slices of each sub-parameter and then concatenating the sub-parameters.
# All concat operations are done on CAT_DIM (currently, there is no support for different concat dims across sub-parameters and TP slicing).
# Similarly, load_hp_checkpoint_state has to take the needed actions when loading from a universal checkpoint.
PARAM_N_SUB_PARAMS = "param_n_sub_params"
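# A hedged sketch of the normalization described above, assuming two
# sub-parameters (e.g. SwiGLU) concatenated on CAT_DIM = 0 and a hypothetical
# `tp_slices` list holding the fused parameter's TP slices:
#
#   import torch
#   n_sub_params, cat_dim = 2, 0
#   # split every TP slice into its sub-parameter fragments
#   fragments = [torch.chunk(s, n_sub_params, dim=cat_dim) for s in tp_slices]
#   # concatenate all slices of each sub-parameter, then the sub-parameters
#   full = torch.cat([torch.cat([f[i] for f in fragments], dim=cat_dim)
#                     for i in range(n_sub_params)], dim=cat_dim)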
# Regex lists of parameter names that require special handling
VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns'
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 = 'parameter_with_2_sub_params_cat_dim_0'
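# A hedged sketch of how such a regex list might be applied to decide whether
# a parameter needs special handling (`info` and `param_name` are hypothetical,
# and the example pattern is illustrative):
#
#   import re
#   patterns = info.get(VOCABULARY_PARAMETER_PATTERNS, [])  # e.g. [r"word_embeddings\.weight"]
#   needs_vocab_handling = any(re.match(p, param_name) for p in patterns)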