config.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
  5. from deepspeed.nebula.constants import *
  6. class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
  7. def __init__(self, param_dict):
  8. super(DeepSpeedNebulaConfig, self).__init__()
  9. self.enabled = None
  10. self.persistent_storage_path = None
  11. self.persistent_time_interval = None
  12. self.num_of_version_in_retention = None
  13. self.enable_nebula_load = None
  14. if NEBULA in param_dict.keys():
  15. nebula_dict = param_dict[NEBULA]
  16. else:
  17. nebula_dict = {}
  18. self._initialize(nebula_dict)
  19. def _initialize(self, nebula_dict):
  20. self.enabled = get_scalar_param(nebula_dict, NEBULA_ENABLED, NEBULA_ENABLED_DEFAULT)
  21. self.load_path = get_scalar_param(nebula_dict, NEBULA_LOAD_PATH, NEBULA_LOAD_PATH_DEFAULT)
  22. self.enable_nebula_load = get_scalar_param(nebula_dict, NEBULA_ENABLE_NEBULA_LOAD,
  23. NEBULA_ENABLE_NEBULA_LOAD_DEFAULT)
  24. self.persistent_storage_path = get_scalar_param(nebula_dict, NEBULA_PERSISTENT_STORAGE_PATH,
  25. NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT)
  26. self.persistent_time_interval = get_scalar_param(nebula_dict, NEBULA_PERSISTENT_TIME_INTERVAL,
  27. NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT)
  28. self.num_of_version_in_retention = get_scalar_param(nebula_dict, NEBULA_NUM_OF_VERSION_IN_RETENTION,
  29. NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT)