123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import os
- import torch
- import torch_nebula
- from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
- CheckpointEngine
- from deepspeed.utils import logger, log_dist
- from deepspeed.nebula.constants import *
- def _get_tag_from_path(path):
- return os.path.basename(os.path.dirname(path))
- class NebulaCheckpointEngine(CheckpointEngine):
- def __init__(self, config_params=None):
- super().__init__(config_params)
- self.checkpoint = None
- self.tag_flag = None
- self.enable_nebula_load = config_params.enable_nebula_load
- self.nebula_load_path = config_params.load_path
- if self.nebula_load_path is None:
- self.nebula_load_path = config_params.persistent_storage_path
- nebula_config_params = {
- NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path,
- NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval,
- NEBULA_NUM_OF_VERSION_IN_RETENTION: config_params.num_of_version_in_retention,
- }
- torch_nebula.init(**nebula_config_params)
- def create(self, tag):
- log_dist(f"[Nebula] Start Checkpoint for tag:{tag}", ranks=[0])
- # -2 means: customer needs to explicitly tell nebula
- # current checkpoint is complete by commit method.
- self.checkpoint = torch_nebula.Checkpoint(tag, -2)
- def save(self, state_dict, path: str):
- log_dist(f"[Nebula] Create dummy files for loading.")
- torch.save("", path)
- tag = _get_tag_from_path(path)
- partition_name = os.path.basename(path)
- logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
- self.checkpoint.save(partition_name, state_dict)
- logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
- return None
- def load(self, path: str, map_location=None):
- tag = _get_tag_from_path(path)
- first_load_flag = self.tag_flag is None or self.tag_flag == tag
- if not self.enable_nebula_load and first_load_flag:
- self.tag_flag = tag
- logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
- partition = torch.load(path, map_location=map_location)
- logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
- return partition
- partition_name = os.path.basename(path)
- logger.info(f"[Nebula] Loading {path} under tag {tag} from nebula path {self.nebula_load_path}...")
- checkpoint = None
- if tag in (None, 'latest', 'latest_universal'):
- # In some cases, there is the inconsistent tag between deepspeed metadata (latest file)
- # and nebula metadata, will lead to the failure on loading with deepspeed tag. Then we
- # will try to load the valid latest checkpoint from nebula(tier3 > tier1). So, in summary
- # when met failure loading for given tag, the loading priority would be like:
- # nebula tier3 latest > nebula tier1 latest.
- checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
- else:
- checkpoint = torch_nebula.get_checkpoint(tag=tag, persist_path=self.nebula_load_path)
- if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
- logger.info(
- f"Unable to find valid checkpoint tag:{tag} from Nebula, try to get latest checkpoint again from nebula {self.nebula_load_path} path!"
- )
- # nebula tier3 latest
- checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
- if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
- logger.info(
- f"Unable to find latest checkpoint from Nebula tier3, try to get latest checkpoint again from nebula tier1 path!"
- )
- # nebula tier1 latest
- checkpoint = torch_nebula.get_latest_checkpoint()
- logger.warning(f"Unable to find valid checkpoint from Nebula under tag:{tag}.")
- return None
- tag = checkpoint.tag
- self.tag_flag = -1
- partition = checkpoint.load(partition_name, map_location=map_location)
- logger.info(f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
- return partition
- def commit(self, tag):
- # nebula commit will be call when all files under give tag are ready to be persisted in the async way.
- logger.info(f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting")
- commit_rls = self.checkpoint.commit()
- if not commit_rls:
- logger.error(f"[Nebula] failed to commit the checkpoint, please check the log.")
- return False
- return commit_rls
|