nebula_checkpoint_engine.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os

import torch
import torch_nebula

from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine
from deepspeed.utils import logger, log_dist
from deepspeed.nebula.constants import *


def _get_tag_from_path(path):
    return os.path.basename(os.path.dirname(path))
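
# For example (illustrative path, not from the source): a partition saved at
#   /ckpt/global_step100/mp_rank_00_model_states.pt
# has tag "global_step100" (its parent directory name) and partition name
# "mp_rank_00_model_states.pt".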


class NebulaCheckpointEngine(CheckpointEngine):

    def __init__(self, config_params=None):
        super().__init__(config_params)
        self.checkpoint = None
        self.tag_flag = None
        self.enable_nebula_load = config_params.enable_nebula_load
        self.nebula_load_path = config_params.load_path
        if self.nebula_load_path is None:
            self.nebula_load_path = config_params.persistent_storage_path

        nebula_config_params = {
            NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path,
            NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval,
            NEBULA_NUM_OF_VERSION_IN_RETENTION: config_params.num_of_version_in_retention,
        }
        torch_nebula.init(**nebula_config_params)
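
    # A minimal sketch of the config object this engine expects, inferred from
    # the attributes read in __init__ above (in DeepSpeed it comes from the
    # "nebula" section of the DeepSpeed config):
    #
    #   config_params.enable_nebula_load           # bool
    #   config_params.load_path                    # str or None
    #   config_params.persistent_storage_path      # str
    #   config_params.persistent_time_interval     # interval between persists
    #   config_params.num_of_version_in_retention  # checkpoint versions kept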

    def create(self, tag):
        log_dist(f"[Nebula] Start Checkpoint for tag:{tag}", ranks=[0])
        # -2 means the caller must explicitly mark the current checkpoint as
        # complete by calling the commit() method.
        self.checkpoint = torch_nebula.Checkpoint(tag, -2)

    def save(self, state_dict, path: str):
        log_dist("[Nebula] Create dummy files for loading.")
        # Write an empty placeholder so the expected file exists on disk; the
        # real state_dict is stored through Nebula below.
        torch.save("", path)

        tag = _get_tag_from_path(path)
        partition_name = os.path.basename(path)
        logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
        self.checkpoint.save(partition_name, state_dict)
        logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
        return None
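
    # Load resolution order implemented below:
    #   1. If Nebula load is disabled and this is the first load for the tag,
    #      read the checkpoint directly from the local path with torch.load.
    #   2. Otherwise look the tag up in Nebula persistent storage (tier3).
    #   3. On a miss, fall back to the latest tier3 checkpoint, then the
    #      latest tier1 checkpoint; if all of these miss, return None.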
    def load(self, path: str, map_location=None):
        tag = _get_tag_from_path(path)
        first_load_flag = self.tag_flag is None or self.tag_flag == tag
        if not self.enable_nebula_load and first_load_flag:
            self.tag_flag = tag
            logger.info(f"[Nebula] Nebula load is disabled. Loading checkpoint from {path} ...")
            partition = torch.load(path, map_location=map_location)
            logger.info(f"[Nebula] Nebula load is disabled. Loaded checkpoint from {path}.")
            return partition

        partition_name = os.path.basename(path)
        logger.info(f"[Nebula] Loading {path} under tag {tag} from nebula path {self.nebula_load_path}...")

        checkpoint = None
        if tag in (None, 'latest', 'latest_universal'):
            # The tag in the DeepSpeed metadata (the "latest" file) can be
            # inconsistent with the Nebula metadata, which makes loading with
            # the DeepSpeed tag fail. In that case we fall back to the most
            # recent valid checkpoint known to Nebula, so the loading priority
            # on failure is:
            #   nebula tier3 latest > nebula tier1 latest.
            checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
        else:
            checkpoint = torch_nebula.get_checkpoint(tag=tag, persist_path=self.nebula_load_path)

        if checkpoint is None or checkpoint.tag == '':
            logger.info(
                f"Unable to find valid checkpoint tag:{tag} from Nebula, trying to get the latest checkpoint from nebula path {self.nebula_load_path}!"
            )
            # nebula tier3 latest
            checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
            if checkpoint is None or checkpoint.tag == '':
                logger.info(
                    "Unable to find the latest checkpoint from Nebula tier3, trying to get the latest checkpoint from the nebula tier1 path!"
                )
                # nebula tier1 latest
                checkpoint = torch_nebula.get_latest_checkpoint()
                if checkpoint is None or checkpoint.tag == '':
                    logger.warning(f"Unable to find valid checkpoint from Nebula under tag:{tag}.")
                    return None

        tag = checkpoint.tag
        # After a successful Nebula lookup, disable the first-load local fast
        # path for subsequent loads.
        self.tag_flag = -1
        partition = checkpoint.load(partition_name, map_location=map_location)
        logger.info(f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
        return partition

    def commit(self, tag):
        # Nebula commit is called once all files under the given tag are
        # ready, so they can be persisted asynchronously.
        logger.info(f"[Nebula] All files for {tag} are saved in tier1; ready to start persisting.")
        commit_rls = self.checkpoint.commit()
        if not commit_rls:
            logger.error("[Nebula] Failed to commit the checkpoint, please check the log.")
            return False
        return commit_rls
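

# Illustrative lifecycle (a sketch, not part of the engine; assumes a running
# Nebula service and a config object with the attributes read in __init__).
# DeepSpeed drives these calls itself during checkpointing; they are shown
# here only to document the expected order:
#
#   engine = NebulaCheckpointEngine(config_params)
#   engine.create(tag="global_step100")   # open a checkpoint for this tag
#   engine.save(state_dict, "/ckpt/global_step100/mp_rank_00_model_states.pt")
#   engine.commit(tag="global_step100")   # trigger async persistence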