12345678910111213141516171819202122232425262728293031323334 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- from deepspeed.utils import logger, log_dist
- from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
- CheckpointEngine
class TorchCheckpointEngine(CheckpointEngine):
    """Checkpoint engine backed directly by ``torch.save`` and ``torch.load``.

    Each lifecycle hook logs a ``[Torch]``-prefixed message; the actual
    serialization work is delegated to PyTorch.
    """

    def __init__(self, config_params=None):
        """Pass ``config_params`` through to the ``CheckpointEngine`` initializer."""
        super().__init__(config_params)

    def create(self, tag):
        """Log that the checkpoint identified by ``tag`` is about to be saved.

        The message is emitted through ``log_dist`` with ``ranks=[0]``;
        nothing else is done here.
        """
        log_dist(f"[Torch] Checkpoint {tag} is about to be saved!", ranks=[0])

    def save(self, state_dict, path: str):
        """Write ``state_dict`` to ``path`` with ``torch.save``.

        A log line is emitted before and after the write. Returns ``None``.
        """
        logger.info(f"[Torch] Saving {path}...")
        torch.save(state_dict, path)
        logger.info(f"[Torch] Saved {path}.")

    def load(self, path: str, map_location=None):
        """Read the checkpoint stored at ``path`` and return the loaded object.

        ``map_location`` is forwarded unchanged to ``torch.load``.

        NOTE(review): ``torch.load`` deserializes with pickle, so ``path``
        should only ever point at a trusted checkpoint file.
        """
        logger.info(f"[Torch] Loading checkpoint from {path}...")
        loaded = torch.load(path, map_location=map_location)
        logger.info(f"[Torch] Loaded checkpoint from {path}.")
        return loaded

    def commit(self, tag):
        """Log that checkpoint ``tag`` is ready and return ``True``."""
        logger.info(f"[Torch] Checkpoint {tag} is ready now!")
        return True
|