job_manager.py

import os
import time
from typing import Dict, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
    from ray.job_submission import JobSubmissionClient, JobStatus  # noqa: F401

from ray_release.logger import logger
from ray_release.util import ANYSCALE_HOST
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.exception import CommandTimeout
from ray_release.util import exponential_backoff_retry


class JobManager:
    """Submits commands to a Ray cluster as Ray jobs and waits for their completion."""

    def __init__(self, cluster_manager: ClusterManager):
        self.job_id_pool: Dict[int, str] = dict()
        self.start_time: Dict[int, float] = dict()
        self.counter = 0
        self.cluster_manager = cluster_manager
        self.job_client = None
        self.last_job_id = None

    def _get_job_client(self) -> "JobSubmissionClient":
        from ray.job_submission import JobSubmissionClient  # noqa: F811

        if not self.job_client:
            self.job_client = JobSubmissionClient(
                self.cluster_manager.get_cluster_address()
            )
        return self.job_client

    def _run_job(self, cmd_to_run, env_vars) -> int:
        self.counter += 1
        command_id = self.counter
        env = os.environ.copy()
        env["RAY_ADDRESS"] = self.cluster_manager.get_cluster_address()
        env.setdefault("ANYSCALE_HOST", ANYSCALE_HOST)

        # Environment variables are prepended to the shell command as KEY=VALUE pairs.
        full_cmd = " ".join(f"{k}={v}" for k, v in env_vars.items()) + " " + cmd_to_run
        logger.info(f"Executing {cmd_to_run} with {env_vars} via ray job submit")

        job_client = self._get_job_client()
        job_id = job_client.submit_job(
            # Entrypoint shell command to execute
            entrypoint=full_cmd,
        )
        self.last_job_id = job_id
        self.job_id_pool[command_id] = job_id
        self.start_time[command_id] = time.time()
        return command_id

    def _get_job_status_with_retry(self, command_id):
        job_client = self._get_job_client()
        return exponential_backoff_retry(
            lambda: job_client.get_job_status(self.job_id_pool[command_id]),
            retry_exceptions=Exception,
            initial_retry_delay_s=1,
            max_retries=3,
        )

    def _wait_job(self, command_id: int, timeout: int):
        from ray.job_submission import JobStatus  # noqa: F811

        start_time = time.monotonic()
        timeout_at = start_time + timeout
        next_status = start_time + 30

        # Poll the job status until it reaches a terminal state or the timeout elapses.
        while True:
            now = time.monotonic()
            if now >= timeout_at:
                raise CommandTimeout(
                    f"Cluster command timed out after {timeout} seconds."
                )

            if now >= next_status:
                logger.info(
                    "... command still running ..."
                    f"({int(now - start_time)} seconds) ..."
                )
                next_status += 30

            status = self._get_job_status_with_retry(command_id)
            if status in {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}:
                break
            time.sleep(1)

        status = self._get_job_status_with_retry(command_id)
        # TODO(sang): Propagate JobInfo.error_type
        if status == JobStatus.SUCCEEDED:
            retcode = 0
        else:
            retcode = -1
        duration = time.time() - self.start_time[command_id]
        return retcode, duration

    def run_and_wait(
        self, cmd_to_run, env_vars, timeout: int = 120
    ) -> Tuple[int, float]:
        cid = self._run_job(cmd_to_run, env_vars)
        return self._wait_job(cid, timeout)

    def get_last_logs(self):
        # return None
        job_client = self._get_job_client()
        return job_client.get_job_logs(self.last_job_id)
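

# Usage sketch (illustrative, not part of the original module). It assumes a
# concrete ClusterManager instance is already available; how that instance is
# constructed depends on the surrounding ray_release test harness, so the
# `cluster_manager` argument below is a stand-in, and the command and env var
# names are hypothetical.
def _example_run(cluster_manager: ClusterManager) -> None:
    job_manager = JobManager(cluster_manager)
    # run_and_wait() submits the command as a Ray job and blocks until it
    # reaches a terminal state or the timeout is hit.
    retcode, duration = job_manager.run_and_wait(
        "python workload.py",            # hypothetical entrypoint script
        env_vars={"RELEASE_TEST": "1"},  # hypothetical environment variable
        timeout=600,
    )
    logger.info(f"Job finished with return code {retcode} after {duration:.1f}s")
    if retcode != 0:
        logger.info(job_manager.get_last_logs())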