test_taskpool.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import unittest
  2. from unittest.mock import patch
  3. import ray
  4. from ray.rllib.utils.actors import TaskPool
  5. def createMockWorkerAndObjectRef(obj_ref):
  6. return ({obj_ref: 1}, obj_ref)
  7. class TaskPoolTest(unittest.TestCase):
  8. @patch("ray.wait")
  9. def test_completed_prefetch_yieldsAllComplete(self, rayWaitMock):
  10. task1 = createMockWorkerAndObjectRef(1)
  11. task2 = createMockWorkerAndObjectRef(2)
  12. # Return the second task as complete and the first as pending
  13. rayWaitMock.return_value = ([2], [1])
  14. pool = TaskPool()
  15. pool.add(*task1)
  16. pool.add(*task2)
  17. fetched = list(pool.completed_prefetch())
  18. self.assertListEqual(fetched, [task2])
  19. @patch("ray.wait")
  20. def test_completed_prefetch_yieldsAllCompleteUpToDefaultLimit(self, rayWaitMock):
  21. # Load the pool with 1000 tasks, mock them all as complete and then
  22. # check that the first call to completed_prefetch only yields 999
  23. # items and the second call yields the final one
  24. pool = TaskPool()
  25. for i in range(1000):
  26. task = createMockWorkerAndObjectRef(i)
  27. pool.add(*task)
  28. rayWaitMock.return_value = (list(range(1000)), [])
  29. # For this test, we're only checking the object refs
  30. fetched = [pair[1] for pair in pool.completed_prefetch()]
  31. self.assertListEqual(fetched, list(range(999)))
  32. # Finally, check the next iteration returns the final taks
  33. fetched = [pair[1] for pair in pool.completed_prefetch()]
  34. self.assertListEqual(fetched, [999])
  35. @patch("ray.wait")
  36. def test_completed_prefetch_yieldsAllCompleteUpToSpecifiedLimit(self, rayWaitMock):
  37. # Load the pool with 1000 tasks, mock them all as complete and then
  38. # check that the first call to completed_prefetch only yield 999 items
  39. # and the second call yields the final one
  40. pool = TaskPool()
  41. for i in range(1000):
  42. task = createMockWorkerAndObjectRef(i)
  43. pool.add(*task)
  44. rayWaitMock.return_value = (list(range(1000)), [])
  45. # Verify that only the first 500 tasks are returned, this should leave
  46. # some tasks in the _fetching deque for later
  47. fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=500)]
  48. self.assertListEqual(fetched, list(range(500)))
  49. # Finally, check the next iteration returns the remaining tasks
  50. fetched = [pair[1] for pair in pool.completed_prefetch()]
  51. self.assertListEqual(fetched, list(range(500, 1000)))
  52. @patch("ray.wait")
  53. def test_completed_prefetch_yieldsRemainingIfIterationStops(self, rayWaitMock):
  54. # Test for issue #7106
  55. # In versions of Ray up to 0.8.1, if the pre-fetch generator failed to
  56. # run to completion, then the TaskPool would fail to clear up already
  57. # fetched tasks resulting in stale object refs being returned
  58. pool = TaskPool()
  59. for i in range(10):
  60. task = createMockWorkerAndObjectRef(i)
  61. pool.add(*task)
  62. rayWaitMock.return_value = (list(range(10)), [])
  63. # This should fetch just the first item in the list
  64. try:
  65. for _ in pool.completed_prefetch():
  66. # Simulate a worker failure returned by ray.get()
  67. raise ray.exceptions.RayError
  68. except ray.exceptions.RayError:
  69. pass
  70. # This fetch should return the remaining pre-fetched tasks
  71. fetched = [pair[1] for pair in pool.completed_prefetch()]
  72. self.assertListEqual(fetched, list(range(1, 10)))
  73. @patch("ray.wait")
  74. def test_reset_workers_pendingFetchesFromFailedWorkersRemoved(self, rayWaitMock):
  75. pool = TaskPool()
  76. # We need to hold onto the tasks for this test so that we can fail a
  77. # specific worker
  78. tasks = []
  79. for i in range(10):
  80. task = createMockWorkerAndObjectRef(i)
  81. pool.add(*task)
  82. tasks.append(task)
  83. # Simulate only some of the work being complete and fetch a couple of
  84. # tasks in order to fill the fetching queue
  85. rayWaitMock.return_value = ([0, 1, 2, 3, 4, 5], [6, 7, 8, 9])
  86. fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=2)]
  87. # As we still have some pending tasks, we need to update the
  88. # completion states to remove the completed tasks
  89. rayWaitMock.return_value = ([], [6, 7, 8, 9])
  90. pool.reset_workers(
  91. [
  92. tasks[0][0],
  93. tasks[1][0],
  94. tasks[2][0],
  95. tasks[3][0],
  96. # OH NO! WORKER 4 HAS CRASHED!
  97. tasks[5][0],
  98. tasks[6][0],
  99. tasks[7][0],
  100. tasks[8][0],
  101. tasks[9][0],
  102. ]
  103. )
  104. # Fetch the remaining tasks which should already be in the _fetching
  105. # queue
  106. fetched = [pair[1] for pair in pool.completed_prefetch()]
  107. self.assertListEqual(fetched, [2, 3, 5])
  108. if __name__ == "__main__":
  109. import pytest
  110. import sys
  111. sys.exit(pytest.main(["-v", __file__]))