_anyscale_job_wrapper.py

  1. """
  2. This script provides extra functionality for Anyscale Jobs tests.
  3. It will be ran on the cluster.
  4. We need to reimplement some utility functions here as it will not
  5. have access to the ray_release package.
  6. """
import argparse
import time
import os
from pathlib import Path
import subprocess
import multiprocessing
import json
import sys
import logging
from urllib.parse import urlparse
from typing import Optional, List, Tuple

OUTPUT_JSON_FILENAME = "output.json"
AWS_CP_TIMEOUT = 300
TIMEOUT_RETURN_CODE = 124  # same as bash timeout

installed_pips = []

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(stream=sys.stderr)
formatter = logging.Formatter(
    fmt="[%(levelname)s %(asctime)s] %(filename)s: %(lineno)d %(message)s"
)
handler.setFormatter(formatter)
logger.addHandler(handler)


def exponential_backoff_retry(
    f, retry_exceptions, initial_retry_delay_s, max_retries
):
    """Call ``f``, retrying on ``retry_exceptions`` with exponential backoff."""
    retry_cnt = 0
    retry_delay_s = initial_retry_delay_s
    while True:
        try:
            return f()
        except retry_exceptions as e:
            retry_cnt += 1
            if retry_cnt > max_retries:
                raise
            logger.warning(
                f"Retry function call failed due to {e}, "
                f"retrying in {retry_delay_s} seconds..."
            )
            time.sleep(retry_delay_s)
            retry_delay_s *= 2
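
# A minimal usage sketch (hypothetical call, not executed here): retry a flaky
# cloud command up to 3 times, waiting 10 s, 20 s, then 40 s between attempts.
#
# exponential_backoff_retry(
#     lambda: subprocess.run(["aws", "s3", "ls", "s3://some-bucket"], check=True),
#     retry_exceptions=subprocess.SubprocessError,
#     initial_retry_delay_s=10,
#     max_retries=3,
# )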


def install_pip(pip: str):
    if pip in installed_pips:
        return
    subprocess.run(["pip", "install", "-q", pip], check=True)
    installed_pips.append(pip)


def run_storage_cp(source: str, target: str):
    """Copy ``source`` to the cloud storage ``target`` (s3:// or gs://).

    Returns True if the upload succeeded, False otherwise.
    """
    if not source or not target:
        return False
    if not Path(source).exists():
        logger.warning(f"Couldn't upload to cloud storage: '{source}' does not exist.")
        return False
    storage_service = urlparse(target).scheme
    cp_cmd_args = []
    if storage_service == "s3":
        cp_cmd_args = [
            "aws",
            "s3",
            "cp",
            source,
            target,
            "--acl",
            "bucket-owner-full-control",
        ]
    elif storage_service == "gs":
        cp_cmd_args = [
            "gcloud",
            "storage",
            "cp",
            source,
            target,
        ]
    else:
        raise Exception(f"Unsupported storage service: {storage_service}")
    try:
        exponential_backoff_retry(
            lambda: subprocess.run(
                cp_cmd_args,
                timeout=AWS_CP_TIMEOUT,
                check=True,
            ),
            subprocess.SubprocessError,
            initial_retry_delay_s=10,
            max_retries=3,
        )
        return True
    except subprocess.SubprocessError:
        logger.exception("Couldn't upload to cloud storage.")
        return False
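
# For illustration (hypothetical local path and bucket, not part of the module):
# run_storage_cp("/tmp/result.json", "s3://my-bucket/result.json") shells out to
#   aws s3 cp /tmp/result.json s3://my-bucket/result.json --acl bucket-owner-full-control
# and retries up to 3 times with exponential backoff on SubprocessError.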


def collect_metrics(time_taken: float) -> bool:
    """Collect prometheus metrics if METRICS_OUTPUT_JSON is set."""
    if "METRICS_OUTPUT_JSON" not in os.environ:
        return False

    # Timeout is the time the test took divided by 200
    # (~7 minutes for a 24h test) but no less than 90s
    # and no more than 900s
    metrics_timeout = max(90, min(time_taken / 200, 900))
    try:
        subprocess.run(
            [
                "python",
                "prometheus_metrics.py",
                str(time_taken),
                "--path",
                os.environ["METRICS_OUTPUT_JSON"],
            ],
            timeout=metrics_timeout,
            check=True,
        )
        return True
    except subprocess.SubprocessError:
        logger.exception("Couldn't collect metrics.")
        return False


# Has to be at module level so it can be pickled by multiprocessing
def _run_bash_command_subprocess(command: List[str], timeout: Optional[float]):
    """Run in a separate multiprocessing process."""
    try:
        subprocess.run(command, check=True, timeout=timeout)
        return_code = 0
    except subprocess.TimeoutExpired:
        return_code = TIMEOUT_RETURN_CODE
    except subprocess.CalledProcessError as e:
        return_code = e.returncode
    print(f"Subprocess return code: {return_code}", file=sys.stderr)
    # Exit so the return code is propagated to the outer process
    sys.exit(return_code)


def run_bash_command(workload: str, timeout: float):
    """Write ``workload`` to a workload.sh script and run it in a spawned process."""
    timeout = timeout if timeout > 0 else None
    cwd = Path.cwd()
    workload_path = cwd / "workload.sh"
    workload_path = workload_path.resolve()
    with open(workload_path, "w") as fp:
        fp.write(workload)
    command = ["bash", "-x", str(workload_path)]
    logger.info(f"Running command {workload}")

    # Pop job's runtime env to allow workload's runtime env to take precedence
    # TODO: Confirm this is safe
    os.environ.pop("RAY_JOB_CONFIG_JSON_ENV_VAR", None)

    # We use multiprocessing with 'spawn' context to avoid
    # forking (as happens when using subprocess directly).
    # Forking messes up Ray interactions and causes deadlocks.
    return_code = None
    try:
        ctx = multiprocessing.get_context("spawn")
        p = ctx.Process(target=_run_bash_command_subprocess, args=(command, timeout))
        p.start()
        logger.info(f"Started process {p.pid}.")
        # Add a little extra to the timeout as _run_bash_command_subprocess
        # also has a timeout internally and it's cleaner to use that
        p.join(timeout=timeout + 10 if timeout is not None else None)
    except multiprocessing.TimeoutError:
        return_code = TIMEOUT_RETURN_CODE
    except multiprocessing.ProcessError:
        pass
    finally:
        if p.is_alive():
            logger.warning(f"Terminating process {p.pid} forcefully.")
            p.terminate()
        if return_code is None:
            return_code = p.exitcode
        os.remove(str(workload_path))
    logger.info(f"Process {p.pid} exited with return code {return_code}.")
    assert return_code is not None
    return return_code
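
# A minimal usage sketch (hypothetical workload string, not executed here):
#
# rc = run_bash_command("python workloads/my_script.py", timeout=3600)
# # rc is 0 on success, TIMEOUT_RETURN_CODE (124) on timeout, or the
# # workload's own non-zero exit code on failure.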


def run_prepare_commands(
    prepare_commands: List[str], prepare_commands_timeouts: List[float]
) -> Tuple[bool, List[int], Optional[float]]:
    """Run prepare commands. All commands must pass. Fails fast."""
    prepare_return_codes = []
    prepare_passed = True
    prepare_time_taken = None

    if not prepare_commands:
        return prepare_passed, prepare_return_codes, prepare_time_taken

    logger.info("### Starting prepare commands ###")
    for prepare_command, timeout in zip(prepare_commands, prepare_commands_timeouts):
        command_start_time = time.monotonic()
        prepare_return_codes.append(run_bash_command(prepare_command, timeout))
        prepare_time_taken = time.monotonic() - command_start_time
        return_code = prepare_return_codes[-1]
        if return_code == 0:
            continue
        timed_out = return_code == TIMEOUT_RETURN_CODE
        if timed_out:
            logger.error(
                f"Prepare command timed out. Time taken: {prepare_time_taken}"
            )
        else:
            logger.info(
                f"Prepare command finished with return code {return_code}. "
                f"Time taken: {prepare_time_taken}"
            )
        logger.error("Prepare command failed.")
        prepare_passed = False
        break
    return prepare_passed, prepare_return_codes, prepare_time_taken
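
# Hypothetical example of how main() uses this helper (values for illustration):
#
# passed, return_codes, last_time_taken = run_prepare_commands(
#     ["pip install -U some_dep", "python prepare_cluster.py"],
#     [300, 600],
# )
# # `passed` turns False as soon as one command fails or times out;
# # the remaining commands are then skipped.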


def main(
    test_workload: str,
    test_workload_timeout: float,
    test_no_raise_on_timeout: bool,
    results_cloud_storage_uri: Optional[str],
    metrics_cloud_storage_uri: Optional[str],
    output_cloud_storage_uri: Optional[str],
    upload_cloud_storage_uri: Optional[str],
    artifact_path: Optional[str],
    prepare_commands: List[str],
    prepare_commands_timeouts: List[str],
):
    """
    This function provides extra functionality for an Anyscale Job.

    1. Runs prepare commands and handles their timeouts
    2. Runs the actual test workload and handles its timeout
    3. Uploads test results.json
    4. Gathers prometheus metrics
    5. Uploads prometheus metrics.json
    6. Uploads output.json
    """
    logger.info("### Starting ###")
    start_time = time.monotonic()

    if len(prepare_commands) != len(prepare_commands_timeouts):
        raise ValueError(
            "`prepare_commands` and `prepare_commands_timeouts` must "
            "have the same length."
        )

    # Run prepare commands. All prepare commands must pass.
    (
        prepare_passed,
        prepare_return_codes,
        last_prepare_time_taken,
    ) = run_prepare_commands(prepare_commands, prepare_commands_timeouts)

    uploaded_results = False
    collected_metrics = False
    uploaded_metrics = False
    uploaded_artifact = artifact_path is not None
    workload_time_taken = None

    # If all prepare commands passed, run actual test workload.
    if prepare_passed:
        logger.info("### Starting entrypoint ###")
        command_start_time = time.monotonic()
        return_code = run_bash_command(test_workload, test_workload_timeout)
        workload_time_taken = time.monotonic() - command_start_time
        timed_out = return_code == TIMEOUT_RETURN_CODE

        if timed_out:
            msg = f"Timed out. Time taken: {workload_time_taken}"
            if test_no_raise_on_timeout:
                logger.info(msg)
            else:
                logger.error(msg)
        else:
            logger.info(
                f"Finished with return code {return_code}. "
                f"Time taken: {workload_time_taken}"
            )

        # Upload results.json
        uploaded_results = run_storage_cp(
            os.environ.get("TEST_OUTPUT_JSON", None), results_cloud_storage_uri
        )

        # Collect prometheus metrics
        collected_metrics = collect_metrics(workload_time_taken)
        if collected_metrics:
            # Upload prometheus metrics
            uploaded_metrics = run_storage_cp(
                os.environ.get("METRICS_OUTPUT_JSON", None), metrics_cloud_storage_uri
            )

        uploaded_artifact = run_storage_cp(
            artifact_path,
            os.path.join(
                upload_cloud_storage_uri, os.environ["USER_GENERATED_ARTIFACT"]
            )
            if "USER_GENERATED_ARTIFACT" in os.environ
            else None,
        )
    else:
        return_code = None

    total_time_taken = time.monotonic() - start_time

    output_json = {
        "return_code": return_code,
        "prepare_return_codes": prepare_return_codes,
        "last_prepare_time_taken": last_prepare_time_taken,
        "workload_time_taken": workload_time_taken,
        "total_time_taken": total_time_taken,
        "uploaded_results": uploaded_results,
        "collected_metrics": collected_metrics,
        "uploaded_metrics": uploaded_metrics,
        "uploaded_artifact": uploaded_artifact,
    }
    output_json = json.dumps(
        output_json, ensure_ascii=True, sort_keys=True, separators=(",", ":")
    )

    output_json_file = (Path.cwd() / OUTPUT_JSON_FILENAME).resolve()
    with open(output_json_file, "w") as fp:
        fp.write(output_json)

    # Upload output.json
    run_storage_cp(str(output_json_file), output_cloud_storage_uri)

    logger.info("### Finished ###")
    # This will be read by the AnyscaleJobRunner on the buildkite runner
    # if output.json cannot be obtained from cloud storage
    logger.info(f"### JSON |{output_json}| ###")

    # Flush buffers
    logging.shutdown()
    print("", flush=True)
    print("", file=sys.stderr, flush=True)

    if return_code == TIMEOUT_RETURN_CODE and test_no_raise_on_timeout:
        return_code = 0
    elif return_code is None:
        return_code = 1

    time.sleep(1)
    return return_code


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "test_workload", type=str, help="test workload, e.g. python workloads/script.py"
    )
    parser.add_argument(
        "--test-workload-timeout",
        default=3600,
        type=float,
        help="test workload timeout (set to <0 for infinite)",
    )
    parser.add_argument(
        "--test-no-raise-on-timeout",
        action="store_true",
        help="don't fail on timeout",
    )
    parser.add_argument(
        "--results-cloud-storage-uri",
        type=str,
        help="bucket address to upload results.json to",
        required=False,
    )
    parser.add_argument(
        "--metrics-cloud-storage-uri",
        type=str,
        help="bucket address to upload metrics.json to",
        required=False,
    )
    parser.add_argument(
        "--output-cloud-storage-uri",
        type=str,
        help="bucket address to upload output.json to",
        required=False,
    )
    parser.add_argument(
        "--upload-cloud-storage-uri",
        type=str,
        help="root cloud storage bucket address to upload artifacts to",
        required=False,
    )
    parser.add_argument(
        "--artifact-path",
        type=str,
        help="user provided artifact path (on head node), must be a single file path",
        required=False,
    )
    parser.add_argument(
        "--prepare-commands", type=str, nargs="*", help="prepare commands to run"
    )
    parser.add_argument(
        "--prepare-commands-timeouts",
        default=3600,
        type=float,
        nargs="*",
        help="timeout for prepare commands (set to <0 for infinite)",
    )
    args = parser.parse_args()

    sys.exit(main(**args.__dict__))
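
# Illustrative invocation (hypothetical bucket URIs and paths; the real values are
# supplied by the Anyscale Jobs tooling that launches this wrapper):
#
# python _anyscale_job_wrapper.py "python workloads/script.py" \
#     --test-workload-timeout 3600 \
#     --results-cloud-storage-uri s3://my-bucket/results.json \
#     --output-cloud-storage-uri s3://my-bucket/output.json \
#     --prepare-commands "pip install -U some_dep" \
#     --prepare-commands-timeouts 600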