123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445 |
- import functools
- import os
- from pathlib import Path
- import pickle
- import sys
- import time
- import unittest
- import ray
- from ray.util.state import list_actors
- from ray.rllib.utils.actor_manager import FaultAwareApply, FaultTolerantActorManager
def load_random_numbers(pkl_file=None):
    """Load the deterministic random-number sequences used by the test actors.

    Args:
        pkl_file: Optional path to the pickle file. When omitted, the file
            ``utils/tests/random_numbers.pkl`` under ``rllib_dir`` is used,
            where ``rllib_dir`` is ``Path(__file__).parent.parent.parent``.

    Returns:
        The unpickled object. ``Actor`` below indexes it as
        ``RANDOM_NUMS[i][count]`` and compares the entries against floats, so
        presumably it is a sequence of per-actor numeric sequences — TODO
        confirm against the data file.
    """
    if pkl_file is None:
        rllib_dir = Path(__file__).parent.parent.parent
        pkl_file = rllib_dir / "utils" / "tests" / "random_numbers.pkl"
    # Context manager closes the handle deterministically; the previous
    # `pickle.load(open(...))` left it open until garbage collection.
    # NOTE: pickle.load is only acceptable here because the file is trusted,
    # repo-local test data — never point this at external input.
    with open(pkl_file, "rb") as f:
        return pickle.load(f)
# Loaded once at import time; every Actor instance reads its own row from it.
RANDOM_NUMS = load_random_numbers()


@ray.remote(max_restarts=-1)
class Actor(FaultAwareApply):
    """Test actor whose failures are driven by a fixed random-number sequence.

    Declared with ``max_restarts=-1`` so Ray keeps restarting it after crashes.
    """

    def __init__(self, i, maybe_crash=True):
        # Row ``i`` of the pre-generated numbers makes each actor's failure
        # pattern reproducible across runs.
        self.random_numbers = RANDOM_NUMS[i]
        self.count = 0
        self.maybe_crash = maybe_crash
        self.config = {
            "recreate_failed_workers": True,
        }

    def _maybe_crash(self):
        """Possibly fail, based on the number at index ``self.count``.

        A draw below 0.1 exits the process with ``sys.exit(1)``; a draw in
        [0.1, 0.2) raises ``AttributeError``; anything else returns normally.
        Does nothing when crashing is disabled.
        """
        if self.maybe_crash:
            roll = self.random_numbers[self.count]
            if roll < 0.1:
                # Hard crash (a 10% chance if draws are uniform on [0, 1)).
                sys.exit(1)
            if roll < 0.2:
                # Application error (another 10% under the same assumption).
                raise AttributeError("sorry")

    def call(self):
        """Increment the call counter, maybe fail, then return the new count."""
        self.count += 1
        self._maybe_crash()
        return self.count

    def ping(self):
        """Maybe fail (without advancing the counter); otherwise return "pong"."""
        self._maybe_crash()
        return "pong"
def wait_for_restore(timeout_s=None, poll_interval_s=0.5):
    """Wait for Ray actor fault tolerance to restore all failed ``Actor``s.

    Polls ``list_actors`` for actors with class name ``"Actor"`` and returns
    once every one of them is either ``"ALIVE"`` (restored) or ``"DEAD"``
    (killed; these come from earlier test cases that already finished).

    Args:
        timeout_s: Maximum number of seconds to wait. ``None`` (the default)
            waits indefinitely, preserving the original behavior.
        poll_interval_s: Seconds to sleep between polls.

    Raises:
        TimeoutError: If ``timeout_s`` is given and some actor is still in
            another state when it expires.
    """
    deadline = None if timeout_s is None else time.monotonic() + timeout_s
    while True:
        settled = [
            a["state"] in ("ALIVE", "DEAD")
            for a in list_actors(filters=[("class_name", "=", "Actor")])
        ]
        print("waiting ... ", settled)
        if all(settled):
            return
        # Bound the wait when requested so a never-recovering actor fails the
        # test instead of hanging the whole run.
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Actors did not reach ALIVE/DEAD within {timeout_s}s: {settled}"
            )
        time.sleep(poll_interval_s)
class TestActorManager(unittest.TestCase):
    """Tests for ``FaultTolerantActorManager`` using deterministically flaky actors.

    Several assertions hard-code result counts. They only hold because each
    ``Actor`` draws its crash/error decisions from the fixed sequences in
    ``RANDOM_NUMS``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_sync_call_healthy_only(self):
        """Test synchronous remote calls to only healthy actors."""
        actors = [Actor.remote(i) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        results = []
        for _ in range(10):
            results.extend(manager.foreach_actor(lambda w: w.call()).ignore_errors())
            # Wait for actors to recover.
            wait_for_restore()

        # Since we only fire calls against healthy actors, the manager never
        # learns that a failed actor has been restored. Once an actor is taken
        # out of the lineup it never comes back, so few calls succeed.
        # Note that we can hardcode 7 here because we are using deterministic
        # sequences of random numbers.
        self.assertEqual(len(results), 7)

        manager.clear()

    def test_sync_call_all_actors(self):
        """Test synchronous remote calls to all actors, regardless of their states."""
        actors = [Actor.remote(i) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        results = []
        for _ in range(10):
            # healthy_only=False: fire at every actor, including ones the
            # manager currently considers unhealthy.
            results.extend(
                manager.foreach_actor(lambda w: w.call(), healthy_only=False)
            )
            # Wait for actors to recover.
            wait_for_restore()

        # We fired against all 4 actors in each of 10 rounds regardless of
        # their status, so we get 40 results back.
        self.assertEqual(len(results), 40)
        # Since the actors are always restored before the next round of calls,
        # many more calls succeed than in the healthy-only case. Some still
        # failed; 15 good results in total. We can hardcode 15 because we are
        # using deterministic sequences of random numbers.
        self.assertEqual(sum(1 for r in results if r.ok), 15)

        manager.clear()

    def test_sync_call_return_obj_refs(self):
        """Test synchronous remote calls to all actors asking for raw ObjectRefs."""
        actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        results = list(
            manager.foreach_actor(
                lambda w: w.call(),
                healthy_only=False,
                return_obj_refs=True,
            )
        )

        # A single call per actor, so we get 4 results back.
        self.assertEqual(len(results), 4)
        for r in results:
            # Each result is an ObjectRef.
            self.assertTrue(r.ok)
            self.assertIsInstance(r.get(), ray.ObjectRef)

        manager.clear()

    def test_sync_call_fire_and_forget(self):
        """Test synchronous remote calls with 0 timeout_seconds."""
        actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        results1 = []
        for _ in range(10):
            manager.probe_unhealthy_actors(mark_healthy=True)
            results1.extend(
                manager.foreach_actor(lambda w: w.call(), timeout_seconds=0)
            )
            # Wait for actors to recover.
            wait_for_restore()

        # Timeout is 0, so we returned immediately.
        # We may get a couple of results back if the calls are fast,
        # but that is not important.

        results2 = [
            r.get()
            for r in manager.foreach_actor(
                lambda w: w.call(), healthy_only=False
            ).ignore_errors()
        ]
        # Results from blocking calls show the # of calls that happened on
        # each remote actor: 10 fire-and-forget calls + 1 blocking call = 11.
        self.assertEqual(results2, [11, 11, 11, 11])

        manager.clear()

    def test_sync_call_same_actor_multiple_times(self):
        """Test multiple synchronous remote calls to the same actor."""
        actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        # 2 synchronous calls to actor 0.
        results = manager.foreach_actor(
            lambda w: w.call(),
            remote_actor_ids=[0, 0],
        )
        # Returns 1 and 2, representing the first and second calls to actor 0.
        self.assertEqual([r.get() for r in results.ignore_errors()], [1, 2])

        manager.clear()

    def test_async_call_same_actor_multiple_times(self):
        """Test multiple asynchronous remote calls to the same actor."""
        actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        # 2 asynchronous calls to actor 0.
        num_of_calls = manager.foreach_actor_async(
            lambda w: w.call(),
            remote_actor_ids=[0, 0],
        )
        self.assertEqual(num_of_calls, 2)

        # Now, let's actually fetch the results.
        results = manager.fetch_ready_async_reqs(timeout_seconds=None)
        # Returns 1 and 2, representing the first and second calls to actor 0.
        self.assertEqual([r.get() for r in results.ignore_errors()], [1, 2])

        manager.clear()

    def test_sync_call_not_ignore_error(self):
        """Test synchronous remote calls that return errors."""
        actors = [Actor.remote(i) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        results = []
        for _ in range(10):
            manager.probe_unhealthy_actors(mark_healthy=True)
            results.extend(manager.foreach_actor(lambda w: w.call()))
            # Wait for actors to recover.
            wait_for_restore()

        # Some calls did error out.
        self.assertTrue(any(not r.ok for r in results))

        manager.clear()

    def test_sync_call_not_bringing_back_actors(self):
        """Test successful remote calls will not bring back actors unless told to."""
        actors = [Actor.remote(i) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        results = manager.foreach_actor(lambda w: w.call())
        # Some calls did error out.
        self.assertTrue(any(not r.ok for r in results))

        # Wait for actors to recover.
        wait_for_restore()

        manager.probe_unhealthy_actors()

        # Restored actors are not marked healthy if we just do probing.
        # Only 2 healthy actors.
        self.assertEqual(manager.num_healthy_actors(), 2)

        manager.clear()

    def test_async_call(self):
        """Test asynchronous remote calls work."""
        actors = [Actor.remote(i) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        results = []
        for _ in range(10):
            manager.foreach_actor_async(lambda w: w.call())
            results.extend(manager.fetch_ready_async_reqs(timeout_seconds=None))
            # Wait for actors to recover.
            wait_for_restore()

        # Note that we can hardcode the numbers here because of the deterministic
        # lists of random numbers we use.
        # 7 calls succeeded, 4 failed.
        # The number of results back is much lower than 40, because we do not probe
        # the actors with this test. As soon as an actor errors out, it will get
        # taken out of the lineup forever.
        self.assertEqual(sum(1 for r in results if r.ok), 7)
        self.assertEqual(sum(1 for r in results if not r.ok), 4)

        manager.clear()

    def test_async_calls_get_dropped_if_inflight_requests_over_limit(self):
        """Test asynchronous remote calls get dropped if too many in-flight calls."""
        actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
        manager = FaultTolerantActorManager(
            actors=actors,
            max_remote_requests_in_flight_per_actor=2,
        )

        # 2 asynchronous calls to actor 0, which reaches the in-flight limit.
        num_of_calls = manager.foreach_actor_async(
            lambda w: w.call(),
            remote_actor_ids=[0, 0],
        )
        self.assertEqual(num_of_calls, 2)

        # Now, let's try to make a third async call to actor 0.
        num_of_calls = manager.foreach_actor_async(
            lambda w: w.call(),
            healthy_only=False,
            remote_actor_ids=[0],
        )
        # The call was dropped: we actually made 0 calls.
        self.assertEqual(num_of_calls, 0)

        manager.clear()

    def test_healthy_only_works_for_list_of_functions(self):
        """Test healthy only mode works when a list of funcs are provided."""
        actors = [Actor.remote(i) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        # Mark the second and third actors (ids 1 and 2) as unhealthy.
        manager.set_actor_state(1, False)
        manager.set_actor_state(2, False)

        def f(value, _):
            return value

        # One callable per actor; callable i returns i.
        func = [functools.partial(f, i) for i in range(4)]

        manager.foreach_actor_async(func, healthy_only=True)
        results = manager.fetch_ready_async_reqs(timeout_seconds=None)

        # Should get results back from calling actor 0 and 3.
        self.assertEqual([r.get() for r in results], [0, 3])

        manager.clear()

    def test_len_of_func_not_match_len_of_actors(self):
        """Test that a list of funcs whose length differs from #actors is rejected."""
        actors = [Actor.remote(i) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        def f(value, _):
            return value

        # Only 3 callables for 4 actors.
        func = [functools.partial(f, i) for i in range(3)]
        with self.assertRaisesRegex(AssertionError, "same number of callables"):
            manager.foreach_actor_async(func, healthy_only=True)

        manager.clear()

    def test_probe_unhealthy_actors(self):
        """Test probe brings back unhealthy actors."""
        actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        # Mark the second and third actors (ids 1 and 2) as unhealthy.
        manager.set_actor_state(1, False)
        manager.set_actor_state(2, False)

        # These actors are actually healthy.
        manager.probe_unhealthy_actors(mark_healthy=True)
        # All 4 actors are now healthy.
        self.assertEqual(len(manager.healthy_actor_ids()), 4)

        manager.clear()

    def test_tags(self):
        """Test that tags work for async calls."""
        actors = [Actor.remote(i, maybe_crash=False) for i in range(4)]
        manager = FaultTolerantActorManager(actors=actors)

        manager.foreach_actor_async(lambda w: w.ping(), tag="pingpong")
        manager.foreach_actor_async(lambda w: w.call(), tag="call")
        time.sleep(1)
        # Materialize once so we can both count and iterate the results.
        results_ping_pong = list(
            manager.fetch_ready_async_reqs(tags="pingpong", timeout_seconds=5)
        )
        results_call = list(
            manager.fetch_ready_async_reqs(tags="call", timeout_seconds=5)
        )
        self.assertEqual(len(results_ping_pong), 4)
        self.assertEqual(len(results_call), 4)
        for result in results_ping_pong:
            data = result.get()
            self.assertEqual(data, "pong")
            self.assertEqual(result.tag, "pingpong")
        for result in results_call:
            data = result.get()
            self.assertEqual(data, 1)
            self.assertEqual(result.tag, "call")

        # test with default tag
        manager.foreach_actor_async(lambda w: w.ping())
        manager.foreach_actor_async(lambda w: w.call())
        time.sleep(1)
        results = list(manager.fetch_ready_async_reqs(timeout_seconds=5))
        self.assertEqual(len(results), 8)
        for result in results:
            data = result.get()
            self.assertEqual(result.tag, None)
            if isinstance(data, str):
                self.assertEqual(data, "pong")
            elif isinstance(data, int):
                self.assertEqual(data, 2)
            else:
                raise ValueError("data is not str or int")

        # test with custom tags
        manager.foreach_actor_async(lambda w: w.ping(), tag="pingpong")
        manager.foreach_actor_async(lambda w: w.call(), tag="call")
        time.sleep(1)
        results = list(
            manager.fetch_ready_async_reqs(
                timeout_seconds=5, tags=["pingpong", "call"]
            )
        )
        self.assertEqual(len(results), 8)
        for result in results:
            data = result.get()
            if isinstance(data, str):
                self.assertEqual(data, "pong")
                self.assertEqual(result.tag, "pingpong")
            elif isinstance(data, int):
                self.assertEqual(data, 3)
                self.assertEqual(result.tag, "call")
            else:
                raise ValueError("data is not str or int")

        # test with incorrect tags
        manager.foreach_actor_async(lambda w: w.ping(), tag="pingpong")
        manager.foreach_actor_async(lambda w: w.call(), tag="call")
        time.sleep(1)
        results = list(
            manager.fetch_ready_async_reqs(timeout_seconds=5, tags=["incorrect"])
        )
        self.assertEqual(len(results), 0)

        # now test that passing no tags still gives back all of the results
        results = list(manager.fetch_ready_async_reqs(timeout_seconds=5))
        self.assertEqual(len(results), 8)
        for result in results:
            data = result.get()
            if isinstance(data, str):
                self.assertEqual(data, "pong")
                self.assertEqual(result.tag, "pingpong")
            elif isinstance(data, int):
                self.assertEqual(data, 4)
                self.assertEqual(result.tag, "call")
            else:
                raise ValueError("result is not str or int")

        manager.clear()
if __name__ == "__main__":
    # Running this file directly hands control to pytest in verbose mode and
    # propagates pytest's return code as the process exit status.
    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|