from typing import List, Optional, Any
import queue

from ray.util.iter import LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.utils.typing import SampleBatchType


def Concurrently(ops: List[LocalIterator],
                 *,
                 mode: str = "round_robin",
                 output_indexes: Optional[List[int]] = None,
                 round_robin_weights: Optional[List[int]] = None
                 ) -> LocalIterator[SampleBatchType]:
    """Operator that runs the given parent iterators concurrently.

    Args:
        mode (str): One of 'round_robin', 'async'. In 'round_robin' mode,
            we alternate between pulling items from each parent iterator in
            order deterministically. In 'async' mode, we pull from each
            parent iterator as fast as they are produced. This is
            non-deterministic.
        output_indexes (list): If specified, only output results from the
            given ops. For example, if ``output_indexes=[0]``, only results
            from the first op in ops will be returned.
        round_robin_weights (list): List of weights to use for round robin
            mode. For example, ``[2, 1]`` will cause the iterator to pull
            twice as many items from the first iterator as the second.
            ``[2, 1, *]`` will cause as many items to be pulled as possible
            from the third iterator without blocking. This is only allowed
            in round robin mode.

    Examples:
        >>> sim_op = ParallelRollouts(...).for_each(...)
        >>> replay_op = LocalReplay(...).for_each(...)
        >>> combined_op = Concurrently([sim_op, replay_op], mode="async")
    """

    if len(ops) < 2:
        raise ValueError("Should specify at least 2 ops.")
    if mode == "round_robin":
        deterministic = True
    elif mode == "async":
        deterministic = False
        if round_robin_weights:
            raise ValueError(
                "round_robin_weights cannot be specified in async mode")
    else:
        raise ValueError("Unknown mode {}".format(mode))
    if round_robin_weights and all(r == "*" for r in round_robin_weights):
        raise ValueError("Cannot specify all round robin weights = *")

    if output_indexes:
        for i in output_indexes:
            assert i in range(len(ops)), ("Index out of range", i)

        def tag(op, i):
            return op.for_each(lambda x: (i, x))

        ops = [tag(op, i) for i, op in enumerate(ops)]

    output = ops[0].union(
        *ops[1:],
        deterministic=deterministic,
        round_robin_weights=round_robin_weights)

    if output_indexes:
        output = (output.filter(lambda tup: tup[0] in output_indexes)
                  .for_each(lambda tup: tup[1]))

    return output
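

# The helper below is a minimal sketch (not part of the RLlib API) showing
# weighted round-robin mode with plain LocalIterators built from generator
# functions, in place of the RLlib rollout ops used in the docstring above.
# The function name and item values are illustrative assumptions only.
def _round_robin_weights_example() -> LocalIterator[int]:
    fast = LocalIterator(
        lambda timeout=None: iter(range(6)), SharedMetrics())
    slow = LocalIterator(
        lambda timeout=None: iter(range(100, 103)), SharedMetrics())
    # With weights [2, 1], two items are pulled from `fast` for every one
    # item pulled from `slow`, in a deterministic order.
    return Concurrently(
        [fast, slow], mode="round_robin", round_robin_weights=[2, 1])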


class Enqueue:
    """Enqueue data items into a queue.Queue instance.

    Returns the input item as output.

    The enqueue is non-blocking, so Enqueue operations can be executed with
    Dequeue via the Concurrently() operator.

    Examples:
        >>> queue = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue))
        >>> read_op = Dequeue(queue)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """

    def __init__(self, output_queue: queue.Queue):
        if not isinstance(output_queue, queue.Queue):
            raise ValueError("Expected queue.Queue, got {}".format(
                type(output_queue)))
        self.queue = output_queue

    def __call__(self, x: Any) -> Any:
        try:
            self.queue.put(x, timeout=0.001)
        except queue.Full:
            # Queue is full; signal the iterator to retry later.
            return _NextValueNotReady()
        return x


def Dequeue(input_queue: queue.Queue,
            check=lambda: True) -> LocalIterator[SampleBatchType]:
    """Dequeue data items from a queue.Queue instance.

    The dequeue is non-blocking, so Dequeue operations can execute with
    Enqueue via the Concurrently() operator.

    Args:
        input_queue (Queue): Queue to pull items from.
        check (fn): Liveness check. When this function returns False,
            Dequeue() will raise an error to halt execution.

    Examples:
        >>> queue = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue))
        >>> read_op = Dequeue(queue)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """
    if not isinstance(input_queue, queue.Queue):
        raise ValueError("Expected queue.Queue, got {}".format(
            type(input_queue)))

    def base_iterator(timeout=None):
        while check():
            try:
                item = input_queue.get(timeout=0.001)
                yield item
            except queue.Empty:
                # Nothing available yet; yield a sentinel so the iterator
                # does not block.
                yield _NextValueNotReady()
        raise RuntimeError("Dequeue `check()` returned False! "
                           "Exiting with Exception from Dequeue iterator.")

    return LocalIterator(base_iterator, SharedMetrics())
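

# A minimal sketch (not part of the RLlib API) of the Enqueue/Dequeue pattern
# from the docstrings above, substituting a plain LocalIterator over integers
# for ParallelRollouts(). The function name, queue size, and item values are
# illustrative assumptions only.
def _enqueue_dequeue_example() -> LocalIterator[int]:
    buffer = queue.Queue(maxsize=100)
    # Writer: copies each item into `buffer` and passes it through unchanged.
    write_op = LocalIterator(
        lambda timeout=None: iter(range(10)),
        SharedMetrics()).for_each(Enqueue(buffer))
    # Reader: yields items from `buffer` as they arrive. With the default
    # `check`, it keeps polling indefinitely.
    read_op = Dequeue(buffer)
    # Interleave writer and reader, returning only the reader's output
    # (op index 1).
    return Concurrently(
        [write_op, read_op], mode="async", output_indexes=[1])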