from typing import List, Any, Optional
import random

from ray.actor import ActorHandle
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
    MultiAgentReplayBuffer
from ray.rllib.execution.common import \
    STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.utils.typing import SampleBatchType


class StoreToReplayBuffer:
    """Callable that stores data into a replay buffer or replay actors.

    If constructed with a local replay buffer, data will be stored into that
    buffer. If constructed with a list of replay actor handles, data will be
    stored randomly among those actors.

    This should be used with the .for_each() operator on a rollouts iterator.
    The batch that was stored is returned.

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> rollouts = ParallelRollouts(...)
        >>> store_op = rollouts.for_each(StoreToReplayBuffer(actors=actors))
        >>> next(store_op)
        SampleBatch(...)
    """

    def __init__(
            self,
            *,
            local_buffer: Optional[MultiAgentReplayBuffer] = None,
            actors: Optional[List[ActorHandle]] = None,
    ):
        """Initialize StoreToReplayBuffer.

        Args:
            local_buffer: The local replay buffer to store the data into.
            actors: An optional list of replay actors to use instead of
                `local_buffer`.

        Raises:
            ValueError: If both or neither of `local_buffer` and `actors`
                are given (an empty `actors` list counts as not given).
        """
        # Use an identity check for the buffer: an object that defines
        # `__len__`/`__bool__` may be falsy (e.g. while still empty) and must
        # still count as "given".
        has_local_buffer = local_buffer is not None
        if has_local_buffer == bool(actors):
            raise ValueError(
                "Exactly one of `local_buffer` or `actors` must be given.")

        if has_local_buffer:
            self.local_actor = local_buffer
            self.replay_actors = None
        else:
            self.local_actor = None
            self.replay_actors = actors

    def __call__(self, batch: SampleBatchType):
        """Store `batch` and return it unchanged.

        Args:
            batch: The sample batch to add to the local buffer or to one
                randomly chosen replay actor.

        Returns:
            The same `batch` object that was passed in.
        """
        # Identity check for the same reason as in `__init__`.
        if self.local_actor is not None:
            self.local_actor.add_batch(batch)
        else:
            actor = random.choice(self.replay_actors)
            # The remote call's result handle is intentionally not awaited.
            actor.add_batch.remote(batch)
        return batch


def Replay(*,
           local_buffer: Optional[MultiAgentReplayBuffer] = None,
           actors: Optional[List[ActorHandle]] = None,
           num_async: int = 4) -> LocalIterator[SampleBatchType]:
    """Replay experiences from the given buffer or actors.

    This should be combined with the StoreToReplayBuffer operation using the
    Concurrently() operator.

    Args:
        local_buffer: Local buffer to use. Only one of this and `actors`
            can be specified.
        actors: List of replay actors. Only one of this and `local_buffer`
            can be specified.
        num_async: In async mode, the max number of async requests in flight
            per actor.

    Returns:
        An iterator over replayed batches. `None` results from actors are
        filtered out; a `None` result from the local buffer is surfaced as a
        `_NextValueNotReady` marker.

    Raises:
        ValueError: If both or neither of `local_buffer` and `actors` are
            given (an empty `actors` list counts as not given).

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors)
        >>> next(replay_op)
        SampleBatch(...)
    """
    # Identity check: a buffer object may be falsy (e.g. if it defines
    # `__len__` and is empty) yet still have been explicitly provided.
    has_local_buffer = local_buffer is not None
    if has_local_buffer == bool(actors):
        raise ValueError(
            "Exactly one of `local_buffer` and `actors` must be given.")

    if not has_local_buffer:
        replay = from_actors(actors)
        return replay.gather_async(
            num_async=num_async).filter(lambda x: x is not None)

    def gen_replay(_):
        # The argument passed in by LocalIterator is not used here.
        while True:
            item = local_buffer.replay()
            if item is None:
                # Nothing to replay right now; signal "try again later".
                yield _NextValueNotReady()
            else:
                yield item

    return LocalIterator(gen_replay, SharedMetrics())


class WaitUntilTimestepsElapsed:
    """Callable that returns True once enough timesteps have been sampled.

    The result is True only when the sampled-steps counter strictly exceeds
    the configured target.
    """

    def __init__(self, target_num_timesteps: int):
        """Initialize WaitUntilTimestepsElapsed.

        Args:
            target_num_timesteps: Threshold the sampled-steps counter must
                exceed before calls return True.
        """
        self.target_num_timesteps = target_num_timesteps

    def __call__(self, item: Any) -> bool:
        """Check the shared sampled-steps counter against the target.

        Args:
            item: Ignored; present so this can be used as an iterator op.

        Returns:
            True if the STEPS_SAMPLED_COUNTER value is greater than the
            target, False otherwise.
        """
        metrics = _get_shared_metrics()
        ts = metrics.counters[STEPS_SAMPLED_COUNTER]
        return ts > self.target_num_timesteps


# TODO(ekl) deprecate this in favor of the replay_sequence_length option.
class SimpleReplayBuffer:
    """Simple replay buffer that operates over batches."""

    def __init__(self,
                 num_slots: int,
                 replay_proportion: Optional[float] = None):
        """Initialize SimpleReplayBuffer.

        Args:
            num_slots: Number of batches to store in total.
            replay_proportion: Not used by this class.
        """
        self.num_slots = num_slots
        self.replay_batches = []
        # Index of the slot that is overwritten next once the buffer is full.
        self.replay_index = 0

    def add_batch(self, sample_batch: SampleBatchType) -> None:
        """Add a batch, overwriting slots round-robin once full.

        Nothing is stored if `num_slots` is not positive.

        Args:
            sample_batch: The batch to store.
        """
        warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
        if self.num_slots > 0:
            if len(self.replay_batches) < self.num_slots:
                self.replay_batches.append(sample_batch)
            else:
                self.replay_batches[self.replay_index] = sample_batch
                self.replay_index += 1
                self.replay_index %= self.num_slots

    def replay(self) -> SampleBatchType:
        """Return one stored batch chosen uniformly at random.

        Raises:
            IndexError: If no batch has been stored yet.
        """
        return random.choice(self.replay_batches)


class MixInReplay:
    """This operator adds replay to a stream of experiences.

    It takes input batches, and returns a list of batches that include
    replayed data as well. The number of replayed batches is determined by
    the configured replay proportion. The max age of a batch is determined by
    the number of replay slots.
    """

    def __init__(self, num_slots: int, replay_proportion: float):
        """Initialize MixInReplay.

        Args:
            num_slots: Number of batches to store in total.
            replay_proportion: The input batch will be returned and an
                additional number of batches proportional to this value
                will be added as well.

        Raises:
            ValueError: If `replay_proportion > 0` but `num_slots` is not
                positive (there would be nothing to replay from).

        Examples:
            # replay proportion 2:1
            >>> replay_op = rollouts.for_each(
            ...     MixInReplay(num_slots=100, replay_proportion=2))
            >>> print(next(replay_op))
            [SampleBatch(), SampleBatch(), SampleBatch()]
            # replay proportion 0:1, replay disabled
            >>> replay_op = rollouts.for_each(
            ...     MixInReplay(num_slots=100, replay_proportion=0))
            >>> print(next(replay_op))
            [SampleBatch()]
        """
        # A non-positive slot count means the buffer never stores anything,
        # so replaying from it would fail later with an IndexError.
        if replay_proportion > 0 and num_slots <= 0:
            raise ValueError(
                "You must set num_slots > 0 if replay_proportion > 0.")
        self.replay_buffer = SimpleReplayBuffer(num_slots)
        self.replay_proportion = replay_proportion

    def __call__(self,
                 sample_batch: SampleBatchType) -> List[SampleBatchType]:
        """Store the new batch and return it plus replayed batches.

        Args:
            sample_batch: The freshly sampled batch.

        Returns:
            A list whose first element is `sample_batch`, followed by
            batches drawn from the replay buffer.
        """
        # Put in replay buffer if enabled.
        self.replay_buffer.add_batch(sample_batch)

        # Proportional replay: each whole unit of the proportion adds one
        # replayed batch; a fractional remainder adds one more with that
        # probability.
        output_batches = [sample_batch]
        f = self.replay_proportion
        while random.random() < f:
            f -= 1
            output_batches.append(self.replay_buffer.replay())
        return output_batches