# deepspeed_checkpoint.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
from typing import Dict
import torch

from .reshape_3d_utils import model_3d_desc
from .reshape_utils import (basic_folder_validation, merge_state, partition_data, get_files, get_files_with_prefix)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map
from .zero_checkpoint import ZeROCheckpoint
from .constants import *

EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1
ARGS_KEY = 'args'
CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'

LAYER_FILE_PREFIX_PATTERN = r'layer_(\d+)-model_.*'

SEQUENTIAL_LAYERS = [
    'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
    'post_attention_layernorm.bias', 'mlp.dense_4h_to_h.bias', 'position_embeddings.weight'
]

LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1}
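
# NOTE: parameters listed in SEQUENTIAL_LAYERS are replicated across
# tensor-parallel ranks, so merging keeps a single copy; all other parameters
# are concatenated across ranks along LAYER_CONCAT_DIM (default: dim 0).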


class DeepSpeedCheckpoint(object):
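    """Wraps a DeepSpeed (Megatron-style) checkpoint folder, mapping embedding,
    transformer, and final-norm layer files to tensor/pipeline parallel ranks,
    and reshaping ZeRO state when new tp/pp/dp degrees are requested."""
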
    def __init__(self,
                 dir,
                 tp_degree=None,
                 pp_degree=None,
                 dp_degree=None,
                 final_layer_norm_idx=FINAL_LAYER_NORM_INDEX):
        self.final_layer_norm_idx = final_layer_norm_idx
        self.dir = dir

        pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0
        self._validate_folder(dir, pipeline_parallel)

        self.zero_checkpoint = ZeROCheckpoint(dir)

        self.file_list = get_files(dir)
        self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
        self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX)

        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)

        self.tp_degree = self.zero_checkpoint.get_src_tp_degree() if tp_degree is None else tp_degree
        self.pp_degree = self.zero_checkpoint.get_src_pp_degree() if pp_degree is None else pp_degree
        self.dp_degree = self.zero_checkpoint.get_src_dp_degree() if dp_degree is None else dp_degree

        self.original_world_size = (self.zero_checkpoint.get_src_tp_degree() *
                                    self.zero_checkpoint.get_src_pp_degree() *
                                    self.zero_checkpoint.get_src_dp_degree())
        self.world_size = self.tp_degree * self.pp_degree * self.dp_degree

        self.old_2d_map = meg_2d_parallel_map(self.zero_checkpoint.get_src_pp_degree(),
                                              self.zero_checkpoint.get_src_tp_degree())
        self.old_2d_map.simple_init()
        self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
                                                  old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
                                                  new_pp_degree=self.pp_degree,
                                                  new_tp_degree=self.tp_degree)
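
        # Reshape the ZeRO checkpoint state whenever any requested degree
        # differs from the degree the checkpoint was saved with.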
        if self.is_change_pp_degree() or self.is_change_tp_degree() or self.is_change_dp_degree():
            self.zero_checkpoint.reshape(model_3d_desc(self.pp_degree, self.tp_degree, self.dp_degree))

        self.global_state = {}

        self._sanity_check()
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
        self.tp_to_final_norm_map = self._build_tp_other_layer_map(self.final_layer_norm_idx)
        self._build_global_state()

    def is_change_tp_degree(self):
        return self.tp_degree != self.zero_checkpoint.get_src_tp_degree()

    def is_change_pp_degree(self):
        return self.pp_degree != self.zero_checkpoint.get_src_pp_degree()

    def is_change_dp_degree(self):
        return self.dp_degree != self.zero_checkpoint.get_src_dp_degree()

    def show_2d_mapping(self):
        print('reshaped 2d map ---- begin')

        for i in range(self.pp_degree):
            for j in range(self.tp_degree):
                file_list = self.get_2d_parallel_files(pp_index=i, tp_index=j)
                print(f'[{i}, {j}] = {file_list}')

        print('reshaped 2d map ---- end')

    def show_tp_embedding_map(self):
        self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')

    def show_tp_final_norm_map(self):
        self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')

    def show_pp_transformer_map(self):
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_transformer_layers')

    def show_transformer_file_map(self):
        self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')
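
    # Global state (iteration count and training args) is replicated in every
    # mp_rank_* file, so reading the first one is sufficient.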
    def _build_global_state(self):
        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
        self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

    def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict:
        return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index,
                                                       tp_index=tp_index,
                                                       dp_index=dp_index,
                                                       keys_to_ignore=[PARAM_SHAPES])

    def get_zero_files(self, pp_index, tp_index, dp_index) -> list:
        return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

    def get_embedding_layer_id(self):
        return self.layer_keys[EMBEDDING_LAYER_INDEX]

    def get_final_norm_layer_id(self):
        return self.layer_keys[self.final_layer_norm_idx]

    def get_iteration(self):
        if ITERATION_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

        return self.global_state[ITERATION_KEY]

    def get_embedding_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_embedding_map
        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
        sd = self._merge_state_dicts(sd_list)
        return sd

    def get_embedding_files(self, tp_index: int) -> list:
        assert tp_index in self.tp_to_embedding_map
        return self.tp_to_embedding_map[tp_index]

    def _get_checkpoint_value(self, key):
        if key not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[key] = sd.get(key, None)

        return self.global_state[key]

    def get_args(self):
        return self._get_checkpoint_value(ARGS_KEY)

    def get_checkpoint_info(self, info_key=CHECKPOINT_INFO_KEY):
        return self._get_checkpoint_value(info_key)
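
    # Merge the state dicts of all source mp_rank_* files that map onto the
    # target (pp_index, tp_index) coordinate of the reshaped 2D grid.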
    def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]

        merged_sd = None
        for sd in sd_list:
            if merged_sd is None:
                merged_sd = sd
            else:
                merged_sd = merge_state(merged_sd, sd)

        return merged_sd

    def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
            sd = self._merge_state_dicts(sd_list)
            t_list.append(sd)

        return t_list

    def get_pp_transformer_map(self, pp_index: int) -> list:
        assert pp_index < self.pp_degree
        return self.pp_to_transformer_map[pp_index]

    def get_final_norm_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_final_norm_map
        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
        return sd

    def get_final_norm_files(self, tp_index: int) -> list:
        assert tp_index in self.tp_to_final_norm_map
        return self.tp_to_final_norm_map[tp_index]
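
    # Map each tensor-parallel rank to its share of the files for a single
    # non-transformer layer (the embedding or the final layernorm).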
    def _build_tp_other_layer_map(self, layer_index: int):
        data_map = {}

        if len(self.layer_files) < 1:
            return data_map

        assert layer_index <= len(self.layer_files)
        layer_files = get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
        layer_file_partitions = partition_data(layer_files, self.tp_degree)
        data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
        return data_map

    def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        file_indices = self.new_2d_map.get_data(pp_index=pp_index, tp_index=tp_index)
        return [self.mp_rank_files[i] for i in file_indices]
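
    # Split the transformer layer keys (everything between the embedding and
    # the final layernorm) across the target pipeline stages; assumes the key
    # count divides evenly (see the XXX note in _build_transformer_file_map).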
    def _build_pp_transformer_map(self):
        data_map = {}

        if self.pp_degree > 0:
            transformer_layers = self.layer_keys[1:self.final_layer_norm_idx]
            layers_per_pp = len(transformer_layers) // self.pp_degree
            data_map = {
                i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
                for i in range(0, self.pp_degree)
            }

        return data_map

    def _dump_mapping(self, data_map, map_tag=None):
        if map_tag is not None:
            print(f'Dump mapping: {map_tag}')

        for k, v in data_map.items():
            print(f'{k} = {v}')
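
    # Build a (tp_index, pp_index) -> list map where each entry holds, per
    # transformer layer assigned to that stage, the files of one tp shard.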
    def _build_transformer_file_map(self):
        transformer_layer_keys = self.layer_keys[1:self.final_layer_norm_idx]
        file_map = {}
        # XXX: this is not guaranteed
        layers_per_pp = 1
        if self.pp_degree > 0:
            layers_per_pp = len(transformer_layer_keys) // self.pp_degree
        #print(f"{transformer_layer_keys} {layers_per_pp}")
        for key_index, layer_key in enumerate(transformer_layer_keys):
            pp_index = key_index // layers_per_pp
            layer_files = get_files_with_prefix(self.layer_files, layer_key + '-')
            layer_file_partitions = partition_data(layer_files, self.tp_degree)
            for tp_index in range(self.tp_degree):
                map_key = (tp_index, pp_index)
                if map_key not in file_map:
                    file_map[map_key] = []
                file_map[map_key].append(layer_file_partitions[tp_index])

        return file_map

    def _sanity_check(self):
        assert len(self.mp_rank_files) % self.tp_degree == 0
        assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0
        assert self.zero_checkpoint.num_files % (self.tp_degree) == 0
        # XXX: fix me - isn't always the case
        # only true with --pp-partition-method 'type:transformer|embedding' \
        # assert (len(self.layer_keys) - 2) % self.pp_degree == 0

    def validate_files(self):
        for file in self.file_list:
            if not os.path.isfile(file):
                print(f'Error: {file} does not exist')
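
    # Collect the sorted, de-duplicated 'layer_<id>' prefixes from the layer
    # file names; e.g. a file matching 'layer_01-model_00-...' yields the key
    # 'layer_01'.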
    def _get_layer_keys(self):
        key_set = set()
        for file_path in self.layer_files:
            _, fname = os.path.split(file_path)
            layer_id = re.search(LAYER_FILE_PREFIX_PATTERN, fname).group(1)
            key_set.add(layer_id)
        sorted_ids = sorted(key_set, key=int)
        layer_keys = [LAYER_FILE_PREFIX + str(layer_id) for layer_id in sorted_ids]
        return layer_keys
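
    # Merge tensor-parallel shards of one layer: replicated tensors
    # (SEQUENTIAL_LAYERS) are taken from the first shard; partitioned tensors
    # are concatenated along their partition dimension.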
    def _merge_state_dicts(self, sd_list):
        merged_sd = {}
        for key in sd_list[0].keys():
            if key not in SEQUENTIAL_LAYERS:
                cat_dim = LAYER_CONCAT_DIM.get(key, 0)
                merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
            else:
                merged_sd[key] = sd_list[0][key]

        return merged_sd

    def _validate_folder(self, dir, pipeline_parallel):
        basic_folder_validation(dir)

        file_list = get_files(dir)

        file_prefix_list = [MODEL_FILE_PREFIX]
        if pipeline_parallel:
            file_prefix_list.extend([LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01'])
        for file_prefix in file_prefix_list:
            ckpt_files = get_files_with_prefix(file_list, file_prefix)
            assert len(ckpt_files) > 0, \
                f'{dir} does not look like a valid DeepSpeed checkpoint folder: cannot find {file_prefix}* files.'
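
# Example usage (a minimal sketch; the checkpoint path and the target tp/pp
# degrees below are hypothetical):
#
#   ds_checkpoint = DeepSpeedCheckpoint('/path/to/checkpoint_dir', tp_degree=1, pp_degree=1)
#   print(ds_checkpoint.get_iteration())
#   ds_checkpoint.show_pp_transformer_map()
#   embedding_sd = ds_checkpoint.get_embedding_state(tp_index=0)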