minibatch_buffer.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from typing import Any, Tuple
  2. import queue
  3. class MinibatchBuffer:
  4. """Ring buffer of recent data batches for minibatch SGD.
  5. This is for use with AsyncSamplesOptimizer.
  6. """
  7. def __init__(self,
  8. inqueue: queue.Queue,
  9. size: int,
  10. timeout: float,
  11. num_passes: int,
  12. init_num_passes: int = 1):
  13. """Initialize a minibatch buffer.
  14. Args:
  15. inqueue (queue.Queue): Queue to populate the internal ring buffer
  16. from.
  17. size (int): Max number of data items to buffer.
  18. timeout (float): Queue timeout
  19. num_passes (int): Max num times each data item should be emitted.
  20. init_num_passes (int): Initial passes for each data item.
  21. Maxiumum number of passes per item are increased to num_passes over
  22. time.
  23. """
  24. self.inqueue = inqueue
  25. self.size = size
  26. self.timeout = timeout
  27. self.max_initial_ttl = num_passes
  28. self.cur_initial_ttl = init_num_passes
  29. self.buffers = [None] * size
  30. self.ttl = [0] * size
  31. self.idx = 0
  32. def get(self) -> Tuple[Any, bool]:
  33. """Get a new batch from the internal ring buffer.
  34. Returns:
  35. buf: Data item saved from inqueue.
  36. released: True if the item is now removed from the ring buffer.
  37. """
  38. if self.ttl[self.idx] <= 0:
  39. self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
  40. self.ttl[self.idx] = self.cur_initial_ttl
  41. if self.cur_initial_ttl < self.max_initial_ttl:
  42. self.cur_initial_ttl += 1
  43. buf = self.buffers[self.idx]
  44. self.ttl[self.idx] -= 1
  45. released = self.ttl[self.idx] <= 0
  46. if released:
  47. self.buffers[self.idx] = None
  48. self.idx = (self.idx + 1) % len(self.buffers)
  49. return buf, released