1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- from typing import Any, Tuple
- import queue
- from ray.rllib.utils.deprecation import deprecation_warning
- from ray.util import log_once
class MinibatchBuffer:
    """Ring buffer of recent data batches for minibatch SGD.

    This is for use with AsyncSamplesOptimizer.

    Every slot of the ring stores one item taken from ``inqueue`` together
    with a counter of how many more times that item may still be emitted.
    ``get()`` visits the slots round-robin and only pulls a fresh item from
    the queue once the counter of the current slot is used up.
    """

    def __init__(
        self,
        inqueue: queue.Queue,
        size: int,
        timeout: float,
        num_passes: int,
        init_num_passes: int = 1,
    ):
        """Initialize a minibatch buffer.

        Args:
            inqueue: Queue to populate the internal ring buffer from.
            size: Max number of data items to buffer. Must be at least 1.
            timeout: Queue timeout, forwarded as ``timeout`` to
                ``inqueue.get()``.
            num_passes: Max num times each data item should be emitted.
            init_num_passes: Initial passes for each data item.
                Maximum number of passes per item are increased to
                num_passes over time.

        Raises:
            ValueError: If ``size`` is smaller than 1.
        """
        # An empty ring cannot serve any item: without this check the first
        # get() would fail with an opaque IndexError on ``self.ttl``.
        if size < 1:
            raise ValueError(
                "MinibatchBuffer `size` must be >= 1, got {}".format(size)
            )

        self.inqueue = inqueue
        self.size = size
        self.timeout = timeout
        # Upper bound for the number of passes a newly inserted item gets.
        self.max_initial_ttl = num_passes
        # Number of passes the *next* inserted item gets; grows by one per
        # insertion until it reaches ``max_initial_ttl``.
        self.cur_initial_ttl = init_num_passes
        self.buffers = [None] * size
        # Remaining passes per slot; <= 0 marks a slot that must be refilled.
        self.ttl = [0] * size
        # Slot served by the next get() call.
        self.idx = 0

        if log_once("minibatch-buffer-deprecation-warning"):
            deprecation_warning(
                old="ray.rllib.execution.minibatch_buffer.MinibatchBuffer"
            )

    def get(self) -> Tuple[Any, bool]:
        """Get a new batch from the internal ring buffer.

        Returns:
            buf: Data item saved from inqueue.
            released: True if the item is now removed from the ring buffer.

        Raises:
            queue.Empty: Propagated from ``inqueue.get()`` when the current
                slot needs a refill and no item arrives within ``timeout``.
                The buffer state is left unchanged in that case.
        """
        slot = self.idx

        if self.ttl[slot] <= 0:
            # Slot is exhausted: refill it. The queue read comes first so
            # that a timeout leaves ttl / cur_initial_ttl / idx untouched.
            self.buffers[slot] = self.inqueue.get(timeout=self.timeout)
            self.ttl[slot] = self.cur_initial_ttl
            # Ramp up the pass count handed to subsequently inserted items.
            if self.cur_initial_ttl < self.max_initial_ttl:
                self.cur_initial_ttl += 1

        buf = self.buffers[slot]
        self.ttl[slot] -= 1
        released = self.ttl[slot] <= 0
        if released:
            # Drop the reference so the item can be garbage collected.
            self.buffers[slot] = None

        # Advance round-robin to the next slot.
        self.idx = (slot + 1) % len(self.buffers)
        return buf, released
|