#!/usr/bin/env python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
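"""
ds_to_universal.py

Convert a DeepSpeed ZeRO checkpoint into a "universal" checkpoint: per-parameter
optimizer states (fp32, exp_avg, exp_avg_sq) are extracted from the flat ZeRO
partitions, merged across data-parallel and tensor-parallel ranks, and written
under <output_folder>/zero together with the remaining (non-sharded) optimizer
state, so the checkpoint can be reloaded under a different parallelism topology.

Example invocation (folder names are placeholders):
    python ds_to_universal.py --input_folder <ds_checkpoint_step_dir> --output_folder <universal_step_dir>
"""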

from functools import partial
import argparse
import glob
import itertools
import multiprocessing
import os
import re
import shutil
import torch
import tqdm
#from pprint import pprint

from deepspeed.checkpoint import DeepSpeedCheckpoint
from deepspeed.checkpoint import (
    OPTIMIZER_STATE_DICT,
    BASE_OPTIMIZER_STATE,
    SINGLE_PARTITION_OF_FP32_GROUPS,
    PARAM_SLICE_MAPPINGS,
    PARAM_SHAPES,
    PARAM,
    CAT_DIM,
    PARAM_N_SUB_PARAMS,
    VOCAB_TENSOR,
    UNIVERSAL_CHECKPOINT_INFO,
    VOCABULARY_PARAMETER_PATTERNS,
    PIPELINE_REPLICATED_PARAMETER_PATTERNS,
    TP_REPLICATED_PARAMETER_PATTERNS,
    PARAMETER_TO_AVERAGE_PATTERNS,
    PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
    PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed Checkpoint folder')
    parser.add_argument('--output_folder', type=str, required=True, help='Output DeepSpeed checkpoint folder')
    parser.add_argument('--num_extract_workers',
                        default=4,
                        type=int,
                        help='How many parallel processes to extract zero shards')
    parser.add_argument(
        '--num_merge_workers',
        default=2,
        type=int,
        help=
        'How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers)'
    )
    parser.add_argument('--keep_temp_folder',
                        action='store_true',
                        help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
    parser.add_argument('--no_strict',
                        dest='strict',
                        action='store_false',
                        help='Do not perform validity checks on converted checkpoint.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args
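

# _create_checkpoint_paths builds the per-rank checkpoint file paths
# (iter_<iteration>/mp_rank_<tp>[_<pp>]/model_optim_rng.pt) for a given TP/PP topology.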
def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
    path_list = []
    iter_folder = f'iter_{iteration:07d}'
    for i in range(0, tp_degree):
        path_list.append([])
        for j in range(0, pp_degree):
            rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
            ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
            path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))

    return path_list


def _save_checkpoint(file_path, chkpt_sd):
    dir, _ = os.path.split(file_path)
    os.makedirs(dir, exist_ok=True)
    torch.save(chkpt_sd, file_path)
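

# extract_zero_shards: for a single (pp, tp, dp) rank, slice that rank's flat
# fp32/exp_avg/exp_avg_sq partitions into per-parameter fragments and save each
# fragment under <dir>/<param_name>/<tp_index>/ (see dump_param_fragment).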
def extract_zero_shards(dir, ds_checkpoint, indices_3D):
    pp_index, tp_index, dp_index = indices_3D
    sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

    # pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}")

    optim_sd = sd[OPTIMIZER_STATE_DICT]
    param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])
    # print(f'{pipeline_replicated_params=}')

    # dict
    state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
    # list
    fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
    param_groups_cnt = len(state_groups)

    for param_group_id in range(param_groups_cnt):

        flat_state = dict(
            exp_avg=state_groups[param_group_id]["exp_avg"],
            exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
            fp32=fp32_groups[param_group_id],
        )

        for name, fragment_mapping in param_slice_mappings[param_group_id].items():
            if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
                # Skip tied weights that are replicated in first and last pp stages
                continue

            # pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}")
            for state_key in flat_state.keys():
                dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name,
                                    fragment_mapping.start, fragment_mapping.numel)


cnt = 0
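

# dump_param_fragment saves one contiguous [offset, offset + numel) slice of a flat
# optimizer-state tensor to <dir>/<param_name>/<tp_index>/<state_name>.<dp_index>.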
def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):

    global cnt  # temp hack

    param_base_path = os.path.join(dir, param_name, str(tp_index))
    os.makedirs(param_base_path, exist_ok=True)

    cnt += 1
    counter = f"{dp_index:0>2d}"

    path = os.path.join(param_base_path, f"{state_name}.{counter}")

    #print(f"{param_name}: {offset}: {numel} => {path}")

    t = state_flat_tensor.narrow(0, offset, numel).clone()
    _save_checkpoint(path, t)
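

# _merge_zero_shards: for each tp rank, concatenate that rank's per-dp fragments
# of one state (fp32/exp_avg/exp_avg_sq) back into a full slice of the expected shape.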
def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
    slices = []
    for tp_index in range(tp_degree):
        prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
        paths = sorted(list(glob.glob(f"{prefix_path}.*")))
        shards = [torch.load(p) for p in paths]
        slice = torch.cat(shards, dim=0).reshape(slice_shape)
        slices.append(slice)
    return slices
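

# merge_tp_slices combines the tp slices of one parameter into a single tensor,
# using the pattern lists advertised in UNIVERSAL_CHECKPOINT_INFO to decide whether
# to keep one replica, average, or concatenate (row- or column-parallel), and strips
# vocabulary padding where applicable. Returns the patterns this parameter did not
# match so callers can verify overall pattern coverage.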
def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):

    name, shape = name_and_shape
    slice_base_path = os.path.join(slice_dir, name)
    param_base_path = os.path.join(dir, name)

    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    replicated_parameters = universal_checkpoint_info.get(TP_REPLICATED_PARAMETER_PATTERNS, [])
    parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
    parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
    vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
    parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, [])
    unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
                             vocabulary_parameters + parameters_with_2_sub_params_cat_dim_0)

    def get_matched_pattern(patterns_, name_):
        matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
        assert len(matched_) <= 1, f'Got more than one matching pattern={matched_} for {name_}'
        if matched_:
            pattern_ = matched_[0]
            unmatched_patterns.discard(pattern_)
            return pattern_
        return None

    for state in ("fp32", "exp_avg", "exp_avg_sq"):
        slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
        final_path = os.path.join(param_base_path, f"{state}.pt")

        #print(f"Expected shape: {shape}")
        #print(f"Fragment sizes:", list(frag.shape for frag in slices))
        ckpt_dict = {}
        if get_matched_pattern(replicated_parameters, name):
            if len(slices) > 1:
                assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
            param = slices[0]
            # print(f'replicate {name} using first slice')
        elif get_matched_pattern(parameters_to_average, name):
            param = sum(slices) / len(slices)
            # print(f'merge {name} using average')
        elif get_matched_pattern(parameters_with_2_sub_params_cat_dim_0, name):
            cat_dim = 0
            chunked_slices = [torch.chunk(s, 2, dim=cat_dim) for s in slices]
            merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=cat_dim)
            merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=cat_dim)
            param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim
            ckpt_dict[PARAM_N_SUB_PARAMS] = 2
        else:
            cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
            # print(f"merge {name} with CAT DIM: {cat_dim}")
            param = torch.cat(slices, dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim

        if get_matched_pattern(vocabulary_parameters, name):
            #print(f"Before {param.shape=}")
            # strip padding
            original_vocab_size = universal_checkpoint_info['original_vocab_size']
            param = param[:original_vocab_size, :]
            ckpt_dict[VOCAB_TENSOR] = True
            #print(f"After {param.shape=}")

        #print(f"Final shape: {param.shape}")
        ckpt_dict[PARAM] = param
        _save_checkpoint(final_path, ckpt_dict)

    return unmatched_patterns


def _get_chunks(l, n):
    for i in range(0, len(l), n):
        yield l[i:i + n]


def _do_parallel_work(do_work, work_chunks, num_workers):
    pool = multiprocessing.Pool(num_workers)
    results = []
    for batch in tqdm.tqdm(work_chunks):
        res = pool.map(do_work, batch)
        results.extend(res)
    pool.close()
    pool.join()
    return results


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
    _3d_range_list = list(
        itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
                          range(ds_checkpoint.dp_degree)))
    #pprint(f'{_3d_range_list=}')
    work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
    #pprint(f'{work_chunks=}')

    # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
    do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
    _do_parallel_work(do_work, work_chunks, args.num_extract_workers)
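

# _merge_tp_slice_files runs merge_tp_slices over all (name, shape) pairs and then
# verifies pattern coverage: a pattern left unmatched by every worker was never
# used, which fails an assertion unless --no_strict was given.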
def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
    work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
    #pprint(work_chunks)

    zero_output_folder = os.path.join(args.output_folder, "zero")
    do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
    unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)

    # verify that all patterns were used
    # if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
    sets = [set(lst) for lst in unmatched_patterns_lists]
    unmatched_patterns = list(set.intersection(*sets))
    if args.strict:
        assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices'
    elif unmatched_patterns:
        print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')
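

# _save_optimizer_state writes the optimizer state shared by all ranks (everything
# except the sharded per-parameter states handled above) to zero/optimizer_state.pt.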
def _save_optimizer_state(args, ds_checkpoint):
    sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS]
    sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0)

    optim_sd = sd[OPTIMIZER_STATE_DICT]
    output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
    zero_output_folder = os.path.join(args.output_folder, "zero")
    output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
    _save_checkpoint(output_file_path, output_sd)


def _check_for_required_state(ds_checkpoint):
    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'


def main():
    print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint')

    args = parse_arguments()
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')

    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
    _check_for_required_state(ds_checkpoint)

    iteration = ds_checkpoint.get_iteration()
    #_create_latest_file(args.output_folder, iteration)
    checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree,
                                                ds_checkpoint.pp_degree)

    slice_shapes = []
    for mp_rank_file in ds_checkpoint.mp_rank_files:
        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
        slice_shapes += mp_sd[PARAM_SHAPES]

    # fix back to normal flat dict, merge duplicates for tp>1
    slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
    temp_dir = os.path.join(args.output_folder, 'tmp')

    print('*** 1. Extracting ZeRO fragments')
    _extract_zero_shard_files(args, ds_checkpoint, temp_dir)

    print('*** 2. Merging slices .....')
    _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

    print('*** 3. Saving common optimizer states')
    _save_optimizer_state(args, ds_checkpoint)

    if not args.keep_temp_folder:
        shutil.rmtree(temp_dir, ignore_errors=True)

    # Copy mp* files into output folder
    for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
        shutil.copy2(f, args.output_folder)

    # Update latest to output folder
    checkpoint_root_folder, step_folder = os.path.split(args.output_folder)
    latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
    with open(latest_file, "w") as f:
        f.write(step_folder)

    print('*** Done!')


if __name__ == "__main__":
    main()