# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
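"""Utilities for reading and reshaping DeepSpeed ZeRO checkpoints.

`ZeROCheckpoint` indexes the per-rank ZeRO state files of a saved
3D-parallel (pipeline x tensor x data parallel) run, maps them onto a
target 3D topology, and merges the matching shards into a single state
dict per rank.
"""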
import torch

from .constants import (BASE_OPTIMIZER_STATE, GROUP_PADDINGS, OPTIMIZER_STATE_DICT, PARTITION_COUNT)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)

GROUP_STATE_KEY = 'state'

class ZeROCheckpoint(object):

    def __init__(self, dir):
        basic_folder_validation(dir)
        self.dir = dir
        self.file_list = get_zero_files(dir)
        self.num_files = len(self.file_list)
        assert self.num_files > 0, f'No ZeRO files found in {dir}'

        self.src_3d = get_model_3d_descriptor(dir)
        # Until reshape() is called, the target topology equals the source
        # topology, so the initial file map is an identity mapping.
        self.target_3d = model_3d_desc(pp_degree=self.src_3d.pp_degree,
                                       tp_degree=self.src_3d.tp_degree,
                                       dp_degree=self.src_3d.dp_degree)
        self._3d_file_map = self.src_3d.reshape(self.target_3d)

    def get_src_world_size(self):
        return self.src_3d.world_size()

    def get_src_tp_degree(self):
        return self.src_3d.tp_degree

    def get_src_pp_degree(self):
        return self.src_3d.pp_degree

    def get_src_dp_degree(self):
        return self.src_3d.dp_degree

    def get_file_indices_for_rank(self, pp_index, tp_index, dp_index):
        assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}'
        dp_2d_map = self._3d_file_map[dp_index]
        return dp_2d_map.get_data(pp_index, tp_index)

    def get_files_for_rank(self, pp_index, tp_index, dp_index):
        file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index)
        return [self.file_list[idx] for idx in file_idx_list]

    def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], strip_tensor_paddings=True):
        state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
        merged_sd = None
        for state_file in state_file_list:
            # Load each shard on CPU to avoid accumulating copies on the GPU.
            sd = torch.load(state_file, map_location=torch.device('cpu'))
            for key in keys_to_ignore:
                sd.pop(key, None)

            if strip_tensor_paddings:
                self._strip_tensor_paddings(sd)

            # Fold this shard into the running merged state dict.
            if merged_sd is None:
                merged_sd = sd
            else:
                merged_sd = merge_state(merged_sd, sd)

        self._update_partition_count(merged_sd)

        if strip_tensor_paddings:
            self._clear_group_paddings(merged_sd)

        return merged_sd

    def print_3d_index_map(self, tag=None):
        if tag:
            print(f'3D index map: {tag}')
        for dp_index, _2d_map in enumerate(self._3d_file_map):
            _2d_map.print_data(f'dp = {dp_index}')

    def print_3d_file_map(self, tag=None):
        if tag:
            print(f'3D file map: {tag}')
        for dp_index, _2d_map in enumerate(self._3d_file_map):
            # pp_degree/tp_degree are integer degrees, so iterate over range().
            for pp_index in range(_2d_map.pp_degree):
                for tp_index in range(_2d_map.tp_degree):
                    file_index_list = _2d_map.get_data(pp_index, tp_index)
                    file_list = [self.file_list[idx] for idx in file_index_list]
                    print(f'{pp_index}, {tp_index}, {dp_index} => {file_list}')

    def reshape(self, target_3d_desc: model_3d_desc):
        self.target_3d = target_3d_desc
        self._3d_file_map = self.src_3d.reshape(self.target_3d)

    def _strip_tensor_paddings(self, sd):
        # Remove the alignment padding that ZeRO appends to each flattened
        # parameter group, so merged tensors have their true (unpadded) length.
        param_group_states = self._get_param_group_states(sd)
        if param_group_states is None:
            return

        group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
        if group_paddings is None:
            return

        for key, group_state in param_group_states.items():
            if group_paddings[key] == 0:
                continue
            for state_name, state_value in group_state.items():
                if torch.is_tensor(state_value):
                    raw_length = state_value.numel() - group_paddings[key]
                    group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()

    def _clear_group_paddings(self, sd):
        group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
        if group_paddings:
            num_groups = len(group_paddings)
            sd[OPTIMIZER_STATE_DICT][GROUP_PADDINGS] = [0] * num_groups

    def _get_optimizer_state(self, sd, state_key):
        optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
        if optimizer_state is None:
            return None
        return optimizer_state.get(state_key, None)

    def _get_param_group_states(self, sd):
        optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
        if optimizer_state is None:
            return None

        base_optimizer_state = optimizer_state.get(BASE_OPTIMIZER_STATE, None)
        if base_optimizer_state is None:
            return None

        return base_optimizer_state.get(GROUP_STATE_KEY, None)

    def _update_partition_count(self, sd):
        # After reshaping, each parameter group is partitioned across the
        # target data-parallel degree.
        partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT)
        if partition_counts:
            num_groups = len(partition_counts)
            sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree] * num_groups
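
# A minimal usage sketch; the checkpoint directory and target degrees below
# are hypothetical, for illustration only:
#
#   ckpt = ZeROCheckpoint('./checkpoints/global_step1000')  # hypothetical path
#   ckpt.print_3d_file_map(tag='source layout')
#   # Retarget the checkpoint to a new 3D topology, then merge shards per rank.
#   ckpt.reshape(model_3d_desc(pp_degree=1, tp_degree=1, dp_degree=2))
#   sd = ckpt.get_state_for_rank(pp_index=0, tp_index=0, dp_index=0)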