anyscale_job_manager.py

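"""Run release-test commands as Anyscale production jobs.

AnyscaleJobManager submits a command through the Anyscale SDK, waits for the
resulting job to reach a terminal state, and retrieves its logs.
"""
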
import time
from contextlib import contextmanager
from typing import Any, Dict, Optional, Tuple, List

import anyscale
from anyscale.sdk.anyscale_client.models import (
    CreateProductionJob,
    HaJobStates,
)

from ray_release.anyscale_util import get_cluster_name
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.exception import (
    CommandTimeout,
    JobStartupTimeout,
    JobStartupFailed,
)
from ray_release.logger import logger
from ray_release.signal_handling import register_handler, unregister_handler
from ray_release.util import (
    exponential_backoff_retry,
    anyscale_job_url,
    format_link,
)

job_status_to_return_code = {
    HaJobStates.SUCCESS: 0,
    HaJobStates.OUT_OF_RETRIES: -1,
    HaJobStates.BROKEN: -2,
    HaJobStates.TERMINATED: -3,
}
terminal_state = set(job_status_to_return_code.keys())
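# The mapping above covers jobs that reached a terminal state on their own;
# _wait_job additionally returns -4 ("soft infra error") for jobs that were
# terminated before the entrypoint ever started running.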


class AnyscaleJobManager:
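    """Run a release-test command as an Anyscale production job.

    Submits the job through the Anyscale SDK, polls it until it reaches a
    terminal state (terminating it on timeout or interruption), and exposes
    the job's return code, duration, and logs.
    """
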
    def __init__(self, cluster_manager: ClusterManager):
        self.start_time = None
        self.counter = 0
        self.cluster_manager = cluster_manager
        self._last_job_result = None
        self._last_logs = None
        self.cluster_startup_timeout = 600
        self._duration = None

    def _run_job(
        self,
        cmd_to_run: str,
        env_vars: Dict[str, Any],
        working_dir: Optional[str] = None,
        upload_path: Optional[str] = None,
        pip: Optional[List[str]] = None,
    ) -> None:
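        """Submit the command as an Anyscale production job.

        This only starts the job; use _wait_job to block until it finishes.
        """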
        env_vars_for_job = env_vars.copy()
        env_vars_for_job[
            "ANYSCALE_JOB_CLUSTER_ENV_NAME"
        ] = self.cluster_manager.cluster_env_name

        logger.info(
            f"Executing {cmd_to_run} with {env_vars_for_job} via Anyscale job submit"
        )

        anyscale_client = self.sdk
        runtime_env = {
            "env_vars": env_vars_for_job,
            "pip": pip or [],
        }
        if working_dir:
            runtime_env["working_dir"] = working_dir
        if upload_path:
            runtime_env["upload_path"] = upload_path

        try:
            job_response = anyscale_client.create_job(
                CreateProductionJob(
                    name=self.cluster_manager.cluster_name,
                    description=f"Smoke test: {self.cluster_manager.smoke_test}",
                    project_id=self.cluster_manager.project_id,
                    config=dict(
                        entrypoint=cmd_to_run,
                        runtime_env=runtime_env,
                        build_id=self.cluster_manager.cluster_env_build_id,
                        compute_config_id=self.cluster_manager.cluster_compute_id,
                        max_retries=0,
                    ),
                ),
            )
        except Exception as e:
            raise JobStartupFailed(
                "Error starting job with name "
                f"{self.cluster_manager.cluster_name}: {e}"
            ) from e

        self.last_job_result = job_response.result
        self.start_time = time.time()
        logger.info(f"Link to job: {format_link(self.job_url)}")

    @property
    def sdk(self):
        return self.cluster_manager.sdk

    @property
    def last_job_result(self):
        return self._last_job_result

    @last_job_result.setter
    def last_job_result(self, value):
        cluster_id = value.state.cluster_id

        # Set this only once.
        if self.cluster_manager.cluster_id is None and cluster_id:
            self.cluster_manager.cluster_id = value.state.cluster_id
            self.cluster_manager.cluster_name = get_cluster_name(
                value.state.cluster_id, self.sdk
            )
        self._last_job_result = value

    @property
    def job_id(self) -> Optional[str]:
        if not self.last_job_result:
            return None
        return self.last_job_result.id

    @property
    def job_url(self) -> Optional[str]:
        if not self.job_id:
            return None
        return anyscale_job_url(self.job_id)

    @property
    def last_job_status(self) -> Optional[HaJobStates]:
        if not self.last_job_result:
            return None
        return self.last_job_result.state.current_state

    @property
    def in_progress(self) -> bool:
        return bool(
            self.last_job_result and self.last_job_status not in terminal_state
        )

    def _get_job_status_with_retry(self):
        anyscale_client = self.cluster_manager.sdk
        return exponential_backoff_retry(
            lambda: anyscale_client.get_production_job(self.job_id),
            retry_exceptions=Exception,
            initial_retry_delay_s=1,
            max_retries=3,
        ).result

    def _terminate_job(self, raise_exceptions: bool = False):
        if not self.in_progress:
            return
        logger.info(f"Terminating job {self.job_id}...")
        try:
            self.sdk.terminate_job(self.job_id)
            logger.info(f"Job {self.job_id} terminated!")
        except Exception:
            msg = f"Couldn't terminate job {self.job_id}!"
            if raise_exceptions:
                logger.error(msg)
                raise
            else:
                logger.exception(msg)

    @contextmanager
    def _terminate_job_context(self):
        """Context to ensure the job is terminated.

        Aside from running _terminate_job at exit, it also registers
        a signal handler to terminate the job if the program is interrupted
        or terminated. It restores the original handlers on exit.
        """

        def terminate_handler(signum, frame):
            self._terminate_job()

        register_handler(terminate_handler)
        yield
        self._terminate_job()
        unregister_handler(terminate_handler)

    def _wait_job(self, timeout: int):
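        """Poll the job until it reaches a terminal state.

        Returns a (return code, duration in seconds) tuple. Raises
        JobStartupTimeout if the cluster does not start in time and
        CommandTimeout if the command itself exceeds the timeout.
        """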
        # The context ensures the job always either finishes normally
        # or is terminated.
        with self._terminate_job_context():
            assert self.job_id, "Job must have been started"
            start_time = time.monotonic()
            # Waiting for the cluster to start counts towards the overall run,
            # but it has its own, separate timeout.
            timeout_at = start_time + self.cluster_startup_timeout
            next_status = start_time + 30
            job_running = False

            while True:
                now = time.monotonic()
                if now >= timeout_at:
                    self._terminate_job()
                    if not job_running:
                        raise JobStartupTimeout(
                            "Cluster did not start within "
                            f"{self.cluster_startup_timeout} seconds."
                        )
                    raise CommandTimeout(f"Job timed out after {timeout} seconds.")

                if now >= next_status:
                    if job_running:
                        msg = "... job still running ..."
                    else:
                        msg = "... job not yet running ..."
                    logger.info(
                        f"{msg}({int(now - start_time)} seconds, "
                        f"{int(timeout_at - now)} seconds to job timeout) ..."
                    )
                    next_status += 30

                result = self._get_job_status_with_retry()
                self.last_job_result = result
                status = self.last_job_status

                if not job_running and status in {
                    HaJobStates.RUNNING,
                    HaJobStates.ERRORED,
                }:
                    logger.info(
                        f"... job started ...({int(now - start_time)} seconds) ..."
                    )
                    job_running = True
                    # Once the job has started, switch from the cluster startup
                    # timeout to the actual command (incl. prepare commands)
                    # timeout.
                    timeout_at = now + timeout

                if status in terminal_state:
                    logger.info(f"Job entered terminal state {status}.")
                    break

                time.sleep(1)

            result = self._get_job_status_with_retry()
            self.last_job_result = result
            status = self.last_job_status
            assert status in terminal_state

        if status == HaJobStates.TERMINATED and not job_running:
            # The job was terminated before it ever started running,
            # which points at an infrastructure problem (soft infra error).
            retcode = -4
        else:
            retcode = job_status_to_return_code[status]
        self._duration = time.time() - self.start_time
        return retcode, self._duration

    def run_and_wait(
        self,
        cmd_to_run,
        env_vars,
        working_dir: Optional[str] = None,
        timeout: int = 120,
        upload_path: Optional[str] = None,
        pip: Optional[List[str]] = None,
    ) -> Tuple[int, float]:
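        """Submit the command as an Anyscale job and block until it finishes.

        Returns a (return code, duration in seconds) tuple.
        """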
        self._run_job(
            cmd_to_run,
            env_vars,
            working_dir=working_dir,
            upload_path=upload_path,
            pip=pip,
        )
        return self._wait_job(timeout)

    def _get_ray_logs(self) -> str:
        """Obtain the job logs.

        If the cluster manager sets a log streaming limit, only that many of
        the last lines are fetched; a limit of -1 fetches all logs.
        """
        if self.cluster_manager.log_streaming_limit == -1:
            return anyscale.job.get_logs(id=self.job_id)
        return anyscale.job.get_logs(
            id=self.job_id, max_lines=self.cluster_manager.log_streaming_limit
        )

    def get_last_logs(self):
        if not self.job_id:
            raise RuntimeError(
                "Job has not been started, therefore there are no logs to obtain."
            )
        if self._last_logs:
            return self._last_logs

        # Skip fetching logs when the job ran for too long and produced
        # too many logs.
        if self._duration is not None and self._duration > 4 * 3_600:
            return None

        ret = exponential_backoff_retry(
            self._get_ray_logs,
            retry_exceptions=Exception,
            initial_retry_delay_s=30,
            max_retries=3,
        )
        if ret and not self.in_progress:
            self._last_logs = ret
        return ret
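

# Illustrative usage sketch: how a test harness might drive this class. The
# ClusterManager setup is elided here (its construction depends on the test
# definition), so treat `cluster_manager` as an already-initialized instance
# and the entrypoint/env vars as hypothetical placeholders.
#
#   cluster_manager = ...  # an initialized ClusterManager for the target test
#   job_manager = AnyscaleJobManager(cluster_manager)
#   retcode, duration = job_manager.run_and_wait(
#       "python workload.py",   # hypothetical entrypoint
#       {"SMOKE_TEST": "0"},    # hypothetical env vars
#       working_dir=".",
#       timeout=1800,
#   )
#   logs = job_manager.get_last_logs()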