# test_glue.py

import os
import pytest
import shutil
import sys
import tempfile
import time
from typing import Type, Callable, Optional
import unittest
from unittest.mock import patch

from ray_release.alerts.handle import result_to_handle_map
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.cluster_manager.full import FullClusterManager
from ray_release.command_runner.command_runner import CommandRunner
from ray_release.test import Test
from ray_release.exception import (
    ReleaseTestConfigError,
    ClusterCreationError,
    ClusterStartupError,
    ClusterStartupTimeout,
    RemoteEnvSetupError,
    CommandError,
    PrepareCommandError,
    CommandTimeout,
    PrepareCommandTimeout,
    TestCommandError,
    TestCommandTimeout,
    FetchResultError,
    LogsError,
    ResultsAlert,
    ClusterNodesWaitTimeout,
)
from ray_release.file_manager.file_manager import FileManager
from ray_release.glue import (
    run_release_test,
    type_str_to_command_runner,
    command_runner_to_cluster_manager,
)
from ray_release.logger import logger
from ray_release.reporter.reporter import Reporter
from ray_release.result import Result, ExitCode
from ray_release.tests.utils import MockSDK, APIDict
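
# End-to-end tests for ray_release.glue.run_release_test(): the cluster
# manager, command runner, and file manager are replaced with mocks so that
# each stage of a release test can be made to fail in isolation and the
# resulting exception and Result.return_code asserted.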


def _fail_on_call(error_type: Type[Exception] = RuntimeError, message: str = "Fail"):
    def _fail(*args, **kwargs):
        raise error_type(message)

    return _fail
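

# MockReturn reroutes attribute access through `return_dict`: a callable entry
# is invoked on access (so _fail_on_call() factories raise immediately), while
# a plain value is wrapped in a function that returns it.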
class MockReturn:
    return_dict = {}

    def __getattribute__(self, item):
        return_dict = object.__getattribute__(self, "return_dict")
        if item in return_dict:
            mocked = return_dict[item]
            if isinstance(mocked, Callable):
                return mocked()
            else:
                return lambda *a, **kw: mocked
        return object.__getattribute__(self, item)
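

# Test subclass with a stubbed BYOD image so no real image lookup is needed.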
class MockTest(Test):
    def get_anyscale_byod_image(self) -> str:
        return ""


class GlueTest(unittest.TestCase):
    def writeClusterEnv(self, content: str):
        with open(os.path.join(self.tempdir, "cluster_env.yaml"), "wt") as fp:
            fp.write(content)

    def writeClusterCompute(self, content: str):
        with open(os.path.join(self.tempdir, "cluster_compute.yaml"), "wt") as fp:
            fp.write(content)

    def setUp(self) -> None:
        self.tempdir = tempfile.mkdtemp()

        self.sdk = MockSDK()
        self.sdk.returns["get_project"] = APIDict(
            result=APIDict(name="unit_test_project")
        )
        self.sdk.returns["get_cloud"] = APIDict(result=APIDict(provider="AWS"))

        self.writeClusterEnv("{'env': true}")
        self.writeClusterCompute("{'compute': true}")

        with open(os.path.join(self.tempdir, "driver_fail.sh"), "wt") as f:
            f.write("exit 1\n")

        with open(os.path.join(self.tempdir, "driver_succeed.sh"), "wt") as f:
            f.write("exit 0\n")

        this_sdk = self.sdk
        this_tempdir = self.tempdir

        self.instances = {}

        self.cluster_manager_return = {}
        self.command_runner_return = {}
        self.file_manager_return = {}

        this_instances = self.instances
        this_cluster_manager_return = self.cluster_manager_return
        this_command_runner_return = self.command_runner_return
        this_file_manager_return = self.file_manager_return

        class MockClusterManager(MockReturn, FullClusterManager):
            def __init__(
                self,
                test_name: str,
                project_id: str,
                sdk=None,
                smoke_test: bool = False,
                log_streaming_limit: int = 100,
            ):
                super(MockClusterManager, self).__init__(
                    test_name,
                    project_id,
                    this_sdk,
                    smoke_test=smoke_test,
                    log_streaming_limit=log_streaming_limit,
                )
                self.return_dict = this_cluster_manager_return
                this_instances["cluster_manager"] = self

        class MockCommandRunner(MockReturn, CommandRunner):
            # Class-level default; overridden per instance in __init__ below.
            return_dict = self.command_runner_return

            def __init__(
                self,
                cluster_manager: ClusterManager,
                file_manager: FileManager,
                working_dir,
                sdk=None,
                artifact_path: Optional[str] = None,
            ):
                super(MockCommandRunner, self).__init__(
                    cluster_manager, file_manager, this_tempdir
                )
                self.return_dict = this_command_runner_return

        class MockFileManager(MockReturn, FileManager):
            def __init__(self, cluster_manager: ClusterManager):
                super(MockFileManager, self).__init__(cluster_manager)
                self.return_dict = this_file_manager_return

        self.mock_alert_return = None

        def mock_alerter(test: Test, result: Result):
            return self.mock_alert_return

        result_to_handle_map["unit_test_alerter"] = (mock_alerter, False)

        type_str_to_command_runner["unit_test"] = MockCommandRunner
        command_runner_to_cluster_manager[MockCommandRunner] = MockClusterManager

        self.test = MockTest(
            name="unit_test_end_to_end",
            run=dict(
                type="unit_test",
                prepare="prepare_cmd",
                script="test_cmd",
                wait_for_nodes=dict(num_nodes=4, timeout=40),
            ),
            working_dir=self.tempdir,
            cluster=dict(
                cluster_env="cluster_env.yaml",
                cluster_compute="cluster_compute.yaml",
                byod={},
            ),
            alert="unit_test_alerter",
        )
        self.anyscale_project = "prj_unit12345678"

    def tearDown(self) -> None:
        shutil.rmtree(self.tempdir)

    def _succeed_until(self, until: str):
        """Make every pipeline stage up to and including ``until`` succeed."""
        # These commands should succeed
        self.cluster_manager_return["cluster_compute_id"] = "valid"
        self.cluster_manager_return["create_cluster_compute"] = None

        if until == "cluster_compute":
            return

        self.cluster_manager_return["cluster_env_id"] = "valid"
        self.cluster_manager_return["create_cluster_env"] = None
        self.cluster_manager_return["cluster_env_build_id"] = "valid"
        self.cluster_manager_return["build_cluster_env"] = None

        if until == "cluster_env":
            return

        self.cluster_manager_return["cluster_id"] = "valid"
        self.cluster_manager_return["start_cluster"] = None

        if until == "cluster_start":
            return

        self.command_runner_return["prepare_remote_env"] = None

        if until == "remote_env":
            return

        self.command_runner_return["wait_for_nodes"] = None

        if until == "wait_for_nodes":
            return

        self.command_runner_return["run_prepare_command"] = None

        if until == "prepare_command":
            return

        self.command_runner_return["run_command"] = None

        if until == "test_command":
            return

        self.command_runner_return["fetch_results"] = {
            "time_taken": 50,
            "last_update": time.time() - 60,
        }

        if until == "fetch_results":
            return

        self.command_runner_return["get_last_logs_ex"] = "Lorem ipsum"

        if until == "get_last_logs":
            return

        self.mock_alert_return = None

    def _run(self, result: Result, **kwargs):
        run_release_test(
            test=self.test,
            anyscale_project=self.anyscale_project,
            result=result,
            log_streaming_limit=1000,
            **kwargs,
        )
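
    # Most tests below arrange success up to a given stage via
    # _succeed_until(), inject a failure at the next stage, and assert both
    # the raised exception type and the Result.return_code.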

    def testInvalidClusterCompute(self):
        result = Result()

        with patch(
            "ray_release.glue.load_test_cluster_compute",
            _fail_on_call(ReleaseTestConfigError),
        ), self.assertRaises(ReleaseTestConfigError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because file not found
        os.unlink(os.path.join(self.tempdir, "cluster_compute.yaml"))
        with self.assertRaisesRegex(ReleaseTestConfigError, "Path not found"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because invalid jinja template
        self.writeClusterCompute("{{ INVALID")
        with self.assertRaisesRegex(ReleaseTestConfigError, "yaml template"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because of invalid yaml (unterminated quoted scalar)
        self.writeClusterCompute("{'test': true, 'fail}")
        with self.assertRaisesRegex(ReleaseTestConfigError, "quoted scalar"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

    def testStartClusterFails(self):
        result = Result()

        self._succeed_until("cluster_env")

        # Fails because API response faulty
        with self.assertRaises(ClusterCreationError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_RESOURCE_ERROR.value)

        self.cluster_manager_return["cluster_id"] = "valid"

        # Fail for random cluster startup reason
        self.cluster_manager_return["start_cluster"] = _fail_on_call(
            ClusterStartupError
        )
        with self.assertRaises(ClusterStartupError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_STARTUP_ERROR.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

        # Fail for cluster startup timeout
        self.cluster_manager_return["start_cluster"] = _fail_on_call(
            ClusterStartupTimeout
        )
        with self.assertRaises(ClusterStartupTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_STARTUP_TIMEOUT.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testPrepareRemoteEnvFails(self):
        result = Result()

        self._succeed_until("cluster_start")

        self.command_runner_return["prepare_remote_env"] = _fail_on_call(
            RemoteEnvSetupError
        )
        with self.assertRaises(RemoteEnvSetupError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.REMOTE_ENV_SETUP_ERROR.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testWaitForNodesFails(self):
        result = Result()

        self._succeed_until("remote_env")

        # Wait for nodes command fails
        self.command_runner_return["wait_for_nodes"] = _fail_on_call(
            ClusterNodesWaitTimeout
        )
        with self.assertRaises(ClusterNodesWaitTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_WAIT_TIMEOUT.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testPrepareCommandFails(self):
        result = Result()

        self._succeed_until("wait_for_nodes")

        # Prepare command fails
        self.command_runner_return["run_prepare_command"] = _fail_on_call(CommandError)
        with self.assertRaises(PrepareCommandError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.PREPARE_ERROR.value)

        # Prepare command times out
        self.command_runner_return["run_prepare_command"] = _fail_on_call(
            CommandTimeout
        )
        with self.assertRaises(PrepareCommandTimeout):
            self._run(result)
        # Special case: Prepare commands are usually waiting for nodes
        # (this may change in the future!)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_WAIT_TIMEOUT.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testTestCommandFails(self):
        result = Result()

        self._succeed_until("prepare_command")

        # Test command fails
        self.command_runner_return["run_command"] = _fail_on_call(CommandError)
        with self.assertRaises(TestCommandError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.COMMAND_ERROR.value)

        # Test command times out
        self.command_runner_return["run_command"] = _fail_on_call(CommandTimeout)
        with self.assertRaises(TestCommandTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.COMMAND_TIMEOUT.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testTestCommandTimeoutLongRunning(self):
        result = Result()

        self._succeed_until("fetch_results")

        # Test command times out
        self.command_runner_return["run_command"] = _fail_on_call(CommandTimeout)
        with self.assertRaises(TestCommandTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.COMMAND_TIMEOUT.value)

        # But now set test to long running
        self.test["run"]["long_running"] = True
        self._run(result)  # Will not fail this time

        self.assertGreaterEqual(result.results["last_update_diff"], 60.0)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testSmokeUnstableTest(self):
        result = Result()

        self._succeed_until("complete")

        self.test["stable"] = False
        self._run(result, smoke_test=True)

        # Ensure stable and smoke_test are set correctly.
        assert not result.stable
        assert result.smoke_test

    def testFetchResultFails(self):
        result = Result()

        self._succeed_until("test_command")

        self.command_runner_return["fetch_results"] = _fail_on_call(FetchResultError)
        with self.assertLogs(logger, "ERROR") as cm:
            self._run(result)
            self.assertTrue(any("Could not fetch results" in o for o in cm.output))
        self.assertEqual(result.return_code, ExitCode.SUCCESS.value)
        self.assertEqual(result.status, "success")

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testFetchResultFailsReqNonEmptyResult(self):
        # Set the `require_result` bit on the registered alert handler.
        handler_fn, _ = result_to_handle_map["unit_test_alerter"]
        result_to_handle_map["unit_test_alerter"] = (handler_fn, True)

        result = Result()

        self._succeed_until("test_command")

        self.command_runner_return["fetch_results"] = _fail_on_call(FetchResultError)
        with self.assertRaisesRegex(FetchResultError, "Fail"):
            with self.assertLogs(logger, "ERROR") as cm:
                self._run(result)
                self.assertTrue(
                    any("Could not fetch results" in o for o in cm.output)
                )
        self.assertEqual(result.return_code, ExitCode.FETCH_RESULT_ERROR.value)
        self.assertEqual(result.status, "infra_error")

        # Ensure cluster was terminated, no matter what
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testLastLogsFails(self):
        result = Result()

        self._succeed_until("fetch_results")

        self.command_runner_return["get_last_logs_ex"] = _fail_on_call(LogsError)

        with self.assertLogs(logger, "ERROR") as cm:
            self._run(result)
            self.assertTrue(any("Error fetching logs" in o for o in cm.output))
        self.assertEqual(result.return_code, ExitCode.SUCCESS.value)
        self.assertEqual(result.status, "success")

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testAlertFails(self):
        result = Result()

        self._succeed_until("get_last_logs")

        self.mock_alert_return = "Alert raised"

        with self.assertRaises(ResultsAlert):
            self._run(result)

        self.assertEqual(result.return_code, ExitCode.COMMAND_ALERT.value)
        self.assertEqual(result.status, "error")
        self.assertEqual(self.instances["cluster_manager"].log_streaming_limit, 1000)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testReportFails(self):
        result = Result()

        self._succeed_until("complete")

        class FailReporter(Reporter):
            def report_result_ex(self, test: Test, result: Result):
                raise RuntimeError

        with self.assertLogs(logger, "ERROR") as cm:
            self._run(result, reporters=[FailReporter()])
            self.assertTrue(any("Error reporting results" in o for o in cm.output))

        self.assertEqual(result.return_code, ExitCode.SUCCESS.value)
        self.assertEqual(result.status, "success")

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))