123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- import unittest
- from unittest.mock import patch
- import ray
- from ray.rllib.utils.actors import TaskPool
- def createMockWorkerAndObjectRef(obj_ref):
- return ({obj_ref: 1}, obj_ref)
- class TaskPoolTest(unittest.TestCase):
- @patch("ray.wait")
- def test_completed_prefetch_yieldsAllComplete(self, rayWaitMock):
- task1 = createMockWorkerAndObjectRef(1)
- task2 = createMockWorkerAndObjectRef(2)
- # Return the second task as complete and the first as pending
- rayWaitMock.return_value = ([2], [1])
- pool = TaskPool()
- pool.add(*task1)
- pool.add(*task2)
- fetched = list(pool.completed_prefetch())
- self.assertListEqual(fetched, [task2])
- @patch("ray.wait")
- def test_completed_prefetch_yieldsAllCompleteUpToDefaultLimit(self, rayWaitMock):
- # Load the pool with 1000 tasks, mock them all as complete and then
- # check that the first call to completed_prefetch only yields 999
- # items and the second call yields the final one
- pool = TaskPool()
- for i in range(1000):
- task = createMockWorkerAndObjectRef(i)
- pool.add(*task)
- rayWaitMock.return_value = (list(range(1000)), [])
- # For this test, we're only checking the object refs
- fetched = [pair[1] for pair in pool.completed_prefetch()]
- self.assertListEqual(fetched, list(range(999)))
- # Finally, check the next iteration returns the final taks
- fetched = [pair[1] for pair in pool.completed_prefetch()]
- self.assertListEqual(fetched, [999])
- @patch("ray.wait")
- def test_completed_prefetch_yieldsAllCompleteUpToSpecifiedLimit(self, rayWaitMock):
- # Load the pool with 1000 tasks, mock them all as complete and then
- # check that the first call to completed_prefetch only yield 999 items
- # and the second call yields the final one
- pool = TaskPool()
- for i in range(1000):
- task = createMockWorkerAndObjectRef(i)
- pool.add(*task)
- rayWaitMock.return_value = (list(range(1000)), [])
- # Verify that only the first 500 tasks are returned, this should leave
- # some tasks in the _fetching deque for later
- fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=500)]
- self.assertListEqual(fetched, list(range(500)))
- # Finally, check the next iteration returns the remaining tasks
- fetched = [pair[1] for pair in pool.completed_prefetch()]
- self.assertListEqual(fetched, list(range(500, 1000)))
- @patch("ray.wait")
- def test_completed_prefetch_yieldsRemainingIfIterationStops(self, rayWaitMock):
- # Test for issue #7106
- # In versions of Ray up to 0.8.1, if the pre-fetch generator failed to
- # run to completion, then the TaskPool would fail to clear up already
- # fetched tasks resulting in stale object refs being returned
- pool = TaskPool()
- for i in range(10):
- task = createMockWorkerAndObjectRef(i)
- pool.add(*task)
- rayWaitMock.return_value = (list(range(10)), [])
- # This should fetch just the first item in the list
- try:
- for _ in pool.completed_prefetch():
- # Simulate a worker failure returned by ray.get()
- raise ray.exceptions.RayError
- except ray.exceptions.RayError:
- pass
- # This fetch should return the remaining pre-fetched tasks
- fetched = [pair[1] for pair in pool.completed_prefetch()]
- self.assertListEqual(fetched, list(range(1, 10)))
- @patch("ray.wait")
- def test_reset_workers_pendingFetchesFromFailedWorkersRemoved(self, rayWaitMock):
- pool = TaskPool()
- # We need to hold onto the tasks for this test so that we can fail a
- # specific worker
- tasks = []
- for i in range(10):
- task = createMockWorkerAndObjectRef(i)
- pool.add(*task)
- tasks.append(task)
- # Simulate only some of the work being complete and fetch a couple of
- # tasks in order to fill the fetching queue
- rayWaitMock.return_value = ([0, 1, 2, 3, 4, 5], [6, 7, 8, 9])
- fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=2)]
- # As we still have some pending tasks, we need to update the
- # completion states to remove the completed tasks
- rayWaitMock.return_value = ([], [6, 7, 8, 9])
- pool.reset_workers(
- [
- tasks[0][0],
- tasks[1][0],
- tasks[2][0],
- tasks[3][0],
- # OH NO! WORKER 4 HAS CRASHED!
- tasks[5][0],
- tasks[6][0],
- tasks[7][0],
- tasks[8][0],
- tasks[9][0],
- ]
- )
- # Fetch the remaining tasks which should already be in the _fetching
- # queue
- fetched = [pair[1] for pair in pool.completed_prefetch()]
- self.assertListEqual(fetched, [2, 3, 5])
- if __name__ == "__main__":
- import pytest
- import sys
- sys.exit(pytest.main(["-v", __file__]))
|