base_tuner.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import sys

from deepspeed.autotuning.constants import *
from deepspeed.autotuning.utils import write_experiments
from deepspeed.utils import logger


class BaseTuner:

    def __init__(self, exps, resource_manager, metric):
        self.all_exps = exps
        self.rm = resource_manager
        self.best_iter = 0
        self.best_exp = None
        self.best_metric_val = None
        self.metric = metric if metric else AUTOTUNING_METRIC_DEFAULT
        logger.info(f"total number of exps = {len(self.all_exps)}")

    def has_next(self):
        """Whether there are more configurations left to evaluate."""
        return len(self.all_exps) > 0

    def next_batch(self, sample_size):
        """Select the next batch of configurations for evaluation."""
        raise NotImplementedError

    def update(self):
        """Update the tuner with the configurations that have been evaluated and their performance results."""

    def tune(self, sample_size=1, n_trials=1000, early_stopping=None):
        i = 0
        try:
            while i < n_trials and self.has_next():
                # Select the next batch of configurations for evaluation
                sampled_exps = self.next_batch(sample_size)
                # Generate experiments for measurement of performance
                exp_paths = write_experiments(sampled_exps, self.rm.exps_dir)
                self.rm.schedule_experiments(exp_paths)
                self.rm.run()
                exp, metric_val = self.rm.parse_results(self.metric)
                if (self.best_exp is None or self.best_metric_val is None
                        or (metric_val and metric_val > self.best_metric_val)):
                    # logger.info(f"tuner finds better = {exp}")
                    self.best_exp = exp
                    self.best_metric_val = metric_val
                    self.best_iter = i

                i += len(sampled_exps)

                # Update the tuner with evaluated performance results
                self.update()
                self.rm.clear()

                # Early stop if no more promising configurations are likely to be found
                if early_stopping and i >= self.best_iter + early_stopping:
                    logger.info(
                        f"Tuner early stopped at iteration {i}. Best iteration is {self.best_iter}. Early stopping threshold is {early_stopping}"
                    )
                    break
            return i
        except Exception:
            # Use an f-string here: passing sys.exc_info()[0] as a positional
            # argument without a %s placeholder would break log formatting.
            logger.info(f"Tuner Error: {sys.exc_info()[0]}")
            return i
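

# --- Usage sketch (illustrative, not part of base_tuner.py) ---
# BaseTuner is abstract: a concrete tuner only needs to implement next_batch().
# The SequentialTuner below is a hypothetical minimal subclass that pops
# experiments off the queue in order; `exps` and `rm` are assumed to match the
# constructor signature above (a list of experiments and a ResourceManager).


class SequentialTuner(BaseTuner):
    """A minimal concrete tuner that evaluates experiments in list order."""

    def next_batch(self, sample_size):
        # Take up to sample_size experiments off the front of the queue.
        batch = self.all_exps[:sample_size]
        self.all_exps = self.all_exps[sample_size:]
        return batch


# Example driver (exps and rm are assumed to be set up by the caller):
#   tuner = SequentialTuner(exps, rm, AUTOTUNING_METRIC_DEFAULT)
#   trials_run = tuner.tune(sample_size=2, early_stopping=8)
#   print(tuner.best_exp, tuner.best_metric_val)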