# multi_agent_replay_buffer.py

import collections
import platform
from typing import Any, Dict

import numpy as np
import ray
from ray.rllib import SampleBatch
from ray.rllib.execution import PrioritizedReplayBuffer, ReplayBuffer
from ray.rllib.execution.buffers.replay_buffer import logger, _ALL_POLICIES
from ray.rllib.policy.rnn_sequencing import \
    timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import SampleBatchType
from ray.util.iter import ParallelIteratorWorker


class MultiAgentReplayBuffer(ParallelIteratorWorker):
    """A replay buffer shard storing data for all policies (in multiagent setup).

    Ray actors are single-threaded, so for scalability, multiple replay actors
    may be created to increase parallelism."""
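
    # Example (an illustrative sketch): in a distributed setup, several shards
    # of this buffer are typically created as Ray actors via the `ReplayActor`
    # handle defined at the bottom of this file, e.g.:
    #   replay_actors = [ReplayActor.remote(num_shards=4, capacity=40000)
    #                    for _ in range(4)]
    # Each shard then manages roughly `capacity / num_shards` items.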

    def __init__(
            self,
            num_shards: int = 1,
            learning_starts: int = 1000,
            capacity: int = 10000,
            replay_batch_size: int = 1,
            prioritized_replay_alpha: float = 0.6,
            prioritized_replay_beta: float = 0.4,
            prioritized_replay_eps: float = 1e-6,
            replay_mode: str = "independent",
            replay_sequence_length: int = 1,
            replay_burn_in: int = 0,
            replay_zero_init_states: bool = True,
            buffer_size=DEPRECATED_VALUE,
    ):
        """Initializes a MultiAgentReplayBuffer instance.

        Args:
            num_shards: The number of buffer shards that exist in total
                (including this one).
            learning_starts: Number of timesteps after which a call to
                `replay()` will yield samples (before that, `replay()` will
                return None).
            capacity: The capacity of the buffer. Note that when
                `replay_sequence_length` > 1, this is the number of sequences
                (not single timesteps) stored.
            replay_batch_size: The batch size to be sampled (in timesteps).
                Note that if `replay_sequence_length` > 1,
                `self.replay_batch_size` will be set to the number of
                sequences sampled (B).
            prioritized_replay_alpha: Alpha parameter for a prioritized
                replay buffer. Use 0.0 for no prioritization.
            prioritized_replay_beta: Beta parameter for a prioritized
                replay buffer.
            prioritized_replay_eps: Epsilon parameter for a prioritized
                replay buffer.
            replay_mode: One of "independent" or "lockstep". Determines
                whether, in the multiagent case, sampling is done across all
                agents/policies equally (lockstep) or independently per
                policy.
            replay_sequence_length: The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer.
            replay_burn_in: The burn-in length in case
                `replay_sequence_length` > 0. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 for each RNN rollout.
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
        """
        # Deprecated args.
        if buffer_size != DEPRECATED_VALUE:
            deprecation_warning(
                "ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False)
            capacity = buffer_size

        self.replay_starts = learning_starts // num_shards
        self.capacity = capacity // num_shards
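        # For example (illustration only): with num_shards=4,
        # learning_starts=1000 and capacity=10000, each shard starts replaying
        # after 250 locally added timesteps and holds at most 2500 items.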
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.replay_mode = replay_mode
        self.replay_sequence_length = replay_sequence_length
        self.replay_burn_in = replay_burn_in
        self.replay_zero_init_states = replay_zero_init_states

        if replay_sequence_length > 1:
            self.replay_batch_size = int(
                max(1, replay_batch_size // replay_sequence_length))
            logger.info(
                "Since replay_sequence_length={} and replay_batch_size={}, "
                "we will replay {} sequences at a time.".format(
                    replay_sequence_length, replay_batch_size,
                    self.replay_batch_size))
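            # For example (illustration only): replay_batch_size=256 with
            # replay_sequence_length=20 results in 256 // 20 = 12 sequences
            # (~240 timesteps) being replayed per `replay()` call.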

        if replay_mode not in ["lockstep", "independent"]:
            raise ValueError("Unsupported replay mode: {}".format(replay_mode))

        def gen_replay():
            while True:
                yield self.replay()

        ParallelIteratorWorker.__init__(self, gen_replay, False)

        def new_buffer():
            if prioritized_replay_alpha == 0.0:
                return ReplayBuffer(self.capacity)
            else:
                return PrioritizedReplayBuffer(
                    self.capacity, alpha=prioritized_replay_alpha)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self.num_added = 0

        # Make externally accessible for testing.
        global _local_replay_buffer
        _local_replay_buffer = self
        # If set, return this instead of the usual data for testing.
        self._fake_batch = None

    @staticmethod
    def get_instance_for_testing():
        """Returns a MultiAgentReplayBuffer instance that has been previously
        instantiated.

        Returns:
            _local_replay_buffer: The most recently instantiated
                MultiAgentReplayBuffer.
        """
        global _local_replay_buffer
        return _local_replay_buffer

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()

    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, adds the individual
        policies' batches to the respective underlying replay buffers.

        Args:
            batch (SampleBatchType): The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            # Lockstep mode: Store under _ALL_POLICIES key (we will always
            # only sample from all policies at the same time).
            if self.replay_mode == "lockstep":
                # Note that prioritization is not supported in this mode.
                for s in batch.timeslices(self.replay_sequence_length):
                    self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
            else:
                for policy_id, sample_batch in batch.policy_batches.items():
                    if self.replay_sequence_length == 1:
                        timeslices = sample_batch.timeslices(1)
                    else:
                        timeslices = timeslice_along_seq_lens_with_overlap(
                            sample_batch=sample_batch,
                            zero_pad_max_seq_len=self.replay_sequence_length,
                            pre_overlap=self.replay_burn_in,
                            zero_init_states=self.replay_zero_init_states,
                        )
                    for time_slice in timeslices:
                        # If SampleBatch has prio-replay weights, average
                        # over these to use as a weight for the entire
                        # sequence.
                        if "weights" in time_slice and \
                                len(time_slice["weights"]):
                            weight = np.mean(time_slice["weights"])
                        else:
                            weight = None
                        self.replay_buffers[policy_id].add(
                            time_slice, weight=weight)
        self.num_added += batch.count
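
    # Example (an illustrative sketch):
    #   buffer.add_batch(SampleBatch({"obs": obs, "actions": actions,
    #                                 "rewards": rewards, "dones": dones}))
    # A plain single-agent SampleBatch like the one above is wrapped into a
    # MultiAgentBatch under DEFAULT_POLICY_ID, split into per-timestep slices
    # (or overlapping sequences if replay_sequence_length > 1), and each slice
    # is stored with the mean of its "weights" column as its priority.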

    def replay(self) -> SampleBatchType:
        """If this buffer was given a fake batch, return it, otherwise
        return a MultiAgentBatch with samples.
        """
        if self._fake_batch:
            if not isinstance(self._fake_batch, MultiAgentBatch):
                self._fake_batch = SampleBatch(
                    self._fake_batch).as_multi_agent()
            return self._fake_batch

        if self.num_added < self.replay_starts:
            return None

        with self.replay_timer:
            # Lockstep mode: Sample an equal amount of steps from all
            # policies at the same time.
            if self.replay_mode == "lockstep":
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)
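
    # Example (an illustrative sketch): in "independent" mode the returned
    # MultiAgentBatch maps each policy ID to its own sampled SampleBatch,
    # e.g. {"default_policy": <SampleBatch of replay_batch_size rows>}; with
    # prioritization enabled, each sampled SampleBatch carries "weights" and
    # "batch_indexes" columns for later priority updates.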

    def update_priorities(self, prio_dict: Dict) -> None:
        """Updates the priorities of underlying replay buffers.

        Computes new priorities from td_errors and prioritized_replay_eps.
        These priorities are used to update underlying replay buffers per
        policy_id.

        Args:
            prio_dict (Dict): A dictionary containing td_errors for
                batches saved in underlying replay buffers.
        """
        with self.update_priorities_timer:
            for policy_id, (batch_indexes, td_errors) in prio_dict.items():
                new_priorities = (
                    np.abs(td_errors) + self.prioritized_replay_eps)
                self.replay_buffers[policy_id].update_priorities(
                    batch_indexes, new_priorities)
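
    # Example (an illustrative sketch) of the expected `prio_dict` layout,
    # assuming `replayed` is the MultiAgentBatch returned by `replay()` and
    # `td_errors` is a NumPy array of per-sample TD errors:
    #   buffer.update_priorities({
    #       "default_policy": (
    #           replayed.policy_batches["default_policy"]["batch_indexes"],
    #           td_errors),
    #   })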

    def stats(self, debug: bool = False) -> Dict:
        """Returns the stats of this buffer and all underlying buffers.

        Args:
            debug (bool): If True, stats of underlying replay buffers will
                be fetched with debug=True.

        Returns:
            stat: Dictionary of buffer stats.
        """
        stat = {
            "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
            "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
            "update_priorities_time_ms": round(
                1000 * self.update_priorities_timer.mean, 3),
        }
        for policy_id, replay_buffer in self.replay_buffers.items():
            stat.update({
                "policy_{}".format(policy_id): replay_buffer.stats(debug=debug)
            })
        return stat
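
    # Example (an illustrative sketch) of the returned dict:
    #   {"add_batch_time_ms": 0.05, "replay_time_ms": 0.02,
    #    "update_priorities_time_ms": 0.01,
    #    "policy_default_policy": {...per-policy buffer stats...}}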

    def get_state(self) -> Dict[str, Any]:
        # Serialize the add-counter and each per-policy buffer's state.
        state = {"num_added": self.num_added, "replay_buffers": {}}
        for policy_id, replay_buffer in self.replay_buffers.items():
            state["replay_buffers"][policy_id] = replay_buffer.get_state()
        return state

    def set_state(self, state: Dict[str, Any]) -> None:
        # Restore the add-counter and per-policy buffer states produced by
        # `get_state()`.
        self.num_added = state["num_added"]
        buffer_states = state["replay_buffers"]
        for policy_id in buffer_states.keys():
            self.replay_buffers[policy_id].set_state(buffer_states[policy_id])


ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
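

# Example usage (an illustrative smoke test, a minimal sketch only): builds a
# local, non-actor buffer, adds 20 timesteps of dummy single-agent data, and
# samples from it. Assumes the SampleBatch columns chosen below are sufficient
# for time-slicing in the installed RLlib version.
if __name__ == "__main__":
    buffer = MultiAgentReplayBuffer(
        num_shards=1,
        learning_starts=10,
        capacity=100,
        replay_batch_size=4,
    )
    # `add_batch()` wraps this single-agent batch into a MultiAgentBatch
    # under the default policy ID and stores it timestep by timestep.
    dummy = SampleBatch({
        "obs": np.arange(20, dtype=np.float32).reshape(20, 1),
        "actions": np.zeros(20, dtype=np.int64),
        "rewards": np.ones(20, dtype=np.float32),
        "dones": np.array([False] * 19 + [True]),
        "eps_id": np.zeros(20, dtype=np.int64),
    })
    buffer.add_batch(dummy)
    # 20 added timesteps exceed learning_starts=10, so replay() returns a
    # MultiAgentBatch of 4 prioritized samples per policy.
    replayed = buffer.replay()
    print(replayed)
    print(buffer.stats())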