# base_tuner.py

import sys

from deepspeed.autotuning.constants import *
from deepspeed.autotuning.utils import write_experiments
from deepspeed.utils import logger


class BaseTuner:
    """Base class for autotuning search strategies.

    Subclasses implement next_batch() to decide which experiments to run next.
    """

    def __init__(self, exps, resource_manager, metric):
        self.all_exps = exps
        self.rm = resource_manager
        self.best_iter = 0
        self.best_exp = None
        self.best_metric_val = None
        self.metric = metric if metric else AUTOTUNING_METRIC_DEFAULT
        logger.info(f"total number of exps = {len(self.all_exps)}")

    def has_next(self):
        """Whether there are more configurations to evaluate."""
        return len(self.all_exps) > 0

    def next_batch(self, sample_size):
        """Select the next batch of configurations for evaluation."""
        raise NotImplementedError

    def update(self):
        """Update the tuner with the configurations evaluated so far and their performance results."""
        # No-op by default; strategies that learn from results can override this.

    def tune(self, sample_size=1, n_trials=1000, early_stopping=None):
        i = 0
        try:
            while i < n_trials and self.has_next():
                # Select the next batch of configurations for evaluation
                sampled_exps = self.next_batch(sample_size)
                # Write out the experiments and run them to measure performance
                exp_paths = write_experiments(sampled_exps, self.rm.exps_dir)
                self.rm.schedule_experiments(exp_paths)
                self.rm.run()
                exp, metric_val = self.rm.parse_results(self.metric)
                if self.best_exp is None or (metric_val
                                             and metric_val > self.best_metric_val):
                    self.best_exp = exp
                    self.best_metric_val = metric_val
                    self.best_iter = i

                i += len(sampled_exps)

                # Update the tuner with the evaluated performance results
                self.update()

                self.rm.clear()

                # Stop early if no better configuration has been found recently
                if early_stopping and i >= self.best_iter + early_stopping:
                    logger.info(
                        f"Tuner early stopped at iteration {i}. Best iteration is "
                        f"{self.best_iter}. Early stopping threshold is {early_stopping}.")
                    break
            return i
        except Exception:
            logger.info(f"Tuner error: {sys.exc_info()[0]}")
            return i
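
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: a minimal concrete tuner
# built on BaseTuner. The GridSearchTuner name and the driver snippet below
# are assumptions for illustration; they presume self.all_exps is a list of
# experiment dicts accepted by write_experiments().
# ---------------------------------------------------------------------------


class GridSearchTuner(BaseTuner):
    """Evaluate the candidate configurations in the order they were given."""

    def next_batch(self, sample_size):
        # Take up to sample_size experiments off the front of the list;
        # has_next() returns False once the list is exhausted.
        sampled_batch = self.all_exps[:sample_size]
        self.all_exps = self.all_exps[sample_size:]
        return sampled_batch


# Hypothetical driver, assuming a ResourceManager `rm` and a list of
# experiments `exps` prepared elsewhere:
#
#   tuner = GridSearchTuner(exps, rm, AUTOTUNING_METRIC_DEFAULT)
#   trials_run = tuner.tune(sample_size=1, n_trials=len(exps))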