horovod_user_test.py 921 B

12345678910111213141516171819202122232425262728293031323334353637
  1. import json
  2. import os
  3. import time
  4. import ray
  5. from horovod_example import main
  6. if __name__ == "__main__":
  7. start = time.time()
  8. addr = os.environ.get("RAY_ADDRESS")
  9. job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test")
  10. runtime_env = {"working_dir": os.path.dirname(__file__)}
  11. if addr.startswith("anyscale://"):
  12. ray.init(address=addr, job_name=job_name, runtime_env=runtime_env)
  13. else:
  14. ray.init(address="auto", runtime_env=runtime_env)
  15. main(
  16. num_workers=6,
  17. use_gpu=True,
  18. placement_group_timeout_s=2000,
  19. timeout_s=120,
  20. kwargs={"num_epochs": 20},
  21. )
  22. taken = time.time() - start
  23. result = {
  24. "time_taken": taken,
  25. }
  26. test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/horovod_user_test.json")
  27. with open(test_output_json, "wt") as f:
  28. json.dump(result, f)
  29. print("Test Successful!")