123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- """
- Progress Bar for Ray Actors (tqdm)
- ==================================
- Tracking the progress of distributed tasks can be tricky.
- This script demonstrates how to implement a simple progress bar,
- backed by a Ray actor, that tracks progress across several
- distributed components.
- Original source: `Link <https://github.com/votingworks/arlo-e2e>`_
- Setup: Dependencies
- -------------------
- First, import some dependencies.
- """
- # Inspiration: https://github.com/honnibal/spacy-ray/pull/1/files#diff-7ede881ddc3e8456b320afb958362b2aR12-R45
- from asyncio import Event
- from typing import Tuple
- from time import sleep
- import ray
- # For typing purposes
- from ray.actor import ActorHandle
- from tqdm import tqdm
- ############################################################
- # This is the Ray "actor" that can be called from anywhere to update
- # our progress. You'll be using the `update` method. Don't
- # instantiate this class yourself. Instead,
- # it's something that you'll get from a `ProgressBar`.
@ray.remote
class ProgressBarActor:
    """Ray actor that accumulates progress reported through `update`.

    Callers report finished work with `update`; a consumer awaits
    `wait_for_update` to learn how much was completed since it last asked,
    together with the running total.
    """

    counter: int  # total number of items completed so far
    delta: int  # items completed since `wait_for_update` last returned
    event: Event  # set by `update`, cleared by `wait_for_update`

    def __init__(self) -> None:
        self.counter = 0
        self.delta = 0
        self.event = Event()

    def update(self, num_items_completed: int) -> None:
        """Add `num_items_completed` to the totals and signal the event."""
        self.counter += num_items_completed
        self.delta += num_items_completed
        self.event.set()

    async def wait_for_update(self) -> Tuple[int, int]:
        """Wait until `update` has been called, then report progress.

        Returns:
            A tuple ``(delta, counter)``: the number of items completed
            since the previous call to `wait_for_update`, and the total
            number of completed items.
        """
        await self.event.wait()
        self.event.clear()
        # Hand back the pending increment and reset it in a single step.
        pending, self.delta = self.delta, 0
        return pending, self.counter

    def get_counter(self) -> int:
        """Return the total number of completed items."""
        return self.counter
- ######################################################################
- # This is where the progress bar starts. You create one of these
- # on the head node, passing in the expected total number of items,
- # and an optional string description.
- # Pass the `actor` reference along to any remote task;
- # when a task completes ten items,
- # it calls `actor.update.remote(10)`.
- # Back on the local node, once you launch your remote Ray tasks, call
- # `print_until_done`, which will feed everything back into a `tqdm` counter.
class ProgressBar:
    """Local front end that renders a `tqdm` bar from a `ProgressBarActor`.

    Create one with the expected total number of items, hand `actor` to the
    remote tasks so they can call `update` on it, then call
    `print_until_done` to display progress until the total is reached.
    """

    progress_actor: ActorHandle
    total: int
    description: str
    # Declared for typing only: the bar itself is created as a local
    # inside `print_until_done` and is never stored on the instance.
    pbar: tqdm

    def __init__(self, total: int, description: str = ""):
        """Create the backing actor and remember the bar's configuration.

        Args:
            total: Expected total number of items to be completed.
            description: Optional label passed to `tqdm` as `desc`.
        """
        # Ray actors don't seem to play nice with mypy, generating
        # a spurious warning for the following line,
        # which we need to suppress. The code is fine.
        self.progress_actor = ProgressBarActor.remote()  # type: ignore
        self.total = total
        self.description = description

    @property
    def actor(self) -> ActorHandle:
        """Returns a reference to the remote `ProgressBarActor`.

        When you complete tasks, call `update` on the actor.
        """
        return self.progress_actor

    def print_until_done(self) -> None:
        """Blocking call.

        Do this after starting a series of remote Ray tasks, to which you've
        passed the actor handle. Each of them calls `update` on the actor.
        When the progress meter reaches 100%, this method returns.

        If `total` is zero or negative there is nothing to wait for, so the
        method returns immediately instead of waiting for an update that
        will never arrive.
        """
        # The context manager closes the bar on every exit path, including
        # an exception or interrupt raised while waiting in `ray.get`.
        with tqdm(desc=self.description, total=self.total) as pbar:
            if self.total <= 0:
                return
            while True:
                delta, counter = ray.get(self.actor.wait_for_update.remote())
                pbar.update(delta)
                if counter >= self.total:
                    return
- #################################################################
- # This is an example of a task that increments the progress bar.
- # Note that this is a Ray Task, but it could very well
- # be any generic Ray Actor.
- #
@ray.remote
def sleep_then_increment(i: int, pba: ActorHandle) -> int:
    """Sleep for `i / 2.0` seconds, report one completed item, return `i`.

    Args:
        i: Determines the sleep duration and is returned unchanged.
        pba: Handle on which `update.remote(1)` is invoked after sleeping.
    """
    delay_seconds = i / 2.0
    sleep(delay_seconds)
    # The result of the remote `update` call is not waited on here.
    pba.update.remote(1)
    return i
- #################################################################
- # Now you can run it and see what happens!
- #
def run():
    """Run the demo: launch tasks, show the bar, then check the results.

    Initializes Ray, starts `num_ticks` `sleep_then_increment` tasks that
    share one progress actor, blocks in `print_until_done` until the bar is
    full, and finally verifies the task results and the actor's counter.

    Raises:
        AssertionError: If the task results or the final counter differ
            from what the demo expects.
    """
    ray.init()
    num_ticks = 6
    pb = ProgressBar(num_ticks)
    actor = pb.actor
    # You can replace this with any arbitrary Ray task/actor.
    tasks_pre_launch = [
        sleep_then_increment.remote(i, actor) for i in range(num_ticks)
    ]
    pb.print_until_done()
    tasks = ray.get(tasks_pre_launch)
    # Fetch the counter outside the assert so the remote call still
    # happens when assertions are disabled with `python -O`.
    final_count = ray.get(actor.get_counter.remote())
    # These were bare `==` expressions whose results were discarded;
    # make them real checks.
    assert tasks == list(range(num_ticks)), (
        f"unexpected task results: {tasks!r}"
    )
    assert final_count == num_ticks, (
        f"expected counter {num_ticks}, got {final_count}"
    )


# Executed at module level, as in the original script.
run()
|