job_file_manager.py

import os
import shutil
import sys
import tempfile
from typing import Optional

import boto3
from google.cloud import storage

from ray_release.aws import RELEASE_AWS_BUCKET
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.exception import FileDownloadError, FileUploadError
from ray_release.file_manager.file_manager import FileManager
from ray_release.job_manager import JobManager
from ray_release.logger import logger
from ray_release.util import (
    exponential_backoff_retry,
    generate_tmp_cloud_storage_path,
    S3_CLOUD_STORAGE,
    GS_CLOUD_STORAGE,
    GS_BUCKET,
)


class JobFileManager(FileManager):
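    """File manager that moves files between the local machine and a cluster.

    Transfers are staged through a temporary object in cloud storage (S3 or
    Google Cloud Storage, selected via ``ANYSCALE_CLOUD_STORAGE_PROVIDER``)
    and completed by a job running on the cluster.
    """
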
    def __init__(self, cluster_manager: ClusterManager):
        import anyscale

        super(JobFileManager, self).__init__(cluster_manager=cluster_manager)

        self.sdk = self.cluster_manager.sdk
        self.s3_client = boto3.client(S3_CLOUD_STORAGE)
        self.cloud_storage_provider = os.environ.get(
            "ANYSCALE_CLOUD_STORAGE_PROVIDER", S3_CLOUD_STORAGE
        )
        if self.cloud_storage_provider == S3_CLOUD_STORAGE:
            self.bucket = str(RELEASE_AWS_BUCKET)
        elif self.cloud_storage_provider == GS_CLOUD_STORAGE:
            self.bucket = GS_BUCKET
            self.gs_client = storage.Client()
        else:
            raise RuntimeError(
                f"Unsupported Anyscale cloud storage provider: "
                f"{self.cloud_storage_provider}"
            )

        self.job_manager = JobManager(cluster_manager)

        # Backward compatibility: pick up anyscale's bundled Ray binaries
        # if available.
        if "ANYSCALE_RAY_DIR" in anyscale.__dict__:
            sys.path.insert(0, f"{anyscale.ANYSCALE_RAY_DIR}/bin")

    def _run_with_retry(self, f, initial_retry_delay_s: int = 10):
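        """Call ``f`` with exponential backoff, retrying up to three times."""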
        assert callable(f)
        return exponential_backoff_retry(
            f,
            retry_exceptions=Exception,
            initial_retry_delay_s=initial_retry_delay_s,
            max_retries=3,
        )

    def _generate_tmp_cloud_storage_path(self):
        location = f"tmp/{generate_tmp_cloud_storage_path()}"
        return location

    def download_from_cloud(
        self, key: str, target: str, delete_after_download: bool = False
    ):
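        """Download ``key`` from the cloud storage bucket to the local ``target``.

        If ``delete_after_download`` is set, the remote object is removed
        afterwards.
        """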
        if self.cloud_storage_provider == S3_CLOUD_STORAGE:
            self._run_with_retry(
                lambda: self.s3_client.download_file(
                    Bucket=self.bucket,
                    Key=key,
                    Filename=target,
                )
            )
        if self.cloud_storage_provider == GS_CLOUD_STORAGE:
            bucket = self.gs_client.bucket(self.bucket)
            blob = bucket.blob(key)
            self._run_with_retry(lambda: blob.download_to_filename(target))
        if delete_after_download:
            self.delete(key)

    def download(self, source: str, target: str):
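        """Copy a single file from the cluster to the local ``target`` path.

        A job on the cluster uploads ``source`` to temporary cloud storage,
        from which it is then downloaded locally.
        """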
        # Attention: Only works for single files at the moment
        remote_upload_to = self._generate_tmp_cloud_storage_path()
        # remote source -> s3
        bucket_address = f"s3://{self.bucket}/{remote_upload_to}"
        retcode, _ = self._run_with_retry(
            lambda: self.job_manager.run_and_wait(
                (
                    f"pip install -q awscli && "
                    f"aws s3 cp {source} {bucket_address} "
                    "--acl bucket-owner-full-control"
                ),
                {},
            )
        )

        if retcode != 0:
            raise FileDownloadError(f"Error downloading file {source} to {target}")

        self.download_from_cloud(remote_upload_to, target, delete_after_download=True)

    def _push_local_dir(self):
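        """Archive the current working directory and unpack it on the cluster."""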
        remote_upload_to = self._generate_tmp_cloud_storage_path()
        # pack local dir
        _, local_path = tempfile.mkstemp()
        shutil.make_archive(local_path, "gztar", os.getcwd())
        # local source -> s3
        self._run_with_retry(
            lambda: self.s3_client.upload_file(
                Filename=local_path + ".tar.gz",
                Bucket=self.bucket,
                Key=remote_upload_to,
            )
        )
        # remove local archive
        os.unlink(local_path)

        bucket_address = f"s3://{self.bucket}/{remote_upload_to}"

        # s3 -> remote target
        retcode, _ = self.job_manager.run_and_wait(
            f"pip install -q awscli && "
            f"aws s3 cp {bucket_address} archive.tar.gz && "
            f"tar xf archive.tar.gz ",
            {},
        )

        if retcode != 0:
            raise FileUploadError(
                f"Error uploading local dir to session "
                f"{self.cluster_manager.cluster_name}."
            )

        try:
            self._run_with_retry(
                lambda: self.s3_client.delete_object(
                    Bucket=self.bucket, Key=remote_upload_to
                ),
                initial_retry_delay_s=2,
            )
        except RuntimeError as e:
            logger.warning(f"Could not remove temporary S3 object: {e}")

    def upload(self, source: Optional[str] = None, target: Optional[str] = None):
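        """Upload ``source`` to ``target`` on the cluster via temporary cloud storage.

        When both arguments are omitted, the whole current working directory
        is pushed instead.
        """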
        if source is None and target is None:
            self._push_local_dir()
            return

        assert isinstance(source, str)
        assert isinstance(target, str)

        remote_upload_to = self._generate_tmp_cloud_storage_path()
        # local source -> s3
        self._run_with_retry(
            lambda: self.s3_client.upload_file(
                Filename=source,
                Bucket=self.bucket,
                Key=remote_upload_to,
            )
        )
        # s3 -> remote target
        bucket_address = f"{S3_CLOUD_STORAGE}://{self.bucket}/{remote_upload_to}"
        retcode, _ = self.job_manager.run_and_wait(
            "pip install -q awscli && " f"aws s3 cp {bucket_address} {target}",
            {},
        )
        if retcode != 0:
            raise FileUploadError(f"Error uploading file {source} to {target}")

        self.delete(remote_upload_to)

    def _delete_gs_fn(self, key: str, recursive: bool = False):
        if recursive:
            blobs = self.gs_client.list_blobs(
                self.bucket,
                prefix=key,
            )
            for blob in blobs:
                blob.delete()
        else:
            blob = self.gs_client.bucket(self.bucket).blob(key)
            blob.delete()

    def _delete_s3_fn(self, key: str, recursive: bool = False):
        if recursive:
            response = self.s3_client.list_objects_v2(Bucket=self.bucket, Prefix=key)
            for object in response["Contents"]:
                self.s3_client.delete_object(Bucket=self.bucket, Key=object["Key"])
        else:
            self.s3_client.delete_object(Bucket=self.bucket, Key=key)

    def delete(self, key: str, recursive: bool = False):
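        """Delete ``key`` (or, if ``recursive``, every object under that prefix).

        Failures are retried and, if they persist, logged as warnings rather
        than raised.
        """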
        def delete_fn():
            if self.cloud_storage_provider == S3_CLOUD_STORAGE:
                self._delete_s3_fn(key, recursive)
                return
            if self.cloud_storage_provider == GS_CLOUD_STORAGE:
                self._delete_gs_fn(key, recursive)
                return

        try:
            self._run_with_retry(
                delete_fn,
                initial_retry_delay_s=2,
            )
        except Exception as e:
            logger.warning(f"Could not remove temporary cloud object: {e}")
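

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): assumes a configured ClusterManager for a
# running cluster and valid cloud storage credentials; the variable names
# below are hypothetical.
#
#   file_manager = JobFileManager(cluster_manager=cluster_manager)
#
#   # Push the current working directory to the cluster:
#   file_manager.upload()
#
#   # Copy a single file to the cluster, then fetch a result file back:
#   file_manager.upload(source="config.yaml", target="/tmp/config.yaml")
#   file_manager.download(source="/tmp/result.json", target="result.json")
# ---------------------------------------------------------------------------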