123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from .constants import *
- import copy
- from ..config_utils import get_scalar_param
# TODO: Reduce config verbosity by returning None or {} when disabled.
# One challenge is that we still need to somehow include the default values;
# for example, the *_ENABLED flags default to false.
def get_data_efficiency_config(param_dict):
    """Assemble the data-efficiency configuration from ``param_dict``.

    Args:
        param_dict: Config dict that may contain a ``DATA_EFFICIENCY`` section.

    Returns:
        A dict holding ``DATA_EFFICIENCY_ENABLED``, ``DATA_EFFICIENCY_SEED``,
        ``DATA_SAMPLING`` and ``DATA_ROUTING`` entries, in that order.

    Note:
        If ``param_dict`` has no ``DATA_EFFICIENCY`` section, an empty one is
        inserted into it (the caller's dict is modified in place).
    """
    enabled = get_data_efficiency_enabled(param_dict)
    seed = get_data_efficiency_seed(param_dict)
    # Make sure the section exists so the nested builders can be handed a dict.
    section = param_dict.setdefault(DATA_EFFICIENCY, {})
    return {
        DATA_EFFICIENCY_ENABLED: enabled,
        DATA_EFFICIENCY_SEED: seed,
        DATA_SAMPLING: get_data_sampling(section),
        DATA_ROUTING: get_data_routing(section),
    }
def get_data_efficiency_enabled(param_dict):
    """Return the data-efficiency "enabled" setting.

    Returns ``False`` when ``param_dict`` has no ``DATA_EFFICIENCY`` section;
    otherwise the value is looked up in that section via ``get_scalar_param``
    with ``DATA_EFFICIENCY_ENABLED_DEFAULT`` as the default.
    """
    if DATA_EFFICIENCY not in param_dict:
        return False
    section = param_dict[DATA_EFFICIENCY]
    return get_scalar_param(section, DATA_EFFICIENCY_ENABLED, DATA_EFFICIENCY_ENABLED_DEFAULT)
def get_data_efficiency_seed(param_dict):
    """Return the data-efficiency seed setting.

    Returns ``DATA_EFFICIENCY_SEED_DEFAULT`` when ``param_dict`` has no
    ``DATA_EFFICIENCY`` section; otherwise the value is looked up in that
    section via ``get_scalar_param`` with the same default.
    """
    if DATA_EFFICIENCY not in param_dict:
        return DATA_EFFICIENCY_SEED_DEFAULT
    section = param_dict[DATA_EFFICIENCY]
    return get_scalar_param(section, DATA_EFFICIENCY_SEED, DATA_EFFICIENCY_SEED_DEFAULT)
def get_data_sampling(param_dict):
    """Build the data-sampling configuration.

    Args:
        param_dict: Dict that may contain a ``DATA_SAMPLING`` section (the
            caller in this file passes the ``DATA_EFFICIENCY`` section).

    Returns:
        A dict holding ``DATA_SAMPLING_ENABLED``, ``DATA_SAMPLING_NUM_EPOCHS``,
        ``DATA_SAMPLING_NUM_WORKERS`` and the nested ``CURRICULUM_LEARNING``
        configuration.

    Note:
        If ``param_dict`` has no ``DATA_SAMPLING`` section, an empty one is
        inserted into it (the caller's dict is modified in place).
    """
    result = {
        DATA_SAMPLING_ENABLED: get_data_sampling_enabled(param_dict),
        DATA_SAMPLING_NUM_EPOCHS: get_data_sampling_num_epochs(param_dict),
        DATA_SAMPLING_NUM_WORKERS: get_data_sampling_num_workers(param_dict),
    }
    # The curriculum-learning builder is handed the sampling section itself.
    section = param_dict.setdefault(DATA_SAMPLING, {})
    result[CURRICULUM_LEARNING] = get_curriculum_learning(section)
    return result
def get_data_sampling_enabled(param_dict):
    """Return the data-sampling "enabled" setting.

    Returns ``False`` when ``param_dict`` has no ``DATA_SAMPLING`` section;
    otherwise the value is looked up in that section via ``get_scalar_param``
    with ``DATA_SAMPLING_ENABLED_DEFAULT`` as the default.
    """
    if DATA_SAMPLING not in param_dict:
        return False
    section = param_dict[DATA_SAMPLING]
    return get_scalar_param(section, DATA_SAMPLING_ENABLED, DATA_SAMPLING_ENABLED_DEFAULT)
def get_data_sampling_num_epochs(param_dict):
    """Return the data-sampling number-of-epochs setting.

    Returns ``DATA_SAMPLING_NUM_EPOCHS_DEFAULT`` when ``param_dict`` has no
    ``DATA_SAMPLING`` section; otherwise the value is looked up in that
    section via ``get_scalar_param`` with the same default.
    """
    if DATA_SAMPLING not in param_dict:
        return DATA_SAMPLING_NUM_EPOCHS_DEFAULT
    section = param_dict[DATA_SAMPLING]
    return get_scalar_param(section, DATA_SAMPLING_NUM_EPOCHS, DATA_SAMPLING_NUM_EPOCHS_DEFAULT)
def get_data_sampling_num_workers(param_dict):
    """Return the data-sampling number-of-workers setting.

    Returns ``DATA_SAMPLING_NUM_WORKERS_DEFAULT`` when ``param_dict`` has no
    ``DATA_SAMPLING`` section; otherwise the value is looked up in that
    section via ``get_scalar_param`` with the same default.
    """
    if DATA_SAMPLING not in param_dict:
        return DATA_SAMPLING_NUM_WORKERS_DEFAULT
    section = param_dict[DATA_SAMPLING]
    return get_scalar_param(section, DATA_SAMPLING_NUM_WORKERS, DATA_SAMPLING_NUM_WORKERS_DEFAULT)
def get_curriculum_learning(param_dict):
    """Build the curriculum-learning configuration.

    Args:
        param_dict: Dict that may contain a ``CURRICULUM_LEARNING`` section
            (the caller in this file passes the ``DATA_SAMPLING`` section).

    Returns:
        A dict that always holds ``CURRICULUM_LEARNING_ENABLED``. When that
        value is truthy, every other key of the ``CURRICULUM_LEARNING``
        section is merged in as well.

    Raises:
        AssertionError: If curriculum learning is enabled but the section
            lacks ``CURRICULUM_LEARNING_METRICS``.

    Note:
        If ``param_dict`` has no ``CURRICULUM_LEARNING`` section, an empty one
        is inserted into it (the caller's dict is modified in place).
    """
    enabled = get_curriculum_learning_enabled(param_dict)
    result = {CURRICULUM_LEARNING_ENABLED: enabled}
    section = param_dict.setdefault(CURRICULUM_LEARNING, {})
    if not enabled:
        # Disabled: only the flag is reported, user parameters are ignored.
        return result
    assert CURRICULUM_LEARNING_METRICS in section, (
        f"Curriculum learning is enabled, {CURRICULUM_LEARNING_METRICS} must be specified")
    result.update(get_curriculum_learning_params(param_dict))
    return result
def get_curriculum_learning_enabled(param_dict):
    """Return the curriculum-learning "enabled" setting.

    Returns ``False`` when ``param_dict`` has no ``CURRICULUM_LEARNING``
    section; otherwise the value is looked up in that section via
    ``get_scalar_param`` with ``CURRICULUM_LEARNING_ENABLED_DEFAULT`` as the
    default.
    """
    if CURRICULUM_LEARNING not in param_dict:
        return False
    section = param_dict[CURRICULUM_LEARNING]
    return get_scalar_param(section, CURRICULUM_LEARNING_ENABLED, CURRICULUM_LEARNING_ENABLED_DEFAULT)
def get_curriculum_learning_params(param_dict):
    """Return the curriculum-learning parameters without the "enabled" flag.

    Args:
        param_dict: Dict that may contain a ``CURRICULUM_LEARNING`` section.

    Returns:
        A shallow copy of the ``CURRICULUM_LEARNING`` section with the
        ``CURRICULUM_LEARNING_ENABLED`` key removed, or an empty dict when the
        section is absent. The section inside ``param_dict`` is not modified.
    """
    if CURRICULUM_LEARNING not in param_dict:
        return {}
    params = copy.copy(param_dict[CURRICULUM_LEARNING])
    # Supply a default to pop() so a section that omits the "enabled" key
    # (e.g. an empty section) does not raise KeyError.
    params.pop(CURRICULUM_LEARNING_ENABLED, None)
    return params
def get_curriculum_enabled_legacy(param_dict):
    """Return the legacy curriculum-learning "enabled" setting.

    Returns ``False`` when ``param_dict`` has no ``CURRICULUM_LEARNING_LEGACY``
    section; otherwise the value is looked up in that section via
    ``get_scalar_param`` with ``CURRICULUM_ENABLED_DEFAULT_LEGACY`` as the
    default.
    """
    if CURRICULUM_LEARNING_LEGACY not in param_dict:
        return False
    legacy_section = param_dict[CURRICULUM_LEARNING_LEGACY]
    return get_scalar_param(legacy_section, CURRICULUM_ENABLED_LEGACY, CURRICULUM_ENABLED_DEFAULT_LEGACY)
def get_curriculum_params_legacy(param_dict):
    """Return the legacy curriculum-learning parameters without the "enabled" flag.

    Args:
        param_dict: Dict that may contain a ``CURRICULUM_LEARNING_LEGACY``
            section.

    Returns:
        A shallow copy of the ``CURRICULUM_LEARNING_LEGACY`` section with the
        ``CURRICULUM_ENABLED_LEGACY`` key removed. When the section is absent
        this returns ``False`` (not an empty dict); that return value is kept
        unchanged for backward compatibility with existing callers.
    """
    if CURRICULUM_LEARNING_LEGACY not in param_dict:
        return False
    params = copy.copy(param_dict[CURRICULUM_LEARNING_LEGACY])
    # Supply a default to pop() so a section that omits the "enabled" key
    # does not raise KeyError.
    params.pop(CURRICULUM_ENABLED_LEGACY, None)
    return params
def get_data_routing(param_dict):
    """Build the data-routing configuration.

    Args:
        param_dict: Dict that may contain a ``DATA_ROUTING`` section (the
            caller in this file passes the ``DATA_EFFICIENCY`` section).

    Returns:
        A dict holding ``DATA_ROUTING_ENABLED`` and the nested ``RANDOM_LTD``
        configuration.

    Note:
        If ``param_dict`` has no ``DATA_ROUTING`` section, an empty one is
        inserted into it (the caller's dict is modified in place).
    """
    result = {DATA_ROUTING_ENABLED: get_data_routing_enabled(param_dict)}
    # The random-LTD builder is handed the routing section itself.
    section = param_dict.setdefault(DATA_ROUTING, {})
    result[RANDOM_LTD] = get_random_ltd(section)
    return result
def get_data_routing_enabled(param_dict):
    """Return the data-routing "enabled" setting.

    Returns ``False`` when ``param_dict`` has no ``DATA_ROUTING`` section;
    otherwise the value is looked up in that section via ``get_scalar_param``
    with ``DATA_ROUTING_ENABLED_DEFAULT`` as the default.
    """
    if DATA_ROUTING not in param_dict:
        return False
    section = param_dict[DATA_ROUTING]
    return get_scalar_param(section, DATA_ROUTING_ENABLED, DATA_ROUTING_ENABLED_DEFAULT)
def get_random_ltd(param_dict):
    """Build the random-LTD configuration.

    Args:
        param_dict: Dict that may contain a ``RANDOM_LTD`` section (the caller
            in this file passes the ``DATA_ROUTING`` section).

    Returns:
        A dict that always holds ``RANDOM_LTD_ENABLED`` and a
        ``RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE`` sub-dict carrying
        ``RANDOM_LTD_LAYER_TOKEN_LR_ENABLED``, both initialised from their
        ``*_DEFAULT`` constants. When random-LTD is enabled, the enabled value
        and every other key of the ``RANDOM_LTD`` section are merged on top,
        so a user-supplied schedule entry replaces the default sub-dict
        wholesale.
    """
    result = {
        RANDOM_LTD_ENABLED: RANDOM_LTD_ENABLED_DEFAULT,
        RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE: {
            RANDOM_LTD_LAYER_TOKEN_LR_ENABLED: RANDOM_LTD_LAYER_TOKEN_LR_ENABLED_DEFAULT,
        },
    }
    # Look the flag up once and reuse it for both the test and the stored value.
    enabled = get_random_ltd_enabled(param_dict)
    if enabled:
        result[RANDOM_LTD_ENABLED] = enabled
        result.update(get_random_ltd_params(param_dict))
    return result
def get_random_ltd_enabled(param_dict):
    """Return the random-LTD "enabled" setting.

    Returns ``False`` when ``param_dict`` has no ``RANDOM_LTD`` section;
    otherwise the value is looked up in that section via ``get_scalar_param``
    with ``RANDOM_LTD_ENABLED_DEFAULT`` as the default.
    """
    if RANDOM_LTD not in param_dict:
        return False
    section = param_dict[RANDOM_LTD]
    return get_scalar_param(section, RANDOM_LTD_ENABLED, RANDOM_LTD_ENABLED_DEFAULT)
def get_random_ltd_params(param_dict):
    """Return the random-LTD parameters without the "enabled" flag.

    Args:
        param_dict: Dict that may contain a ``RANDOM_LTD`` section.

    Returns:
        A shallow copy of the ``RANDOM_LTD`` section with the
        ``RANDOM_LTD_ENABLED`` key removed, or an empty dict when the section
        is absent. The section inside ``param_dict`` is not modified.
    """
    if RANDOM_LTD not in param_dict:
        return {}
    params = copy.copy(param_dict[RANDOM_LTD])
    # Supply a default to pop() so a section that omits the "enabled" key
    # does not raise KeyError.
    params.pop(RANDOM_LTD_ENABLED, None)
    return params
|