# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank


class DeepSpeedOptimizer(object):
    pass


class ZeROOptimizer(DeepSpeedOptimizer):

    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
        # ZeRO checkpoint files live under the `zero/` sub-directory of the checkpoint.
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
        assert os.path.isfile(
            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
        optim_sd = torch.load(optim_state_path)

        # Restore optimizer-wide (non per-parameter) state first.
        self._load_global_state(optim_sd)

        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        if self.mpu is None:
            logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
            tp_world_size = 1
        else:
            tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
                else self.mpu.get_tensor_model_parallel_world_size()

        for i, (param_group,
                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
            # We assume that all params in the same param_group have the same optimizer state keys.
            opt_keys = set()
            steps = []

            lp_groups = getattr(self, lp_groups_name)
            for lp in lp_groups[i]:
                if lp._hp_mapping is not None:
                    #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    # Load the sharded high-precision state for this low-precision param.
                    step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                       tp_world_size)
                    for key in lp._hp_mapping.get_optim_state_keys():
                        opt_keys.add(key)
                    steps.append(step)

            hp_param = param_group['params'][0]
            assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
            if steps[0] is not None:
                self.optimizer.state[hp_param]['step'] = steps[0]

            # Map the per-param fragments back into the flat optimizer state tensors.
            map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

            # Restore param_group hyperparameters (lr, betas, ...) except the params list itself.
            for key, value in loaded_param_group.items():
                if key == 'params':
                    continue
                param_group[key] = value
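
# A minimal usage sketch, not part of this file: it assumes a ZeRO optimizer
# subclass that exposes its low-precision parameter groups under an attribute
# such as `bit16_groups` (the attribute name here is illustrative), and a
# checkpoint directory containing a `zero/` sub-folder with `optimizer_state.pt`
# plus one folder of sharded high-precision state per parameter name:
#
#   optimizer.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_dir)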