shuffled_input.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import logging
  2. import random
  3. from ray.rllib.offline.input_reader import InputReader
  4. from ray.rllib.utils.annotations import override, DeveloperAPI
  5. from ray.rllib.utils.typing import SampleBatchType
  6. logger = logging.getLogger(__name__)
  7. @DeveloperAPI
  8. class ShuffledInput(InputReader):
  9. """Randomizes data over a sliding window buffer of N batches.
  10. This increases the randomization of the data, which is useful if the
  11. batches were not in random order to start with.
  12. """
  13. @DeveloperAPI
  14. def __init__(self, child: InputReader, n: int = 0):
  15. """Initializes a ShuffledInput instance.
  16. Args:
  17. child: child input reader to shuffle.
  18. n: If positive, shuffle input over this many batches.
  19. """
  20. self.n = n
  21. self.child = child
  22. self.buffer = []
  23. @override(InputReader)
  24. def next(self) -> SampleBatchType:
  25. if self.n <= 1:
  26. return self.child.next()
  27. if len(self.buffer) < self.n:
  28. logger.info("Filling shuffle buffer to {} batches".format(self.n))
  29. while len(self.buffer) < self.n:
  30. self.buffer.append(self.child.next())
  31. logger.info("Shuffle buffer filled")
  32. i = random.randint(0, len(self.buffer) - 1)
  33. self.buffer[i] = self.child.next()
  34. return random.choice(self.buffer)