checkpoint_engine.py 653 B

123456789101112131415161718192021222324252627282930
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
  5. class CheckpointEngine(object):
  6. # init checkpoint engine for save/load
  7. def __init__(self, config_params=None):
  8. pass
  9. def create(self, tag):
  10. # create checkpoint on give tag for save/load.
  11. pass
  12. def makedirs(self, path, exist_ok=False):
  13. os.makedirs(path, exist_ok=exist_ok)
  14. def save(self, state_dict, path: str):
  15. pass
  16. def load(self, path: str, map_location=None):
  17. pass
  18. def commit(self, tag):
  19. # to tell checkpoint services if all files are ready.
  20. pass