base_optimizer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank


class DeepSpeedOptimizer(object):
    pass


class ZeROOptimizer(DeepSpeedOptimizer):

    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
        """Load high-precision (hp) optimizer state from the per-parameter files under
        ``<checkpoint_dir>/zero``.

        ``lp_groups_name`` is the name of the attribute on ``self`` that holds the
        low-precision (lp) parameter groups, parallel to ``self.optimizer.param_groups``.
        """
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
        assert os.path.isfile(
            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
        optim_sd = torch.load(optim_state_path)

        self._load_global_state(optim_sd)

        # Resolve tensor-model-parallel rank and world size; fall back to 1 when no MPU is supplied.
        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        if self.mpu is None:
            logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
            tp_world_size = 1
        else:
            tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
                else self.mpu.get_tensor_model_parallel_world_size()

        for i, (param_group,
                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
            # We have an assumption that all params in the same param_group have the same keys
            opt_keys = set()
            steps = []

            lp_groups = getattr(self, lp_groups_name)
            for lp in lp_groups[i]:
                if lp._hp_mapping is not None:
                    #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                       tp_world_size)
                    for key in lp._hp_mapping.get_optim_state_keys():
                        opt_keys.add(key)
                    steps.append(step)

            # The flat hp parameter for this group; every loaded fragment must report the same 'step'.
            hp_param = param_group['params'][0]
            assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
            if steps[0] is not None:
                self.optimizer.state[hp_param]['step'] = steps[0]

            # Map the per-parameter optimizer state fragments onto the flat hp parameter's state.
            map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

            # Restore group hyperparameters (e.g. lr) from the checkpoint, keeping the live 'params'.
            for key, value in loaded_param_group.items():
                if key == 'params':
                    continue
                param_group[key] = value
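
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how a
# ZeROOptimizer subclass might restore its high-precision state from a
# checkpoint directory. The subclass name, the "bf16_groups" attribute, and
# the wrapper method are assumptions for the example, not definitions from
# this file.
#
#   class MyBF16Optimizer(ZeROOptimizer):
#       ...
#       def load_from_checkpoint_dir(self, checkpoint_dir: str) -> None:
#           # "bf16_groups" must name an attribute on self holding the list of
#           # low-precision parameter groups, parallel to optimizer.param_groups.
#           self.load_hp_checkpoint_state_from_checkpoint_dir(
#               lp_groups_name="bf16_groups", checkpoint_dir=checkpoint_dir)
#
# The checkpoint_dir passed in is expected to contain a "zero" subdirectory
# with optimizer_state.pt plus one folder per parameter (named after
# self.param_names[lp]) holding that parameter's partitioned hp state.
# ---------------------------------------------------------------------------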