concurrency_ops.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
import queue
from typing import Any, List, Optional, Union

from ray.util.iter import LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.utils.typing import SampleBatchType
  6. def Concurrently(ops: List[LocalIterator],
  7. *,
  8. mode: str = "round_robin",
  9. output_indexes: Optional[List[int]] = None,
  10. round_robin_weights: Optional[List[int]] = None
  11. ) -> LocalIterator[SampleBatchType]:
  12. """Operator that runs the given parent iterators concurrently.
  13. Args:
  14. mode (str): One of 'round_robin', 'async'. In 'round_robin' mode,
  15. we alternate between pulling items from each parent iterator in
  16. order deterministically. In 'async' mode, we pull from each parent
  17. iterator as fast as they are produced. This is non-deterministic.
  18. output_indexes (list): If specified, only output results from the
  19. given ops. For example, if ``output_indexes=[0]``, only results
  20. from the first op in ops will be returned.
  21. round_robin_weights (list): List of weights to use for round robin
  22. mode. For example, ``[2, 1]`` will cause the iterator to pull twice
  23. as many items from the first iterator as the second. ``[2, 1, *]``
  24. will cause as many items to be pulled as possible from the third
  25. iterator without blocking. This is only allowed in round robin
  26. mode.
  27. Examples:
  28. >>> sim_op = ParallelRollouts(...).for_each(...)
  29. >>> replay_op = LocalReplay(...).for_each(...)
  30. >>> combined_op = Concurrently([sim_op, replay_op], mode="async")
  31. """
  32. if len(ops) < 2:
  33. raise ValueError("Should specify at least 2 ops.")
  34. if mode == "round_robin":
  35. deterministic = True
  36. elif mode == "async":
  37. deterministic = False
  38. if round_robin_weights:
  39. raise ValueError(
  40. "round_robin_weights cannot be specified in async mode")
  41. else:
  42. raise ValueError("Unknown mode {}".format(mode))
  43. if round_robin_weights and all(r == "*" for r in round_robin_weights):
  44. raise ValueError("Cannot specify all round robin weights = *")
  45. if output_indexes:
  46. for i in output_indexes:
  47. assert i in range(len(ops)), ("Index out of range", i)
  48. def tag(op, i):
  49. return op.for_each(lambda x: (i, x))
  50. ops = [tag(op, i) for i, op in enumerate(ops)]
  51. output = ops[0].union(
  52. *ops[1:],
  53. deterministic=deterministic,
  54. round_robin_weights=round_robin_weights)
  55. if output_indexes:
  56. output = (output.filter(lambda tup: tup[0] in output_indexes).for_each(
  57. lambda tup: tup[1]))
  58. return output
  59. class Enqueue:
  60. """Enqueue data items into a queue.Queue instance.
  61. Returns the input item as output.
  62. The enqueue is non-blocking, so Enqueue operations can executed with
  63. Dequeue via the Concurrently() operator.
  64. Examples:
  65. >>> queue = queue.Queue(100)
  66. >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue))
  67. >>> read_op = Dequeue(queue)
  68. >>> combined_op = Concurrently([write_op, read_op], mode="async")
  69. >>> next(combined_op)
  70. SampleBatch(...)
  71. """
  72. def __init__(self, output_queue: queue.Queue):
  73. if not isinstance(output_queue, queue.Queue):
  74. raise ValueError("Expected queue.Queue, got {}".format(
  75. type(output_queue)))
  76. self.queue = output_queue
  77. def __call__(self, x: Any) -> Any:
  78. try:
  79. self.queue.put(x, timeout=0.001)
  80. except queue.Full:
  81. return _NextValueNotReady()
  82. return x
  83. def Dequeue(input_queue: queue.Queue,
  84. check=lambda: True) -> LocalIterator[SampleBatchType]:
  85. """Dequeue data items from a queue.Queue instance.
  86. The dequeue is non-blocking, so Dequeue operations can execute with
  87. Enqueue via the Concurrently() operator.
  88. Args:
  89. input_queue (Queue): queue to pull items from.
  90. check (fn): liveness check. When this function returns false,
  91. Dequeue() will raise an error to halt execution.
  92. Examples:
  93. >>> queue = queue.Queue(100)
  94. >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue))
  95. >>> read_op = Dequeue(queue)
  96. >>> combined_op = Concurrently([write_op, read_op], mode="async")
  97. >>> next(combined_op)
  98. SampleBatch(...)
  99. """
  100. if not isinstance(input_queue, queue.Queue):
  101. raise ValueError("Expected queue.Queue, got {}".format(
  102. type(input_queue)))
  103. def base_iterator(timeout=None):
  104. while check():
  105. try:
  106. item = input_queue.get(timeout=0.001)
  107. yield item
  108. except queue.Empty:
  109. yield _NextValueNotReady()
  110. raise RuntimeError("Dequeue `check()` returned False! "
  111. "Exiting with Exception from Dequeue iterator.")
  112. return LocalIterator(base_iterator, SharedMetrics())