config.py 4.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.runtime.config_utils import get_scalar_param, get_dict_param, DeepSpeedConfigObject
  5. from deepspeed.autotuning.constants import *
  6. class DeepSpeedAutotuningConfig(DeepSpeedConfigObject):
  7. def __init__(self, param_dict):
  8. super(DeepSpeedAutotuningConfig, self).__init__()
  9. self.enabled = None
  10. self.start_step = None
  11. self.end_step = None
  12. self.metric_path = None
  13. self.arg_mappings = None
  14. self.metric = None
  15. self.model_info = None
  16. self.results_dir = None
  17. self.exps_dir = None
  18. self.overwrite = None
  19. if param_dict and AUTOTUNING in param_dict.keys():
  20. autotuning_dict = param_dict[AUTOTUNING]
  21. else:
  22. autotuning_dict = {}
  23. self._initialize(autotuning_dict)
  24. def _initialize(self, autotuning_dict):
  25. self.enabled = get_scalar_param(autotuning_dict, AUTOTUNING_ENABLED, AUTOTUNING_ENABLED_DEFAULT)
  26. self.fast = get_scalar_param(autotuning_dict, AUTOTUNING_FAST, AUTOTUNING_FAST_DEFAULT)
  27. self.results_dir = get_scalar_param(autotuning_dict, AUTOTUNING_RESULTS_DIR, AUTOTUNING_RESULTS_DIR_DEFAULT)
  28. assert self.results_dir, "results_dir cannot be empty"
  29. self.exps_dir = get_scalar_param(autotuning_dict, AUTOTUNING_EXPS_DIR, AUTOTUNING_EXPS_DIR_DEFAULT)
  30. assert self.exps_dir, "exps_dir cannot be empty"
  31. self.overwrite = get_scalar_param(autotuning_dict, AUTOTUNING_OVERWRITE, AUTOTUNING_OVERWRITE_DEFAULT)
  32. self.start_profile_step = get_scalar_param(autotuning_dict, AUTOTUNING_START_PROFILE_STEP,
  33. AUTOTUNING_START_PROFILE_STEP_DEFAULT)
  34. self.end_profile_step = get_scalar_param(autotuning_dict, AUTOTUNING_END_PROFILE_STEP,
  35. AUTOTUNING_END_PROFILE_STEP_DEFAULT)
  36. self.metric = get_scalar_param(autotuning_dict, AUTOTUNING_METRIC, AUTOTUNING_METRIC_DEFAULT)
  37. self.metric_path = get_scalar_param(autotuning_dict, AUTOTUNING_METRIC_PATH, AUTOTUNING_METRIC_PATH_DEFAULT)
  38. self.tuner_type = get_scalar_param(autotuning_dict, AUTOTUNING_TUNER_TYPE, AUTOTUNING_TUNER_TYPE_DEFAULT)
  39. self.tuner_early_stopping = get_scalar_param(autotuning_dict, AUTOTUNING_TUNER_EARLY_STOPPING,
  40. AUTOTUNING_TUNER_EARLY_STOPPING_DEFAULT)
  41. self.tuner_num_trials = get_scalar_param(autotuning_dict, AUTOTUNING_TUNER_NUM_TRIALS,
  42. AUTOTUNING_TUNER_NUM_TRIALS_DEFAULT)
  43. self.arg_mappings = get_dict_param(autotuning_dict, AUTOTUNING_ARG_MAPPINGS, AUTOTUNING_ARG_MAPPINGS_DEFAULT)
  44. self.model_info = get_model_info_config(autotuning_dict)
  45. self.model_info_path = get_scalar_param(autotuning_dict, AUTOTUNING_MODEL_INFO_PATH,
  46. AUTOTUNING_MODEL_INFO_PATH_DEFAULT)
  47. self.mp_size = get_scalar_param(autotuning_dict, AUTOTUNING_MP_SIZE, AUTOTUNING_MP_SIZE_DEFAULT)
  48. self.max_train_batch_size = get_dict_param(autotuning_dict, AUTOTUNING_MAX_TRAIN_BATCH_SIZE,
  49. AUTOTUNING_MAX_TRAIN_BATCH_SIZE_DEFAULT)
  50. self.min_train_batch_size = get_dict_param(autotuning_dict, AUTOTUNING_MIN_TRAIN_BATCH_SIZE,
  51. AUTOTUNING_MIN_TRAIN_BATCH_SIZE_DEFAULT)
  52. self.max_train_micro_batch_size_per_gpu = get_dict_param(
  53. autotuning_dict, AUTOTUNING_MAX_TRAIN_MICRO_BATCH_SIZE_PER_GPU,
  54. AUTOTUNING_MAX_TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT)
  55. self.min_train_micro_batch_size_per_gpu = get_dict_param(
  56. autotuning_dict, AUTOTUNING_MIN_TRAIN_MICRO_BATCH_SIZE_PER_GPU,
  57. AUTOTUNING_MIN_TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT)
  58. self.num_tuning_micro_batch_sizes = get_dict_param(autotuning_dict, AUTOTUNING_NUM_TUNING_MICRO_BATCH_SIZES,
  59. AUTOTUNING_NUM_TUNING_MICRO_BATCH_SIZES_DEFAULT)
  60. def get_model_info_config(param_dict):
  61. if MODEL_INFO in param_dict and param_dict[MODEL_INFO] is not None:
  62. model_info_config = {}
  63. for key, default_value in MODEL_INFO_KEY_DEFAULT_DICT.items():
  64. model_info_config[key] = get_scalar_param(param_dict[MODEL_INFO], key, default_value)
  65. return model_info_config
  66. return None
  67. def get_default_model_info_config():
  68. return MODEL_INFO_KEY_DEFAULT_DICT