# train_tensorflow_mnist_test.py

  1. import json
  2. import os
  3. import time
  4. import ray
  5. from ray.train.examples.tensorflow_mnist_example import train_tensorflow_mnist
  6. if __name__ == "__main__":
  7. start = time.time()
  8. addr = os.environ.get("RAY_ADDRESS")
  9. job_name = os.environ.get("RAY_JOB_NAME", "train_tensorflow_mnist_test")
  10. if addr is not None and addr.startswith("anyscale://"):
  11. ray.init(address=addr, job_name=job_name)
  12. else:
  13. ray.init(address="auto")
  14. train_tensorflow_mnist(num_workers=6, use_gpu=True, epochs=20)
  15. taken = time.time() - start
  16. result = {
  17. "time_taken": taken,
  18. }
  19. test_output_json = os.environ.get(
  20. "TEST_OUTPUT_JSON", "/tmp/train_tensorflow_mnist_test.json"
  21. )
  22. with open(test_output_json, "wt") as f:
  23. json.dump(result, f)
  24. print("Test Successful!")