zero_checkpoint.py

'''Copyright The Microsoft DeepSpeed Team'''

import torch

from .constants import (BASE_OPTIMIZER_STATE,
                        GROUP_PADDINGS,
                        OPTIMIZER_STATE_DICT,
                        PARTITION_COUNT)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)

GROUP_STATE_KEY = 'state'
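
# ZeROCheckpoint indexes the per-rank ZeRO partition files of a checkpoint
# folder and merges them into a single state dict for a requested
# (pipeline, tensor, data) parallel rank; reshape() retargets the file map
# to a new 3D topology.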
class ZeROCheckpoint(object):
    def __init__(self, dir):
        basic_folder_validation(dir)
        self.dir = dir
        self.file_list = get_zero_files(dir)
        self.num_files = len(self.file_list)
        assert self.num_files > 0, f'No ZeRO files found in {dir}'

        self.src_3d = get_model_3d_descriptor(dir)
        # The target topology defaults to the source topology; reshape()
        # replaces it and rebuilds the file map.
        self.target_3d = model_3d_desc(pp_degree=self.src_3d.pp_degree,
                                       tp_degree=self.src_3d.tp_degree,
                                       dp_degree=self.src_3d.dp_degree)
        self._3d_file_map = self.src_3d.reshape(self.target_3d)

    def get_src_world_size(self):
        return self.src_3d.world_size()

    def get_src_tp_degree(self):
        return self.src_3d.tp_degree

    def get_src_pp_degree(self):
        return self.src_3d.pp_degree

    def get_src_dp_degree(self):
        return self.src_3d.dp_degree

    def get_file_indices_for_rank(self, pp_index, tp_index, dp_index):
        assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}'
        dp_2d_map = self._3d_file_map[dp_index]
        return dp_2d_map.get_data(pp_index, tp_index)

    def get_files_for_rank(self, pp_index, tp_index, dp_index):
        file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index)
        return [self.file_list[idx] for idx in file_idx_list]

    def get_state_for_rank(self,
                           pp_index,
                           tp_index,
                           dp_index,
                           keys_to_ignore=[],
                           strip_tensor_paddings=True):
        state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
        merged_sd = None
        for state_file in state_file_list:
            # Load each ZeRO partition on CPU to avoid GPU memory pressure.
            sd = torch.load(state_file, map_location=torch.device('cpu'))
            for key in keys_to_ignore:
                sd.pop(key, None)

            if strip_tensor_paddings:
                self._strip_tensor_paddings(sd)

            if merged_sd is None:
                merged_sd = sd
            else:
                merged_sd = merge_state(merged_sd, sd)

            self._update_partition_count(merged_sd)
            if strip_tensor_paddings:
                self._clear_group_paddings(merged_sd)

        return merged_sd

    def print_3d_index_map(self, tag=None):
        if tag:
            print(f'3D index map: {tag}')
        for dp_index, _2d_map in enumerate(self._3d_file_map):
            _2d_map.print_data(f'dp = {dp_index}')

    def print_3d_file_map(self, tag=None):
        if tag:
            print(f'3D file map: {tag}')
        for dp_index, _2d_map in enumerate(self._3d_file_map):
            # pp_degree/tp_degree are integers, so iterate over range() of each.
            for pp_index in range(_2d_map.pp_degree):
                for tp_index in range(_2d_map.tp_degree):
                    file_index_list = _2d_map.get_data(pp_index, tp_index)
                    file_list = [self.file_list[idx] for idx in file_index_list]
                    print(f'{pp_index}, {tp_index}, {dp_index} => {file_list}')

    def reshape(self, target_3d_desc: model_3d_desc):
        self.target_3d = target_3d_desc
        self._3d_file_map = self.src_3d.reshape(self.target_3d)

    def _strip_tensor_paddings(self, sd):
        # Remove the alignment padding that ZeRO appends to each flattened
        # parameter-group state tensor.
        param_group_states = self._get_param_group_states(sd)
        if param_group_states is None:
            return

        group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
        if group_paddings is None:
            return

        for key, group_state in param_group_states.items():
            if group_paddings[key] == 0:
                continue
            for state_name, state_value in group_state.items():
                if torch.is_tensor(state_value):
                    raw_length = state_value.numel() - group_paddings[key]
                    group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()

    def _clear_group_paddings(self, sd):
        group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
        if group_paddings:
            num_groups = len(group_paddings)
            sd[OPTIMIZER_STATE_DICT][GROUP_PADDINGS] = [0] * num_groups

    def _get_optimizer_state(self, sd, state_key):
        optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
        if optimizer_state is None:
            return None
        return optimizer_state.get(state_key, None)

    def _get_param_group_states(self, sd):
        optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
        if optimizer_state is None:
            return None
        base_optimizer_state = optimizer_state.get(BASE_OPTIMIZER_STATE, None)
        if base_optimizer_state is None:
            return None
        return base_optimizer_state.get(GROUP_STATE_KEY, None)

    def _update_partition_count(self, sd):
        # Rewrite the per-group partition counts to reflect the target DP degree.
        partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT)
        if partition_counts:
            num_groups = len(partition_counts)
            sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree] * num_groups
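
# Usage sketch (an illustrative addition, not part of the original module).
# Because this module uses relative imports it must be imported through its
# package; the import path below assumes the file's usual location in the
# DeepSpeed tree, and the checkpoint directory is a hypothetical placeholder.
#
#   from deepspeed.checkpoint.zero_checkpoint import ZeROCheckpoint
#
#   ckpt = ZeROCheckpoint('/path/to/zero/checkpoint')
#   print(f'source world size: {ckpt.get_src_world_size()}')
#   ckpt.print_3d_file_map(tag='source')
#
#   # Merge the ZeRO shards that map onto rank (pp=0, tp=0, dp=0).
#   sd = ckpt.get_state_for_rank(pp_index=0, tp_index=0, dp_index=0)
#   print(list(sd.keys()))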