minibatch_buffer.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from typing import Any, Tuple
  2. import queue
  3. from ray.rllib.utils.deprecation import deprecation_warning
  4. from ray.util import log_once
  5. class MinibatchBuffer:
  6. """Ring buffer of recent data batches for minibatch SGD.
  7. This is for use with AsyncSamplesOptimizer.
  8. """
  9. def __init__(
  10. self,
  11. inqueue: queue.Queue,
  12. size: int,
  13. timeout: float,
  14. num_passes: int,
  15. init_num_passes: int = 1,
  16. ):
  17. """Initialize a minibatch buffer.
  18. Args:
  19. inqueue (queue.Queue): Queue to populate the internal ring buffer
  20. from.
  21. size: Max number of data items to buffer.
  22. timeout: Queue timeout
  23. num_passes: Max num times each data item should be emitted.
  24. init_num_passes: Initial passes for each data item.
  25. Maxiumum number of passes per item are increased to num_passes over
  26. time.
  27. """
  28. self.inqueue = inqueue
  29. self.size = size
  30. self.timeout = timeout
  31. self.max_initial_ttl = num_passes
  32. self.cur_initial_ttl = init_num_passes
  33. self.buffers = [None] * size
  34. self.ttl = [0] * size
  35. self.idx = 0
  36. if log_once("minibatch-buffer-deprecation-warning"):
  37. deprecation_warning(
  38. old="ray.rllib.execution.minibatch_buffer.MinibatchBuffer"
  39. )
  40. def get(self) -> Tuple[Any, bool]:
  41. """Get a new batch from the internal ring buffer.
  42. Returns:
  43. buf: Data item saved from inqueue.
  44. released: True if the item is now removed from the ring buffer.
  45. """
  46. if self.ttl[self.idx] <= 0:
  47. self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
  48. self.ttl[self.idx] = self.cur_initial_ttl
  49. if self.cur_initial_ttl < self.max_initial_ttl:
  50. self.cur_initial_ttl += 1
  51. buf = self.buffers[self.idx]
  52. self.ttl[self.idx] -= 1
  53. released = self.ttl[self.idx] <= 0
  54. if released:
  55. self.buffers[self.idx] = None
  56. self.idx = (self.idx + 1) % len(self.buffers)
  57. return buf, released