# test_actor_manager.py (16 KB)
  1. import functools
  2. import os
  3. from pathlib import Path
  4. import pickle
  5. import sys
  6. import time
  7. import unittest
  8. import ray
  9. from ray.util.state import list_actors
  10. from ray.rllib.utils.actor_manager import FaultAwareApply, FaultTolerantActorManager
  11. def load_random_numbers():
  12. """Loads deterministic random numbers from data file."""
  13. rllib_dir = Path(__file__).parent.parent.parent
  14. pkl_file = os.path.join(
  15. rllib_dir,
  16. "utils",
  17. "tests",
  18. "random_numbers.pkl",
  19. )
  20. return pickle.load(open(pkl_file, "rb"))
# Deterministic random numbers, loaded once when this module is imported.
# Actor.__init__ picks its own entry from this table via RANDOM_NUMS[i].
RANDOM_NUMS = load_random_numbers()
  22. @ray.remote(max_restarts=-1)
  23. class Actor(FaultAwareApply):
  24. def __init__(self, i, maybe_crash=True):
  25. self.random_numbers = RANDOM_NUMS[i]
  26. self.count = 0
  27. self.maybe_crash = maybe_crash
  28. self.config = {
  29. "recreate_failed_workers": True,
  30. }
  31. def _maybe_crash(self):
  32. if not self.maybe_crash:
  33. return
  34. r = self.random_numbers[self.count]
  35. # 10% chance of crashing.
  36. if r < 0.1:
  37. sys.exit(1)
  38. # Another 10% chance of throwing errors.
  39. elif r < 0.2:
  40. raise AttributeError("sorry")
  41. def call(self):
  42. self.count += 1
  43. self._maybe_crash()
  44. # Otherwise, return good result.
  45. return self.count
  46. def ping(self):
  47. self._maybe_crash()
  48. return "pong"
  49. def wait_for_restore():
  50. """Wait for Ray actor fault tolerence to restore all failed actors."""
  51. while True:
  52. states = [
  53. # Wait till all actors are either "ALIVE" (retored),
  54. # or "DEAD" (cancelled. these actors are from other
  55. # finished test cases).
  56. a["state"] == "ALIVE" or a["state"] == "DEAD"
  57. for a in list_actors(filters=[("class_name", "=", "Actor")])
  58. ]
  59. print("waiting ... ", states)
  60. if all(states):
  61. break
  62. # Otherwise, wait a bit.
  63. time.sleep(0.5)
  64. class TestActorManager(unittest.TestCase):
  65. @classmethod
  66. def setUpClass(cls) -> None:
  67. ray.init()
  68. @classmethod
  69. def tearDownClass(cls) -> None:
  70. ray.shutdown()
  71. def test_sync_call_healthy_only(self):
  72. """Test synchronous remote calls to only healthy actors."""
  73. actors = [Actor.remote(i) for i in range(4)]
  74. manager = FaultTolerantActorManager(actors=actors)
  75. results = []
  76. for _ in range(10):
  77. results.extend(manager.foreach_actor(lambda w: w.call()).ignore_errors())
  78. # Wait for actors to recover.
  79. wait_for_restore()
  80. # Notice that since we only fire calls against healthy actors,
  81. # we wouldn't be aware that the actors have been recovered.
  82. # So once an actor is taken out of the lineup (10% chance),
  83. # it will not go back in, and we should have few results here.
  84. # Basically takes us 7 calls to kill all the actors.
  85. # Note that we can hardcode 10 here because we are using deterministic
  86. # sequences of random numbers.
  87. self.assertEqual(len(results), 7)
  88. manager.clear()
  89. def test_sync_call_all_actors(self):
  90. """Test synchronous remote calls to all actors, regardless of their states."""
  91. actors = [Actor.remote(i) for i in range(4)]
  92. manager = FaultTolerantActorManager(actors=actors)
  93. results = []
  94. for _ in range(10):
  95. # Make sure we have latest states of all actors.
  96. results.extend(
  97. manager.foreach_actor(lambda w: w.call(), healthy_only=False)
  98. )
  99. # Wait for actors to recover.
  100. wait_for_restore()
  101. # We fired against all actors regardless of their status.
  102. # So we should get 40 results back.
  103. self.assertEqual(len(results), 40)
  104. # Since the actors are always restored before next round of calls,
  105. # we should get more results back.
  106. # Some of these calls still failed, but 15 good results in total.
  107. # Note that we can hardcode 15 here because we are using deterministic
  108. # sequences of random numbers.
  109. self.assertEqual(len([r for r in results if r.ok]), 15)
  110. manager.clear()
  111. def test_sync_call_return_obj_refs(self):
  112. """Test synchronous remote calls to all actors asking for raw ObjectRefs."""
  113. actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
  114. manager = FaultTolerantActorManager(actors=actors)
  115. results = list(
  116. manager.foreach_actor(
  117. lambda w: w.call(),
  118. healthy_only=False,
  119. return_obj_refs=True,
  120. )
  121. )
  122. # We fired against all actors regardless of their status.
  123. # So we should get 40 results back.
  124. self.assertEqual(len(results), 4)
  125. for r in results:
  126. # Each result is an ObjectRef.
  127. self.assertTrue(r.ok)
  128. self.assertTrue(isinstance(r.get(), ray.ObjectRef))
  129. manager.clear()
  130. def test_sync_call_fire_and_forget(self):
  131. """Test synchronous remote calls with 0 timeout_seconds."""
  132. actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
  133. manager = FaultTolerantActorManager(actors=actors)
  134. results1 = []
  135. for _ in range(10):
  136. manager.probe_unhealthy_actors(mark_healthy=True)
  137. results1.extend(
  138. manager.foreach_actor(lambda w: w.call(), timeout_seconds=0)
  139. )
  140. # Wait for actors to recover.
  141. wait_for_restore()
  142. # Timeout is 0, so we returned immediately.
  143. # We may get a couple of results back if the calls are fast,
  144. # but that is not important.
  145. results2 = [
  146. r.get()
  147. for r in manager.foreach_actor(
  148. lambda w: w.call(), healthy_only=False
  149. ).ignore_errors()
  150. ]
  151. # Results from blocking calls show the # of calls happend on
  152. # each remote actor. 11 calls to each actor in total.
  153. self.assertEqual(results2, [11, 11, 11, 11])
  154. manager.clear()
  155. def test_sync_call_same_actor_multiple_times(self):
  156. """Test multiple synchronous remote calls to the same actor."""
  157. actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
  158. manager = FaultTolerantActorManager(actors=actors)
  159. # 2 synchronous call to actor 0.
  160. results = manager.foreach_actor(
  161. lambda w: w.call(),
  162. remote_actor_ids=[0, 0],
  163. )
  164. # Returns 1 and 2, representing the first and second calls to actor 0.
  165. self.assertEqual([r.get() for r in results.ignore_errors()], [1, 2])
  166. manager.clear()
  167. def test_async_call_same_actor_multiple_times(self):
  168. """Test multiple asynchronous remote calls to the same actor."""
  169. actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
  170. manager = FaultTolerantActorManager(actors=actors)
  171. # 2 asynchronous call to actor 0.
  172. num_of_calls = manager.foreach_actor_async(
  173. lambda w: w.call(),
  174. remote_actor_ids=[0, 0],
  175. )
  176. self.assertEqual(num_of_calls, 2)
  177. # Now, let's actually fetch the results.
  178. results = manager.fetch_ready_async_reqs(timeout_seconds=None)
  179. # Returns 1 and 2, representing the first and second calls to actor 0.
  180. self.assertEqual([r.get() for r in results.ignore_errors()], [1, 2])
  181. manager.clear()
  182. def test_sync_call_not_ignore_error(self):
  183. """Test synchronous remote calls that returns errors."""
  184. actors = [Actor.remote(i) for i in range(4)]
  185. manager = FaultTolerantActorManager(actors=actors)
  186. results = []
  187. for _ in range(10):
  188. manager.probe_unhealthy_actors(mark_healthy=True)
  189. results.extend(manager.foreach_actor(lambda w: w.call()))
  190. # Wait for actors to recover.
  191. wait_for_restore()
  192. # Some calls did error out.
  193. self.assertTrue(any([not r.ok for r in results]))
  194. manager.clear()
  195. def test_sync_call_not_bringing_back_actors(self):
  196. """Test successful remote calls will not bring back actors unless told to."""
  197. actors = [Actor.remote(i) for i in range(4)]
  198. manager = FaultTolerantActorManager(actors=actors)
  199. results = manager.foreach_actor(lambda w: w.call())
  200. # Some calls did error out.
  201. self.assertTrue(any([not r.ok for r in results]))
  202. # Wait for actors to recover.
  203. wait_for_restore()
  204. manager.probe_unhealthy_actors()
  205. # Restored actors are not marked healthy if we just do probing.
  206. # Only 2 healthy actors.
  207. self.assertEqual(manager.num_healthy_actors(), 2)
  208. manager.clear()
  209. def test_async_call(self):
  210. """Test asynchronous remote calls work."""
  211. actors = [Actor.remote(i) for i in range(4)]
  212. manager = FaultTolerantActorManager(actors=actors)
  213. results = []
  214. for _ in range(10):
  215. manager.foreach_actor_async(lambda w: w.call())
  216. results.extend(manager.fetch_ready_async_reqs(timeout_seconds=None))
  217. # Wait for actors to recover.
  218. wait_for_restore()
  219. # Note that we can hardcode the numbers here because of the deterministic
  220. # lists of random numbers we use.
  221. # 7 calls succeeded, 4 failed.
  222. # The number of results back is much lower than 40, because we do not probe
  223. # the actors with this test. As soon as an actor errors out, it will get
  224. # taken out of the lineup forever.
  225. self.assertEqual(len([r for r in results if r.ok]), 7)
  226. self.assertEqual(len([r for r in results if not r.ok]), 4)
  227. manager.clear()
  228. def test_async_calls_get_dropped_if_inflight_requests_over_limit(self):
  229. """Test asynchronous remote calls get dropped if too many in-flight calls."""
  230. actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
  231. manager = FaultTolerantActorManager(
  232. actors=actors,
  233. max_remote_requests_in_flight_per_actor=2,
  234. )
  235. # 2 asynchronous call to actor 1.
  236. num_of_calls = manager.foreach_actor_async(
  237. lambda w: w.call(),
  238. remote_actor_ids=[0, 0],
  239. )
  240. self.assertEqual(num_of_calls, 2)
  241. # Now, let's try to make another async call to actor 1.
  242. num_of_calls = manager.foreach_actor_async(
  243. lambda w: w.call(),
  244. healthy_only=False,
  245. remote_actor_ids=[0],
  246. )
  247. # We actually made 0 calls.
  248. self.assertEqual(num_of_calls, 0)
  249. manager.clear()
  250. def test_healthy_only_works_for_list_of_functions(self):
  251. """Test healthy only mode works when a list of funcs are provided."""
  252. actors = [Actor.remote(i) for i in range(4)]
  253. manager = FaultTolerantActorManager(actors=actors)
  254. # Mark first and second actor as unhealthy.
  255. manager.set_actor_state(1, False)
  256. manager.set_actor_state(2, False)
  257. def f(id, _):
  258. return id
  259. func = [functools.partial(f, i) for i in range(4)]
  260. manager.foreach_actor_async(func, healthy_only=True)
  261. results = manager.fetch_ready_async_reqs(timeout_seconds=None)
  262. # Should get results back from calling actor 0 and 3.
  263. self.assertEqual([r.get() for r in results], [0, 3])
  264. manager.clear()
  265. def test_len_of_func_not_match_len_of_actors(self):
  266. """Test healthy only mode works when a list of funcs are provided."""
  267. actors = [Actor.remote(i) for i in range(4)]
  268. manager = FaultTolerantActorManager(actors=actors)
  269. def f(id, _):
  270. return id
  271. func = [functools.partial(f, i) for i in range(3)]
  272. with self.assertRaisesRegexp(AssertionError, "same number of callables") as _:
  273. manager.foreach_actor_async(func, healthy_only=True),
  274. manager.clear()
  275. def test_probe_unhealthy_actors(self):
  276. """Test probe brings back unhealthy actors."""
  277. actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
  278. manager = FaultTolerantActorManager(actors=actors)
  279. # Mark first and second actor as unhealthy.
  280. manager.set_actor_state(1, False)
  281. manager.set_actor_state(2, False)
  282. # These actors are actually healthy.
  283. manager.probe_unhealthy_actors(mark_healthy=True)
  284. # Both actors are now healthy.
  285. self.assertEqual(len(manager.healthy_actor_ids()), 4)
  286. def test_tags(self):
  287. """Test that tags work for async calls."""
  288. actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
  289. manager = FaultTolerantActorManager(actors=actors)
  290. manager.foreach_actor_async(lambda w: w.ping(), tag="pingpong")
  291. manager.foreach_actor_async(lambda w: w.call(), tag="call")
  292. time.sleep(1)
  293. results_ping_pong = manager.fetch_ready_async_reqs(
  294. tags="pingpong", timeout_seconds=5
  295. )
  296. results_call = manager.fetch_ready_async_reqs(tags="call", timeout_seconds=5)
  297. self.assertEquals(len(list(results_ping_pong)), 4)
  298. self.assertEquals(len(list(results_call)), 4)
  299. for result in results_ping_pong:
  300. data = result.get()
  301. self.assertEqual(data, "pong")
  302. self.assertEqual(result.tag, "pingpong")
  303. for result in results_call:
  304. data = result.get()
  305. self.assertEqual(data, 1)
  306. self.assertEqual(result.tag, "call")
  307. # test with default tag
  308. manager.foreach_actor_async(lambda w: w.ping())
  309. manager.foreach_actor_async(lambda w: w.call())
  310. time.sleep(1)
  311. results = manager.fetch_ready_async_reqs(timeout_seconds=5)
  312. self.assertEquals(len(list(results)), 8)
  313. for result in results:
  314. data = result.get()
  315. self.assertEqual(result.tag, None)
  316. if isinstance(data, str):
  317. self.assertEqual(data, "pong")
  318. elif isinstance(data, int):
  319. self.assertEqual(data, 2)
  320. else:
  321. raise ValueError("data is not str or int")
  322. # test with custom tags
  323. manager.foreach_actor_async(lambda w: w.ping(), tag="pingpong")
  324. manager.foreach_actor_async(lambda w: w.call(), tag="call")
  325. time.sleep(1)
  326. results = manager.fetch_ready_async_reqs(
  327. timeout_seconds=5, tags=["pingpong", "call"]
  328. )
  329. self.assertEquals(len(list(results)), 8)
  330. for result in results:
  331. data = result.get()
  332. if isinstance(data, str):
  333. self.assertEqual(data, "pong")
  334. self.assertEqual(result.tag, "pingpong")
  335. elif isinstance(data, int):
  336. self.assertEqual(data, 3)
  337. self.assertEqual(result.tag, "call")
  338. else:
  339. raise ValueError("data is not str or int")
  340. # test with incorrect tags
  341. manager.foreach_actor_async(lambda w: w.ping(), tag="pingpong")
  342. manager.foreach_actor_async(lambda w: w.call(), tag="call")
  343. time.sleep(1)
  344. results = manager.fetch_ready_async_reqs(timeout_seconds=5, tags=["incorrect"])
  345. self.assertEquals(len(list(results)), 0)
  346. # now test that passing no tags still gives back all of the results
  347. results = manager.fetch_ready_async_reqs(timeout_seconds=5)
  348. self.assertEquals(len(list(results)), 8)
  349. for result in results:
  350. data = result.get()
  351. if isinstance(data, str):
  352. self.assertEqual(data, "pong")
  353. self.assertEqual(result.tag, "pingpong")
  354. elif isinstance(data, int):
  355. self.assertEqual(data, 4)
  356. self.assertEqual(result.tag, "call")
  357. else:
  358. raise ValueError("result is not str or int")
  359. if __name__ == "__main__":
  360. import pytest
  361. sys.exit(pytest.main(["-v", __file__]))