'''Copyright The Microsoft DeepSpeed Team'''
import sys
from deepspeed.autotuning.constants import *
from deepspeed.autotuning.utils import write_experiments
from deepspeed.utils import logger
  6. class BaseTuner:
  7. def __init__(self, exps, resource_manager, metric):
  8. self.all_exps = exps
  9. self.rm = resource_manager
  10. self.best_iter = 0
  11. self.best_exp = None
  12. self.best_metric_val = None
  13. self.metric = metric if metric else AUTOTUNING_METRIC_DEFAULT
  14. logger.info(f"total number of exps = {len(self.all_exps)}")
  15. def has_next(self):
  16. """Whether there exists more configurations for evaluation"""
  17. if len(self.all_exps) > 0:
  18. return True
  19. else:
  20. return False
  21. def next_batch(self, sample_size):
  22. """Select the next batch of configurations for evaluation"""
  23. raise NotImplementedError
  24. def update(self):
  25. """"Update the tuner with what configurations have been evaluated and their performance results"""
  26. def tune(self, sample_size=1, n_trials=1000, early_stopping=None):
  27. i = 0
  28. try:
  29. while i < n_trials and self.has_next():
  30. # Select the next batch of configuratiosn for evaluation
  31. sampled_exps = self.next_batch(sample_size)
  32. # Generate experiments for measurement of performance
  33. exp_paths = write_experiments(sampled_exps, self.rm.exps_dir)
  34. self.rm.schedule_experiments(exp_paths)
  35. self.rm.run()
  36. exp, metric_val = self.rm.parse_results(self.metric)
  37. if self.best_exp == None or self.best_metric_val == None or (
  38. metric_val and metric_val > self.best_metric_val):
  39. # logger.info(f"tuner finds better = {exp}")
  40. self.best_exp = exp
  41. self.best_metric_val = metric_val
  42. self.best_iter = i
  43. i += len(sampled_exps)
  44. # Update the tuner with evaluated performance results
  45. self.update()
  46. self.rm.clear()
  47. # Early stop if no more promising configurations are likely to be found
  48. if early_stopping and i >= self.best_iter + early_stopping:
  49. logger.info(
  50. f"Tuner early stopped at iteration {i}. Best iteration is {self.best_iter}. Early stopping threshold is {early_stopping}"
  51. )
  52. break
  53. return i
  54. except:
  55. logger.info("Tunner Error:", sys.exc_info()[0])
  56. return i