replay_ops.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from typing import List, Any, Optional
  2. import random
  3. from ray.actor import ActorHandle
  4. from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
  5. from ray.util.iter_metrics import SharedMetrics
  6. from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
  7. from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
  8. MultiAgentReplayBuffer
  9. from ray.rllib.execution.common import \
  10. STEPS_SAMPLED_COUNTER, _get_shared_metrics
  11. from ray.rllib.utils.typing import SampleBatchType
  12. class StoreToReplayBuffer:
  13. """Callable that stores data into replay buffer actors.
  14. If constructed with a local replay actor, data will be stored into that
  15. buffer. If constructed with a list of replay actor handles, data will
  16. be stored randomly among those actors.
  17. This should be used with the .for_each() operator on a rollouts iterator.
  18. The batch that was stored is returned.
  19. Examples:
  20. >>> actors = [ReplayActor.remote() for _ in range(4)]
  21. >>> rollouts = ParallelRollouts(...)
  22. >>> store_op = rollouts.for_each(StoreToReplayActors(actors=actors))
  23. >>> next(store_op)
  24. SampleBatch(...)
  25. """
  26. def __init__(
  27. self,
  28. *,
  29. local_buffer: Optional[MultiAgentReplayBuffer] = None,
  30. actors: Optional[List[ActorHandle]] = None,
  31. ):
  32. """
  33. Args:
  34. local_buffer: The local replay buffer to store the data into.
  35. actors: An optional list of replay actors to use instead of
  36. `local_buffer`.
  37. """
  38. if bool(local_buffer) == bool(actors):
  39. raise ValueError(
  40. "Either `local_buffer` or `replay_actors` must be given, "
  41. "not both!")
  42. if local_buffer:
  43. self.local_actor = local_buffer
  44. self.replay_actors = None
  45. else:
  46. self.local_actor = None
  47. self.replay_actors = actors
  48. def __call__(self, batch: SampleBatchType):
  49. if self.local_actor:
  50. self.local_actor.add_batch(batch)
  51. else:
  52. actor = random.choice(self.replay_actors)
  53. actor.add_batch.remote(batch)
  54. return batch
  55. def Replay(*,
  56. local_buffer: MultiAgentReplayBuffer = None,
  57. actors: List[ActorHandle] = None,
  58. num_async: int = 4) -> LocalIterator[SampleBatchType]:
  59. """Replay experiences from the given buffer or actors.
  60. This should be combined with the StoreToReplayActors operation using the
  61. Concurrently() operator.
  62. Args:
  63. local_buffer: Local buffer to use. Only one of this and replay_actors
  64. can be specified.
  65. actors: List of replay actors. Only one of this and local_buffer
  66. can be specified.
  67. num_async: In async mode, the max number of async requests in flight
  68. per actor.
  69. Examples:
  70. >>> actors = [ReplayActor.remote() for _ in range(4)]
  71. >>> replay_op = Replay(actors=actors)
  72. >>> next(replay_op)
  73. SampleBatch(...)
  74. """
  75. if bool(local_buffer) == bool(actors):
  76. raise ValueError(
  77. "Exactly one of local_buffer and replay_actors must be given.")
  78. if actors:
  79. replay = from_actors(actors)
  80. return replay.gather_async(
  81. num_async=num_async).filter(lambda x: x is not None)
  82. def gen_replay(_):
  83. while True:
  84. item = local_buffer.replay()
  85. if item is None:
  86. yield _NextValueNotReady()
  87. else:
  88. yield item
  89. return LocalIterator(gen_replay, SharedMetrics())
  90. class WaitUntilTimestepsElapsed:
  91. """Callable that returns True once a given number of timesteps are hit."""
  92. def __init__(self, target_num_timesteps: int):
  93. self.target_num_timesteps = target_num_timesteps
  94. def __call__(self, item: Any) -> bool:
  95. metrics = _get_shared_metrics()
  96. ts = metrics.counters[STEPS_SAMPLED_COUNTER]
  97. return ts > self.target_num_timesteps
  98. # TODO(ekl) deprecate this in favor of the replay_sequence_length option.
  99. class SimpleReplayBuffer:
  100. """Simple replay buffer that operates over batches."""
  101. def __init__(self,
  102. num_slots: int,
  103. replay_proportion: Optional[float] = None):
  104. """Initialize SimpleReplayBuffer.
  105. Args:
  106. num_slots (int): Number of batches to store in total.
  107. """
  108. self.num_slots = num_slots
  109. self.replay_batches = []
  110. self.replay_index = 0
  111. def add_batch(self, sample_batch: SampleBatchType) -> None:
  112. warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
  113. if self.num_slots > 0:
  114. if len(self.replay_batches) < self.num_slots:
  115. self.replay_batches.append(sample_batch)
  116. else:
  117. self.replay_batches[self.replay_index] = sample_batch
  118. self.replay_index += 1
  119. self.replay_index %= self.num_slots
  120. def replay(self) -> SampleBatchType:
  121. return random.choice(self.replay_batches)
  122. class MixInReplay:
  123. """This operator adds replay to a stream of experiences.
  124. It takes input batches, and returns a list of batches that include replayed
  125. data as well. The number of replayed batches is determined by the
  126. configured replay proportion. The max age of a batch is determined by the
  127. number of replay slots.
  128. """
  129. def __init__(self, num_slots: int, replay_proportion: float):
  130. """Initialize MixInReplay.
  131. Args:
  132. num_slots (int): Number of batches to store in total.
  133. replay_proportion (float): The input batch will be returned
  134. and an additional number of batches proportional to this value
  135. will be added as well.
  136. Examples:
  137. # replay proportion 2:1
  138. >>> replay_op = MixInReplay(rollouts, 100, replay_proportion=2)
  139. >>> print(next(replay_op))
  140. [SampleBatch(<input>), SampleBatch(<replay>), SampleBatch(<rep.>)]
  141. # replay proportion 0:1, replay disabled
  142. >>> replay_op = MixInReplay(rollouts, 100, replay_proportion=0)
  143. >>> print(next(replay_op))
  144. [SampleBatch(<input>)]
  145. """
  146. if replay_proportion > 0 and num_slots == 0:
  147. raise ValueError(
  148. "You must set num_slots > 0 if replay_proportion > 0.")
  149. self.replay_buffer = SimpleReplayBuffer(num_slots)
  150. self.replay_proportion = replay_proportion
  151. def __call__(self, sample_batch: SampleBatchType) -> List[SampleBatchType]:
  152. # Put in replay buffer if enabled.
  153. self.replay_buffer.add_batch(sample_batch)
  154. # Proportional replay.
  155. output_batches = [sample_batch]
  156. f = self.replay_proportion
  157. while random.random() < f:
  158. f -= 1
  159. output_batches.append(self.replay_buffer.replay())
  160. return output_batches