123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- from typing import List, Optional, Any
- import queue
- from ray.util.iter import LocalIterator, _NextValueNotReady
- from ray.util.iter_metrics import SharedMetrics
- from ray.rllib.utils.typing import SampleBatchType
def Concurrently(ops: List[LocalIterator],
                 *,
                 mode: str = "round_robin",
                 output_indexes: Optional[List[int]] = None,
                 round_robin_weights: Optional[List[int]] = None
                 ) -> LocalIterator[SampleBatchType]:
    """Operator that runs the given parent iterators concurrently.

    Every parent iterator is first tagged so that each of its items becomes
    an ``(op_index, item)`` tuple, and the tagged iterators are then merged
    with ``union()``.

    Args:
        ops (list): The parent iterators to run. At least two are required.
        mode (str): One of 'round_robin', 'async'. In 'round_robin' mode,
            we alternate between pulling items from each parent iterator in
            order deterministically. In 'async' mode, we pull from each parent
            iterator as fast as they are produced. This is non-deterministic.
        output_indexes (list): If specified, only output results from the
            given ops. For example, if ``output_indexes=[0]``, only results
            from the first op in ops will be returned. When this is None or
            empty, nothing is filtered and the items are returned in their
            tagged ``(op_index, item)`` form.
        round_robin_weights (list): List of weights to use for round robin
            mode. For example, ``[2, 1]`` will cause the iterator to pull twice
            as many items from the first iterator as the second.
            ``[2, 1, "*"]`` will cause as many items to be pulled as possible
            from the third iterator without blocking. This is only allowed in
            round robin mode.

    Returns:
        The merged iterator. If ``output_indexes`` is given, it yields only
        the untagged items of the selected ops.

    Raises:
        ValueError: If fewer than two ops are given, ``mode`` is unknown,
            ``round_robin_weights`` is combined with 'async' mode, every
            round robin weight is ``"*"``, or an entry of ``output_indexes``
            is not a valid index into ``ops``.

    Examples:
        >>> sim_op = ParallelRollouts(...).for_each(...)
        >>> replay_op = LocalReplay(...).for_each(...)
        >>> combined_op = Concurrently([sim_op, replay_op], mode="async")
    """
    if len(ops) < 2:
        raise ValueError("Should specify at least 2 ops.")

    if mode == "round_robin":
        deterministic = True
    elif mode == "async":
        deterministic = False
        if round_robin_weights:
            raise ValueError(
                "round_robin_weights cannot be specified in async mode")
    else:
        raise ValueError("Unknown mode {}".format(mode))

    if round_robin_weights and all(r == "*" for r in round_robin_weights):
        raise ValueError("Cannot specify all round robin weights = *")

    selected = None
    if output_indexes:
        valid_indexes = range(len(ops))
        for i in output_indexes:
            # Raise explicitly instead of using `assert`: asserts are
            # stripped under `python -O`, which would let a bad index
            # through and produce an iterator that never outputs anything.
            if i not in valid_indexes:
                raise ValueError(
                    "output_indexes entry {} is out of range for {} "
                    "ops".format(i, len(ops)))
        # Built once so the per-item membership test below is O(1).
        selected = frozenset(output_indexes)

    def tag(op, i):
        # `i` is bound per call, so each lambda keeps its own op index.
        return op.for_each(lambda x: (i, x))

    ops = [tag(op, i) for i, op in enumerate(ops)]

    output = ops[0].union(
        *ops[1:],
        deterministic=deterministic,
        round_robin_weights=round_robin_weights)

    if selected is not None:
        # Keep only items from the requested ops, then strip the tag.
        output = (output.filter(lambda tup: tup[0] in selected).for_each(
            lambda tup: tup[1]))

    return output
class Enqueue:
    """Callable that pushes each item it receives onto a queue.Queue.

    On success the input item is returned unchanged, so the callable can be
    used as a pass-through stage in ``for_each()``.

    The put waits at most 1 ms. If the queue is still full after that, the
    item is not enqueued and a ``_NextValueNotReady`` instance is returned
    instead. Because it never blocks for long, Enqueue operations can be
    executed together with Dequeue via the Concurrently() operator.

    Examples:
        >>> q = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(q))
        >>> read_op = Dequeue(q)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """

    def __init__(self, output_queue: queue.Queue):
        """Store the target queue.

        Args:
            output_queue (Queue): queue that items are put into.

        Raises:
            ValueError: If ``output_queue`` is not a queue.Queue instance.
        """
        is_queue = isinstance(output_queue, queue.Queue)
        if not is_queue:
            message = "Expected queue.Queue, got {}".format(
                type(output_queue))
            raise ValueError(message)
        self.queue = output_queue

    def __call__(self, x: Any) -> Any:
        """Try to enqueue ``x``; return ``x`` on success."""
        try:
            self.queue.put(x, block=True, timeout=0.001)
        except queue.Full:
            # Queue stayed full for the whole wait: report "no value yet"
            # rather than blocking the calling iterator.
            return _NextValueNotReady()
        else:
            return x
def Dequeue(input_queue: queue.Queue,
            check=lambda: True) -> LocalIterator[SampleBatchType]:
    """Build an iterator that pulls data items out of a queue.Queue.

    Each pull waits at most 1 ms. When nothing arrives in that time, the
    iterator yields a ``_NextValueNotReady`` instance instead of blocking,
    so Dequeue operations can execute together with Enqueue via the
    Concurrently() operator.

    Args:
        input_queue (Queue): queue to pull items from.
        check (fn): liveness check, called before every pull attempt. When
            it returns false, the iterator raises a RuntimeError to halt
            execution.

    Returns:
        A LocalIterator wrapping the queue-reading generator, created with a
        fresh SharedMetrics instance.

    Raises:
        ValueError: If ``input_queue`` is not a queue.Queue instance.

    Examples:
        >>> q = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(q))
        >>> read_op = Dequeue(q)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """
    queue_ok = isinstance(input_queue, queue.Queue)
    if not queue_ok:
        message = "Expected queue.Queue, got {}".format(type(input_queue))
        raise ValueError(message)

    # `timeout` is accepted but not used here; presumably the iterator
    # wrapper passes it when invoking the generator function (confirm
    # against LocalIterator).
    def base_iterator(timeout=None):
        while True:
            if not check():
                raise RuntimeError(
                    "Dequeue `check()` returned False! "
                    "Exiting with Exception from Dequeue iterator.")
            try:
                yield input_queue.get(timeout=0.001)
            except queue.Empty:
                # Nothing arrived within the wait: signal "not ready".
                yield _NextValueNotReady()

    return LocalIterator(base_iterator, SharedMetrics())
|