config.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .constants import *
  5. import copy
  6. from ..config_utils import get_scalar_param
# TODO: Reduce config verbosity by returning None or {} when a feature is disabled.
# One challenge is that we still need to include the default values somehow;
# for example, the *_ENABLED flags default to false.
  10. def get_data_efficiency_config(param_dict):
  11. output = {}
  12. output[DATA_EFFICIENCY_ENABLED] = get_data_efficiency_enabled(param_dict)
  13. output[DATA_EFFICIENCY_SEED] = get_data_efficiency_seed(param_dict)
  14. if DATA_EFFICIENCY not in param_dict.keys():
  15. param_dict[DATA_EFFICIENCY] = {}
  16. sub_param_dict = param_dict[DATA_EFFICIENCY]
  17. output[DATA_SAMPLING] = get_data_sampling(sub_param_dict)
  18. output[DATA_ROUTING] = get_data_routing(sub_param_dict)
  19. return output
  20. def get_data_efficiency_enabled(param_dict):
  21. if DATA_EFFICIENCY in param_dict.keys():
  22. return get_scalar_param(param_dict[DATA_EFFICIENCY], DATA_EFFICIENCY_ENABLED, DATA_EFFICIENCY_ENABLED_DEFAULT)
  23. else:
  24. return False
  25. def get_data_efficiency_seed(param_dict):
  26. if DATA_EFFICIENCY in param_dict.keys():
  27. return get_scalar_param(param_dict[DATA_EFFICIENCY], DATA_EFFICIENCY_SEED, DATA_EFFICIENCY_SEED_DEFAULT)
  28. else:
  29. return DATA_EFFICIENCY_SEED_DEFAULT
  30. def get_data_sampling(param_dict):
  31. output = {}
  32. output[DATA_SAMPLING_ENABLED] = get_data_sampling_enabled(param_dict)
  33. output[DATA_SAMPLING_NUM_EPOCHS] = get_data_sampling_num_epochs(param_dict)
  34. output[DATA_SAMPLING_NUM_WORKERS] = get_data_sampling_num_workers(param_dict)
  35. if DATA_SAMPLING not in param_dict.keys():
  36. param_dict[DATA_SAMPLING] = {}
  37. sub_param_dict = param_dict[DATA_SAMPLING]
  38. output[CURRICULUM_LEARNING] = get_curriculum_learning(sub_param_dict)
  39. return output
  40. def get_data_sampling_enabled(param_dict):
  41. if DATA_SAMPLING in param_dict.keys():
  42. return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_ENABLED, DATA_SAMPLING_ENABLED_DEFAULT)
  43. else:
  44. return False
  45. def get_data_sampling_num_epochs(param_dict):
  46. if DATA_SAMPLING in param_dict.keys():
  47. return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_NUM_EPOCHS, DATA_SAMPLING_NUM_EPOCHS_DEFAULT)
  48. else:
  49. return DATA_SAMPLING_NUM_EPOCHS_DEFAULT
  50. def get_data_sampling_num_workers(param_dict):
  51. if DATA_SAMPLING in param_dict.keys():
  52. return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_NUM_WORKERS,
  53. DATA_SAMPLING_NUM_WORKERS_DEFAULT)
  54. else:
  55. return DATA_SAMPLING_NUM_WORKERS_DEFAULT
  56. def get_curriculum_learning(param_dict):
  57. output = {}
  58. output[CURRICULUM_LEARNING_ENABLED] = get_curriculum_learning_enabled(param_dict)
  59. if CURRICULUM_LEARNING not in param_dict.keys():
  60. param_dict[CURRICULUM_LEARNING] = {}
  61. sub_param_dict = param_dict[CURRICULUM_LEARNING]
  62. if output[CURRICULUM_LEARNING_ENABLED]:
  63. assert CURRICULUM_LEARNING_METRICS in sub_param_dict.keys(
  64. ), f"Curriculum learning is enabled, {CURRICULUM_LEARNING_METRICS} must be specified"
  65. for key, val in get_curriculum_learning_params(param_dict).items():
  66. output[key] = val
  67. return output
  68. def get_curriculum_learning_enabled(param_dict):
  69. if CURRICULUM_LEARNING in param_dict.keys():
  70. return get_scalar_param(param_dict[CURRICULUM_LEARNING], CURRICULUM_LEARNING_ENABLED,
  71. CURRICULUM_LEARNING_ENABLED_DEFAULT)
  72. else:
  73. return False
  74. def get_curriculum_learning_params(param_dict):
  75. if CURRICULUM_LEARNING in param_dict.keys():
  76. curriculum_learning_params = copy.copy(param_dict[CURRICULUM_LEARNING])
  77. curriculum_learning_params.pop(CURRICULUM_LEARNING_ENABLED)
  78. return curriculum_learning_params
  79. else:
  80. return {}
  81. def get_curriculum_enabled_legacy(param_dict):
  82. if CURRICULUM_LEARNING_LEGACY in param_dict.keys():
  83. return get_scalar_param(param_dict[CURRICULUM_LEARNING_LEGACY], CURRICULUM_ENABLED_LEGACY,
  84. CURRICULUM_ENABLED_DEFAULT_LEGACY)
  85. else:
  86. return False
  87. def get_curriculum_params_legacy(param_dict):
  88. if CURRICULUM_LEARNING_LEGACY in param_dict.keys():
  89. curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING_LEGACY])
  90. curriculum_params.pop(CURRICULUM_ENABLED_LEGACY)
  91. return curriculum_params
  92. else:
  93. return False
  94. def get_data_routing(param_dict):
  95. output = {}
  96. output[DATA_ROUTING_ENABLED] = get_data_routing_enabled(param_dict)
  97. if DATA_ROUTING not in param_dict.keys():
  98. param_dict[DATA_ROUTING] = {}
  99. sub_param_dict = param_dict[DATA_ROUTING]
  100. output[RANDOM_LTD] = get_random_ltd(sub_param_dict)
  101. return output
  102. def get_data_routing_enabled(param_dict):
  103. if DATA_ROUTING in param_dict.keys():
  104. return get_scalar_param(param_dict[DATA_ROUTING], DATA_ROUTING_ENABLED, DATA_ROUTING_ENABLED_DEFAULT)
  105. else:
  106. return False
  107. def get_random_ltd(param_dict):
  108. output = {}
  109. output[RANDOM_LTD_ENABLED] = RANDOM_LTD_ENABLED_DEFAULT
  110. output[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE] = {}
  111. output[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][
  112. RANDOM_LTD_LAYER_TOKEN_LR_ENABLED] = RANDOM_LTD_LAYER_TOKEN_LR_ENABLED_DEFAULT
  113. if get_random_ltd_enabled(param_dict):
  114. output[RANDOM_LTD_ENABLED] = get_random_ltd_enabled(param_dict)
  115. for key, val in get_random_ltd_params(param_dict).items():
  116. output[key] = val
  117. return output
  118. def get_random_ltd_enabled(param_dict):
  119. if RANDOM_LTD in param_dict.keys():
  120. return get_scalar_param(param_dict[RANDOM_LTD], RANDOM_LTD_ENABLED, RANDOM_LTD_ENABLED_DEFAULT)
  121. else:
  122. return False
  123. def get_random_ltd_params(param_dict):
  124. if RANDOM_LTD in param_dict.keys():
  125. random_ltd_params = copy.copy(param_dict[RANDOM_LTD])
  126. random_ltd_params.pop(RANDOM_LTD_ENABLED)
  127. return random_ltd_params
  128. else:
  129. return {}