# train_small_connect.py
  1. """Small cluster training
  2. This training run will start 4 workers on 4 nodes (including head node).
  3. Test owner: Yard1 (primary), krfricke
  4. Acceptance criteria: Should run through and report final results.
  5. """
  6. import json
  7. import os
  8. import time
  9. import ray
  10. from lightgbm_ray import RayParams
  11. from ray.util.lightgbm.release_test_util import train_ray
  12. if __name__ == "__main__":
  13. addr = os.environ.get("RAY_ADDRESS")
  14. job_name = os.environ.get("RAY_JOB_NAME", "train_small")
  15. if addr.startswith("anyscale://"):
  16. ray.init(address=addr, job_name=job_name)
  17. else:
  18. ray.init(address="auto")
  19. ray_params = RayParams(
  20. elastic_training=False,
  21. max_actor_restarts=2,
  22. num_actors=4,
  23. cpus_per_actor=4,
  24. gpus_per_actor=0,
  25. )
  26. @ray.remote
  27. def train():
  28. train_ray(
  29. path="/data/classification.parquet",
  30. num_workers=None,
  31. num_boost_rounds=100,
  32. num_files=25,
  33. regression=False,
  34. use_gpu=False,
  35. ray_params=ray_params,
  36. lightgbm_params=None,
  37. )
  38. start = time.time()
  39. ray.get(train.remote())
  40. taken = time.time() - start
  41. result = {
  42. "time_taken": taken,
  43. }
  44. test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/train_small.json")
  45. with open(test_output_json, "wt") as f:
  46. json.dump(result, f)
  47. print("PASSED.")