# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Symbolic constants used for model checkpointing: optimizer and module
state keys, checkpoint file-name fragments, and universal-checkpoint metadata.
"""

  7. #########################################
  8. # Optimizer checkpoint keys
  9. #########################################
  10. OPTIMIZER_STATE_DICT = "optimizer_state_dict"
  11. FP32_GROUPS = "fp32_groups"
  12. FP32_FLAT_GROUPS = 'fp32_flat_groups'
  13. BASE_OPTIMIZER_STATE = 'base_optimizer_state'
  14. SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
  15. GROUP_PADDINGS = 'group_paddings'
  16. PARTITION_COUNT = 'partition_count'
  17. ZERO_STAGE = 'zero_stage'
  18. CLIP_GRAD = 'clip_grad'
  19. FP32_WEIGHT_KEY = "fp32"
  20. #########################################
  21. # Module checkpoint keys
  22. #########################################
  23. PARAM = 'param'
  24. PARAM_SHAPES = 'param_shapes'
  25. BUFFER_NAMES = 'buffer_names'
  26. FROZEN_PARAM_SHAPES = 'frozen_param_shapes'
  27. FROZEN_PARAM_FRAGMENTS = 'frozen_param_fragments'
  28. #########################################
  29. # Checkpoint naming constants
  30. #########################################
  31. MODEL_FILE_PREFIX = 'mp_rank_'
  32. ZERO_FILE_PREFIX = 'zero_pp_rank_'
  33. OPTIM_FILE_SUFFIX = '_optim_states.pt'
  34. MODEL_FILE_SUFFIX = '_model_states.pt'
  35. LAYER_FILE_PREFIX = 'layer_'
  36. BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX
  37. FP16_ZERO_FILE_PREFIX = 'fp16_' + ZERO_FILE_PREFIX
  38. #########################################
  39. # Checkpoint utility keys
  40. #########################################
  41. DS_VERSION = 'ds_version'
  42. #########################################
  43. # Universal Checkpoint keys
  44. #########################################
  45. UNIVERSAL_CHECKPOINT_INFO = 'universal_checkpoint_info'
  46. UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version'
  47. # Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training
  48. UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2
  49. # Vocabulary padding
  50. VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
  51. PADDED_VOCAB_SIZE = 'padded_vocab_size'
  52. ORIGINAL_VOCAB_SIZE = 'original_vocab_size'
  53. # Parameter splitting/merging
  54. PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
  55. CAT_DIM = "cat_dim"
  56. # Regex list of parameters that require special handling
  57. VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
  58. PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
  59. PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
  60. PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'