rte_ray_client.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """Runtime env test on Ray Client
  2. This test installs runtime environments on a remote cluster using local
  3. pip requirements.txt files. It is intended to be run using Anyscale connect.
  4. This complements existing per-commit tests in CI, for which we don't have
  5. access to a physical remote cluster.
  6. Test owner: architkulkarni
  7. Acceptance criteria: Should run through and print "PASSED"
  8. """
  9. import argparse
  10. import json
  11. import os
  12. import tempfile
  13. import time
  14. from pathlib import Path
  15. import ray
  16. def test_pip_requirements_files(tmpdir: str):
  17. """Test requirements.txt with tasks and actors.
  18. Test specifying in @ray.remote decorator and in .options.
  19. """
  20. pip_file_18 = Path(os.path.join(tmpdir, "runtime_env_pip_18.txt"))
  21. pip_file_18.write_text("requests==2.18.0")
  22. env_18 = {"pip": str(pip_file_18)}
  23. pip_file_16 = Path(os.path.join(tmpdir, "runtime_env_pip_16.txt"))
  24. pip_file_16.write_text("requests==2.16.0")
  25. env_16 = {"pip": str(pip_file_16)}
  26. @ray.remote(runtime_env=env_16)
  27. def get_version():
  28. import requests
  29. return requests.__version__
  30. # TODO(architkulkarni): Uncomment after #19002 is fixed
  31. # assert ray.get(get_version.remote()) == "2.16.0"
  32. assert ray.get(get_version.options(runtime_env=env_18).remote()) == "2.18.0"
  33. @ray.remote(runtime_env=env_18)
  34. class VersionActor:
  35. def get_version(self):
  36. import requests
  37. return requests.__version__
  38. # TODO(architkulkarni): Uncomment after #19002 is fixed
  39. # actor_18 = VersionActor.remote()
  40. # assert ray.get(actor_18.get_version.remote()) == "2.18.0"
  41. actor_16 = VersionActor.options(runtime_env=env_16).remote()
  42. assert ray.get(actor_16.get_version.remote()) == "2.16.0"
  43. if __name__ == "__main__":
  44. parser = argparse.ArgumentParser()
  45. parser.add_argument(
  46. "--smoke-test", action="store_true", help="Finish quickly for testing."
  47. )
  48. args = parser.parse_args()
  49. start = time.time()
  50. addr = os.environ.get("RAY_ADDRESS")
  51. job_name = os.environ.get("RAY_JOB_NAME", "rte_ray_client")
  52. # Test reconnecting to the same cluster multiple times.
  53. for use_working_dir in [True, True, False, False]:
  54. with tempfile.TemporaryDirectory() as tmpdir:
  55. runtime_env = {"working_dir": tmpdir} if use_working_dir else None
  56. print("Testing with use_working_dir=" + str(use_working_dir))
  57. if addr is not None and addr.startswith("anyscale://"):
  58. ray.init(address=addr, job_name=job_name, runtime_env=runtime_env)
  59. else:
  60. ray.init(address="auto", runtime_env=runtime_env)
  61. test_pip_requirements_files(tmpdir)
  62. ray.shutdown()
  63. taken = time.time() - start
  64. result = {
  65. "time_taken": taken,
  66. }
  67. test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/rte_ray_client.json")
  68. with open(test_output_json, "wt") as f:
  69. json.dump(result, f)
  70. print("PASSED")