job_manager.py

import os
import time
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.exception import CommandTimeout
from ray_release.logger import logger
from ray_release.util import (
    ANYSCALE_HOST,
    exponential_backoff_retry,
)

if TYPE_CHECKING:
    from ray.job_submission import JobSubmissionClient  # noqa: F401


class JobManager:
    def __init__(self, cluster_manager: ClusterManager):
        # Map internal command ids to submitted Ray job ids and start times.
        self.job_id_pool: Dict[int, str] = dict()
        self.start_time: Dict[int, float] = dict()
        self.counter = 0
        self.cluster_manager = cluster_manager
        self.job_client = None
        self.last_job_id = None

    def _get_job_client(self) -> "JobSubmissionClient":
        from ray.job_submission import JobSubmissionClient  # noqa: F811

        if not self.job_client:
            self.job_client = JobSubmissionClient(
                self.cluster_manager.get_cluster_address()
            )
        return self.job_client

    def _run_job(
        self,
        cmd_to_run: str,
        env_vars: Dict[str, Any],
        working_dir: Optional[str] = None,
    ) -> int:
        self.counter += 1
        command_id = self.counter
        env = os.environ.copy()
        env["RAY_ADDRESS"] = self.cluster_manager.get_cluster_address()
        env.setdefault("ANYSCALE_HOST", str(ANYSCALE_HOST))

        # Prefix the entrypoint with the requested environment variables.
        full_cmd = " ".join(f"{k}={v}" for k, v in env_vars.items()) + " " + cmd_to_run
        logger.info(f"Executing {cmd_to_run} with {env_vars} via ray job submit")

        job_client = self._get_job_client()

        runtime_env = None
        if working_dir:
            runtime_env = {"working_dir": working_dir}

        job_id = job_client.submit_job(
            # Entrypoint shell command to execute
            entrypoint=full_cmd,
            runtime_env=runtime_env,
        )
        self.last_job_id = job_id
        self.job_id_pool[command_id] = job_id
        self.start_time[command_id] = time.time()
        return command_id

    def _get_job_status_with_retry(self, command_id):
        job_client = self._get_job_client()
        return exponential_backoff_retry(
            lambda: job_client.get_job_status(self.job_id_pool[command_id]),
            retry_exceptions=Exception,
            initial_retry_delay_s=1,
            max_retries=3,
        )

    def _wait_job(self, command_id: int, timeout: int):
        from ray.job_submission import JobStatus  # noqa: F811

        start_time = time.monotonic()
        timeout_at = start_time + timeout
        next_status = start_time + 30
        # Poll until the job reaches a terminal state or the timeout is hit.
        while True:
            now = time.monotonic()
            if now >= timeout_at:
                raise CommandTimeout(
                    f"Cluster command timed out after {timeout} seconds."
                )

            if now >= next_status:
                logger.info(
                    "... command still running ..."
                    f"({int(now - start_time)} seconds) ..."
                )
                next_status += 30

            status = self._get_job_status_with_retry(command_id)
            if status in {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}:
                break
            time.sleep(1)

        status = self._get_job_status_with_retry(command_id)
        # TODO(sang): Propagate JobInfo.error_type
        if status == JobStatus.SUCCEEDED:
            retcode = 0
        else:
            retcode = -1
        duration = time.time() - self.start_time[command_id]
        return retcode, duration

    def run_and_wait(
        self,
        cmd_to_run,
        env_vars,
        working_dir: Optional[str] = None,
        timeout: int = 120,
    ) -> Tuple[int, float]:
        cid = self._run_job(cmd_to_run, env_vars, working_dir=working_dir)
        return self._wait_job(cid, timeout)

    def get_last_logs(self):
        # return None
        job_client = self._get_job_client()
        return job_client.get_job_logs(self.last_job_id)
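

# Usage sketch (illustrative only, not part of the module above): a minimal
# example of how a caller might drive JobManager, assuming a ClusterManager
# instance has already been configured elsewhere in ray_release. The
# entrypoint, env vars, and timeout below are placeholders, not values taken
# from this file.
#
#     job_manager = JobManager(cluster_manager)
#     retcode, duration = job_manager.run_and_wait(
#         "python workload.py",                      # hypothetical entrypoint
#         env_vars={"TEST_OUTPUT_JSON": "/tmp/result.json"},
#         working_dir=".",
#         timeout=1800,
#     )
#     if retcode != 0:
#         logger.error(job_manager.get_last_logs())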