index_based_tuner.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import random
  5. from .base_tuner import BaseTuner
  6. class RandomTuner(BaseTuner):
  7. """Explore the search space in random order"""
  8. def __init__(self, exps: list, resource_manager, metric):
  9. super().__init__(exps, resource_manager, metric)
  10. def next_batch(self, sample_size=1):
  11. if sample_size > len(self.all_exps):
  12. sample_size = len(self.all_exps)
  13. sampled_batch = random.sample(self.all_exps, sample_size)
  14. self.all_exps = [x for x in self.all_exps if x not in sampled_batch]
  15. return sampled_batch
  16. class GridSearchTuner(BaseTuner):
  17. """Explore the search space in sequential order"""
  18. def __init__(self, exps: list, resource_manager, metric):
  19. super().__init__(exps, resource_manager, metric)
  20. def next_batch(self, sample_size=1):
  21. if sample_size > len(self.all_exps):
  22. sample_size = len(self.all_exps)
  23. sampled_batch = self.all_exps[0:sample_size]
  24. self.all_exps = [x for x in self.all_exps if x not in sampled_batch]
  25. return sampled_batch