12345678910111213141516171819202122232425262728293031323334353637383940414243 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
- from deepspeed.nebula.constants import *
class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
    """Nebula checkpointing settings parsed from a DeepSpeed config dict.

    Reads the sub-dictionary stored under the ``NEBULA`` key of the
    top-level config (an empty dict is used when that key is absent) and
    exposes each setting as an attribute. Any setting not present falls back
    to its matching ``NEBULA_*_DEFAULT`` constant.
    """

    def __init__(self, param_dict):
        """Build the Nebula config from the top-level DeepSpeed config.

        Args:
            param_dict: The top-level DeepSpeed configuration mapping. Only
                the entry under ``NEBULA`` (if present) is consulted.
        """
        super().__init__()

        # Declare every field that ``_initialize`` assigns, so the complete
        # attribute set is visible in one place. ``_initialize`` overwrites
        # all of these.
        self.enabled = None
        self.load_path = None
        self.persistent_storage_path = None
        self.persistent_time_interval = None
        self.num_of_version_in_retention = None
        self.enable_nebula_load = None

        # A missing "nebula" section is treated as an empty one, so every
        # field below takes its default value.
        nebula_dict = param_dict.get(NEBULA, {})
        self._initialize(nebula_dict)

    def _initialize(self, nebula_dict):
        """Populate the Nebula fields from the ``NEBULA`` config section.

        Args:
            nebula_dict: Mapping of Nebula settings. Each field is read with
                ``get_scalar_param`` using the corresponding
                ``NEBULA_*_DEFAULT`` constant as the fallback (presumably
                returned when the key is absent; see ``get_scalar_param``).
        """
        self.enabled = get_scalar_param(nebula_dict, NEBULA_ENABLED, NEBULA_ENABLED_DEFAULT)

        self.load_path = get_scalar_param(nebula_dict, NEBULA_LOAD_PATH, NEBULA_LOAD_PATH_DEFAULT)

        self.enable_nebula_load = get_scalar_param(nebula_dict, NEBULA_ENABLE_NEBULA_LOAD,
                                                   NEBULA_ENABLE_NEBULA_LOAD_DEFAULT)

        self.persistent_storage_path = get_scalar_param(nebula_dict, NEBULA_PERSISTENT_STORAGE_PATH,
                                                        NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT)

        self.persistent_time_interval = get_scalar_param(nebula_dict, NEBULA_PERSISTENT_TIME_INTERVAL,
                                                         NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT)

        self.num_of_version_in_retention = get_scalar_param(nebula_dict, NEBULA_NUM_OF_VERSION_IN_RETENTION,
                                                            NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT)
|