123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- import unittest
- import ray
- from ray.rllib.agents.registry import get_trainer_class
- from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
- MultiAgentMountainCar
- from ray.rllib.policy.policy import PolicySpec
- from ray.rllib.utils.test_utils import check_train_results, \
- framework_iterator
- from ray.tune import register_env
def check_support_multiagent(alg, config):
    """Run one training iteration of ``alg`` in a two-policy multi-agent env.

    Registers two-agent MountainCar and CartPole envs, maps agent ``i`` to
    ``policy_i`` (two policies differing only in ``gamma``), and then, for
    every framework yielded by ``framework_iterator(config)``, builds the
    trainer, calls ``train()`` once and validates the result dict.

    Args:
        alg: Registered RLlib algorithm name, e.g. ``"PPO"`` or ``"DQN"``.
        config: Trainer config dict. NOTE: mutated in place; the
            ``"multiagent"`` key is overwritten here (and
            ``framework_iterator`` presumably sets the framework on it).

    Raises:
        Whatever ``train()`` / ``check_train_results()`` raise. The trainer
        is always stopped first, so a failure does not leak its workers.
    """
    register_env("multi_agent_mountaincar",
                 lambda _: MultiAgentMountainCar({"num_agents": 2}))
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 2}))

    # Simulate a simple multi-agent setup: two policies that differ only in
    # their discount factor.
    policies = {
        "policy_0": PolicySpec(config={"gamma": 0.99}),
        "policy_1": PolicySpec(config={"gamma": 0.95}),
    }
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Assumes integer agent ids 0..num_agents-1 (num_agents=2 above), so
        # agent i is controlled by policy_i.
        return policy_ids[agent_id]

    config["multiagent"] = {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
    }

    # DDPG-family and SAC run against mountain car; everything else against
    # cartpole (presumably because of the action-space type — confirm).
    if alg in ("DDPG", "APEX_DDPG", "SAC"):
        env = "multi_agent_mountaincar"
    else:
        env = "multi_agent_cartpole"

    for fw in framework_iterator(config):
        # These algorithms are skipped under eager/tf2 frameworks.
        if fw in ("tf2", "tfe") and \
                alg in ("A3C", "APEX", "APEX_DDPG", "IMPALA"):
            continue
        trainer = get_trainer_class(alg)(config=config, env=env)
        try:
            results = trainer.train()
            check_train_results(results)
            print(results)
        finally:
            # Always release the trainer (and its remote workers), even when
            # training or result validation fails.
            trainer.stop()
class TestSupportedMultiAgentPG(unittest.TestCase):
    """Multi-agent smoke tests for A3C, IMPALA, PG and PPO."""

    @classmethod
    def setUpClass(cls) -> None:
        # A single Ray runtime (4 CPUs) is shared by all tests in this class.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_a3c_multiagent(self):
        a3c_config = {"num_workers": 1, "optimizer": {"grads_per_step": 1}}
        check_support_multiagent("A3C", a3c_config)

    def test_impala_multiagent(self):
        impala_config = {"num_gpus": 0}
        check_support_multiagent("IMPALA", impala_config)

    def test_pg_multiagent(self):
        pg_config = {"num_workers": 1, "optimizer": {}}
        check_support_multiagent("PG", pg_config)

    def test_ppo_multiagent(self):
        # Small batch / SGD settings, presumably to keep the single train()
        # call cheap.
        ppo_config = {
            "num_workers": 1,
            "num_sgd_iter": 1,
            "train_batch_size": 10,
            "rollout_fragment_length": 10,
            "sgd_minibatch_size": 1,
        }
        check_support_multiagent("PPO", ppo_config)
class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
    """Multi-agent smoke tests for APEX, APEX_DDPG, DDPG, DQN and SAC."""

    @classmethod
    def setUpClass(cls) -> None:
        # A single Ray runtime (6 CPUs) is shared by all tests in this class.
        ray.init(num_cpus=6)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_apex_multiagent(self):
        apex_config = {
            "num_workers": 2,
            "timesteps_per_iteration": 100,
            "num_gpus": 0,
            "buffer_size": 1000,
            "min_time_s_per_reporting": 1,
            "learning_starts": 10,
            "target_network_update_freq": 100,
            "optimizer": {
                "num_replay_buffer_shards": 1,
            },
        }
        check_support_multiagent("APEX", apex_config)

    def test_apex_ddpg_multiagent(self):
        apex_ddpg_config = {
            "num_workers": 2,
            "timesteps_per_iteration": 100,
            "buffer_size": 1000,
            "num_gpus": 0,
            "min_time_s_per_reporting": 1,
            "learning_starts": 10,
            "target_network_update_freq": 100,
            "use_state_preprocessor": True,
        }
        check_support_multiagent("APEX_DDPG", apex_ddpg_config)

    def test_ddpg_multiagent(self):
        ddpg_config = {
            "timesteps_per_iteration": 1,
            "buffer_size": 1000,
            "use_state_preprocessor": True,
            "learning_starts": 500,
        }
        check_support_multiagent("DDPG", ddpg_config)

    def test_dqn_multiagent(self):
        dqn_config = {
            "timesteps_per_iteration": 1,
            "buffer_size": 1000,
        }
        check_support_multiagent("DQN", dqn_config)

    def test_sac_multiagent(self):
        sac_config = {
            "num_workers": 0,
            "buffer_size": 1000,
            "normalize_actions": False,
        }
        check_support_multiagent("SAC", sac_config)
if __name__ == "__main__":
    import sys

    import pytest

    # An optional first CLI argument names a single TestCase class to run
    # (pytest "file::Class" syntax); without it, every TestCase in this file
    # is collected.
    target = __file__
    if len(sys.argv) > 1:
        target += "::" + sys.argv[1]
    sys.exit(pytest.main(["-v", target]))
|