test_segmentation_buffer.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. import unittest
  2. from typing import List, Union
  3. import numpy as np
  4. from rllib_dt.dt.segmentation_buffer import (
  5. MultiAgentSegmentationBuffer,
  6. SegmentationBuffer,
  7. )
  8. import ray
  9. from ray.rllib.policy.sample_batch import (
  10. DEFAULT_POLICY_ID,
  11. MultiAgentBatch,
  12. SampleBatch,
  13. concat_samples,
  14. )
  15. from ray.rllib.utils import test_utils
  16. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  17. from ray.rllib.utils.typing import PolicyID
  18. tf1, tf, tfv = try_import_tf()
  19. torch, _ = try_import_torch()
  20. def _generate_episode_batch(ep_len, eps_id, obs_dim=8, act_dim=3):
  21. """Generate a batch containing one episode."""
  22. # These values are not actually correct as usual. But using eps_id
  23. # as the values allow us to identify them in the tests.
  24. batch = SampleBatch(
  25. {
  26. SampleBatch.OBS: np.full((ep_len, obs_dim), eps_id, dtype=np.float32),
  27. SampleBatch.ACTIONS: np.full(
  28. (ep_len, act_dim), eps_id + 100, dtype=np.float32
  29. ),
  30. SampleBatch.REWARDS: np.ones((ep_len,), dtype=np.float32),
  31. SampleBatch.RETURNS_TO_GO: np.arange(
  32. ep_len, -1, -1, dtype=np.float32
  33. ).reshape((ep_len + 1, 1)),
  34. SampleBatch.EPS_ID: np.full((ep_len,), eps_id, dtype=np.int32),
  35. SampleBatch.T: np.arange(ep_len, dtype=np.int32),
  36. SampleBatch.ATTENTION_MASKS: np.ones(ep_len, dtype=np.float32),
  37. SampleBatch.TERMINATEDS: np.array([False] * (ep_len - 1) + [True]),
  38. SampleBatch.TRUNCATEDS: np.array([False] * ep_len),
  39. }
  40. )
  41. return batch
  42. def _assert_sample_batch_keys(batch: SampleBatch):
  43. """Assert sampled batch has the requisite keys."""
  44. assert SampleBatch.OBS in batch
  45. assert SampleBatch.ACTIONS in batch
  46. assert SampleBatch.RETURNS_TO_GO in batch
  47. assert SampleBatch.T in batch
  48. assert SampleBatch.ATTENTION_MASKS in batch
  49. def _assert_sample_batch_not_equal(b1: SampleBatch, b2: SampleBatch):
  50. """Assert that the two batches are not equal."""
  51. for key in b1.keys() & b2.keys():
  52. if b1[key].shape == b2[key].shape:
  53. assert not np.allclose(
  54. b1[key], b2[key]
  55. ), f"Key {key} contain the same value when they should not."
  56. def _assert_is_segment(segment: SampleBatch, episode: SampleBatch):
  57. """Assert that the sampled segment is a segment of episode."""
  58. timesteps = segment[SampleBatch.T]
  59. masks = segment[SampleBatch.ATTENTION_MASKS] > 0.5
  60. seq_len = timesteps.shape[0]
  61. episode_segment = episode.slice(timesteps[0], timesteps[-1] + 1)
  62. assert np.allclose(
  63. segment[SampleBatch.OBS][masks], episode_segment[SampleBatch.OBS]
  64. )
  65. assert np.allclose(
  66. segment[SampleBatch.ACTIONS][masks], episode_segment[SampleBatch.ACTIONS]
  67. )
  68. assert np.allclose(
  69. segment[SampleBatch.RETURNS_TO_GO][:seq_len][masks],
  70. episode_segment[SampleBatch.RETURNS_TO_GO],
  71. )
  72. def _get_internal_buffer(
  73. buffer: Union[SegmentationBuffer, MultiAgentSegmentationBuffer],
  74. policy_id: PolicyID = DEFAULT_POLICY_ID,
  75. ) -> List[SampleBatch]:
  76. """Get the internal buffer list from the buffer. If MultiAgent then return the
  77. internal buffer corresponding to the given policy_id.
  78. """
  79. if type(buffer) == SegmentationBuffer:
  80. return buffer._buffer
  81. elif type(buffer) == MultiAgentSegmentationBuffer:
  82. return buffer.buffers[policy_id]._buffer
  83. else:
  84. raise NotImplementedError
  85. def _as_sample_batch(
  86. batch: Union[SampleBatch, MultiAgentBatch],
  87. policy_id: PolicyID = DEFAULT_POLICY_ID,
  88. ) -> SampleBatch:
  89. """Returns a SampleBatch. If MultiAgentBatch then return the SampleBatch
  90. corresponding to the given policy_id.
  91. """
  92. if type(batch) == SampleBatch:
  93. return batch
  94. elif type(batch) == MultiAgentBatch:
  95. return batch.policy_batches[policy_id]
  96. else:
  97. raise NotImplementedError
  98. class TestSegmentationBuffer(unittest.TestCase):
  99. @classmethod
  100. def setUpClass(cls):
  101. ray.init()
  102. @classmethod
  103. def tearDownClass(cls):
  104. ray.shutdown()
  105. def test_add(self):
  106. """Test adding to segmentation buffer."""
  107. for buffer_cls in [SegmentationBuffer, MultiAgentSegmentationBuffer]:
  108. max_seq_len = 3
  109. max_ep_len = 10
  110. capacity = 1
  111. buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
  112. # generate batch
  113. episode_batches = []
  114. for i in range(4):
  115. episode_batches.append(_generate_episode_batch(max_ep_len, i))
  116. batch = concat_samples(episode_batches)
  117. # add to buffer and check that only last one is kept (due to replacement)
  118. buffer.add(batch)
  119. self.assertEqual(
  120. len(_get_internal_buffer(buffer)),
  121. 1,
  122. "The internal buffer should only contain one SampleBatch since"
  123. " the capacity is 1.",
  124. )
  125. test_utils.check(episode_batches[-1], _get_internal_buffer(buffer)[0])
  126. # add again
  127. buffer.add(episode_batches[0])
  128. test_utils.check(episode_batches[0], _get_internal_buffer(buffer)[0])
  129. # make buffer of enough capacity
  130. capacity = len(episode_batches)
  131. buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
  132. # add to buffer and make sure all are in
  133. buffer.add(batch)
  134. self.assertEqual(
  135. len(_get_internal_buffer(buffer)),
  136. len(episode_batches),
  137. "internal buffer doesn't have the right number of episodes.",
  138. )
  139. for i in range(len(episode_batches)):
  140. test_utils.check(episode_batches[i], _get_internal_buffer(buffer)[i])
  141. # add another one and make sure it replaced one of them
  142. new_batch = _generate_episode_batch(max_ep_len, 12345)
  143. buffer.add(new_batch)
  144. self.assertEqual(
  145. len(_get_internal_buffer(buffer)),
  146. len(episode_batches),
  147. "internal buffer doesn't have the right number of episodes.",
  148. )
  149. found = False
  150. for episode_batch in _get_internal_buffer(buffer):
  151. if episode_batch[SampleBatch.EPS_ID][0] == 12345:
  152. test_utils.check(episode_batch, new_batch)
  153. found = True
  154. break
  155. assert found, "new_batch not added to buffer."
  156. # test that adding too long an episode errors
  157. long_batch = _generate_episode_batch(max_ep_len + 1, 123)
  158. with self.assertRaises(ValueError):
  159. buffer.add(long_batch)
  160. def test_sample_basic(self):
  161. """Test sampling from a segmentation buffer."""
  162. for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
  163. max_seq_len = 5
  164. max_ep_len = 15
  165. capacity = 4
  166. obs_dim = 10
  167. act_dim = 2
  168. buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
  169. # generate batch and add to buffer
  170. episode_batches = []
  171. for i in range(8):
  172. episode_batches.append(
  173. _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
  174. )
  175. batch = concat_samples(episode_batches)
  176. buffer.add(batch)
  177. # sample a few times and check shape
  178. for bs in range(10, 20):
  179. batch = _as_sample_batch(buffer.sample(bs))
  180. # check the keys exist
  181. _assert_sample_batch_keys(batch)
  182. # check the shapes
  183. self.assertEquals(
  184. batch[SampleBatch.OBS].shape, (bs, max_seq_len, obs_dim)
  185. )
  186. self.assertEquals(
  187. batch[SampleBatch.ACTIONS].shape, (bs, max_seq_len, act_dim)
  188. )
  189. self.assertEquals(
  190. batch[SampleBatch.RETURNS_TO_GO].shape,
  191. (
  192. bs,
  193. max_seq_len + 1,
  194. 1,
  195. ),
  196. )
  197. self.assertEquals(batch[SampleBatch.T].shape, (bs, max_seq_len))
  198. self.assertEquals(
  199. batch[SampleBatch.ATTENTION_MASKS].shape, (bs, max_seq_len)
  200. )
  201. def test_sample_content(self):
  202. """Test that the content of the sampling are valid."""
  203. for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
  204. max_seq_len = 5
  205. max_ep_len = 200
  206. capacity = 1
  207. obs_dim = 11
  208. act_dim = 1
  209. buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
  210. # generate single episode and add to buffer
  211. episode = _generate_episode_batch(max_ep_len, 123, obs_dim, act_dim)
  212. buffer.add(episode)
  213. # sample twice and make sure they are not equal.
  214. # with a 200 max_ep_len and 200 samples, the probability that the two
  215. # samples are equal by chance is (1/200)**200 which is basically zero.
  216. sample1 = _as_sample_batch(buffer.sample(200))
  217. sample2 = _as_sample_batch(buffer.sample(200))
  218. _assert_sample_batch_keys(sample1)
  219. _assert_sample_batch_keys(sample2)
  220. _assert_sample_batch_not_equal(sample1, sample2)
  221. # sample and make sure the segments are actual segments of the episode
  222. batch = _as_sample_batch(buffer.sample(1000))
  223. _assert_sample_batch_keys(batch)
  224. for elem in batch.rows():
  225. _assert_is_segment(SampleBatch(elem), episode)
  226. def test_sample_capacity(self):
  227. """Test that sampling from buffer of capacity > 1 works."""
  228. for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
  229. max_seq_len = 3
  230. max_ep_len = 10
  231. capacity = 100
  232. obs_dim = 1
  233. act_dim = 1
  234. buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
  235. # Generate batch and add to buffer
  236. episode_batches = []
  237. for i in range(capacity):
  238. episode_batches.append(
  239. _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
  240. )
  241. buffer.add(concat_samples(episode_batches))
  242. # Sample 100 times and check that samples are from at least 2 different
  243. # episodes. The [robability of all sampling from 1 episode by chance is
  244. # (1/100)**99 which is basically zero.
  245. batch = _as_sample_batch(buffer.sample(100))
  246. eps_ids = set()
  247. for i in range(100):
  248. # obs generated by _generate_episode_batch contains eps_id
  249. # use -1 because there might be front padding
  250. eps_id = int(batch[SampleBatch.OBS][i, -1, 0])
  251. eps_ids.add(eps_id)
  252. self.assertGreater(
  253. len(eps_ids), 1, "buffer.sample is always returning the same episode."
  254. )
  255. def test_padding(self):
  256. """Test that sample will front pad segments."""
  257. for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
  258. max_seq_len = 10
  259. max_ep_len = 100
  260. capacity = 1
  261. obs_dim = 3
  262. act_dim = 2
  263. buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
  264. for ep_len in range(1, max_seq_len):
  265. # generate batch with episode lengths that are shorter than
  266. # max_seq_len to test padding.
  267. batch = _generate_episode_batch(ep_len, 123, obs_dim, act_dim)
  268. buffer.add(batch)
  269. samples = _as_sample_batch(buffer.sample(50))
  270. for i in range(50):
  271. # calculate number of pads based on the attention mask.
  272. num_pad = int(
  273. ep_len - samples[SampleBatch.ATTENTION_MASKS][i].sum()
  274. )
  275. for key in samples.keys():
  276. # make sure padding are added.
  277. assert np.allclose(
  278. samples[key][i, :num_pad], 0.0
  279. ), "samples were not padded correctly."
  280. def test_multi_agent(self):
  281. max_seq_len = 5
  282. max_ep_len = 20
  283. capacity = 10
  284. obs_dim = 3
  285. act_dim = 5
  286. ma_buffer = MultiAgentSegmentationBuffer(capacity, max_seq_len, max_ep_len)
  287. policy_id1 = "1"
  288. policy_id2 = "2"
  289. policy_id3 = "3"
  290. policy_ids = {policy_id1, policy_id2, policy_id3}
  291. policy1_batches = []
  292. for i in range(0, 10):
  293. policy1_batches.append(
  294. _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
  295. )
  296. policy2_batches = []
  297. for i in range(10, 20):
  298. policy2_batches.append(
  299. _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
  300. )
  301. policy3_batches = []
  302. for i in range(20, 30):
  303. policy3_batches.append(
  304. _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
  305. )
  306. batches_mapping = {
  307. policy_id1: policy1_batches,
  308. policy_id2: policy2_batches,
  309. policy_id3: policy3_batches,
  310. }
  311. ma_batch = MultiAgentBatch(
  312. {
  313. policy_id1: concat_samples(policy1_batches),
  314. policy_id2: concat_samples(policy2_batches),
  315. policy_id3: concat_samples(policy3_batches),
  316. },
  317. max_ep_len * 10,
  318. )
  319. ma_buffer.add(ma_batch)
  320. # check all are added properly
  321. for policy_id in policy_ids:
  322. assert policy_id in ma_buffer.buffers.keys()
  323. for policy_id, buffer in ma_buffer.buffers.items():
  324. assert policy_id in policy_ids
  325. for i in range(10):
  326. test_utils.check(
  327. batches_mapping[policy_id][i], _get_internal_buffer(buffer)[i]
  328. )
  329. # check that sampling are proper
  330. for _ in range(50):
  331. ma_sample = ma_buffer.sample(100)
  332. for policy_id in policy_ids:
  333. assert policy_id in ma_sample.policy_batches.keys()
  334. for policy_id, batch in ma_sample.policy_batches.items():
  335. eps_id_start = (int(policy_id) - 1) * 10
  336. eps_id_end = eps_id_start + 10
  337. _assert_sample_batch_keys(batch)
  338. for i in range(100):
  339. # Obs generated by _generate_episode_batch contains eps_id.
  340. # Use -1 index because there might be front padding
  341. eps_id = int(batch[SampleBatch.OBS][i, -1, 0])
  342. assert (
  343. eps_id_start <= eps_id < eps_id_end
  344. ), "batch within multi agent batch has the wrong agent's episode."
  345. # sample twice and make sure they are not equal (probability equal almost zero)
  346. ma_sample1 = ma_buffer.sample(200)
  347. ma_sample2 = ma_buffer.sample(200)
  348. for policy_id in policy_ids:
  349. _assert_sample_batch_not_equal(
  350. ma_sample1.policy_batches[policy_id],
  351. ma_sample2.policy_batches[policy_id],
  352. )
  353. if __name__ == "__main__":
  354. import sys
  355. import pytest
  356. sys.exit(pytest.main(["-v", __file__]))