123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549 |
- #!/usr/bin/env python
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from functools import partial
- from itertools import chain
- import argparse
- import glob
- import itertools
- import math
- from concurrent.futures import ProcessPoolExecutor
- import os
- import re
- import shutil
- import torch
- import tqdm
- #from pprint import pprint
- from deepspeed.checkpoint import DeepSpeedCheckpoint
- from deepspeed.checkpoint import (
- OPTIMIZER_STATE_DICT,
- ZERO_STAGE,
- BASE_OPTIMIZER_STATE,
- SINGLE_PARTITION_OF_FP32_GROUPS,
- PARAM_GROUPS,
- PARAM_SLICE_MAPPINGS,
- PARAM_SHAPES,
- PARAM,
- CAT_DIM,
- PARAM_N_SUB_PARAMS,
- SUB_PARAM_SHAPE,
- VOCAB_TENSOR,
- UNIVERSAL_CHECKPOINT_INFO,
- UNIVERSAL_CHECKPOINT_VERSION_KEY,
- UNIVERSAL_CHECKPOINT_VERSION_VALUE,
- VOCABULARY_PARAMETER_PATTERNS,
- PIPELINE_REPLICATED_PARAMETER_PATTERNS,
- TP_REPLICATED_PARAMETER_PATTERNS,
- PARAMETER_TO_AVERAGE_PATTERNS,
- PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
- PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
- PARAMETER_WITH_SUB_PARAMS,
- SubparamShape,
- )
def parse_arguments():
    """Parse the command-line flags of the DeepSpeed -> universal checkpoint converter.

    Returns:
        argparse.Namespace with ``input_folder``, ``output_folder``,
        ``num_extract_workers``, ``num_merge_workers``, ``keep_temp_folder``,
        ``strict`` and ``inject_missing_state``. The namespace is also printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed Checkpoint folder')
    parser.add_argument('--output_folder', type=str, required=True, help='Output DeepSpeed checkpoint folder')
    parser.add_argument('--num_extract_workers',
                        default=4,
                        type=int,
                        help='How many parallel processes to extract zero shards')
    parser.add_argument(
        '--num_merge_workers',
        default=2,
        type=int,
        help=
        'How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers)'
    )
    parser.add_argument('--keep_temp_folder',
                        action='store_true',
                        help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
    # store_false: validity checks are on by default and --no_strict turns them off.
    parser.add_argument('--no_strict',
                        dest='strict',
                        action='store_false',
                        help='Do not perform validity checks on converted checkpoint.')
    parser.add_argument('--inject-missing-state',
                        action='store_true',
                        help='Inject missing checkpoint state into the checkpoint if it is absent.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args
def atoi(text):
    """Return ``text`` converted to ``int`` when it is made only of decimal digits, else unchanged.

    ``str.isdecimal`` is used rather than ``str.isdigit``. ``isdigit`` also accepts
    characters such as superscript '²', which ``int()`` rejects with ValueError.
    """
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)

    Runs of digits become ints, so "rank_10" sorts after "rank_9". Because the split
    pattern has one capture group, text pieces and digit runs alternate at fixed
    positions. Two keys therefore never compare an int with a str.
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]
- def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
- path_list = []
- iter_folder = f'iter_{iteration:07d}'
- for i in range(0, tp_degree):
- path_list.append([])
- for j in range(0, pp_degree):
- rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
- ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
- path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))
- return path_list
- def _save_checkpoint(file_path, chkpt_sd):
- dir, _ = os.path.split(file_path)
- os.makedirs(dir, exist_ok=True)
- torch.save(chkpt_sd, file_path)
def extract_zero_shards(dir, ds_checkpoint, indices_3D):
    """Dump every optimizer-state fragment held by one (pp, tp, dp) rank into ``dir``.

    ``indices_3D`` is a ``(pp_index, tp_index, dp_index)`` tuple. For each parameter
    group, the flat ``exp_avg``, ``exp_avg_sq`` and fp32 tensors (plus ``step`` when
    present) are cut per parameter according to PARAM_SLICE_MAPPINGS. Each piece is
    written with ``dump_param_fragment``. Parameters matching a pipeline-replicated
    pattern are written only from pipeline stage 0.
    """
    pp_index, tp_index, dp_index = indices_3D
    zero_sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

    optim_sd = zero_sd[OPTIMIZER_STATE_DICT]
    mappings_per_group = optim_sd[PARAM_SLICE_MAPPINGS]
    uc_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    pp_replicated_patterns = uc_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])

    def is_replicated_copy(param_name):
        # Tied weights are replicated in the first and last pp stages; keep only stage 0's copy.
        return pp_index > 0 and any(re.match(pattern, param_name) for pattern in pp_replicated_patterns)

    base_state = optim_sd[BASE_OPTIMIZER_STATE]["state"]  # dict, looked up by integer group id
    fp32_partitions = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]  # list, one entry per group

    for group_id in range(len(base_state)):
        group_state = base_state[group_id]
        flat_state = {
            "exp_avg": group_state["exp_avg"],
            "exp_avg_sq": group_state["exp_avg_sq"],
            "fp32": fp32_partitions[group_id],
        }
        if "step" in group_state:
            flat_state["step"] = group_state["step"]

        for param_name, fragment in mappings_per_group[group_id].items():
            if is_replicated_copy(param_name):
                continue
            for state_key, flat_tensor in flat_state.items():
                dump_param_fragment(dir, tp_index, dp_index, state_key, flat_tensor, param_name, fragment.start,
                                    fragment.numel)
def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
    """Dump the ZeRO-3 optimizer-state fragments of data-parallel rank ``dp_index`` into ``temp_dir``.

    Parameters are laid out back-to-back, in ``param_shapes`` order, in the rank's flat
    ``exp_avg``/``exp_avg_sq``/fp32 tensors. Each parameter occupies ``partitioned_numel``
    elements, as returned by ``_zero_partitioned_param_info``. This rank owns elements
    ``[dp_index * partitioned_numel, (dp_index + 1) * partitioned_numel)`` of each
    parameter. Anything past the parameter's real numel is padding and is not written.
    Fragments are dumped under tp index 0.
    """
    state_dict = torch.load(optim_files[dp_index], map_location='cpu')
    zero_sd = state_dict[OPTIMIZER_STATE_DICT]
    # NOTE(review): only parameter group 0 is read (state[0] / fp32_flat_groups[0]);
    # presumably ZeRO-3 checkpoints here have a single group. Confirm for multi-group setups.
    base_state = zero_sd['optimizer_state_dict']['state'][0]
    flat_state = dict(
        exp_avg=base_state["exp_avg"],
        exp_avg_sq=base_state["exp_avg_sq"],
        fp32=zero_sd['fp32_flat_groups'][0],
    )

    offset = 0
    for name, shape in param_shapes.items():
        unpartitioned_numel = shape.numel()
        partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
        # Real (non-padding) elements of this parameter held by this rank. A rank whose
        # partition starts past the parameter's end holds only padding, so clamp at 0.
        # abs() would turn that negative remainder into a bogus positive count and leak
        # padding into the merged tensor.
        remaining_numel = unpartitioned_numel - dp_index * partitioned_numel
        padding_free_numel = max(0, min(partitioned_numel, remaining_numel))
        for state_key, flat_tensor in flat_state.items():
            dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_tensor, name, offset, padding_free_numel)
        offset += partitioned_numel
# Module-level counter that dump_param_fragment increments; nothing in the code shown here reads it.
cnt = 0


def dp_index_to_str(dp_index):
    """Format a data-parallel rank as the zero-padded (at least 2 digits) suffix of fragment file names."""
    return format(dp_index, "0>2d")
def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):
    """Save one rank's fragment of a parameter's optimizer state.

    The target file is ``<dir>/<param_name>/<tp_index>/<state_name>.<NN>``, where ``NN``
    is ``dp_index_to_str(dp_index)``. Tensor states are first narrowed to
    ``[offset, offset + numel)`` along dim 0. The ``step`` state and non-tensor values
    (e.g. a python int) are saved unchanged.
    """
    global cnt  # temp hack

    fragment_dir = os.path.join(dir, param_name, str(tp_index))
    os.makedirs(fragment_dir, exist_ok=True)
    cnt += 1
    fragment_path = os.path.join(fragment_dir, f"{state_name}.{dp_index_to_str(dp_index)}")

    fragment = state_flat_tensor
    if torch.is_tensor(fragment) and state_name != "step":
        # clone() so torch.save writes only this slice, not the whole flat storage the view shares.
        fragment = fragment.narrow(0, offset, numel).clone()
    _save_checkpoint(fragment_path, fragment)
- def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
- slices = []
- for tp_index in range(tp_degree):
- prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
- paths = glob.glob(f"{prefix_path}.*")
- if len(paths) == 0:
- continue
- pattern = re.compile(f"{prefix_path}\\.([0-9]+)")
- dp_indices = set()
- for p in paths:
- m = pattern.match(p)
- if m:
- dp_indices.add(int(m.group(1)))
- else:
- raise ValueError(f"Cannot parse dp_rank from {p}")
- paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
- shards = [torch.load(p) for p in paths]
- if state == "step":
- assert all(v == shards[0] for v in shards), "All shards must have the same step value"
- slice = shards[0]
- else:
- if slice_shape is None:
- slice = torch.cat(shards, dim=0)
- else:
- slice = torch.cat(shards, dim=0).reshape(slice_shape)
- slices.append(slice)
- return slices
def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
    """Merge the tp slices of one parameter and write ``<dir>/<name>/<state>.pt`` per optimizer state.

    Args:
        ds_checkpoint: source checkpoint. Its UNIVERSAL_CHECKPOINT_INFO supplies the
            name patterns that select the merge rule for this parameter.
        dir: output root for merged parameters.
        slice_dir: root of the extracted per-rank fragments.
        tp_degree: tensor-parallel degree of the source checkpoint.
        name_and_shape: ``(name, shape)``. ``shape`` is passed to ``_merge_zero_shards``
            as the shape each tp slice is reshaped to.

    Merge rule for each of fp32/exp_avg/exp_avg_sq (first matching branch wins):
        - TP-replicated pattern: all slices must be equal; the first one is kept.
        - to-average pattern: element-wise mean of the slices.
        - 2-sub-params/cat-dim-0 pattern: every slice is chunked into two halves along
          dim 0, each half is concatenated across ranks, and the two results are joined.
        - SubparamShape pattern: each sub-param's per-rank chunk along
          ``partition_dim`` is gathered across ranks, then the sub-params are joined.
        - otherwise: concatenate along dim 1 for row-parallel patterns, else dim 0.
    Vocabulary-pattern params are additionally truncated to ``original_vocab_size`` rows.
    If ``step`` fragments exist, the first tp slice's value is saved to ``step.pt``.

    Returns:
        The set of configured patterns this parameter did not match. The caller
        intersects these sets across parameters to find patterns that were never used.
    """
    name, shape = name_and_shape
    slice_base_path = os.path.join(slice_dir, name)
    param_base_path = os.path.join(dir, name)

    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    replicated_parameters = universal_checkpoint_info.get(TP_REPLICATED_PARAMETER_PATTERNS, [])
    parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
    parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
    vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
    parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, [])
    parameter_with_sub_params = universal_checkpoint_info.get(PARAMETER_WITH_SUB_PARAMS, [])

    # Start with every configured pattern marked unused; the matchers below discard
    # each pattern they match, so whatever remains is returned to the caller.
    unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
                             vocabulary_parameters + parameters_with_2_sub_params_cat_dim_0)
    unmatched_patterns.update(chain.from_iterable(SubparamShape(**s).patterns for s in parameter_with_sub_params))

    def get_matched_pattern(patterns_, name_):
        # Return the single pattern in patterns_ that re.match-es name_ (or None), and
        # mark it as used. More than one match is treated as a configuration error.
        matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
        assert len(matched_) <= 1, f'Got more than one matching patterns={matched_} for {name_}'
        if matched_:
            pattern_ = matched_[0]
            unmatched_patterns.discard(pattern_)
            return pattern_
        return None

    def get_matched_sub_params_pattern(name_):
        # Return the SubparamShape of the first sub-params entry with a pattern that
        # matches name_ (marking that pattern as used), or None.
        for subparam_shape_dict in parameter_with_sub_params:
            subparam_shape = SubparamShape(**subparam_shape_dict)
            for pattern_ in subparam_shape.patterns:
                if re.match(pattern_, name_):
                    unmatched_patterns.discard(pattern_)
                    return subparam_shape
        return None

    matched_sub_params_shape = get_matched_sub_params_pattern(name)

    # step is identical across dp fragments; keep the value from the first tp slice found.
    step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape)
    if step_merged:
        _save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0])

    for state in ("fp32", "exp_avg", "exp_avg_sq"):
        slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
        final_path = os.path.join(param_base_path, f"{state}.pt")

        #print(f"Expected shape: {shape}")
        #print(f"Fragment sizes:", list(frag.shape for frag in slices))
        # Extra metadata saved next to the merged tensor (cat dim, sub-param info, vocab flag).
        ckpt_dict = {}
        # Note: get_matched_pattern is only called for the branches actually reached, so
        # only those patterns get marked as used by this parameter.
        if get_matched_pattern(replicated_parameters, name):
            if len(slices) > 1:
                assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
            param = slices[0]
            # print(f'replicate {name} using first slice')
        elif get_matched_pattern(parameters_to_average, name):
            param = sum(slices) / len(slices)
            # print(f'merge {name} using average')
        elif get_matched_pattern(parameters_with_2_sub_params_cat_dim_0, name):
            # Each rank's slice holds its share of two sub-params stacked along dim 0:
            # gather the first halves across ranks, then the second halves, then join.
            cat_dim = 0
            chunked_slices = [torch.chunk(s, 2, dim=cat_dim) for s in slices]
            merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=cat_dim)
            merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=cat_dim)
            param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim
            ckpt_dict[PARAM_N_SUB_PARAMS] = 2
        elif matched_sub_params_shape:
            merged_chunks = []
            partition_dim = matched_sub_params_shape.partition_dim

            # Sizes of the sub-params along the partition dim (a tuple when there are several).
            sub_dim_sizes = matched_sub_params_shape.shape[partition_dim]
            if not isinstance(sub_dim_sizes, tuple):
                sub_dim_sizes = (sub_dim_sizes, )

            # Full shape with tuple dims summed, then the partition dim divided by
            # tp_degree: the shape of one rank's slice.
            partition_shape = [sum(d) if isinstance(d, tuple) else d for d in matched_sub_params_shape.shape]
            partition_shape = [d // tp_degree if i == partition_dim else d for i, d in enumerate(partition_shape)]
            slices = [s.view(partition_shape) for s in slices]

            # For each sub-param, take its per-rank chunk from every slice and concatenate across ranks.
            offset = 0
            for sub_dim_size in sub_dim_sizes:
                part_sub_dim_size = sub_dim_size // tp_degree
                merged_chunks.append(
                    torch.cat([s.narrow(partition_dim, offset, part_sub_dim_size) for s in slices], dim=partition_dim))
                offset += part_sub_dim_size
            param = torch.cat(merged_chunks, dim=partition_dim)
            ckpt_dict[SUB_PARAM_SHAPE] = matched_sub_params_shape
        else:
            # Row-parallel params are split along dim 1; everything else along dim 0.
            cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
            # print(f"merge {name} with CAT DIM: {cat_dim}")
            param = torch.cat(slices, dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim

        if get_matched_pattern(vocabulary_parameters, name):
            #print(f"Before {param.shape=}")
            # strip padding
            # NOTE(review): the [:, :] indexing assumes a 2-D param with the vocab along dim 0.
            original_vocab_size = universal_checkpoint_info['original_vocab_size']
            param = param[:original_vocab_size, :]
            ckpt_dict[VOCAB_TENSOR] = True
            #print(f"After {param.shape=}")

        #print(f"Final shape: {param.shape}")
        ckpt_dict[PARAM] = param
        _save_checkpoint(final_path, ckpt_dict)

    return unmatched_patterns
def merge_zero3_slices(dp_degree, dir, slice_dir, name):
    """Merge the dp fragments of ZeRO-3 parameter ``name`` into ``<dir>/<name>/<state>.pt``.

    Fragments are read with a tp degree of 1, so one merged tensor per state is
    expected. ``dp_degree`` is not used by this function.
    """
    fragments_root = os.path.join(slice_dir, name)
    output_root = os.path.join(dir, name)
    for state_name in ("fp32", "exp_avg", "exp_avg_sq"):
        merged = _merge_zero_shards(fragments_root, state_name, 1)[0]
        _save_checkpoint(os.path.join(output_root, f"{state_name}.pt"), merged)
def _do_parallel_work(do_work, work_chunks, num_workers):
    """Apply ``do_work`` to every item of ``work_chunks`` and return the results in input order.

    With ``num_workers > 1`` the items are submitted to a ProcessPoolExecutor, so
    ``do_work`` and the items must be picklable. Otherwise they run sequentially in
    this process. A tqdm progress bar is shown in both cases.
    """
    if num_workers <= 1:
        # No parallel pass for unit testing
        # We can't create child processes in tests
        return [do_work(work) for work in tqdm.tqdm(work_chunks)]

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(do_work, work) for work in work_chunks]
        # Collect in submission order so results line up with work_chunks.
        return [future.result() for future in tqdm.tqdm(futures)]
def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
    """Dump the ZeRO fragments of every (pp, tp, dp) rank of ``ds_checkpoint`` into ``temp_dir``.

    Ranks are enumerated pp-major, dp-minor, and processed with
    ``args.num_extract_workers`` workers.
    """
    pp_ranks = range(ds_checkpoint.pp_degree)
    tp_ranks = range(ds_checkpoint.tp_degree)
    dp_ranks = range(ds_checkpoint.dp_degree)
    rank_triples = [(pp, tp, dp) for pp in pp_ranks for tp in tp_ranks for dp in dp_ranks]

    worker = partial(extract_zero_shards, temp_dir, ds_checkpoint)
    _do_parallel_work(worker, rank_triples, args.num_extract_workers)
def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir):
    """Dump the ZeRO-3 fragments of every data-parallel rank into ``temp_dir``, one work item per rank."""
    worker = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir)
    dp_ranks = list(range(dp_degree))
    _do_parallel_work(worker, dp_ranks, args.num_extract_workers)
def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
    """Merge the extracted fragments of every parameter into ``<output_folder>/zero``.

    ``slice_shapes`` maps parameter name -> per-tp-slice shape. Each worker returns
    the patterns its parameter did not match. A pattern is reported as unused only
    if no parameter matched it. With ``args.strict`` that fails the conversion;
    otherwise a warning is printed.
    """
    zero_output_folder = os.path.join(args.output_folder, "zero")
    do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
    unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)

    # A pattern is unused only if every worker left it unmatched. set.intersection()
    # requires at least one set, so with no parameters there is nothing to report.
    # Sorting keeps the reported order deterministic.
    sets = [set(lst) for lst in unmatched_patterns_lists]
    unmatched_patterns = sorted(set.intersection(*sets)) if sets else []
    if args.strict:
        assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices'
    elif unmatched_patterns:
        print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')
def _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir):
    """Merge the ZeRO-3 fragments of every parameter in ``param_shapes`` into ``<output_folder>/zero``."""
    worker = partial(merge_zero3_slices, dp_degree, os.path.join(args.output_folder, "zero"), temp_dir)
    _do_parallel_work(worker, param_shapes.keys(), args.num_merge_workers)
- def _zero_partitioned_param_info(unpartitioned_numel, world_size):
- remainder = unpartitioned_numel % world_size
- padding_numel = (world_size - remainder) if remainder else 0
- partitioned_numel = math.ceil(unpartitioned_numel / world_size)
- return partitioned_numel, padding_numel
def _parse_model_states_stage3(files):
    """Return the PARAM_SHAPES entry of the first model-states file, loaded onto CPU.

    Presumably an iterable of name -> shape dicts (the caller flattens it); confirm
    against the checkpoint writer.
    """
    first_model_states = files[0]
    model_sd = torch.load(first_model_states, map_location=torch.device('cpu'))
    return model_sd[PARAM_SHAPES]
def _save_optimizer_state(args, ds_checkpoint):
    """Write the rank-independent optimizer state to ``<output_folder>/zero/optimizer_state.pt``.

    The state is taken from rank (pp=0, tp=0, dp=0). Every entry except the sharded
    ones (base optimizer state, slice mappings, fp32 partitions) is kept, and the base
    optimizer's PARAM_GROUPS are copied to the top level.
    """
    sharded_keys = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS]
    rank0_sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0)
    rank0_optim_sd = rank0_sd[OPTIMIZER_STATE_DICT]

    common_sd = {}
    for key, value in rank0_optim_sd.items():
        if key not in sharded_keys:
            common_sd[key] = value
    common_sd[PARAM_GROUPS] = rank0_optim_sd[BASE_OPTIMIZER_STATE][PARAM_GROUPS]

    _save_checkpoint(os.path.join(args.output_folder, "zero", "optimizer_state.pt"), common_sd)
def _save_optimizer_state_stage3(args, optim_files):
    """Write the ZeRO-3 common optimizer state to ``<output_folder>/zero/optimizer_state.pt``.

    The OPTIMIZER_STATE_DICT of the first optim file is saved as-is, with the wrapped
    optimizer's PARAM_GROUPS copied to its top level.
    """
    # NOTE(review): unlike the ZeRO-1/2 path, no sharded entries are dropped here, so
    # rank 0's flat tensors (e.g. 'fp32_flat_groups') end up in this file as well.
    # Confirm that is intended.
    first_sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
    zero_optim_sd = first_sd[OPTIMIZER_STATE_DICT]
    zero_optim_sd[PARAM_GROUPS] = zero_optim_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
    target_path = os.path.join(args.output_folder, "zero", "optimizer_state.pt")
    _save_checkpoint(target_path, zero_optim_sd)
def _get_optim_files(checkpoint_dir):
    """Return the ``*_optim_states.pt`` files in ``checkpoint_dir`` in natural order."""
    return _get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")


def _get_model_state_files(checkpoint_dir):
    """Return the ``*_model_states.pt`` files in ``checkpoint_dir`` in natural order."""
    return _get_checkpoint_files(checkpoint_dir, "*_model_states.pt")


def _get_checkpoint_files(checkpoint_dir, glob_pattern):
    """Return the files in ``checkpoint_dir`` matching ``glob_pattern``, sorted with ``natural_keys``.

    Raises:
        FileNotFoundError: if no file matches.
    """
    # Only glob_pattern is a pattern; escape the directory so characters such as '['
    # in it are matched literally.
    search = os.path.join(glob.escape(checkpoint_dir), glob_pattern)
    ckpt_files = sorted(glob.glob(search), key=natural_keys)
    if not ckpt_files:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
    return ckpt_files
def _get_zero_stage(optim_files):
    """Return the ZeRO stage recorded in the first optimizer-state file (1 when the key is absent)."""
    first_sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
    return first_sd[OPTIMIZER_STATE_DICT].get(ZERO_STAGE, 1)
def _inject_missing_state(ds_checkpoint):
    """Add a minimal UNIVERSAL_CHECKPOINT_INFO entry to ``ds_checkpoint.global_state`` when it is absent.

    The entry is injected only if neither ``global_state`` nor the first mp_rank file
    carries that info. It then contains only the universal checkpoint version key.
    """
    if UNIVERSAL_CHECKPOINT_INFO in ds_checkpoint.global_state:
        return
    mp_rank_sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
    if UNIVERSAL_CHECKPOINT_INFO in mp_rank_sd:
        # NOTE(review): the info found in the mp_rank file is not copied into
        # global_state, so nothing is injected here. Confirm that is intended.
        return
    ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {
        UNIVERSAL_CHECKPOINT_VERSION_KEY: UNIVERSAL_CHECKPOINT_VERSION_VALUE,
    }
def _check_for_required_state(ds_checkpoint):
    """Assert that ``ds_checkpoint`` provides UNIVERSAL_CHECKPOINT_INFO, which the extract/merge steps read."""
    info_missing = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) is None
    assert not info_missing, (f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. '
                              f'Verify that client creates this state.')
def _convert_zero_stage_1_2(args):
    """Convert a ZeRO stage 1/2 checkpoint: extract fragments, merge tp slices, save common state."""
    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
    if args.inject_missing_state:
        _inject_missing_state(ds_checkpoint)
    else:
        _check_for_required_state(ds_checkpoint)

    # Collect PARAM_SHAPES (name -> shape dicts) from every mp_rank file and flatten
    # them into one dict; duplicate names across tp ranks collapse to one entry.
    slice_shapes = []
    for mp_rank_file in ds_checkpoint.mp_rank_files:
        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
        slice_shapes += mp_sd[PARAM_SHAPES]
    slice_shapes = {k: v for d in slice_shapes for k, v in d.items()}

    temp_dir = os.path.join(args.output_folder, 'tmp')

    print('*** 1. Extracting ZeRO fragments')
    _extract_zero_shard_files(args, ds_checkpoint, temp_dir)

    print('*** 2. Merging slices .....')
    _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

    print('*** 3. Saving common optimizer states')
    _save_optimizer_state(args, ds_checkpoint)

    if not args.keep_temp_folder:
        shutil.rmtree(temp_dir, ignore_errors=True)

    # Copy mp* files into output folder (escape the folder so only 'mp*' acts as a pattern).
    for f in glob.glob(os.path.join(glob.escape(args.input_folder), 'mp*')):
        shutil.copy2(f, args.output_folder)


def _convert_zero_stage_3(args, optim_files):
    """Convert a ZeRO stage 3 checkpoint: extract per-dp fragments, merge them, save common state."""
    model_files = _get_model_state_files(args.input_folder)
    param_shapes = _parse_model_states_stage3(model_files)
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}
    dp_degree = len(model_files)  # one model-states file per data-parallel rank

    temp_dir = os.path.join(args.output_folder, 'tmp')

    print('*** 1. Extracting ZeRO fragments')
    _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)

    print('*** 2. Merging slices .....')
    _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir)

    print('*** 3. Saving common optimizer states')
    _save_optimizer_state_stage3(args, optim_files)

    if not args.keep_temp_folder:
        shutil.rmtree(temp_dir, ignore_errors=True)

    # Copy *model_states files into output folder
    for f in glob.glob(os.path.join(glob.escape(args.input_folder), '*model_states.pt')):
        shutil.copy2(f, args.output_folder)


def _update_latest_universal(output_folder):
    """Write ``latest_universal`` next to ``output_folder``, containing the output folder's name."""
    # Normalize first. With a trailing separator, os.path.split() returns an empty name,
    # and the file would be written empty inside output_folder itself.
    checkpoint_root_folder, step_folder = os.path.split(os.path.normpath(output_folder))
    latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
    with open(latest_file, "w") as f:
        f.write(step_folder)


def main(args):
    """Convert the DeepSpeed checkpoint in ``args.input_folder`` to a universal checkpoint in ``args.output_folder``.

    The ZeRO stage read from the optimizer files selects the stage 1/2 or stage 3
    conversion. Afterwards, ``latest_universal`` in the output folder's parent
    directory is updated to name the output folder.
    """
    print('Convert DeepSpeed Checkpoint to Universal Checkpoint')
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')

    optim_files = _get_optim_files(args.input_folder)
    zero_stage = _get_zero_stage(optim_files)

    if zero_stage <= 2:
        _convert_zero_stage_1_2(args)
    else:
        _convert_zero_stage_3(args, optim_files)

    # Update latest to output folder
    _update_latest_universal(args.output_folder)

    print('*** Done!')
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then run the conversion with them.
    args = parse_arguments()
    main(args=args)
|