job_runner.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import json
  2. import os
  3. import shutil
  4. import tempfile
  5. from typing import TYPE_CHECKING, Any, Dict, Optional
  6. from ray_release.cluster_manager.cluster_manager import ClusterManager
  7. from ray_release.command_runner.command_runner import CommandRunner
  8. from ray_release.exception import (
  9. ClusterNodesWaitTimeout,
  10. CommandError,
  11. CommandTimeout,
  12. LocalEnvSetupError,
  13. LogsError,
  14. FetchResultError,
  15. )
  16. from ray_release.file_manager.file_manager import FileManager
  17. from ray_release.job_manager import JobManager
  18. from ray_release.logger import logger
  19. from ray_release.util import format_link, get_anyscale_sdk
  20. from ray_release.wheels import install_matching_ray_locally
  21. if TYPE_CHECKING:
  22. from anyscale.sdk.anyscale_client.sdk import AnyscaleSDK
  23. class JobRunner(CommandRunner):
  24. def __init__(
  25. self,
  26. cluster_manager: ClusterManager,
  27. file_manager: FileManager,
  28. working_dir: str,
  29. sdk: Optional["AnyscaleSDK"] = None,
  30. artifact_path: Optional[str] = None,
  31. ):
  32. super(JobRunner, self).__init__(
  33. cluster_manager=cluster_manager,
  34. file_manager=file_manager,
  35. working_dir=working_dir,
  36. )
  37. self.sdk = sdk or get_anyscale_sdk()
  38. self.job_manager = JobManager(cluster_manager)
  39. self.last_command_scd_id = None
  40. def prepare_local_env(self, ray_wheels_url: Optional[str] = None):
  41. if not os.environ.get("BUILDKITE"):
  42. return
  43. # Install matching Ray for job submission
  44. try:
  45. install_matching_ray_locally(
  46. ray_wheels_url or os.environ.get("RAY_WHEELS", None)
  47. )
  48. except Exception as e:
  49. raise LocalEnvSetupError(f"Error setting up local environment: {e}") from e
  50. def _copy_script_to_working_dir(self, script_name):
  51. script = os.path.join(os.path.dirname(__file__), f"_{script_name}")
  52. shutil.copy(script, script_name)
  53. def prepare_remote_env(self):
  54. self._copy_script_to_working_dir("wait_cluster.py")
  55. self._copy_script_to_working_dir("prometheus_metrics.py")
  56. # Do not upload the files here. Instead, we use the job runtime environment
  57. # to automatically upload the local working dir.
  58. def wait_for_nodes(self, num_nodes: int, timeout: float = 900):
  59. # Wait script should be uploaded already. Kick off command
  60. try:
  61. # Give 30 seconds more to acount for communication
  62. self.run_prepare_command(
  63. f"python wait_cluster.py {num_nodes} {timeout}", timeout=timeout + 30
  64. )
  65. except (CommandError, CommandTimeout) as e:
  66. raise ClusterNodesWaitTimeout(
  67. f"Not all {num_nodes} nodes came up within {timeout} seconds."
  68. ) from e
  69. def save_metrics(self, start_time: float, timeout: float = 900):
  70. self.run_prepare_command(
  71. f"python prometheus_metrics.py {start_time}", timeout=timeout
  72. )
  73. def run_command(
  74. self,
  75. command: str,
  76. env: Optional[Dict] = None,
  77. timeout: float = 3600.0,
  78. raise_on_timeout: bool = True,
  79. ) -> float:
  80. full_env = self.get_full_command_env(env)
  81. if full_env:
  82. env_str = " ".join(f"{k}={v}" for k, v in full_env.items()) + " "
  83. else:
  84. env_str = ""
  85. full_command = f"{env_str}{command}"
  86. logger.info(
  87. f"Running command in cluster {self.cluster_manager.cluster_name}: "
  88. f"{full_command}"
  89. )
  90. logger.info(
  91. f"Link to cluster: "
  92. f"{format_link(self.cluster_manager.get_cluster_url())}"
  93. )
  94. status_code, time_taken = self.job_manager.run_and_wait(
  95. full_command, full_env, working_dir=".", timeout=int(timeout)
  96. )
  97. if status_code != 0:
  98. raise CommandError(f"Command returned non-success status: {status_code}")
  99. return time_taken
  100. def get_last_logs_ex(self, scd_id: Optional[str] = None):
  101. try:
  102. return self.job_manager.get_last_logs()
  103. except Exception as e:
  104. raise LogsError(f"Could not get last logs: {e}") from e
  105. def _fetch_json(self, path: str) -> Dict[str, Any]:
  106. try:
  107. tmpfile = tempfile.mkstemp(suffix=".json")[1]
  108. logger.info(tmpfile)
  109. self.file_manager.download(path, tmpfile)
  110. with open(tmpfile, "rt") as f:
  111. data = json.load(f)
  112. os.unlink(tmpfile)
  113. return data
  114. except Exception as e:
  115. raise FetchResultError(f"Could not fetch results from session: {e}") from e
  116. def fetch_results(self) -> Dict[str, Any]:
  117. return self._fetch_json(self._RESULT_OUTPUT_JSON)
  118. def fetch_metrics(self) -> Dict[str, Any]:
  119. return self._fetch_json(self._METRICS_OUTPUT_JSON)
  120. def fetch_artifact(self):
  121. raise NotImplementedError