torch_checkpoint_engine.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from deepspeed.utils import logger, log_dist
  6. from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
  7. CheckpointEngine
  8. class TorchCheckpointEngine(CheckpointEngine):
  9. def __init__(self, config_params=None):
  10. super().__init__(config_params)
  11. def create(self, tag):
  12. log_dist(f"[Torch] Checkpoint {tag} is about to be saved!", ranks=[0])
  13. def save(self, state_dict, path: str):
  14. logger.info(f"[Torch] Saving {path}...")
  15. torch.save(state_dict, path)
  16. logger.info(f"[Torch] Saved {path}.")
  17. return None
  18. def load(self, path: str, map_location=None):
  19. logger.info(f"[Torch] Loading checkpoint from {path}...")
  20. partition = torch.load(path, map_location=map_location)
  21. logger.info(f"[Torch] Loaded checkpoint from {path}.")
  22. return partition
  23. def commit(self, tag):
  24. logger.info(f"[Torch] Checkpoint {tag} is ready now!")
  25. return True