replay_ops.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from typing import Optional
  2. import random
  3. from ray.rllib.utils.replay_buffers.replay_buffer import warn_replay_capacity
  4. from ray.rllib.utils.typing import SampleBatchType
  5. from ray.rllib.utils.deprecation import deprecation_warning
  6. from ray.util import log_once
  7. # TODO(sven) deprecate this class.
  8. class SimpleReplayBuffer:
  9. """Simple replay buffer that operates over batches."""
  10. def __init__(self, num_slots: int, replay_proportion: Optional[float] = None):
  11. """Initialize SimpleReplayBuffer.
  12. Args:
  13. num_slots: Number of batches to store in total.
  14. """
  15. self.num_slots = num_slots
  16. self.replay_batches = []
  17. self.replay_index = 0
  18. if log_once("simple_replay_buffer_deprecation_warning"):
  19. deprecation_warning(old="ray.rllib.execution.replay_ops.SimpleReplayBuffer")
  20. def add_batch(self, sample_batch: SampleBatchType) -> None:
  21. warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
  22. if self.num_slots > 0:
  23. if len(self.replay_batches) < self.num_slots:
  24. self.replay_batches.append(sample_batch)
  25. else:
  26. self.replay_batches[self.replay_index] = sample_batch
  27. self.replay_index += 1
  28. self.replay_index %= self.num_slots
  29. def replay(self) -> SampleBatchType:
  30. return random.choice(self.replay_batches)
  31. def __len__(self):
  32. return len(self.replay_batches)