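"""Tests for RLlib's trajectory view API.

Covers the ViewRequirements generated by models/policies, time-shifted
views (prev/next actions and rewards), LSTM/attention sequencing, and
SampleBatch.get_single_step_input_dict().
"""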
import copy
import gym
from gym.spaces import Box, Discrete
import numpy as np
import unittest

import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.examples.env.multi_agent import MultiAgentPendulum
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
    EpisodeEnvAwareAttentionPolicy, EpisodeEnvAwareLSTMPolicy
from ray.rllib.models.tf.attention_net import GTrXLNet
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import framework_iterator, check

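# Used by test_traj_view_attention_net below: every SGD minibatch handed to
# `on_learn_on_batch` must have exactly the configured `sgd_minibatch_size`
# of 201 timesteps, and the DebugCounterEnv observations must form
# consecutive counter sequences within each episode.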
class MyCallbacks(DefaultCallbacks):
    @override(DefaultCallbacks)
    def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
        assert train_batch.count == 201
        assert sum(train_batch[SampleBatch.SEQ_LENS]) == 201
        for k, v in train_batch.items():
            if k in ["state_in_0", SampleBatch.SEQ_LENS]:
                assert len(v) == len(train_batch[SampleBatch.SEQ_LENS])
            else:
                assert len(v) == 201
        current = None
        for o in train_batch[SampleBatch.OBS]:
            if current:
                assert o == current + 1
            current = o
            if o == 15:
                current = None


class TestTrajectoryViewAPI(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_traj_view_normal_case(self):
        """Tests whether Model and Policy return the correct ViewRequirements."""
        config = dqn.DEFAULT_CONFIG.copy()
        config["num_envs_per_worker"] = 10
        config["rollout_fragment_length"] = 4

        for _ in framework_iterator(config):
            trainer = dqn.DQNTrainer(
                config,
                env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv")
            policy = trainer.get_policy()
            view_req_model = policy.model.view_requirements
            view_req_policy = policy.view_requirements
            assert len(view_req_model) == 1, view_req_model
            assert len(view_req_policy) == 10, view_req_policy
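            # A ViewRequirement describes which (possibly time-shifted) view
            # of the collected trajectory a column provides: `data_col` names
            # the underlying raw column (None = the column itself), `shift`
            # the time offset into it.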
            for key in [
                    SampleBatch.OBS,
                    SampleBatch.ACTIONS,
                    SampleBatch.REWARDS,
                    SampleBatch.DONES,
                    SampleBatch.NEXT_OBS,
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    "weights",
            ]:
                assert key in view_req_policy
                # None of the view cols has a special underlying data_col,
                # except next-obs.
                if key != SampleBatch.NEXT_OBS:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].shift == 1
            rollout_worker = trainer.workers.local_worker()
            sample_batch = rollout_worker.sample()
            expected_count = \
                config["num_envs_per_worker"] * \
                config["rollout_fragment_length"]
            assert sample_batch.count == expected_count
            for v in sample_batch.values():
                assert len(v) == expected_count
            trainer.stop()

    def test_traj_view_lstm_prev_actions_and_rewards(self):
        """Tests whether Policy/Model return correct LSTM ViewRequirements."""
        config = ppo.DEFAULT_CONFIG.copy()
        config["model"] = config["model"].copy()
        # Activate LSTM + prev-action + rewards.
        config["model"]["use_lstm"] = True
        config["model"]["lstm_use_prev_action"] = True
        config["model"]["lstm_use_prev_reward"] = True

        for _ in framework_iterator(config):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            policy = trainer.get_policy()
            view_req_model = policy.model.view_requirements
            view_req_policy = policy.view_requirements
            # 7 = obs + prev-a + prev-r + 2x state-in + 2x state-out.
            assert len(view_req_model) == 7, view_req_model
            assert len(view_req_policy) == 20, \
                (len(view_req_policy), view_req_policy)
            for key in [
                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
                    SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS,
                    SampleBatch.PREV_REWARDS, "advantages", "value_targets",
                    SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
            ]:
                assert key in view_req_policy
                if key == SampleBatch.PREV_ACTIONS:
                    assert view_req_policy[key].data_col == SampleBatch.ACTIONS
                    assert view_req_policy[key].shift == -1
                elif key == SampleBatch.PREV_REWARDS:
                    assert view_req_policy[key].data_col == SampleBatch.REWARDS
                    assert view_req_policy[key].shift == -1
                elif key not in [
                        SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                        SampleBatch.PREV_REWARDS
                ]:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].shift == 1
            trainer.stop()

    def test_traj_view_attention_net(self):
        config = ppo.DEFAULT_CONFIG.copy()
        # Setup attention net.
        config["model"] = config["model"].copy()
        config["model"]["max_seq_len"] = 50
        config["model"]["custom_model"] = GTrXLNet
        config["model"]["custom_model_config"] = {
            "num_transformer_units": 1,
            "attention_dim": 64,
            "num_heads": 2,
            "memory_inference": 50,
            "memory_training": 50,
            "head_dim": 32,
            "ff_hidden_dim": 32,
        }
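        # Roughly speaking, `memory_inference`/`memory_training` = 50 means
        # each transformer unit attends over a view of its own last 50
        # outputs, which the sampler provides as shifted state views.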
        # Test with odd batch numbers.
        config["train_batch_size"] = 1031
        config["sgd_minibatch_size"] = 201
        config["num_sgd_iter"] = 5
        config["num_workers"] = 0
        config["callbacks"] = MyCallbacks
        config["env_config"] = {
            "config": {
                "start_at_t": 1
            }
        }  # first obs is [1.0]

        for _ in framework_iterator(config, frameworks="tf2"):
            trainer = ppo.PPOTrainer(
                config,
                env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv",
            )
            rw = trainer.workers.local_worker()
            sample = rw.sample()
            assert sample.count == trainer.config["rollout_fragment_length"]
            results = trainer.train()
            assert results["timesteps_total"] == config["train_batch_size"]
            trainer.stop()

    def test_traj_view_next_action(self):
        action_space = Discrete(2)
        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_config=ppo.DEFAULT_CONFIG,
            rollout_fragment_length=200,
            policy_spec=ppo.PPOTorchPolicy,
            policy_mapping_fn=None,
            num_envs=1,
        )
        # Add the next action (a') and the 2nd-next action (a'') to the
        # policy's view requirements. Both should then show up in
        # postprocessing and train batches.
        # Switch them off for action computations: the next actions cannot be
        # known yet at action-computation time.
        rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
            "next_actions"] = ViewRequirement(
                SampleBatch.ACTIONS,
                shift=1,
                space=action_space,
                used_for_compute_actions=False)
        rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
            "2nd_next_actions"] = ViewRequirement(
                SampleBatch.ACTIONS,
                shift=2,
                space=action_space,
                used_for_compute_actions=False)
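        # `used_for_compute_actions=False` keeps a column out of the input
        # dict passed to compute_actions(); the column only gets materialized
        # later, for postprocessing and training.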
        # Make sure we have DONEs as well.
        rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
            "dones"] = ViewRequirement()
        batch = rollout_worker_w_api.sample()
        self.assertTrue("next_actions" in batch)
        self.assertTrue("2nd_next_actions" in batch)
        expected_a_ = None  # expected next action
        expected_a__ = None  # expected 2nd-next action
        for i in range(len(batch["actions"])):
            a, d, a_, a__ = \
                batch["actions"][i], batch["dones"][i], \
                batch["next_actions"][i], batch["2nd_next_actions"][i]
            # Episode done: next action and 2nd-next action should be 0.
            if d:
                check(a_, 0)
                check(a__, 0)
                expected_a_ = None
                expected_a__ = None
                continue
            # Episode is not done and we have an expected next-a.
            if expected_a_ is not None:
                check(a, expected_a_)
            if expected_a__ is not None:
                check(a_, expected_a__)
            expected_a__ = a__
            expected_a_ = a_

    def test_traj_view_lstm_functionality(self):
        action_space = Box(float("-inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 200
        assert rollout_fragment_length % max_seq_len == 0
        policies = {
            "pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id, episode, **kwargs):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "use_lstm": True,
                "max_seq_len": max_seq_len,
            },
        }

        rw = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=config,
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            normalize_actions=False,
            num_envs=1,
        )
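        # All four agents map to "pol0", so while `result.count` counts env
        # steps, the policy's batch aggregates the agents' steps and thus
        # holds at least `rollout_fragment_length` rows.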
        for iteration in range(20):
            result = rw.sample()
            check(result.count, rollout_fragment_length)
            pol_batch_w = result.policy_batches["pol0"]
            assert pol_batch_w.count >= rollout_fragment_length
            analyze_rnn_batch(
                pol_batch_w,
                max_seq_len,
                view_requirements=rw.policy_map["pol0"].view_requirements)

    def test_traj_view_attention_functionality(self):
        action_space = Box(float("-inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 201
        policies = {
            "pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space,
                     {}),
        }

        def policy_fn(agent_id, episode, **kwargs):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "max_seq_len": max_seq_len,
            },
        }

        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=config,
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            normalize_actions=False,
            num_envs=1,
        )
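        # Smoke test: sampling with the attention-enabled policy must simply
        # run through without errors.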
        batch = rollout_worker_w_api.sample()
        print(batch)

    def test_counting_by_agent_steps(self):
        """Tests whether counting by agent steps (vs env steps) works."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)

        num_agents = 3

        config["num_workers"] = 2
        config["num_sgd_iter"] = 2
        config["framework"] = "torch"
        config["rollout_fragment_length"] = 21
        config["train_batch_size"] = 147
        config["multiagent"] = {
            "policies": {f"p{i}" for i in range(num_agents)},
            "policy_mapping_fn": lambda aid, **kwargs: "p{}".format(aid),
            "count_steps_by": "agent_steps",
        }
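        # With count_steps_by="agent_steps", `train_batch_size` (147) counts
        # agent steps, not env steps. Since all 3 agents step in lockstep,
        # agent_timesteps_total should be 3x timesteps_total (env steps).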
        # Env setup.
        config["env"] = MultiAgentPendulum
        config["env_config"] = {"num_agents": num_agents}

        num_iterations = 2
        trainer = ppo.PPOTrainer(config=config)
        results = None
        for i in range(num_iterations):
            results = trainer.train()
        self.assertEqual(results["agent_timesteps_total"],
                         results["timesteps_total"] * num_agents)
        self.assertGreaterEqual(results["agent_timesteps_total"],
                                num_iterations * config["train_batch_size"])
        self.assertLessEqual(results["agent_timesteps_total"],
                             (num_iterations + 1) * config["train_batch_size"])
        trainer.stop()

    def test_get_single_step_input_dict_batch_repeat_value_larger_1(self):
        """Test whether a SampleBatch produces the correct 1-step input dict."""
        space = Box(-1.0, 1.0, ())

        # With batch-repeat-value > 1: state_in_0 is only built every n
        # timesteps.
        view_reqs = {
            "state_in_0": ViewRequirement(
                data_col="state_out_0",
                shift="-5:-1",
                space=space,
                batch_repeat_value=5,
            ),
            "state_out_0": ViewRequirement(
                space=space, used_for_compute_actions=False),
        }
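        # shift="-5:-1" means: a view made up of the 5 most recent
        # state_out_0 values (t-5 .. t-1). batch_repeat_value=5 means this
        # view row is only stored every 5 timesteps, i.e. at ts=0, 5, 10, ...
        # (see the commented ts markers in the batches below).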
        # Trajectory of 1 ts (0) (we would like to compute the 1st).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
            ]),
            "state_out_0": np.array([1]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[0, 0, 0, 0, 1]],  # ts=1
                "seq_lens": [1],
            })

        # Trajectory of 6 ts (0-5) (we would like to compute the 6th).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [1, 2, 3, 4, 5],  # ts=5
            ]),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[2, 3, 4, 5, 6]],  # ts=6
                "seq_lens": [1],
            })

        # Trajectory of 12 ts (0-11) (we would like to compute the 12th).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [1, 2, 3, 4, 5],  # ts=5
                [6, 7, 8, 9, 10],  # ts=10
            ]),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[8, 9, 10, 11, 12]],  # ts=12
                "seq_lens": [1],
            })

    def test_get_single_step_input_dict_batch_repeat_value_1(self):
        """Test whether a SampleBatch produces the correct 1-step input dict."""
        space = Box(-1.0, 1.0, ())

        # With batch-repeat-value==1: state_in_0 is built each timestep.
        view_reqs = {
            "state_in_0": ViewRequirement(
                data_col="state_out_0",
                shift="-5:-1",
                space=space,
                batch_repeat_value=1,
            ),
            "state_out_0": ViewRequirement(
                space=space, used_for_compute_actions=False),
        }

        # Trajectory of 1 ts (0) (we would like to compute the 1st).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
            ]),
            "state_out_0": np.array([1]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[0, 0, 0, 0, 1]],  # ts=1
                "seq_lens": [1],
            })

        # Trajectory of 6 ts (0-5) (we would like to compute the 6th).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [0, 0, 0, 0, 1],  # ts=1
                [0, 0, 0, 1, 2],  # ts=2
                [0, 0, 1, 2, 3],  # ts=3
                [0, 1, 2, 3, 4],  # ts=4
                [1, 2, 3, 4, 5],  # ts=5
            ]),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[2, 3, 4, 5, 6]],  # ts=6
                "seq_lens": [1],
            })

        # Trajectory of 12 ts (0-11) (we would like to compute the 12th).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [0, 0, 0, 0, 1],  # ts=1
                [0, 0, 0, 1, 2],  # ts=2
                [0, 0, 1, 2, 3],  # ts=3
                [0, 1, 2, 3, 4],  # ts=4
                [1, 2, 3, 4, 5],  # ts=5
                [2, 3, 4, 5, 6],  # ts=6
                [3, 4, 5, 6, 7],  # ts=7
                [4, 5, 6, 7, 8],  # ts=8
                [5, 6, 7, 8, 9],  # ts=9
                [6, 7, 8, 9, 10],  # ts=10
                [7, 8, 9, 10, 11],  # ts=11
            ]),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[8, 9, 10, 11, 12]],  # ts=12
                "seq_lens": [1],
            })


def analyze_rnn_batch(batch, max_seq_len, view_requirements):
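    """Checks the consistency of the given RNN SampleBatch.

    Verifies prev-reward/action and next-obs alignment, state-in/out
    hand-over across timesteps, and correct zero-padding after
    `pad_batch_to_sequences_of_same_size`.
    """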
    count = batch.count

    # Check prev_reward/action, next_obs consistency.
    for idx in range(count):
        # If the timestep is tracked by the batch, use it directly.
        if "t" in batch:
            ts = batch["t"][idx]
        # Else, the ts is stored as the last component of the obs ([3]=ts).
        else:
            ts = batch["obs"][idx][3]

        obs_t = batch["obs"][idx]
        a_t = batch["actions"][idx]
        r_t = batch["rewards"][idx]
        state_in_0 = batch["state_in_0"][idx]
        state_in_1 = batch["state_in_1"][idx]

        # Check postprocessing outputs.
        if "2xobs" in batch:
            postprocessed_col_t = batch["2xobs"][idx]
            assert (obs_t == postprocessed_col_t / 2.0).all()

        # Check state-in/out and next-obs values.
        if idx > 0:
            next_obs_t_m_1 = batch["new_obs"][idx - 1]
            state_out_0_t_m_1 = batch["state_out_0"][idx - 1]
            state_out_1_t_m_1 = batch["state_out_1"][idx - 1]
            # Same trajectory as for t-1 -> Should be able to match.
            if (batch[SampleBatch.AGENT_INDEX][idx] ==
                    batch[SampleBatch.AGENT_INDEX][idx - 1]
                    and batch[SampleBatch.EPS_ID][idx] ==
                    batch[SampleBatch.EPS_ID][idx - 1]):
                assert batch["unroll_id"][idx - 1] == batch["unroll_id"][idx]
                assert (obs_t == next_obs_t_m_1).all()
                assert (state_in_0 == state_out_0_t_m_1).all()
                assert (state_in_1 == state_out_1_t_m_1).all()
            # Different trajectory.
            else:
                assert batch["unroll_id"][idx - 1] != batch["unroll_id"][idx]
                assert not (obs_t == next_obs_t_m_1).all()
                assert not (state_in_0 == state_out_0_t_m_1).all()
                assert not (state_in_1 == state_out_1_t_m_1).all()
                # A new trajectory must start with all-zero internal states.
                if ts == 0:
                    assert (state_in_0 == 0.0).all()
                    assert (state_in_1 == 0.0).all()

        # Check initial zero-internal states (at ts=0).
        if ts == 0:
            assert (state_in_0 == 0.0).all()
            assert (state_in_1 == 0.0).all()
        # Check prev. a/r values.
        if idx < count - 1:
            prev_actions_t_p_1 = batch["prev_actions"][idx + 1]
            prev_rewards_t_p_1 = batch["prev_rewards"][idx + 1]
            # Same trajectory as for t+1 -> Should be able to match.
            if batch[SampleBatch.AGENT_INDEX][idx] == \
                    batch[SampleBatch.AGENT_INDEX][idx + 1] and \
                    batch[SampleBatch.EPS_ID][idx] == \
                    batch[SampleBatch.EPS_ID][idx + 1]:
                assert (a_t == prev_actions_t_p_1).all()
                assert r_t == prev_rewards_t_p_1
            # Different (new) trajectory: assume the prev-a/r at its first
            # timestep to always be zero-filled ([3]=ts).
            elif ts == 0:
                assert (prev_actions_t_p_1 == 0).all()
                assert prev_rewards_t_p_1 == 0.0
    pad_batch_to_sequences_of_same_size(
        batch,
        max_seq_len=max_seq_len,
        shuffle=False,
        batch_divisibility_req=1,
        view_requirements=view_requirements,
    )
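    # After padding, the data is chopped into sequences of (at most)
    # `max_seq_len` steps, each zero-padded to exactly `max_seq_len` rows,
    # while the state-in columns keep only one row per sequence.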
    # Check after seq-len 0-padding.
    cursor = 0
    for i, seq_len in enumerate(batch[SampleBatch.SEQ_LENS]):
        state_in_0 = batch["state_in_0"][i]
        state_in_1 = batch["state_in_1"][i]
        for j in range(seq_len):
            k = cursor + j
            ts = batch["t"][k]
            obs_t = batch["obs"][k]
            a_t = batch["actions"][k]
            r_t = batch["rewards"][k]

            # Check postprocessing outputs.
            if "2xobs" in batch:
                postprocessed_col_t = batch["2xobs"][k]
                assert (obs_t == postprocessed_col_t / 2.0).all()

            # Check state-in/out and next-obs values.
            if j > 0:
                next_obs_t_m_1 = batch["new_obs"][k - 1]
                # state_out_0_t_m_1 = batch["state_out_0"][k - 1]
                # state_out_1_t_m_1 = batch["state_out_1"][k - 1]
                # Always the same trajectory as for t-1.
                assert batch["unroll_id"][k - 1] == batch["unroll_id"][k]
                assert (obs_t == next_obs_t_m_1).all()
                # assert (state_in_0 == state_out_0_t_m_1).all()
                # assert (state_in_1 == state_out_1_t_m_1).all()
            # Check initial zero-internal states.
            elif ts == 0:
                assert (state_in_0 == 0.0).all()
                assert (state_in_1 == 0.0).all()

        # All rows past the sequence's end must be zero-padded.
        for j in range(seq_len, max_seq_len):
            k = cursor + j
            obs_t = batch["obs"][k]
            a_t = batch["actions"][k]
            r_t = batch["rewards"][k]
            assert (obs_t == 0.0).all()
            assert (a_t == 0.0).all()
            assert (r_t == 0.0).all()
        cursor += max_seq_len


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))