cluster_manager.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import abc
  2. import time
  3. from typing import TYPE_CHECKING, Any, Dict, Optional
  4. from ray_release.aws import (
  5. add_tags_to_aws_config,
  6. RELEASE_AWS_RESOURCE_TYPES_TO_TRACK_FOR_BILLING,
  7. )
  8. from ray_release.anyscale_util import get_project_name, LAST_LOGS_LENGTH
  9. from ray_release.config import DEFAULT_AUTOSUSPEND_MINS, DEFAULT_MAXIMUM_UPTIME_MINS
  10. from ray_release.test import Test
  11. from ray_release.exception import CloudInfoError
  12. from ray_release.util import anyscale_cluster_url, dict_hash, get_anyscale_sdk
  13. from ray_release.logger import logger
  14. if TYPE_CHECKING:
  15. from anyscale.sdk.anyscale_client.sdk import AnyscaleSDK
  class ClusterManager(abc.ABC):
      """Base class holding the state and lifecycle hooks for a release-test cluster.

      This base class derives names for the cluster, its cluster environment
      and its cluster compute, and resolves the compute's cloud provider. The
      lifecycle hooks (``build_configs``, ``delete_configs``, ``start_cluster``,
      ``terminate_cluster_ex``, ``get_cluster_address``) raise
      ``NotImplementedError`` here and must be provided by subclasses.
      """

      def __init__(
          self,
          test: Test,
          project_id: str,
          sdk: Optional["AnyscaleSDK"] = None,
          smoke_test: bool = False,
          log_streaming_limit: int = LAST_LOGS_LENGTH,
      ):
          """Initialize cluster-manager state for ``test``.

          Args:
              test: The release test the cluster is created for.
              project_id: Anyscale project ID; also used to look up
                  ``self.project_name`` and to build the cluster URL.
              sdk: Anyscale SDK client. When falsy, one is obtained from
                  ``get_anyscale_sdk()``.
              smoke_test: Whether this is a smoke-test run. If so,
                  ``-smoke-test`` is appended to the generated cluster name.
              log_streaming_limit: Stored as ``self.log_streaming_limit``
                  (defaults to ``LAST_LOGS_LENGTH``).
          """
          self.sdk = sdk or get_anyscale_sdk()
          self.test = test
          self.smoke_test = smoke_test
          self.project_id = project_id
          self.project_name = get_project_name(self.project_id, self.sdk)
          self.log_streaming_limit = log_streaming_limit
          # The integer Unix-timestamp suffix distinguishes repeated runs of
          # the same test.
          self.cluster_name = (
              f"{test.get_name()}{'-smoke-test' if smoke_test else ''}_{int(time.time())}"
          )
          # set_cluster_env fills cluster_env_name. set_cluster_compute fills
          # cluster_compute, cloud_provider and cluster_compute_name. The other
          # fields are not assigned in this base class; presumably subclasses
          # set them (confirm).
          self.cluster_id = None
          self.cluster_env = None
          self.cluster_env_name = None
          self.cluster_env_id = None
          self.cluster_env_build_id = None
          self.cluster_compute = None
          self.cluster_compute_name = None
          self.cluster_compute_id = None
          self.cloud_provider = None
          # Used by set_cluster_compute as defaults for the compute config's
          # "idle_termination_minutes" / "maximum_uptime_minutes" keys.
          self.autosuspend_minutes = DEFAULT_AUTOSUSPEND_MINS
          self.maximum_uptime_minutes = DEFAULT_MAXIMUM_UPTIME_MINS

      def set_cluster_env(self):
          """Derive ``self.cluster_env_name`` from the test's BYOD image and runtime env.

          The BYOD image name has ``/``, ``:`` and ``.`` replaced by ``_``. It is
          then suffixed with ``__env__`` and a hash of the test's BYOD runtime
          env, so different runtime envs on the same image get different names.
          """
          byod_image_name_normalized = (
              self.test.get_anyscale_byod_image()
              .replace("/", "_")
              .replace(":", "_")
              .replace(".", "_")
          )
          self.cluster_env_name = (
              f"{byod_image_name_normalized}"
              f"__env__{dict_hash(self.test.get_byod_runtime_env())}"
          )

      def set_cluster_compute(
          self,
          cluster_compute: Dict[str, Any],
          extra_tags: Optional[Dict[str, str]] = None,
      ):
          """Store the cluster compute config, with defaults and billing tags applied.

          Missing ``idle_termination_minutes`` / ``maximum_uptime_minutes`` keys
          are filled from ``self.autosuspend_minutes`` /
          ``self.maximum_uptime_minutes``. The cloud provider is then resolved,
          ``extra_tags`` are added for AWS clouds, and
          ``self.cluster_compute_name`` is derived from a hash of the final
          config.

          Args:
              cluster_compute: Cluster compute config. The caller's dict is not
                  modified; a shallow copy is stored on ``self.cluster_compute``.
              extra_tags: Optional tags to add to the AWS config. They are only
                  applied when the cloud provider is ``"AWS"``.

          Raises:
              CloudInfoError: If the cloud referenced by ``cloud_id`` cannot be
                  looked up.
          """
          extra_tags = extra_tags or {}
          # Copy before setdefault so the defaults don't leak into the caller's
          # dict. A shallow copy is enough because only top-level keys are added.
          self.cluster_compute = dict(cluster_compute)
          self.cluster_compute.setdefault(
              "idle_termination_minutes", self.autosuspend_minutes
          )
          self.cluster_compute.setdefault(
              "maximum_uptime_minutes", self.maximum_uptime_minutes
          )
          self.cloud_provider = self._get_cloud_provider(self.cluster_compute)
          self.cluster_compute = self._annotate_cluster_compute(
              self.cluster_compute,
              cloud_provider=self.cloud_provider,
              extra_tags=extra_tags,
          )
          # Hashing the final (annotated) config means a change in defaults or
          # tags produces a different compute name.
          self.cluster_compute_name = (
              f"{self.project_name}_{self.project_id[4:8]}"
              f"__compute__{self.test.get_name()}__"
              f"{dict_hash(self.cluster_compute)}"
          )

      def _get_cloud_provider(self, cluster_compute: Dict[str, Any]) -> Optional[str]:
          """Return the provider of the cloud referenced by ``cluster_compute["cloud_id"]``.

          Returns ``None`` when ``cluster_compute`` is empty or has no
          ``cloud_id`` key.

          Raises:
              CloudInfoError: If the SDK lookup (or reading its result) raises.
                  The original exception is chained as the cause.
          """
          if not cluster_compute or "cloud_id" not in cluster_compute:
              return None
          try:
              return self.sdk.get_cloud(cluster_compute["cloud_id"]).result.provider
          except Exception as e:
              raise CloudInfoError(f"Could not obtain cloud information: {e}") from e

      def _annotate_cluster_compute(
          self,
          cluster_compute: Dict[str, Any],
          cloud_provider: Optional[str],
          extra_tags: Dict[str, str],
      ) -> Dict[str, Any]:
          """Return ``cluster_compute`` with ``extra_tags`` added to its AWS config.

          If there are no tags, or the provider is not ``"AWS"``, the input
          object is returned unchanged. Otherwise a shallow copy is returned,
          with its ``"aws"`` entry replaced by the result of
          ``add_tags_to_aws_config``. The existing ``"aws"`` value, or ``{}`` if
          absent, is passed to that helper.
          NOTE(review): whether ``add_tags_to_aws_config`` mutates the nested
          dict it receives is not visible here — confirm.
          """
          if not extra_tags or cloud_provider != "AWS":
              return cluster_compute
          cluster_compute = cluster_compute.copy()
          aws = cluster_compute.get("aws", {})
          cluster_compute["aws"] = add_tags_to_aws_config(
              aws, extra_tags, RELEASE_AWS_RESOURCE_TYPES_TO_TRACK_FOR_BILLING
          )
          return cluster_compute

      def build_configs(self, timeout: float = 30.0):
          """Build the cluster env / compute configs. Must be implemented by subclasses."""
          raise NotImplementedError

      def delete_configs(self):
          """Delete the configs created by ``build_configs``. Must be implemented by subclasses."""
          raise NotImplementedError

      def start_cluster(self, timeout: float = 600.0):
          """Start the cluster. Must be implemented by subclasses."""
          raise NotImplementedError

      def terminate_cluster(self, wait: bool = False):
          """Terminate the cluster on a best-effort basis.

          Delegates to ``terminate_cluster_ex``. Any exception it raises is
          logged with its traceback and swallowed, so callers can use this
          during cleanup without crashing.

          Args:
              wait: Forwarded to ``terminate_cluster_ex``.
          """
          try:
              # Forward the caller's ``wait`` instead of hard-coding False.
              self.terminate_cluster_ex(wait=wait)
          except Exception as e:
              logger.exception(f"Could not terminate cluster: {e}")

      def terminate_cluster_ex(self, wait: bool = False):
          """Terminate the cluster and raise on failure. Must be implemented by subclasses."""
          raise NotImplementedError

      def get_cluster_address(self) -> str:
          """Return the cluster's address. Must be implemented by subclasses."""
          raise NotImplementedError

      def get_cluster_url(self) -> Optional[str]:
          """Return the Anyscale URL of the cluster.

          Returns ``None`` if ``project_id`` or ``cluster_id`` is not set.
          """
          if not self.project_id or not self.cluster_id:
              return None
          return anyscale_cluster_url(self.project_id, self.cluster_id)