import gym
import numpy as np
import random
import unittest

import ray
from ray.tune.registry import register_env
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.rollout_worker import get_global_worker, \
    RolloutWorker
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
    BasicMultiAgent, EarlyDoneMultiAgent, FlexAgentsMultiAgent, \
    RoundRobinMultiAgent
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.env.multi_agent_env import MultiAgentEnvWrapper
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.numpy import one_hot
from ray.rllib.utils.test_utils import check


class TestMultiAgentEnv(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_basic_mock(self):
        env = BasicMultiAgent(4)
        obs = env.reset()
        self.assertEqual(obs, {0: 0, 1: 0, 2: 0, 3: 0})
        for _ in range(24):
            obs, rew, done, info = env.step({0: 0, 1: 0, 2: 0, 3: 0})
            self.assertEqual(obs, {0: 0, 1: 0, 2: 0, 3: 0})
            self.assertEqual(rew, {0: 1, 1: 1, 2: 1, 3: 1})
            self.assertEqual(done, {
                0: False,
                1: False,
                2: False,
                3: False,
                "__all__": False
            })
        obs, rew, done, info = env.step({0: 0, 1: 0, 2: 0, 3: 0})
        self.assertEqual(done, {
            0: True,
            1: True,
            2: True,
            3: True,
            "__all__": True
        })
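
    # RoundRobinMultiAgent only returns an obs for the single agent whose
    # turn it is, so actions must be supplied for exactly that agent.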

    def test_round_robin_mock(self):
        env = RoundRobinMultiAgent(2)
        obs = env.reset()
        self.assertEqual(obs, {0: 0})
        for _ in range(5):
            obs, rew, done, info = env.step({0: 0})
            self.assertEqual(obs, {1: 0})
            self.assertEqual(done["__all__"], False)
            obs, rew, done, info = env.step({1: 0})
            self.assertEqual(obs, {0: 0})
            self.assertEqual(done["__all__"], False)
        obs, rew, done, info = env.step({0: 0})
        self.assertEqual(done["__all__"], True)
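
    # Sub-environments must not be reset at construction time; the first
    # `poll()` triggers the (lazy) reset.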

    def test_no_reset_until_poll(self):
        env = MultiAgentEnvWrapper(lambda v: BasicMultiAgent(2), [], 1)
        self.assertFalse(env.get_sub_environments()[0].resetted)
        env.poll()
        self.assertTrue(env.get_sub_environments()[0].resetted)
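
    # Vectorized case: All obs/reward/done dicts returned by `poll()` are
    # keyed first by sub-env index, then by agent id.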

    def test_vectorize_basic(self):
        env = MultiAgentEnvWrapper(lambda v: BasicMultiAgent(2), [], 2)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        self.assertEqual(rew, {0: {}, 1: {}})
        self.assertEqual(dones, {
            0: {
                "__all__": False
            },
            1: {
                "__all__": False
            },
        })
        for _ in range(24):
            env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
            obs, rew, dones, _, _ = env.poll()
            self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
            self.assertEqual(rew, {0: {0: 1, 1: 1}, 1: {0: 1, 1: 1}})
            self.assertEqual(
                dones, {
                    0: {
                        0: False,
                        1: False,
                        "__all__": False
                    },
                    1: {
                        0: False,
                        1: False,
                        "__all__": False
                    }
                })
        env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(
            dones, {
                0: {
                    0: True,
                    1: True,
                    "__all__": True
                },
                1: {
                    0: True,
                    1: True,
                    "__all__": True
                }
            })

        # Reset processing: Sending actions into an all-done sub-env must
        # raise; the sub-envs have to be reset via `try_reset()` first.
        self.assertRaises(
            ValueError, lambda: env.send_actions({
                0: {
                    0: 0,
                    1: 0
                },
                1: {
                    0: 0,
                    1: 0
                }
            }))
        self.assertEqual(env.try_reset(0), {0: {0: 0, 1: 0}})
        self.assertEqual(env.try_reset(1), {1: {0: 0, 1: 0}})
        env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        self.assertEqual(rew, {0: {0: 1, 1: 1}, 1: {0: 1, 1: 1}})
        self.assertEqual(
            dones, {
                0: {
                    0: False,
                    1: False,
                    "__all__": False
                },
                1: {
                    0: False,
                    1: False,
                    "__all__": False
                }
            })
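
    # Same as above, but with the round-robin env: Only the currently
    # acting agent shows up in each polled obs dict.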

    def test_vectorize_round_robin(self):
        env = MultiAgentEnvWrapper(lambda v: RoundRobinMultiAgent(2), [], 2)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
        self.assertEqual(rew, {0: {}, 1: {}})
        env.send_actions({0: {0: 0}, 1: {0: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}})
        env.send_actions({0: {1: 0}, 1: {1: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
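
    # With 5 agents mapped via `agent_id % 2`, p0 serves agents 0, 2 and 4
    # (3 agents x 50 env steps = 150 agent steps) and p1 serves agents 1
    # and 3 (100 agent steps).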

    def test_multi_agent_sample(self):
        def policy_mapping_fn(agent_id, episode, worker, **kwargs):
            return "p{}".format(agent_id % 2)

        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=policy_mapping_fn,
            rollout_fragment_length=50)
        batch = ev.sample()
        self.assertEqual(batch.count, 50)
        self.assertEqual(batch.policy_batches["p0"].count, 150)
        self.assertEqual(batch.policy_batches["p1"].count, 100)
        self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
                         list(range(25)) * 6)

    def test_multi_agent_sample_sync_remote(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            # The old mapping-fn signature (taking only `agent_id`) will
            # raise a soft-deprecation warning (the new signature is
            # (agent_id, episode, **kwargs)), but must not break this test.
            policy_mapping_fn=(lambda agent_id: "p{}".format(agent_id % 2)),
            rollout_fragment_length=50,
            num_envs=4,
            remote_worker_envs=True,
            remote_env_batch_wait_ms=99999999)
        batch = ev.sample()
        self.assertEqual(batch.count, 200)

    def test_multi_agent_sample_async_remote(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
            rollout_fragment_length=50,
            num_envs=4,
            remote_worker_envs=True)
        batch = ev.sample()
        self.assertEqual(batch.count, 200)
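
    # A hard horizon of 10 truncates episodes early, but `sample()` must
    # still return exactly `rollout_fragment_length` env steps.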

    def test_multi_agent_sample_with_horizon(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
            episode_horizon=10,  # Test with episode horizon set.
            rollout_fragment_length=50)
        batch = ev.sample()
        self.assertEqual(batch.count, 50)

    def test_sample_from_early_done_env(self):
        ev = RolloutWorker(
            env_creator=lambda _: EarlyDoneMultiAgent(),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
            batch_mode="complete_episodes",
            rollout_fragment_length=1)
        # This used to raise an error because EarlyDoneMultiAgent terminates
        # e.g. agent0 without publishing any further observation for agent1.
        # This limitation has been fixed: An env may now terminate any agent
        # at any time (and may also return rewards for any agent at any time,
        # even when that agent doesn't get an obs in the same `step()` call).
        ma_batch = ev.sample()
        # Make sure the agents took the correct (alternating-timesteps)
        # path, except for the last timestep, on which both agents were
        # terminated.
        ag0_ts = ma_batch.policy_batches["p0"]["t"]
        ag1_ts = ma_batch.policy_batches["p1"]["t"]
        self.assertTrue(np.all(np.abs(ag0_ts[:-1] - ag1_ts[:-1]) == 1.0))
        self.assertTrue(ag0_ts[-1] == ag1_ts[-1])
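
    # FlexAgentsMultiAgent may add and remove agents on the fly during an
    # episode; training should handle this without errors.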

    def test_multi_agent_with_flex_agents(self):
        register_env("flex_agents_multi_agent_cartpole",
                     lambda _: FlexAgentsMultiAgent())
        pg = PGTrainer(
            env="flex_agents_multi_agent_cartpole",
            config={
                "num_workers": 0,
                "framework": "tf",
            })
        for i in range(10):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))

    def test_multi_agent_sample_round_robin(self):
        ev = RolloutWorker(
            env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
            rollout_fragment_length=50)
        batch = ev.sample()
        self.assertEqual(batch.count, 50)
        # Since agents are introduced into the env in round-robin fashion,
        # some env steps don't count as proper transitions.
        self.assertEqual(batch.policy_batches["p0"].count, 42)
        check(batch.policy_batches["p0"]["obs"][:10],
              one_hot(np.array([0, 1, 2, 3, 4] * 2), 10))
        check(batch.policy_batches["p0"]["new_obs"][:10],
              one_hot(np.array([1, 2, 3, 4, 5] * 2), 10))
        self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10],
                         [100, 100, 100, 100, 0] * 2)
        self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10],
                         [False, False, False, False, True] * 2)
        self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10],
                         [4, 9, 14, 19, 24, 5, 10, 15, 20, 25])
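
    # RNN state values don't have to be tensors; any (nested) structure
    # returned as state-out at timestep t must reappear as state-in at t+1.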

    def test_custom_rnn_state_values(self):
        h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}}

        class StatefulPolicy(RandomPolicy):
            def compute_actions(self,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None,
                                episodes=None,
                                explore=True,
                                timestep=None,
                                **kwargs):
                return [0] * len(obs_batch), [[h] * len(obs_batch)], {}

            def get_initial_state(self):
                return [{}]  # Initial state: an empty dict.

        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=StatefulPolicy,
            rollout_fragment_length=5)
        batch = ev.sample()
        self.assertEqual(batch.count, 5)
        self.assertEqual(batch["state_in_0"][0], {})
        self.assertEqual(batch["state_out_0"][0], h)
        self.assertEqual(batch["state_in_0"][1], h)
        self.assertEqual(batch["state_out_0"][1], h)
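
    # A policy may generate extra (e.g. model-based) trajectories while
    # computing actions and attach them to the ongoing episode via
    # `add_extra_batch()`; these must show up in the sampled batch.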

    def test_returning_model_based_rollouts_data(self):
        class ModelBasedPolicy(DQNTFPolicy):
            def compute_actions_from_input_dict(self,
                                                input_dict,
                                                explore=None,
                                                timestep=None,
                                                episodes=None,
                                                **kwargs):
                obs_batch = input_dict["obs"]
                # During the policy's loss initialization phase, no episodes
                # are passed in.
                if episodes is not None:
                    # Pretend we did a model-based rollout and want to return
                    # the extra trajectory.
                    env_id = episodes[0].env_id
                    fake_eps = Episode(episodes[0].policy_map,
                                       episodes[0].policy_mapping_fn,
                                       lambda: None, lambda x: None, env_id)
                    builder = get_global_worker().sampler.sample_collector
                    agent_id = "extra_0"
                    policy_id = "p1"  # Use p1 so we can easily check it.
                    builder.add_init_obs(fake_eps, agent_id, env_id,
                                         policy_id, -1, obs_batch[0])
                    for t in range(4):
                        builder.add_action_reward_next_obs(
                            episode_id=fake_eps.episode_id,
                            agent_id=agent_id,
                            env_id=env_id,
                            policy_id=policy_id,
                            agent_done=t == 3,
                            values=dict(
                                t=t,
                                actions=0,
                                rewards=0,
                                dones=t == 3,
                                infos={},
                                new_obs=obs_batch[0]))
                    batch = builder.postprocess_episode(
                        episode=fake_eps, build=True)
                    episodes[0].add_extra_batch(batch)

                # Just return zeros for actions.
                return [0] * len(obs_batch), [], {}

        ev = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}),
            policy_spec={
                "p0": PolicySpec(policy_class=ModelBasedPolicy),
                "p1": PolicySpec(policy_class=ModelBasedPolicy),
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
            rollout_fragment_length=5)
        batch = ev.sample()
        # 5 environment steps (rollout_fragment_length).
        self.assertEqual(batch.count, 5)
        # 10 agent steps for p0: 2 agents, both using p0 as their policy.
        self.assertEqual(batch.policy_batches["p0"].count, 10)
        # 20 agent steps for p1: Each time the two "real" agents take one
        # env step, the policy injects 4 fake steps for p1:
        # 5 (rollout_fragment_length) * 4 = 20.
        self.assertEqual(batch.policy_batches["p1"].count, 20)
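
    # All 10 agents map to the (single) default policy; training should
    # push the summed episode reward toward 50 * num_agents.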

    def test_train_multi_agent_cartpole_single_policy(self):
        n = 10
        register_env("multi_agent_cartpole",
                     lambda _: MultiAgentCartPole({"num_agents": n}))
        pg = PGTrainer(
            env="multi_agent_cartpole",
            config={
                "num_workers": 0,
                "framework": "tf",
            })
        for i in range(50):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))
            if result["episode_reward_mean"] >= 50 * n:
                return
        raise Exception("failed to improve reward")
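
    # Two differently configured policies; the mapping fn sends all agents
    # to policy_1, so policy_2 remains untrained but must still be usable
    # for action computation.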

    def test_train_multi_agent_cartpole_multi_policy(self):
        n = 10
        register_env("multi_agent_cartpole",
                     lambda _: MultiAgentCartPole({"num_agents": n}))

        def gen_policy():
            config = {
                "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
                "n_step": random.choice([1, 2, 3, 4, 5]),
            }
            return PolicySpec(config=config)

        pg = PGTrainer(
            env="multi_agent_cartpole",
            config={
                "num_workers": 0,
                "multiagent": {
                    "policies": {
                        "policy_1": gen_policy(),
                        "policy_2": gen_policy(),
                    },
                    "policy_mapping_fn": lambda aid, **kwargs: "policy_1",
                },
                "framework": "tf",
            })
        # Just check that it runs without crashing.
        for i in range(10):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))
        self.assertTrue(
            pg.compute_single_action([0, 0, 0, 0], policy_id="policy_1") in
            [0, 1])
        self.assertTrue(
            pg.compute_single_action([0, 0, 0, 0], policy_id="policy_2") in
            [0, 1])
        self.assertRaisesRegex(
            KeyError,
            "not found in PolicyMap",
            lambda: pg.compute_single_action(
                [0, 0, 0, 0], policy_id="policy_3"))


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))