config.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. """
  3. Copyright (c) Microsoft Corporation
  4. Licensed under the MIT license.
  5. """
  6. from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
  7. from deepspeed.nebula.constants import *
  8. class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
  9. def __init__(self, param_dict):
  10. super(DeepSpeedNebulaConfig, self).__init__()
  11. self.enabled = None
  12. self.persistent_storage_path = None
  13. self.persistent_time_interval = None
  14. self.num_of_version_in_retention = None
  15. self.enable_nebula_load = None
  16. if NEBULA in param_dict.keys():
  17. nebula_dict = param_dict[NEBULA]
  18. else:
  19. nebula_dict = {}
  20. self._initialize(nebula_dict)
  21. def _initialize(self, nebula_dict):
  22. self.enabled = get_scalar_param(nebula_dict,
  23. NEBULA_ENABLED,
  24. NEBULA_ENABLED_DEFAULT)
  25. self.load_path = get_scalar_param(nebula_dict,
  26. NEBULA_LOAD_PATH,
  27. NEBULA_LOAD_PATH_DEFAULT)
  28. self.enable_nebula_load = get_scalar_param(nebula_dict,
  29. NEBULA_ENABLE_NEBULA_LOAD,
  30. NEBULA_ENABLE_NEBULA_LOAD_DEFAULT)
  31. self.persistent_storage_path = get_scalar_param(
  32. nebula_dict,
  33. NEBULA_PERSISTENT_STORAGE_PATH,
  34. NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT)
  35. self.persistent_time_interval = get_scalar_param(
  36. nebula_dict,
  37. NEBULA_PERSISTENT_TIME_INTERVAL,
  38. NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT)
  39. self.num_of_version_in_retention = get_scalar_param(
  40. nebula_dict,
  41. NEBULA_NUM_OF_VERSION_IN_RETENTION,
  42. NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT)