#!/usr/bin/env python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from functools import partial
from itertools import chain
import argparse
import glob
import itertools
import math
from concurrent.futures import ProcessPoolExecutor
import os
import re
import shutil
import torch
import tqdm
#from pprint import pprint

from deepspeed.checkpoint import DeepSpeedCheckpoint
from deepspeed.checkpoint import (
    OPTIMIZER_STATE_DICT,
    ZERO_STAGE,
    BASE_OPTIMIZER_STATE,
    SINGLE_PARTITION_OF_FP32_GROUPS,
    PARAM_GROUPS,
    PARAM_SLICE_MAPPINGS,
    PARAM_SHAPES,
    PARAM,
    CAT_DIM,
    PARAM_N_SUB_PARAMS,
    SUB_PARAM_SHAPE,
    VOCAB_TENSOR,
    UNIVERSAL_CHECKPOINT_INFO,
    UNIVERSAL_CHECKPOINT_VERSION_KEY,
    UNIVERSAL_CHECKPOINT_VERSION_VALUE,
    VOCABULARY_PARAMETER_PATTERNS,
    PIPELINE_REPLICATED_PARAMETER_PATTERNS,
    TP_REPLICATED_PARAMETER_PATTERNS,
    PARAMETER_TO_AVERAGE_PATTERNS,
    PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
    PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
    PARAMETER_WITH_SUB_PARAMS,
    SubparamShape,
)
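
# Converts a DeepSpeed ZeRO (stage 1/2 or 3) checkpoint into the "universal" layout, intended to
# allow resuming training under a different parallelism configuration. A typical invocation
# (paths below are illustrative only) might look like:
#
#   python ds_to_universal.py \
#       --input_folder  /path/to/checkpoints/global_step1000 \
#       --output_folder /path/to/checkpoints/global_step1000_universal \
#       --num_extract_workers 4 --num_merge_workers 2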


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed Checkpoint folder')
    parser.add_argument('--output_folder', type=str, required=True, help='Output DeepSpeed checkpoint folder')
    parser.add_argument('--num_extract_workers',
                        default=4,
                        type=int,
                        help='How many parallel processes to extract zero shards')
    parser.add_argument(
        '--num_merge_workers',
        default=2,
        type=int,
        help=
        'How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers)'
    )
    parser.add_argument('--keep_temp_folder',
                        action='store_true',
                        help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
    parser.add_argument('--no_strict',
                        dest='strict',
                        action='store_false',
                        help='Do not perform validity checks on converted checkpoint.')
    parser.add_argument('--inject-missing-state',
                        action='store_true',
                        help='Inject missing checkpoint state into the checkpoint if it is absent.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args


def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]
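
# Build per-(tp, pp) rank checkpoint paths following the Megatron-style layout implied by the
# format strings below (mp_rank_XX when pp_degree == 1, mp_rank_XX_YYY otherwise).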
def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
    path_list = []
    iter_folder = f'iter_{iteration:07d}'
    for i in range(0, tp_degree):
        path_list.append([])
        for j in range(0, pp_degree):
            rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
            ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
            path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))

    return path_list


def _save_checkpoint(file_path, chkpt_sd):
    dir, _ = os.path.split(file_path)
    os.makedirs(dir, exist_ok=True)
    torch.save(chkpt_sd, file_path)
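
# ZeRO stage 1/2: for a single (pp, tp, dp) rank, slice that rank's flattened optimizer state
# (exp_avg, exp_avg_sq, fp32 master weights, and optionally step) back into per-parameter
# fragments using the param slice mappings stored in the checkpoint, and write each fragment
# to its own file under the temp directory.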
def extract_zero_shards(dir, ds_checkpoint, indices_3D):
    pp_index, tp_index, dp_index = indices_3D
    sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

    # pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}")

    optim_sd = sd[OPTIMIZER_STATE_DICT]
    param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])
    # print(f'{pipeline_replicated_params=}')

    # dict
    state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
    # list
    fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
    param_groups_cnt = len(state_groups)

    for param_group_id in range(param_groups_cnt):
        flat_state = dict(
            exp_avg=state_groups[param_group_id]["exp_avg"],
            exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
            fp32=fp32_groups[param_group_id],
        )

        if "step" in state_groups[param_group_id]:
            flat_state["step"] = state_groups[param_group_id]["step"]

        for name, fragment_mapping in param_slice_mappings[param_group_id].items():
            if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
                # Skip tied weights that are replicated in first and last pp stages
                continue

            # pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}")
            for state_key in flat_state.keys():
                dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name,
                                    fragment_mapping.start, fragment_mapping.numel)
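
# ZeRO stage 3: each dp rank owns a contiguous partition of every parameter, so fragments are
# recovered by walking the param_shapes dict in order and clipping off the alignment padding
# that partitioning adds to the final rank's shard.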
def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
    state_dict = torch.load(optim_files[dp_index], map_location='cpu')

    flat_state = dict(
        exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
        exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
        fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
    )

    offset = 0
    for name, shape in param_shapes.items():
        unpartitioned_numel = shape.numel()
        partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
        padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))

        for state_key in flat_state.keys():
            dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
                                padding_free_numel)

        offset += partitioned_numel


cnt = 0


def dp_index_to_str(dp_index):
    return f"{dp_index:0>2d}"
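
# Write one fragment file per (param, tp rank, optimizer state, dp rank), e.g.
# <temp_dir>/<param_name>/<tp_index>/exp_avg.03 for dp rank 3.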
def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):

    global cnt  # temp hack

    param_base_path = os.path.join(dir, param_name, str(tp_index))
    os.makedirs(param_base_path, exist_ok=True)
    cnt += 1

    path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")

    #print(f"{param_name}: {offset}: {numel} => {path}")

    # State might be a python int or a tensor
    if state_name != "step" and torch.is_tensor(state_flat_tensor):
        state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()

    _save_checkpoint(path, state_flat_tensor)
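
# Re-assemble a parameter's state for each tp rank by concatenating its dp-rank fragment files
# in dp order; "step" shards are scalars and must all agree, so the first one is kept.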
def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
    slices = []
    for tp_index in range(tp_degree):
        prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
        paths = glob.glob(f"{prefix_path}.*")
        if len(paths) == 0:
            continue

        pattern = re.compile(f"{prefix_path}\\.([0-9]+)")
        dp_indices = set()
        for p in paths:
            m = pattern.match(p)
            if m:
                dp_indices.add(int(m.group(1)))
            else:
                raise ValueError(f"Cannot parse dp_rank from {p}")

        paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
        shards = [torch.load(p) for p in paths]

        if state == "step":
            assert all(v == shards[0] for v in shards), "All shards must have the same step value"
            slice = shards[0]
        else:
            if slice_shape is None:
                slice = torch.cat(shards, dim=0)
            else:
                slice = torch.cat(shards, dim=0).reshape(slice_shape)

        slices.append(slice)

    return slices
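
# Merge the tp-rank slices of one parameter into a single full tensor. How the slices are
# combined is driven by the regex pattern lists advertised in UNIVERSAL_CHECKPOINT_INFO:
#   - TP-replicated params: all slices are identical, keep the first
#   - params to average: element-wise mean across tp ranks
#   - params made of 2 sub-params concatenated on dim 0: chunk each slice, then re-concatenate
#   - params with declared sub-param shapes: narrow/concatenate along the declared partition dim
#   - everything else: concatenate on dim 1 (row parallel) or dim 0 (default)
# Vocabulary params additionally have their padded rows stripped back to original_vocab_size.
# Returns the set of patterns this call did not use, so the caller can detect stale patterns.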
def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
    name, shape = name_and_shape
    slice_base_path = os.path.join(slice_dir, name)
    param_base_path = os.path.join(dir, name)

    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    replicated_parameters = universal_checkpoint_info.get(TP_REPLICATED_PARAMETER_PATTERNS, [])
    parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
    parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
    vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
    parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, [])
    parameter_with_sub_params = universal_checkpoint_info.get(PARAMETER_WITH_SUB_PARAMS, [])

    unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
                             vocabulary_parameters + parameters_with_2_sub_params_cat_dim_0)
    unmatched_patterns.update(chain.from_iterable(SubparamShape(**s).patterns for s in parameter_with_sub_params))

    def get_matched_pattern(patterns_, name_):
        matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
        assert len(matched_) <= 1, f'Got more than one matching patterns={matched_} for {name_}'
        if matched_:
            pattern_ = matched_[0]
            unmatched_patterns.discard(pattern_)
            return pattern_
        return None

    def get_matched_sub_params_pattern(name_):
        for subparam_shape_dict in parameter_with_sub_params:
            subparam_shape = SubparamShape(**subparam_shape_dict)
            for pattern_ in subparam_shape.patterns:
                if re.match(pattern_, name_):
                    unmatched_patterns.discard(pattern_)
                    return subparam_shape
        return None

    matched_sub_params_shape = get_matched_sub_params_pattern(name)

    step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape)
    if step_merged:
        _save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0])

    for state in ("fp32", "exp_avg", "exp_avg_sq"):
        slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
        final_path = os.path.join(param_base_path, f"{state}.pt")

        #print(f"Expected shape: {shape}")
        #print(f"Fragment sizes:", list(frag.shape for frag in slices))
        ckpt_dict = {}
        if get_matched_pattern(replicated_parameters, name):
            if len(slices) > 1:
                assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
            param = slices[0]
            # print(f'replicate {name} using first slice')
        elif get_matched_pattern(parameters_to_average, name):
            param = sum(slices) / len(slices)
            # print(f'merge {name} using average')
        elif get_matched_pattern(parameters_with_2_sub_params_cat_dim_0, name):
            cat_dim = 0
            chunked_slices = [torch.chunk(s, 2, dim=cat_dim) for s in slices]
            merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=cat_dim)
            merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=cat_dim)
            param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim
            ckpt_dict[PARAM_N_SUB_PARAMS] = 2
        elif matched_sub_params_shape:
            merged_chunks = []
            partition_dim = matched_sub_params_shape.partition_dim

            sub_dim_sizes = matched_sub_params_shape.shape[partition_dim]
            if not isinstance(sub_dim_sizes, tuple):
                sub_dim_sizes = (sub_dim_sizes, )

            partition_shape = [sum(d) if isinstance(d, tuple) else d for d in matched_sub_params_shape.shape]
            partition_shape = [d // tp_degree if i == partition_dim else d for i, d in enumerate(partition_shape)]
            slices = [s.view(partition_shape) for s in slices]

            offset = 0
            for sub_dim_size in sub_dim_sizes:
                part_sub_dim_size = sub_dim_size // tp_degree
                merged_chunks.append(
                    torch.cat([s.narrow(partition_dim, offset, part_sub_dim_size) for s in slices], dim=partition_dim))
                offset += part_sub_dim_size
            param = torch.cat(merged_chunks, dim=partition_dim)
            ckpt_dict[SUB_PARAM_SHAPE] = matched_sub_params_shape
        else:
            cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
            # print(f"merge {name} with CAT DIM: {cat_dim}")
            param = torch.cat(slices, dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim

        if get_matched_pattern(vocabulary_parameters, name):
            #print(f"Before {param.shape=}")
            # strip padding
            original_vocab_size = universal_checkpoint_info['original_vocab_size']
            param = param[:original_vocab_size, :]
            ckpt_dict[VOCAB_TENSOR] = True
            #print(f"After {param.shape=}")

        #print(f"Final shape: {param.shape}")
        ckpt_dict[PARAM] = param
        _save_checkpoint(final_path, ckpt_dict)

    return unmatched_patterns
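
# ZeRO stage 3 has no tensor-parallel slicing to undo: each state is simply the concatenation
# of its dp fragments (single "tp rank"), saved as-is.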
def merge_zero3_slices(dp_degree, dir, slice_dir, name):
    slice_base_path = os.path.join(slice_dir, name)
    param_base_path = os.path.join(dir, name)

    for state in ("fp32", "exp_avg", "exp_avg_sq"):
        slices = _merge_zero_shards(slice_base_path, state, 1)
        final_path = os.path.join(param_base_path, f"{state}.pt")
        _save_checkpoint(final_path, slices[0])
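
# Run do_work over work_chunks with a process pool, falling back to a serial loop when
# num_workers <= 1 (used by unit tests that cannot create child processes).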
def _do_parallel_work(do_work, work_chunks, num_workers):
    results = []
    if num_workers > 1:
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            future_list = [executor.submit(do_work, work) for work in work_chunks]
            for f in tqdm.tqdm(future_list):
                results.append(f.result())
    else:
        # No parallel pass for unit testing
        # We can't create child processes in tests
        for work in tqdm.tqdm(work_chunks):
            results.append(do_work(work))
    return results


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
    _3d_range_list = list(
        itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
                          range(ds_checkpoint.dp_degree)))
    #pprint(f'{_3d_range_list=}')

    do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
    _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)


def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir):
    do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir)
    _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
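
# Merge tp slices for every parameter in parallel, then verify that every declared pattern was
# matched by at least one worker; unmatched patterns trigger an assertion unless --no_strict
# was given, in which case only a warning is printed.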
def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
    zero_output_folder = os.path.join(args.output_folder, "zero")
    do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
    unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)

    # verify that all patterns were used
    # if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
    sets = [set(lst) for lst in unmatched_patterns_lists]
    unmatched_patterns = list(set.intersection(*sets))
    if args.strict:
        assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices'
    elif unmatched_patterns:
        print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')


def _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir):
    zero_output_folder = os.path.join(args.output_folder, "zero")
    do_work = partial(merge_zero3_slices, dp_degree, zero_output_folder, temp_dir)
    _do_parallel_work(do_work, param_shapes.keys(), args.num_merge_workers)
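
# Per-rank partition size and padding for a flat parameter under ZeRO partitioning.
# Worked example: 10 elements over world_size=4 gives partitioned_numel=ceil(10/4)=3 and
# padding_numel=4-(10%4)=2, i.e. each rank holds 3 elements with 2 padded at the tail.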
def _zero_partitioned_param_info(unpartitioned_numel, world_size):
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel


def _parse_model_states_stage3(files):
    return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
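
# Save the optimizer state that is common across ranks (e.g. param_groups hyperparameters),
# excluding the sharded tensors, which were already written as per-parameter files.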
def _save_optimizer_state(args, ds_checkpoint):
    sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS]
    sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0)
    optim_sd = sd[OPTIMIZER_STATE_DICT]

    output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
    output_sd[PARAM_GROUPS] = optim_sd[BASE_OPTIMIZER_STATE][PARAM_GROUPS]
    zero_output_folder = os.path.join(args.output_folder, "zero")
    output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
    _save_checkpoint(output_file_path, output_sd)


def _save_optimizer_state_stage3(args, optim_files):
    sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
    output_sd = sd[OPTIMIZER_STATE_DICT]
    output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
    zero_output_folder = os.path.join(args.output_folder, "zero")
    output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
    _save_checkpoint(output_file_path, output_sd)


def _get_optim_files(checkpoint_dir):
    return _get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")


def _get_model_state_files(checkpoint_dir):
    return _get_checkpoint_files(checkpoint_dir, "*_model_states.pt")


def _get_checkpoint_files(checkpoint_dir, glob_pattern):
    ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
    if len(ckpt_files) == 0:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
    return ckpt_files


def _get_zero_stage(optim_files):
    state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
    optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
    zero_stage = optimizer_state.get(ZERO_STAGE, 1)
    return zero_stage
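
# If the checkpoint carries no UNIVERSAL_CHECKPOINT_INFO, install a minimal one containing only
# the version key (enabled via --inject-missing-state).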
def _inject_missing_state(ds_checkpoint):
    if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
        sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
        if UNIVERSAL_CHECKPOINT_INFO not in sd:
            ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
            ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
                UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE


def _check_for_required_state(ds_checkpoint):
    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
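
# Conversion pipeline: detect the ZeRO stage from the first optimizer shard, then
#   1. extract per-parameter ZeRO fragments into <output_folder>/tmp,
#   2. merge them (across tp ranks for stage 1/2, across dp ranks for stage 3) into
#      <output_folder>/zero/<param_name>/{fp32,exp_avg,exp_avg_sq[,step]}.pt,
#   3. save the common optimizer state and copy the model-state files into the output folder,
# and finally point 'latest_universal' at the new step folder.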
def main(args):
    print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint')
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')

    optim_files = _get_optim_files(args.input_folder)
    zero_stage = _get_zero_stage(optim_files)

    if zero_stage <= 2:
        ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
        if args.inject_missing_state:
            _inject_missing_state(ds_checkpoint)
        else:
            _check_for_required_state(ds_checkpoint)

        iteration = ds_checkpoint.get_iteration()
        #_create_latest_file(args.output_folder, iteration)
        checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree,
                                                    ds_checkpoint.pp_degree)

        slice_shapes = []
        for mp_rank_file in ds_checkpoint.mp_rank_files:
            mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
            slice_shapes += mp_sd[PARAM_SHAPES]

        # fix back to normal flat dict, merge duplicates for tp>1
        slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
        temp_dir = os.path.join(args.output_folder, 'tmp')

        print('*** 1. Extracting ZeRO fragments')
        _extract_zero_shard_files(args, ds_checkpoint, temp_dir)

        print('*** 2. Merging slices .....')
        _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

        print('*** 3. Saving common optimizer states')
        _save_optimizer_state(args, ds_checkpoint)

        if not args.keep_temp_folder:
            shutil.rmtree(temp_dir, ignore_errors=True)

        # Copy mp* files into output folder
        for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
            shutil.copy2(f, args.output_folder)
    else:
        model_files = _get_model_state_files(args.input_folder)
        param_shapes = _parse_model_states_stage3(model_files)
        param_shapes = {k: v for d in param_shapes for k, v in d.items()}
        dp_degree = len(model_files)

        temp_dir = os.path.join(args.output_folder, 'tmp')

        print('*** 1. Extracting ZeRO fragments')
        _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)

        print('*** 2. Merging slices .....')
        _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir)

        print('*** 3. Saving common optimizer states')
        _save_optimizer_state_stage3(args, optim_files)

        if not args.keep_temp_folder:
            shutil.rmtree(temp_dir, ignore_errors=True)

        # Copy *model_states files into output folder
        for f in glob.glob(os.path.join(args.input_folder, '*model_states.pt')):
            shutil.copy2(f, args.output_folder)

    # Update latest to output folder
    checkpoint_root_folder, step_folder = os.path.split(args.output_folder)
    latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
    with open(latest_file, "w") as f:
        f.write(step_folder)

    print('*** Done!')


if __name__ == "__main__":
    args = parse_arguments()
    main(args)