123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421 |
- import unittest
- from typing import List, Union
- import numpy as np
- from rllib_dt.dt.segmentation_buffer import (
- MultiAgentSegmentationBuffer,
- SegmentationBuffer,
- )
- import ray
- from ray.rllib.policy.sample_batch import (
- DEFAULT_POLICY_ID,
- MultiAgentBatch,
- SampleBatch,
- concat_samples,
- )
- from ray.rllib.utils import test_utils
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
- from ray.rllib.utils.typing import PolicyID
- tf1, tf, tfv = try_import_tf()
- torch, _ = try_import_torch()
def _generate_episode_batch(ep_len, eps_id, obs_dim=8, act_dim=3):
    """Generate a batch containing one episode."""
    # These values are not actually correct as usual. But using eps_id
    # as the values allow us to identify them in the tests.
    obs = np.full((ep_len, obs_dim), eps_id, dtype=np.float32)
    actions = np.full((ep_len, act_dim), eps_id + 100, dtype=np.float32)
    rewards = np.ones((ep_len,), dtype=np.float32)
    # Returns-to-go count down from ep_len to 0, one extra row vs. the
    # other keys, shaped as a column vector.
    returns_to_go = np.arange(ep_len, -1, -1, dtype=np.float32).reshape(
        (ep_len + 1, 1)
    )
    eps_ids = np.full((ep_len,), eps_id, dtype=np.int32)
    timesteps = np.arange(ep_len, dtype=np.int32)
    masks = np.ones(ep_len, dtype=np.float32)
    # Episode terminates on its final step and is never truncated.
    terminateds = np.array([False] * (ep_len - 1) + [True])
    truncateds = np.array([False] * ep_len)

    return SampleBatch(
        {
            SampleBatch.OBS: obs,
            SampleBatch.ACTIONS: actions,
            SampleBatch.REWARDS: rewards,
            SampleBatch.RETURNS_TO_GO: returns_to_go,
            SampleBatch.EPS_ID: eps_ids,
            SampleBatch.T: timesteps,
            SampleBatch.ATTENTION_MASKS: masks,
            SampleBatch.TERMINATEDS: terminateds,
            SampleBatch.TRUNCATEDS: truncateds,
        }
    )
def _assert_sample_batch_keys(batch: SampleBatch):
    """Assert sampled batch has the requisite keys."""
    required_keys = (
        SampleBatch.OBS,
        SampleBatch.ACTIONS,
        SampleBatch.RETURNS_TO_GO,
        SampleBatch.T,
        SampleBatch.ATTENTION_MASKS,
    )
    for key in required_keys:
        assert key in batch
def _assert_sample_batch_not_equal(b1: SampleBatch, b2: SampleBatch):
    """Assert that the two batches are not equal."""
    # Only keys present in both batches can be compared; keys whose shapes
    # differ are trivially unequal, so no numeric check is needed for them.
    for key in b1.keys() & b2.keys():
        if b1[key].shape != b2[key].shape:
            continue
        assert not np.allclose(
            b1[key], b2[key]
        ), f"Key {key} contain the same value when they should not."
def _assert_is_segment(segment: SampleBatch, episode: SampleBatch):
    """Assert that the sampled segment is a segment of episode."""
    ts = segment[SampleBatch.T]
    # Boolean mask of the non-padded positions within the segment.
    valid = segment[SampleBatch.ATTENTION_MASKS] > 0.5
    seg_len = ts.shape[0]
    # The matching slice of the source episode, per the segment's timesteps.
    expected = episode.slice(ts[0], ts[-1] + 1)

    for key in (SampleBatch.OBS, SampleBatch.ACTIONS):
        assert np.allclose(segment[key][valid], expected[key])
    # Returns-to-go carry one extra trailing entry; compare only the first
    # seg_len rows against the episode slice.
    assert np.allclose(
        segment[SampleBatch.RETURNS_TO_GO][:seg_len][valid],
        expected[SampleBatch.RETURNS_TO_GO],
    )
def _get_internal_buffer(
    buffer: Union[SegmentationBuffer, MultiAgentSegmentationBuffer],
    policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> List[SampleBatch]:
    """Get the internal buffer list from the buffer. If MultiAgent then return the
    internal buffer corresponding to the given policy_id.

    Raises:
        NotImplementedError: If ``buffer`` is neither buffer type.
    """
    # Use isinstance rather than type() equality; check the multi-agent
    # wrapper first so it wins even if it were ever made a subclass of
    # SegmentationBuffer.
    if isinstance(buffer, MultiAgentSegmentationBuffer):
        return buffer.buffers[policy_id]._buffer
    elif isinstance(buffer, SegmentationBuffer):
        return buffer._buffer
    else:
        raise NotImplementedError(f"Unsupported buffer type: {type(buffer)}")
def _as_sample_batch(
    batch: Union[SampleBatch, MultiAgentBatch],
    policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> SampleBatch:
    """Returns a SampleBatch. If MultiAgentBatch then return the SampleBatch
    corresponding to the given policy_id.

    Raises:
        NotImplementedError: If ``batch`` is neither batch type.
    """
    # Use isinstance rather than type() equality so subclasses of either
    # batch type are handled correctly.
    if isinstance(batch, SampleBatch):
        return batch
    elif isinstance(batch, MultiAgentBatch):
        return batch.policy_batches[policy_id]
    else:
        raise NotImplementedError(f"Unsupported batch type: {type(batch)}")
class TestSegmentationBuffer(unittest.TestCase):
    """Tests for SegmentationBuffer and MultiAgentSegmentationBuffer."""

    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_add(self):
        """Test adding to segmentation buffer."""
        for buffer_cls in [SegmentationBuffer, MultiAgentSegmentationBuffer]:
            max_seq_len = 3
            max_ep_len = 10
            capacity = 1
            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)

            # generate batch
            episode_batches = []
            for i in range(4):
                episode_batches.append(_generate_episode_batch(max_ep_len, i))
            batch = concat_samples(episode_batches)

            # add to buffer and check that only last one is kept (due to
            # replacement)
            buffer.add(batch)
            self.assertEqual(
                len(_get_internal_buffer(buffer)),
                1,
                "The internal buffer should only contain one SampleBatch since"
                " the capacity is 1.",
            )
            test_utils.check(episode_batches[-1], _get_internal_buffer(buffer)[0])

            # add again
            buffer.add(episode_batches[0])
            test_utils.check(episode_batches[0], _get_internal_buffer(buffer)[0])

            # make buffer of enough capacity
            capacity = len(episode_batches)
            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)

            # add to buffer and make sure all are in
            buffer.add(batch)
            self.assertEqual(
                len(_get_internal_buffer(buffer)),
                len(episode_batches),
                "internal buffer doesn't have the right number of episodes.",
            )
            for i in range(len(episode_batches)):
                test_utils.check(
                    episode_batches[i], _get_internal_buffer(buffer)[i]
                )

            # add another one and make sure it replaced one of them
            new_batch = _generate_episode_batch(max_ep_len, 12345)
            buffer.add(new_batch)
            self.assertEqual(
                len(_get_internal_buffer(buffer)),
                len(episode_batches),
                "internal buffer doesn't have the right number of episodes.",
            )
            found = False
            for episode_batch in _get_internal_buffer(buffer):
                if episode_batch[SampleBatch.EPS_ID][0] == 12345:
                    test_utils.check(episode_batch, new_batch)
                    found = True
                    break
            assert found, "new_batch not added to buffer."

            # test that adding too long an episode errors
            long_batch = _generate_episode_batch(max_ep_len + 1, 123)
            with self.assertRaises(ValueError):
                buffer.add(long_batch)

    def test_sample_basic(self):
        """Test sampling from a segmentation buffer."""
        for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
            max_seq_len = 5
            max_ep_len = 15
            capacity = 4
            obs_dim = 10
            act_dim = 2
            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)

            # generate batch and add to buffer
            episode_batches = []
            for i in range(8):
                episode_batches.append(
                    _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
                )
            batch = concat_samples(episode_batches)
            buffer.add(batch)

            # sample a few times and check shape
            for bs in range(10, 20):
                batch = _as_sample_batch(buffer.sample(bs))
                # check the keys exist
                _assert_sample_batch_keys(batch)
                # check the shapes
                # NOTE: use assertEqual; the assertEquals alias was removed
                # in Python 3.12.
                self.assertEqual(
                    batch[SampleBatch.OBS].shape, (bs, max_seq_len, obs_dim)
                )
                self.assertEqual(
                    batch[SampleBatch.ACTIONS].shape, (bs, max_seq_len, act_dim)
                )
                self.assertEqual(
                    batch[SampleBatch.RETURNS_TO_GO].shape,
                    (
                        bs,
                        max_seq_len + 1,
                        1,
                    ),
                )
                self.assertEqual(batch[SampleBatch.T].shape, (bs, max_seq_len))
                self.assertEqual(
                    batch[SampleBatch.ATTENTION_MASKS].shape, (bs, max_seq_len)
                )

    def test_sample_content(self):
        """Test that the content of the sampling are valid."""
        for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
            max_seq_len = 5
            max_ep_len = 200
            capacity = 1
            obs_dim = 11
            act_dim = 1
            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)

            # generate single episode and add to buffer
            episode = _generate_episode_batch(max_ep_len, 123, obs_dim, act_dim)
            buffer.add(episode)

            # sample twice and make sure they are not equal.
            # with a 200 max_ep_len and 200 samples, the probability that the two
            # samples are equal by chance is (1/200)**200 which is basically zero.
            sample1 = _as_sample_batch(buffer.sample(200))
            sample2 = _as_sample_batch(buffer.sample(200))
            _assert_sample_batch_keys(sample1)
            _assert_sample_batch_keys(sample2)
            _assert_sample_batch_not_equal(sample1, sample2)

            # sample and make sure the segments are actual segments of the episode
            batch = _as_sample_batch(buffer.sample(1000))
            _assert_sample_batch_keys(batch)
            for elem in batch.rows():
                _assert_is_segment(SampleBatch(elem), episode)

    def test_sample_capacity(self):
        """Test that sampling from buffer of capacity > 1 works."""
        for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
            max_seq_len = 3
            max_ep_len = 10
            capacity = 100
            obs_dim = 1
            act_dim = 1
            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)

            # Generate batch and add to buffer
            episode_batches = []
            for i in range(capacity):
                episode_batches.append(
                    _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
                )
            buffer.add(concat_samples(episode_batches))

            # Sample 100 times and check that samples are from at least 2
            # different episodes. The probability of all sampling from 1 episode
            # by chance is (1/100)**99 which is basically zero.
            batch = _as_sample_batch(buffer.sample(100))
            eps_ids = set()
            for i in range(100):
                # obs generated by _generate_episode_batch contains eps_id
                # use -1 because there might be front padding
                eps_id = int(batch[SampleBatch.OBS][i, -1, 0])
                eps_ids.add(eps_id)
            self.assertGreater(
                len(eps_ids), 1, "buffer.sample is always returning the same episode."
            )

    def test_padding(self):
        """Test that sample will front pad segments."""
        for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
            max_seq_len = 10
            max_ep_len = 100
            capacity = 1
            obs_dim = 3
            act_dim = 2
            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)

            for ep_len in range(1, max_seq_len):
                # generate batch with episode lengths that are shorter than
                # max_seq_len to test padding.
                batch = _generate_episode_batch(ep_len, 123, obs_dim, act_dim)
                buffer.add(batch)

                samples = _as_sample_batch(buffer.sample(50))
                for i in range(50):
                    # calculate number of pads based on the attention mask.
                    num_pad = int(
                        ep_len - samples[SampleBatch.ATTENTION_MASKS][i].sum()
                    )
                    for key in samples.keys():
                        # make sure padding are added.
                        assert np.allclose(
                            samples[key][i, :num_pad], 0.0
                        ), "samples were not padded correctly."

    def test_multi_agent(self):
        """Test adding to and sampling from a MultiAgentSegmentationBuffer."""
        max_seq_len = 5
        max_ep_len = 20
        capacity = 10
        obs_dim = 3
        act_dim = 5

        ma_buffer = MultiAgentSegmentationBuffer(capacity, max_seq_len, max_ep_len)

        policy_id1 = "1"
        policy_id2 = "2"
        policy_id3 = "3"
        policy_ids = {policy_id1, policy_id2, policy_id3}

        # Disjoint eps_id ranges per policy (0-9, 10-19, 20-29) so sampled
        # segments can be attributed to the right policy below.
        policy1_batches = []
        for i in range(0, 10):
            policy1_batches.append(
                _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
            )
        policy2_batches = []
        for i in range(10, 20):
            policy2_batches.append(
                _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
            )
        policy3_batches = []
        for i in range(20, 30):
            policy3_batches.append(
                _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
            )

        batches_mapping = {
            policy_id1: policy1_batches,
            policy_id2: policy2_batches,
            policy_id3: policy3_batches,
        }

        ma_batch = MultiAgentBatch(
            {
                policy_id1: concat_samples(policy1_batches),
                policy_id2: concat_samples(policy2_batches),
                policy_id3: concat_samples(policy3_batches),
            },
            max_ep_len * 10,
        )

        ma_buffer.add(ma_batch)

        # check all are added properly
        for policy_id in policy_ids:
            assert policy_id in ma_buffer.buffers.keys()
        for policy_id, buffer in ma_buffer.buffers.items():
            assert policy_id in policy_ids
            for i in range(10):
                test_utils.check(
                    batches_mapping[policy_id][i], _get_internal_buffer(buffer)[i]
                )

        # check that sampling are proper
        for _ in range(50):
            ma_sample = ma_buffer.sample(100)
            for policy_id in policy_ids:
                assert policy_id in ma_sample.policy_batches.keys()
            for policy_id, batch in ma_sample.policy_batches.items():
                eps_id_start = (int(policy_id) - 1) * 10
                eps_id_end = eps_id_start + 10

                _assert_sample_batch_keys(batch)

                for i in range(100):
                    # Obs generated by _generate_episode_batch contains eps_id.
                    # Use -1 index because there might be front padding
                    eps_id = int(batch[SampleBatch.OBS][i, -1, 0])
                    assert (
                        eps_id_start <= eps_id < eps_id_end
                    ), "batch within multi agent batch has the wrong agent's episode."

        # sample twice and make sure they are not equal (probability equal almost zero)
        ma_sample1 = ma_buffer.sample(200)
        ma_sample2 = ma_buffer.sample(200)
        for policy_id in policy_ids:
            _assert_sample_batch_not_equal(
                ma_sample1.policy_batches[policy_id],
                ma_sample2.policy_batches[policy_id],
            )
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests via pytest and propagate its exit code.
    result = pytest.main(["-v", __file__])
    sys.exit(result)
|