minibatch_buffer.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
import queue
from typing import Any, Tuple
  3. class MinibatchBuffer:
  4. """Ring buffer of recent data batches for minibatch SGD.
  5. This is for use with AsyncSamplesOptimizer.
  6. """
  7. def __init__(self,
  8. inqueue: queue.Queue,
  9. size: int,
  10. timeout: float,
  11. num_passes: int,
  12. init_num_passes: int = 1):
  13. """Initialize a minibatch buffer.
  14. Args:
  15. inqueue: Queue to populate the internal ring buffer from.
  16. size: Max number of data items to buffer.
  17. timeout: Queue timeout
  18. num_passes: Max num times each data item should be emitted.
  19. init_num_passes: Initial max passes for each data item
  20. """
  21. self.inqueue = inqueue
  22. self.size = size
  23. self.timeout = timeout
  24. self.max_ttl = num_passes
  25. self.cur_max_ttl = init_num_passes
  26. self.buffers = [None] * size
  27. self.ttl = [0] * size
  28. self.idx = 0
  29. def get(self) -> Tuple[Any, bool]:
  30. """Get a new batch from the internal ring buffer.
  31. Returns:
  32. buf: Data item saved from inqueue.
  33. released: True if the item is now removed from the ring buffer.
  34. """
  35. if self.ttl[self.idx] <= 0:
  36. self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
  37. self.ttl[self.idx] = self.cur_max_ttl
  38. if self.cur_max_ttl < self.max_ttl:
  39. self.cur_max_ttl += 1
  40. buf = self.buffers[self.idx]
  41. self.ttl[self.idx] -= 1
  42. released = self.ttl[self.idx] <= 0
  43. if released:
  44. self.buffers[self.idx] = None
  45. self.idx = (self.idx + 1) % len(self.buffers)
  46. return buf, released