model_based_tuner.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import numbers

import hjson
import numpy as np

from ..constants import AUTOTUNING, AUTOTUNING_METRIC_LATENCY, AUTOTUNING_METRIC_PATH
from .base_tuner import BaseTuner
from .cost_model import XGBoostCostModel
from .utils import *
from ..utils import *

INIT_NUM = 2


class ModelBasedTuner(BaseTuner):
    """Explores the search space with a cost model."""

    def __init__(self, exps: list, resource_manager, metric, tuning_space):
        super().__init__(exps, resource_manager, metric)
        self.tuning_space = tuning_space
        self.best_iter = 0

        self.all_configs = [e['ds_config'] for e in exps]
        self.num_all_configs = len(self.all_configs)

        self.dims = dict_to_dims(self.tuning_space)

        logger.info(f"Create config dim: {self.dims}, all configs: {self.num_all_configs}")

        self.visited = set()

        self.trials = []
        self.trial_pt = 0

        # Seed the search with a few configs that are evaluated before the cost
        # model has any training data: the first config, then random unvisited ones.
        init_num = min(INIT_NUM, self.num_all_configs)
        exp_feature = 0
        for _ in range(init_num):
            while exp_feature in self.visited:
                exp_feature = np.random.randint(self.num_all_configs)
            self.trials.append(exp_feature)
            self.visited.add(exp_feature)

        # Cost model used to estimate and rank candidate configurations.
        self.cost_model = XGBoostCostModel("rank")

        self.evaluated_configs = []
        self.evaluated_perf = []

        self.train_ct = 0

        self.random_exploration_ratio = 0.2  # probability of random exploration in next_batch()

    def find_estimated_top_configs(self):
        """Use the cost model to predict the performance of each configuration and
        rank them for the next round of evaluation."""

        configs = []

        for c in self.all_configs:
            flattened_ds_config = flatten(c)
            feature_val = []
            for k, v in flattened_ds_config.items():
                if isinstance(v, numbers.Number):
                    feature_val.append(v)
            configs.append(feature_val)

        # TODO: the current implementation requires that all configs have the same shape.
        configs = np.array(configs, dtype=np.float32)
        estimates = self.cost_model.predict(configs)

        # For latency, lower is better, so keep argsort's ascending order;
        # otherwise (e.g., throughput) higher is better, so reverse it.
        top_idx = np.argsort(estimates)
        top_idx_ret = top_idx if self.metric == AUTOTUNING_METRIC_LATENCY else top_idx[::-1]

        return top_idx_ret

    def next_batch(self, sample_size):
        sampled_batch = []

        counter = 0
        while counter < sample_size:
            if len(self.visited) >= self.num_all_configs:
                break

            # Select the most promising trial that has not been evaluated yet.
            index = None
            while self.trial_pt < len(self.trials):
                logger.debug(f"trials: {self.trials}")
                candidate = self.trials[self.trial_pt]
                if candidate not in self.visited:
                    index = candidate
                    break
                self.trial_pt += 1

            # To avoid over-exploitation, occasionally (or whenever the ranked
            # trials are exhausted) pick a random config that has not been explored.
            if index is None or np.random.rand() < self.random_exploration_ratio:
                index = np.random.randint(self.num_all_configs)
                while index in self.visited:
                    index = np.random.randint(self.num_all_configs)

            # Track both the sampled configs and their indices.
            sampled_batch.append(self.all_exps[index])
            self.visited.add(index)
            counter += 1

        return sampled_batch

    def has_next(self):
        return len(self.visited) < self.num_all_configs

    def update(self):
        for exp_id, (exp, err) in self.rm.finished_experiments.items():
            feature_val = []
            if err:
                logger.info(
                    f"Skipping exp_id = {exp_id}, exp_name = {exp['name']}, the experiment did not run successfully with error = {err}, thus a metrics.txt does not exist for it. Please check the stderr.log in {exp['result_dir']}"
                )
                # Record failed experiments with a performance of 0.0 so the cost
                # model learns to rank them low.
                ds_config = exp["ds_config"]
                flattened_ds_config = flatten(ds_config)
                for k, v in flattened_ds_config.items():
                    if isinstance(v, numbers.Number):
                        feature_val.append(v)
                self.evaluated_configs.append(feature_val)
                self.evaluated_perf.append(0.0)
                continue

            p = exp["ds_config"][AUTOTUNING][AUTOTUNING_METRIC_PATH]
            with open(p, 'r') as f:
                results = hjson.load(f)
                curr_iter = results[self.metric]
                logger.debug(f"parsing the results for {exp_id}, result is {curr_iter}")

            ds_config = exp["ds_config"]
            flattened_ds_config = flatten(ds_config)
            for k, v in flattened_ds_config.items():
                if isinstance(v, numbers.Number):
                    feature_val.append(v)

            self.evaluated_configs.append(feature_val)
            self.evaluated_perf.append(curr_iter)

        logger.debug(f"**Evaluated configs: {len(self.evaluated_configs)}, evaluated perf: {self.evaluated_perf}")

        # Refit the cost model on everything evaluated so far and re-rank the
        # remaining search space for the next round.
        self.cost_model.fit(self.evaluated_configs, self.evaluated_perf)

        estimated_top_configs = self.find_estimated_top_configs()
        self.trials = estimated_top_configs
        self.trial_pt = 0
        self.train_ct += 1
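

# A minimal sketch of how a caller might drive this tuner (illustrative only:
# `exps`, `rm`, `tuning_space`, and `sample_size` are placeholders, and the
# actual DeepSpeed autotuner drives this loop from its own entry points):
#
#   tuner = ModelBasedTuner(exps, rm, AUTOTUNING_METRIC_LATENCY, tuning_space)
#   while tuner.has_next():
#       batch = tuner.next_batch(sample_size)
#       # ... run `batch` through the resource manager and record results ...
#       tuner.update()  # refit the cost model and re-rank the remaining configs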