anyscale_job_manager.py

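"""Run release test commands as Anyscale production jobs.

AnyscaleJobManager submits a command as an Anyscale production job via the
Anyscale SDK, waits for the job to reach a terminal state, maps that state to
a return code, and fetches the job driver and Ray error logs via the Anyscale
CLI.
"""
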
import os
import time
import subprocess
import tempfile
from collections import deque
from contextlib import contextmanager
from typing import Any, Dict, Optional, Tuple

from anyscale.sdk.anyscale_client.models import (
    CreateProductionJob,
    HaJobStates,
)
from ray_release.anyscale_util import LAST_LOGS_LENGTH, get_cluster_name
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.exception import (
    CommandTimeout,
    JobStartupTimeout,
    JobStartupFailed,
)
from ray_release.logger import logger
from ray_release.signal_handling import register_handler, unregister_handler
from ray_release.util import (
    ANYSCALE_HOST,
    ERROR_LOG_PATTERNS,
    exponential_backoff_retry,
    anyscale_job_url,
    format_link,
)

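# Return code reported for each terminal Anyscale job state. A job that is
# TERMINATED before it ever started running is reported as -4 ("soft infra
# error") in _wait_job below.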
job_status_to_return_code = {
    HaJobStates.SUCCESS: 0,
    HaJobStates.OUT_OF_RETRIES: -1,
    HaJobStates.BROKEN: -2,
    HaJobStates.TERMINATED: -3,
}
terminal_state = set(job_status_to_return_code.keys())


class AnyscaleJobManager:
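    """Submit a single command as an Anyscale production job and monitor it."""
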
    def __init__(self, cluster_manager: ClusterManager):
        self.start_time = None
        self.counter = 0
        self.cluster_manager = cluster_manager
        self._last_job_result = None
        self._last_logs = None
        self.cluster_startup_timeout = 600

    def _run_job(
        self,
        cmd_to_run: str,
        env_vars: Dict[str, Any],
        working_dir: Optional[str] = None,
        upload_path: Optional[str] = None,
    ) -> None:
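        """Submit cmd_to_run as an Anyscale production job (max_retries=0),
        recording the resulting job handle and start time."""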
        env = os.environ.copy()
        env.setdefault("ANYSCALE_HOST", str(ANYSCALE_HOST))

        logger.info(f"Executing {cmd_to_run} with {env_vars} via Anyscale job submit")

        anyscale_client = self.sdk

        runtime_env = {"env_vars": env_vars}
        if working_dir:
            runtime_env["working_dir"] = working_dir
        if upload_path:
            runtime_env["upload_path"] = upload_path

        try:
            job_response = anyscale_client.create_job(
                CreateProductionJob(
                    name=self.cluster_manager.cluster_name,
                    description=f"Smoke test: {self.cluster_manager.smoke_test}",
                    project_id=self.cluster_manager.project_id,
                    config=dict(
                        entrypoint=cmd_to_run,
                        runtime_env=runtime_env,
                        build_id=self.cluster_manager.cluster_env_build_id,
                        compute_config_id=self.cluster_manager.cluster_compute_id,
                        max_retries=0,
                    ),
                ),
            )
        except Exception as e:
            raise JobStartupFailed(
                "Error starting job with name "
                f"{self.cluster_manager.cluster_name}: "
                f"{e}"
            ) from e

        self.last_job_result = job_response.result
        self.start_time = time.time()
        logger.info(f"Link to job: {format_link(self.job_url)}")
        return

    @property
    def sdk(self):
        return self.cluster_manager.sdk

    @property
    def last_job_result(self):
        return self._last_job_result

    @last_job_result.setter
    def last_job_result(self, value):
        cluster_id = value.state.cluster_id

        # Set this only once.
        if self.cluster_manager.cluster_id is None and cluster_id:
            self.cluster_manager.cluster_id = value.state.cluster_id
            self.cluster_manager.cluster_name = get_cluster_name(
                value.state.cluster_id, self.sdk
            )
        self._last_job_result = value

    @property
    def job_id(self) -> Optional[str]:
        if not self.last_job_result:
            return None
        return self.last_job_result.id

    @property
    def job_url(self) -> Optional[str]:
        if not self.job_id:
            return None
        return anyscale_job_url(self.job_id)

    @property
    def last_job_status(self) -> Optional[HaJobStates]:
        if not self.last_job_result:
            return None
        return self.last_job_result.state.current_state

    @property
    def in_progress(self) -> bool:
        return self.last_job_result and self.last_job_status not in terminal_state

    def _get_job_status_with_retry(self):
        anyscale_client = self.cluster_manager.sdk
        return exponential_backoff_retry(
            lambda: anyscale_client.get_production_job(self.job_id),
            retry_exceptions=Exception,
            initial_retry_delay_s=1,
            max_retries=3,
        ).result

    def _terminate_job(self, raise_exceptions: bool = False):
        if not self.in_progress:
            return
        logger.info(f"Terminating job {self.job_id}...")
        try:
            self.sdk.terminate_job(self.job_id)
            logger.info(f"Job {self.job_id} terminated!")
        except Exception:
            msg = f"Couldn't terminate job {self.job_id}!"
            if raise_exceptions:
                logger.error(msg)
                raise
            else:
                logger.exception(msg)

    @contextmanager
    def _terminate_job_context(self):
        """
        Context to ensure the job is terminated.

        Aside from running _terminate_job at exit, it also registers
        a signal handler to terminate the job if the program is interrupted
        or terminated. It restores the original handlers on exit.
        """

        def terminate_handler(signum, frame):
            self._terminate_job()

        register_handler(terminate_handler)
        try:
            yield
        finally:
            # Terminate (a no-op if already terminal) and unregister the
            # handler even if the wrapped block raised.
            self._terminate_job()
            unregister_handler(terminate_handler)

    def _wait_job(self, timeout: int):
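        """Poll the job until it reaches a terminal state.

        The cluster startup timeout applies until the job starts running;
        after that, the command timeout applies. Returns a tuple of
        (return code, duration in seconds).
        """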
        # The context ensures the job always either finishes normally
        # or is terminated.
        with self._terminate_job_context():
            assert self.job_id, "Job must have been started"
            start_time = time.monotonic()
            # Waiting for the cluster to start is counted as part of the
            # whole run.
            timeout_at = start_time + self.cluster_startup_timeout
            next_status = start_time + 30
            job_running = False

            while True:
                now = time.monotonic()
                if now >= timeout_at:
                    self._terminate_job()
                    if not job_running:
                        raise JobStartupTimeout(
                            "Cluster did not start within "
                            f"{self.cluster_startup_timeout} seconds."
                        )
                    raise CommandTimeout(f"Job timed out after {timeout} seconds.")

                if now >= next_status:
                    if job_running:
                        msg = "... job still running ..."
                    else:
                        msg = "... job not yet running ..."
                    logger.info(
                        f"{msg}({int(now - start_time)} seconds, "
                        f"{int(timeout_at - now)} seconds to job timeout) ..."
                    )
                    next_status += 30

                result = self._get_job_status_with_retry()
                self.last_job_result = result
                status = self.last_job_status

                if not job_running and status in {
                    HaJobStates.RUNNING,
                    HaJobStates.ERRORED,
                }:
                    logger.info(
                        f"... job started ...({int(now - start_time)} seconds) ..."
                    )
                    job_running = True
                    # Once the job has started, switch from the cluster startup
                    # timeout to the actual command (incl. prepare commands)
                    # timeout.
                    timeout_at = now + timeout

                if status in terminal_state:
                    logger.info(f"Job entered terminal state {status}.")
                    break

                time.sleep(1)

            result = self._get_job_status_with_retry()
            self.last_job_result = result
            status = self.last_job_status
            assert status in terminal_state

            if status == HaJobStates.TERMINATED and not job_running:
                # The job was terminated before it ever started running:
                # treat this as a soft infra error.
                retcode = -4
            else:
                retcode = job_status_to_return_code[status]
            duration = time.time() - self.start_time
            return retcode, duration

    def run_and_wait(
        self,
        cmd_to_run,
        env_vars,
        working_dir: Optional[str] = None,
        timeout: int = 120,
        upload_path: Optional[str] = None,
    ) -> Tuple[int, float]:
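        """Run cmd_to_run as an Anyscale job and wait for it to finish.

        Returns a (return code, duration in seconds) tuple.
        """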
        self._run_job(
            cmd_to_run, env_vars, working_dir=working_dir, upload_path=upload_path
        )
        return self._wait_job(timeout)

    def _get_ray_logs(self) -> Tuple[Optional[str], Optional[str]]:
        """
        Obtain the job driver log and any Ray logs that contain keywords
        indicating a crash, such as ERROR or Traceback.
        """
        # Create a temporary directory for the downloaded log files.
        tmpdir = tempfile.mkdtemp()
        try:
            subprocess.check_output(
                [
                    "anyscale",
                    "logs",
                    "cluster",
                    "--id",
                    self.cluster_manager.cluster_id,
                    "--head-only",
                    "--download",
                    "--download-dir",
                    tmpdir,
                ]
            )
        except Exception as e:
            logger.exception(f"Failed to download logs from Anyscale: {e}")
            return None, None
        return AnyscaleJobManager._find_job_driver_and_ray_error_logs(tmpdir)

    @staticmethod
    def _find_job_driver_and_ray_error_logs(
        tmpdir: str,
    ) -> Tuple[Optional[str], Optional[str]]:
        # Ignore Ray log files that contain exceptions but do not indicate
        # a Ray crash.
        ignored_ray_files = [
            "monitor.log",
            "event_AUTOSCALER.log",
            "event_JOBS.log",
        ]
        error_output = None
        job_driver_output = None
        matched_pattern_count = 0
        for root, _, files in os.walk(tmpdir):
            for file in files:
                if file in ignored_ray_files:
                    continue
                with open(os.path.join(root, file)) as lines:
                    output = "".join(deque(lines, maxlen=3 * LAST_LOGS_LENGTH))
                # job-driver logs
                if file.startswith("job-driver-"):
                    job_driver_output = output
                    continue
                # Ray error logs: favor the file that matches the most
                # error patterns.
                matched = len(
                    [error for error in ERROR_LOG_PATTERNS if error in output]
                )
                if matched > matched_pattern_count:
                    matched_pattern_count = matched
                    error_output = output
        return job_driver_output, error_output

    def get_last_logs(self):
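        """Return the job driver log if available, otherwise the Ray error
        log, fetching with retries and caching the result once the job has
        finished."""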
        if not self.job_id:
            raise RuntimeError(
                "Job has not been started, therefore there are no logs to obtain."
            )
        if self._last_logs:
            return self._last_logs

        def _get_logs():
            job_driver_log, ray_error_log = self._get_ray_logs()
            assert job_driver_log or ray_error_log, "No logs fetched"
            if job_driver_log:
                return job_driver_log
            else:
                return ray_error_log

        ret = exponential_backoff_retry(
            _get_logs,
            retry_exceptions=Exception,
            initial_retry_delay_s=30,
            max_retries=3,
        )
        if ret and not self.in_progress:
            self._last_logs = ret
        return ret
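

# Example usage (a minimal sketch, not part of the release tooling; the test
# command, env vars, and cluster manager setup below are hypothetical, and the
# ClusterManager is assumed to already carry a cluster env build and compute
# config):
#
#     cluster_manager = ...  # a configured ray_release ClusterManager
#     job_manager = AnyscaleJobManager(cluster_manager)
#     retcode, duration = job_manager.run_and_wait(
#         cmd_to_run="python workloads/my_release_test.py",
#         env_vars={"SMOKE_TEST": "1"},
#         timeout=1800,
#     )
#     if retcode != 0:
#         logger.error(job_manager.get_last_logs())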