replay_buffer.py

import collections
import logging
import numpy as np
import platform
import random
from typing import Any, Dict, List, Optional

# Importing ray before psutil makes sure we use ray's bundled psutil version.
import ray  # noqa F401
import psutil  # noqa E402
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
from ray.rllib.policy.rnn_sequencing import \
    timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
    DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.util.iter import ParallelIteratorWorker
from ray.util.debug import log_once
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \
    deprecation_warning
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType

# Constant that represents all policies in lockstep replay mode.
_ALL_POLICIES = "__all__"

logger = logging.getLogger(__name__)

def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None:
    """Warn if the configured replay buffer capacity is too large."""
    if log_once("replay_capacity"):
        item_size = item.size_bytes()
        psutil_mem = psutil.virtual_memory()
        total_gb = psutil_mem.total / 1e9
        mem_size = num_items * item_size / 1e9
        msg = ("Estimated max memory usage for replay buffer is {} GB "
               "({} batches of size {}, {} bytes each), "
               "available system memory is {} GB".format(
                   mem_size, num_items, item.count, item_size, total_gb))
        if mem_size > total_gb:
            raise ValueError(msg)
        elif mem_size > 0.2 * total_gb:
            logger.warning(msg)
        else:
            logger.info(msg)


@Deprecated(new="warn_replay_capacity", error=False)
def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
    return warn_replay_capacity(item=item, num_items=num_items)

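
# Worked example of the estimate above (illustrative numbers, not from the
# original module): with capacity=50000 timesteps and incoming batches of
# count=50 timesteps at ~400 KB each, `warn_replay_capacity` is called with
# num_items = 50000 / 50 = 1000 batches, so the estimated peak usage is
# roughly 1000 * 400 KB = 0.4 GB. That estimate is compared against total
# system memory (log a warning above 20% of it, raise above 100%).
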
@DeveloperAPI
class ReplayBuffer:
    @DeveloperAPI
    def __init__(self,
                 capacity: int = 10000,
                 size: Optional[int] = DEPRECATED_VALUE):
        """Initializes a ReplayBuffer instance.

        Args:
            capacity (int): Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
        """
        # Deprecated args.
        if size != DEPRECATED_VALUE:
            deprecation_warning(
                "ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False)
            capacity = size

        # The actual storage (list of SampleBatches).
        self._storage = []
        self.capacity = capacity
        # The next index to overwrite in the buffer.
        self._next_idx = 0
        self._hit_count = np.zeros(self.capacity)
        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0
        self._num_timesteps_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0
    def __len__(self) -> int:
        return len(self._storage)

    @DeveloperAPI
    def add(self, item: SampleBatchType, weight: float) -> None:
        assert item.count > 0, item
        warn_replay_capacity(item=item, num_items=self.capacity / item.count)
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            self._storage[self._next_idx] = item

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0
    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
        out.decompress_if_needed()
        return out

    @DeveloperAPI
    def sample(self, num_items: int) -> SampleBatchType:
        """Sample a batch of experiences.

        Args:
            num_items (int): Number of items to sample from this buffer.

        Returns:
            SampleBatchType: concatenated batch of items.
        """
        idxes = [
            random.randint(0,
                           len(self._storage) - 1) for _ in range(num_items)
        ]
        self._num_timesteps_sampled += num_items
        return self._encode_sample(idxes)
    @DeveloperAPI
    def stats(self, debug=False) -> dict:
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @DeveloperAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            Dict[str, Any]: The serializable local state.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @DeveloperAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state (Dict[str, Any]): The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]

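
# A minimal usage sketch of the uniform ReplayBuffer (illustrative only, not
# part of the original module). The "obs"/"rewards" column names below are
# assumptions for the example.
def _example_replay_buffer_usage():
    buffer = ReplayBuffer(capacity=100)
    for i in range(10):
        # Each item is a (small) SampleBatch; `weight` is ignored by the
        # uniform buffer but kept for API compatibility.
        buffer.add(
            SampleBatch({
                "obs": np.array([i]),
                "rewards": np.array([1.0]),
            }),
            weight=None)
    # Uniformly sample 4 stored items and get them back as one concatenated
    # SampleBatch.
    return buffer.sample(4)
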
@DeveloperAPI
class PrioritizedReplayBuffer(ReplayBuffer):
    @DeveloperAPI
    def __init__(self,
                 capacity: int = 10000,
                 alpha: float = 1.0,
                 size: Optional[int] = DEPRECATED_VALUE):
        """Initializes a PrioritizedReplayBuffer instance.

        Args:
            capacity (int): Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            alpha (float): How much prioritization is used
                (0.0=no prioritization, 1.0=full prioritization).
        """
        super(PrioritizedReplayBuffer, self).__init__(capacity, size)
        assert alpha > 0
        self._alpha = alpha

        # Segment trees require a power-of-two capacity.
        it_capacity = 1
        while it_capacity < self.capacity:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self._prio_change_stats = WindowStat("reprio", 1000)
    @DeveloperAPI
    @override(ReplayBuffer)
    def add(self, item: SampleBatchType, weight: float) -> None:
        idx = self._next_idx
        super(PrioritizedReplayBuffer, self).add(item, weight)
        if weight is None:
            weight = self._max_priority
        self._it_sum[idx] = weight**self._alpha
        self._it_min[idx] = weight**self._alpha

    def _sample_proportional(self, num_items: int) -> List[int]:
        res = []
        for _ in range(num_items):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res
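
    # Note on the sampling scheme above and the weights computed in `sample()`
    # below (standard prioritized-replay math, stated here for clarity): an
    # item with priority p_i is drawn with probability
    #     P(i) = p_i**alpha / sum_k p_k**alpha,
    # and its importance-sampling correction is
    #     w_i = (N * P(i))**(-beta) / max_j w_j,
    # so rarely sampled (low-priority) items receive larger weights.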
    @DeveloperAPI
    @override(ReplayBuffer)
    def sample(self, num_items: int, beta: float) -> SampleBatchType:
        """Sample a batch of experiences and return priority weights, indices.

        Args:
            num_items (int): Number of items to sample from this buffer.
            beta (float): To what degree to use importance weights
                (0 - no corrections, 1 - full correction).

        Returns:
            SampleBatchType: Concatenated batch of items including "weights"
                and "batch_indexes" fields denoting the importance-sampling
                weight of each sampled transition and the original indices
                of the sampled experiences in the buffer.
        """
        assert beta >= 0.0

        idxes = self._sample_proportional(num_items)

        weights = []
        batch_indexes = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            count = self._storage[idx].count
            # If zero-padded, count will not be the actual batch size of the
            # data.
            if isinstance(self._storage[idx], SampleBatch) and \
                    self._storage[idx].zero_padded:
                actual_size = self._storage[idx].max_seq_len
            else:
                actual_size = count
            weights.extend([weight / max_weight] * actual_size)
            batch_indexes.extend([idx] * actual_size)
            self._num_timesteps_sampled += count
        batch = self._encode_sample(idxes)

        # Note: prioritization is not supported in lockstep replay mode.
        if isinstance(batch, SampleBatch):
            batch["weights"] = np.array(weights)
            batch["batch_indexes"] = np.array(batch_indexes)

        return batch
    @DeveloperAPI
    def update_priorities(self, idxes: List[int],
                          priorities: List[float]) -> None:
        """Updates the priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Args:
            idxes (List[int]): List of indices of the sampled transitions.
            priorities (List[float]): List of updated priorities corresponding
                to the transitions at the sampled indices denoted by `idxes`.
        """
        # Making sure we don't pass in e.g. a torch tensor.
        assert isinstance(idxes, (list, np.ndarray)), \
            "ERROR: `idxes` is not a list or np.ndarray, but " \
            "{}!".format(type(idxes).__name__)
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            delta = priority**self._alpha - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)
    @DeveloperAPI
    @override(ReplayBuffer)
    def stats(self, debug: bool = False) -> Dict:
        parent = ReplayBuffer.stats(self, debug)
        if debug:
            parent.update(self._prio_change_stats.stats())
        return parent

    @DeveloperAPI
    @override(ReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            Dict[str, Any]: The serializable local state.
        """
        # Get parent state.
        state = super().get_state()
        # Add prio weights.
        state.update({
            "sum_segment_tree": self._it_sum.get_state(),
            "min_segment_tree": self._it_min.get_state(),
            "max_priority": self._max_priority,
        })
        return state

    @DeveloperAPI
    @override(ReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state (Dict[str, Any]): The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        super().set_state(state)
        self._it_sum.set_state(state["sum_segment_tree"])
        self._it_min.set_state(state["min_segment_tree"])
        self._max_priority = state["max_priority"]

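
# A minimal usage sketch of the PrioritizedReplayBuffer (illustrative only,
# not part of the original module). Column names and the constant new
# priorities are assumptions for the example.
def _example_prioritized_replay_usage():
    buffer = PrioritizedReplayBuffer(capacity=16, alpha=0.6)
    for i in range(8):
        # weight=None means "use the current max priority" for new items.
        buffer.add(
            SampleBatch({
                "obs": np.array([i]),
                "rewards": np.array([float(i)]),
            }),
            weight=None)
    # Sample proportionally to priority; beta controls the strength of the
    # importance-sampling correction returned in batch["weights"].
    batch = buffer.sample(4, beta=0.4)
    # After computing new TD-errors for the sampled items, feed them back as
    # priorities for the returned "batch_indexes".
    new_priorities = np.full(len(batch["batch_indexes"]), 0.5)
    buffer.update_priorities(batch["batch_indexes"], new_priorities)
    return batch
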
# Visible for testing.
_local_replay_buffer = None


class LocalReplayBuffer(ParallelIteratorWorker):
    """A replay buffer shard storing data for all policies (in multiagent setup).

    Ray actors are single-threaded, so for scalability, multiple replay actors
    may be created to increase parallelism."""
    def __init__(
            self,
            num_shards: int = 1,
            learning_starts: int = 1000,
            capacity: int = 10000,
            replay_batch_size: int = 1,
            prioritized_replay_alpha: float = 0.6,
            prioritized_replay_beta: float = 0.4,
            prioritized_replay_eps: float = 1e-6,
            replay_mode: str = "independent",
            replay_sequence_length: int = 1,
            replay_burn_in: int = 0,
            replay_zero_init_states: bool = True,
            buffer_size=DEPRECATED_VALUE,
    ):
        """Initializes a LocalReplayBuffer instance.

        Args:
            num_shards (int): The number of buffer shards that exist in total
                (including this one).
            learning_starts (int): Number of timesteps after which a call to
                `replay()` will yield samples (before that, `replay()` will
                return None).
            capacity (int): The capacity of the buffer. Note that when
                `replay_sequence_length` > 1, this is the number of sequences
                (not single timesteps) stored.
            replay_batch_size (int): The batch size to be sampled (in
                timesteps). Note that if `replay_sequence_length` > 1,
                `self.replay_batch_size` will be set to the number of
                sequences sampled (B).
            prioritized_replay_alpha (float): Alpha parameter for a prioritized
                replay buffer.
            prioritized_replay_beta (float): Beta parameter for a prioritized
                replay buffer.
            prioritized_replay_eps (float): Epsilon parameter for a prioritized
                replay buffer.
            replay_mode (str): One of "independent" or "lockstep". Determines,
                in the multiagent case, whether sampling is done independently
                per policy or across all policies in lockstep.
            replay_sequence_length (int): The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer.
            replay_burn_in (int): The burn-in length in case
                `replay_sequence_length` > 0. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 for each RNN rollout.
            replay_zero_init_states (bool): Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
        """
        # Deprecated args.
        if buffer_size != DEPRECATED_VALUE:
            deprecation_warning(
                "ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False)
            capacity = buffer_size

        self.replay_starts = learning_starts // num_shards
        self.capacity = capacity // num_shards
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.replay_mode = replay_mode
        self.replay_sequence_length = replay_sequence_length
        self.replay_burn_in = replay_burn_in
        self.replay_zero_init_states = replay_zero_init_states

        if replay_sequence_length > 1:
            self.replay_batch_size = int(
                max(1, replay_batch_size // replay_sequence_length))
            logger.info(
                "Since replay_sequence_length={} and replay_batch_size={}, "
                "we will replay {} sequences at a time.".format(
                    replay_sequence_length, replay_batch_size,
                    self.replay_batch_size))

        if replay_mode not in ["lockstep", "independent"]:
            raise ValueError("Unsupported replay mode: {}".format(replay_mode))

        def gen_replay():
            while True:
                yield self.replay()

        ParallelIteratorWorker.__init__(self, gen_replay, False)

        def new_buffer():
            return PrioritizedReplayBuffer(
                self.capacity, alpha=prioritized_replay_alpha)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self.num_added = 0

        # Make externally accessible for testing.
        global _local_replay_buffer
        _local_replay_buffer = self
        # If set, return this instead of the usual data for testing.
        self._fake_batch = None
    @staticmethod
    def get_instance_for_testing():
        global _local_replay_buffer
        return _local_replay_buffer

    def get_host(self) -> str:
        return platform.node()
    def add_batch(self, batch: SampleBatchType) -> None:
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multiagent.
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
        with self.add_batch_timer:
            # Lockstep mode: Store under _ALL_POLICIES key (we will always
            # only sample from all policies at the same time).
            if self.replay_mode == "lockstep":
                # Note that prioritization is not supported in this mode.
                for s in batch.timeslices(self.replay_sequence_length):
                    self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
            else:
                for policy_id, sample_batch in batch.policy_batches.items():
                    if self.replay_sequence_length == 1:
                        timeslices = sample_batch.timeslices(1)
                    else:
                        timeslices = timeslice_along_seq_lens_with_overlap(
                            sample_batch=sample_batch,
                            zero_pad_max_seq_len=self.replay_sequence_length,
                            pre_overlap=self.replay_burn_in,
                            zero_init_states=self.replay_zero_init_states,
                        )
                    for time_slice in timeslices:
                        # If SampleBatch has prio-replay weights, average
                        # over these to use as a weight for the entire
                        # sequence.
                        if "weights" in time_slice and \
                                len(time_slice["weights"]):
                            weight = np.mean(time_slice["weights"])
                        else:
                            weight = None
                        self.replay_buffers[policy_id].add(
                            time_slice, weight=weight)
        self.num_added += batch.count
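
    # Note on `add_batch()` above (illustrative, assumed numbers): with
    # replay_sequence_length=1, a 200-timestep SampleBatch is simply stored
    # as 200 one-step items. With replay_sequence_length > 1, the incoming
    # data is instead cut by `timeslice_along_seq_lens_with_overlap` into
    # fixed-length, zero-padded sequence slices, each overlapping the
    # previous one by `replay_burn_in` timesteps.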
    def replay(self) -> SampleBatchType:
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            return MultiAgentBatch({
                DEFAULT_POLICY_ID: fake_batch
            }, fake_batch.count)

        if self.num_added < self.replay_starts:
            return None

        with self.replay_timer:
            # Lockstep mode: Sample an equal number of steps from all
            # policies at the same time.
            if self.replay_mode == "lockstep":
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)
    def update_priorities(self, prio_dict: Dict) -> None:
        with self.update_priorities_timer:
            for policy_id, (batch_indexes, td_errors) in prio_dict.items():
                new_priorities = (
                    np.abs(td_errors) + self.prioritized_replay_eps)
                self.replay_buffers[policy_id].update_priorities(
                    batch_indexes, new_priorities)

    def stats(self, debug: bool = False) -> Dict:
        stat = {
            "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
            "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
            "update_priorities_time_ms": round(
                1000 * self.update_priorities_timer.mean, 3),
        }
        for policy_id, replay_buffer in self.replay_buffers.items():
            stat.update({
                "policy_{}".format(policy_id): replay_buffer.stats(debug=debug)
            })
        return stat
    def get_state(self) -> Dict[str, Any]:
        state = {"num_added": self.num_added, "replay_buffers": {}}
        for policy_id, replay_buffer in self.replay_buffers.items():
            state["replay_buffers"][policy_id] = replay_buffer.get_state()
        return state

    def set_state(self, state: Dict[str, Any]) -> None:
        self.num_added = state["num_added"]
        buffer_states = state["replay_buffers"]
        for policy_id in buffer_states.keys():
            self.replay_buffers[policy_id].set_state(buffer_states[policy_id])


ReplayActor = ray.remote(num_cpus=0)(LocalReplayBuffer)
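
# A minimal end-to-end usage sketch of LocalReplayBuffer (illustrative only,
# not part of the original module). Parameters and column names below are
# assumptions; in RLlib, execution plans normally drive these calls.
def _example_local_replay_buffer_usage():
    local_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=10,
        capacity=100,
        replay_batch_size=4,
    )
    # A single-policy SampleBatch gets wrapped into a MultiAgentBatch and
    # sliced into one-step items (replay_sequence_length=1 by default).
    local_buffer.add_batch(
        SampleBatch({
            "obs": np.arange(16),
            "rewards": np.ones(16),
        }))
    # Returns None until `learning_starts` timesteps have been added.
    replayed = local_buffer.replay()
    if replayed is not None:
        # Feed (fake, constant) TD-errors back as new priorities per policy.
        pol_batch = replayed.policy_batches[DEFAULT_POLICY_ID]
        local_buffer.update_priorities({
            DEFAULT_POLICY_ID: (pol_batch["batch_indexes"],
                                np.ones(len(pol_batch["batch_indexes"]))),
        })
    return replayed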