tune_4x16.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
"""Moderate Ray Tune run (4 trials, 16 actors each).

This training run will start 4 Ray Tune trials, each starting 16 actors.
The cluster comprises 32 nodes.

Test owner: Yard1 (primary), krfricke

Acceptance criteria: Should run through and report final results, as well
as the Ray Tune results table. No trials should error. All trials should
run in parallel.
"""
  9. from collections import Counter
  10. import json
  11. import os
  12. import time
  13. import ray
  14. from ray import tune
  15. from lightgbm_ray import RayParams
  16. from release_test_util import train_ray
  17. def train_wrapper(config, ray_params):
  18. train_ray(
  19. path="/data/classification.parquet",
  20. num_workers=None,
  21. num_boost_rounds=100,
  22. num_files=128,
  23. regression=False,
  24. use_gpu=False,
  25. ray_params=ray_params,
  26. lightgbm_params=config,
  27. )
  28. if __name__ == "__main__":
  29. search_space = {
  30. "eta": tune.loguniform(1e-4, 1e-1),
  31. "subsample": tune.uniform(0.5, 1.0),
  32. "max_depth": tune.randint(1, 9),
  33. }
  34. ray.init(address="auto", runtime_env={"working_dir": os.path.dirname(__file__)})
  35. ray_params = RayParams(
  36. elastic_training=False,
  37. max_actor_restarts=2,
  38. num_actors=16,
  39. cpus_per_actor=2,
  40. gpus_per_actor=0,
  41. )
  42. start = time.time()
  43. analysis = tune.run(
  44. tune.with_parameters(train_wrapper, ray_params=ray_params),
  45. config=search_space,
  46. num_samples=4,
  47. resources_per_trial=ray_params.get_tune_resources(),
  48. )
  49. taken = time.time() - start
  50. result = {
  51. "time_taken": taken,
  52. "trial_states": dict(Counter([trial.status for trial in analysis.trials])),
  53. }
  54. test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/tune_4x16.json")
  55. with open(test_output_json, "wt") as f:
  56. json.dump(result, f)
  57. print("PASSED.")