progress_bar.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
"""
Progress Bar for Ray Actors (tqdm)
==================================

Tracking the progress of distributed tasks can be tricky.

This script demonstrates how to implement a simple progress bar, backed by
a Ray actor, that tracks progress across multiple distributed components.

Original source: `Link <https://github.com/votingworks/arlo-e2e>`_

Setup: Dependencies
-------------------
First, import some dependencies.
"""
  13. # Inspiration: https://github.com/honnibal/spacy-ray/pull/
  14. # 1/files#diff-7ede881ddc3e8456b320afb958362b2aR12-R45
  15. from asyncio import Event
  16. from typing import Tuple
  17. from time import sleep
  18. import ray
  19. # For typing purposes
  20. from ray.actor import ActorHandle
  21. from tqdm import tqdm
  22. ############################################################
  23. # This is the Ray "actor" that can be called from anywhere to update
  24. # our progress. You'll be using the `update` method. Don't
  25. # instantiate this class yourself. Instead,
  26. # it's something that you'll get from a `ProgressBar`.
  27. @ray.remote
  28. class ProgressBarActor:
  29. counter: int
  30. delta: int
  31. event: Event
  32. def __init__(self) -> None:
  33. self.counter = 0
  34. self.delta = 0
  35. self.event = Event()
  36. def update(self, num_items_completed: int) -> None:
  37. """Updates the ProgressBar with the incremental
  38. number of items that were just completed.
  39. """
  40. self.counter += num_items_completed
  41. self.delta += num_items_completed
  42. self.event.set()
  43. async def wait_for_update(self) -> Tuple[int, int]:
  44. """Blocking call.
  45. Waits until somebody calls `update`, then returns a tuple of
  46. the number of updates since the last call to
  47. `wait_for_update`, and the total number of completed items.
  48. """
  49. await self.event.wait()
  50. self.event.clear()
  51. saved_delta = self.delta
  52. self.delta = 0
  53. return saved_delta, self.counter
  54. def get_counter(self) -> int:
  55. """
  56. Returns the total number of complete items.
  57. """
  58. return self.counter
  59. ######################################################################
  60. # This is where the progress bar starts. You create one of these
  61. # on the head node, passing in the expected total number of items,
  62. # and an optional string description.
  63. # Pass along the `actor` reference to any remote task,
  64. # and if they complete ten
  65. # tasks, they'll call `actor.update.remote(10)`.
  66. # Back on the local node, once you launch your remote Ray tasks, call
  67. # `print_until_done`, which will feed everything back into a `tqdm` counter.
  68. class ProgressBar:
  69. progress_actor: ActorHandle
  70. total: int
  71. description: str
  72. pbar: tqdm
  73. def __init__(self, total: int, description: str = ""):
  74. # Ray actors don't seem to play nice with mypy, generating
  75. # a spurious warning for the following line,
  76. # which we need to suppress. The code is fine.
  77. self.progress_actor = ProgressBarActor.remote() # type: ignore
  78. self.total = total
  79. self.description = description
  80. @property
  81. def actor(self) -> ActorHandle:
  82. """Returns a reference to the remote `ProgressBarActor`.
  83. When you complete tasks, call `update` on the actor.
  84. """
  85. return self.progress_actor
  86. def print_until_done(self) -> None:
  87. """Blocking call.
  88. Do this after starting a series of remote Ray tasks, to which you've
  89. passed the actor handle. Each of them calls `update` on the actor.
  90. When the progress meter reaches 100%, this method returns.
  91. """
  92. pbar = tqdm(desc=self.description, total=self.total)
  93. while True:
  94. delta, counter = ray.get(self.actor.wait_for_update.remote())
  95. pbar.update(delta)
  96. if counter >= self.total:
  97. pbar.close()
  98. return
  99. #################################################################
  100. # This is an example of a task that increments the progress bar.
  101. # Note that this is a Ray Task, but it could very well
  102. # be any generic Ray Actor.
  103. #
  104. @ray.remote
  105. def sleep_then_increment(i: int, pba: ActorHandle) -> int:
  106. sleep(i / 2.0)
  107. pba.update.remote(1)
  108. return i
  109. #################################################################
  110. # Now you can run it and see what happens!
  111. #
  112. def run():
  113. ray.init()
  114. num_ticks = 6
  115. pb = ProgressBar(num_ticks)
  116. actor = pb.actor
  117. # You can replace this with any arbitrary Ray task/actor.
  118. tasks_pre_launch = [
  119. sleep_then_increment.remote(i, actor) for i in range(0, num_ticks)
  120. ]
  121. pb.print_until_done()
  122. tasks = ray.get(tasks_pre_launch)
  123. tasks == list(range(num_ticks))
  124. num_ticks == ray.get(actor.get_counter.remote())
  125. run()