123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- import unittest
- import ray
- from ray import air
- from ray import tune
- from ray.rllib.utils.framework import try_import_tf
- from ray.tune.registry import get_trainable_cls
- tf1, tf, tfv = try_import_tf()
def _run_single_iteration(trainable, config):
    """Train ``trainable`` for one iteration through Tune; fail on trial errors.

    Args:
        trainable: The Trainable class to hand to ``tune.Tuner``.
        config: The param space (algorithm config dict) for the run.

    Raises:
        AssertionError: If any trial of the run ended with an error.
    """
    results = tune.Tuner(
        trainable,
        param_space=config,
        run_config=air.RunConfig(stop={"training_iteration": 1}, verbose=1),
    ).fit()
    # Exceptions raised inside a trial are reported on the returned result
    # grid rather than re-raised by `fit()`, so they must be checked
    # explicitly; otherwise a crashing algorithm would still pass the test.
    errors = results.errors
    if errors:
        raise AssertionError(
            f"{len(errors)} trial(s) failed; first error: {errors[0]!r}"
        ) from errors[0]


def check_support(alg, config, test_eager=False, test_trace=True):
    """Run ``alg`` for one training iteration with framework "tf2".

    The algorithm is run on "Pendulum-v1" (continuous actions) and on
    "CartPole-v1" (discrete actions). "DQN", "APEX" and "SimpleQ" are not run
    on the continuous env; "DDPG", "APEX_DDPG" and "TD3" are not run on the
    discrete env.

    Args:
        alg: Registered name of the algorithm, resolved via
            ``get_trainable_cls``.
        config: Base config dict. It is copied, so the caller's dict is left
            unmodified.
        test_eager: Whether to run once with ``eager_tracing=False``.
        test_trace: Whether to run once with ``eager_tracing=True``.

    Raises:
        AssertionError: If any of the Tune runs reports a trial error.
    """
    # Shallow copy: only top-level keys are (re)assigned below.
    config = dict(config, framework="tf2", log_level="ERROR")
    # The trainable class does not depend on the env, so resolve it once.
    trainable = get_trainable_cls(alg)

    # Test both continuous and discrete actions.
    for cont in (True, False):
        if cont and alg in ("DQN", "APEX", "SimpleQ"):
            continue
        if not cont and alg in ("DDPG", "APEX_DDPG", "TD3"):
            continue

        config["env"] = "Pendulum-v1" if cont else "CartPole-v1"

        if test_eager:
            print("tf-eager: alg={} cont.act={}".format(alg, cont))
            config["eager_tracing"] = False
            _run_single_iteration(trainable, config)
        if test_trace:
            config["eager_tracing"] = True
            print("tf-eager-tracing: alg={} cont.act={}".format(alg, cont))
            _run_single_iteration(trainable, config)
- class TestEagerSupportPG(unittest.TestCase):
- def setUp(self):
- ray.init(num_cpus=4)
- def tearDown(self):
- ray.shutdown()
- def test_simple_q(self):
- check_support(
- "SimpleQ",
- {
- "num_workers": 0,
- "num_steps_sampled_before_learning_starts": 0,
- },
- )
- def test_dqn(self):
- check_support(
- "DQN",
- {
- "num_workers": 0,
- "num_steps_sampled_before_learning_starts": 0,
- },
- )
- def test_ddpg(self):
- check_support("DDPG", {"num_workers": 0})
- # TODO(sven): Add these once APEX_DDPG supports eager.
- # def test_apex_ddpg(self):
- # check_support("APEX_DDPG", {"num_workers": 1})
- def test_td3(self):
- check_support("TD3", {"num_workers": 0})
- def test_a2c(self):
- check_support("A2C", {"num_workers": 0})
- def test_a3c(self):
- check_support("A3C", {"num_workers": 1})
- def test_pg(self):
- check_support("PG", {"num_workers": 0})
- def test_ppo(self):
- check_support("PPO", {"num_workers": 0})
- def test_appo(self):
- check_support("APPO", {"num_workers": 1, "num_gpus": 0})
- def test_impala(self):
- check_support("IMPALA", {"num_workers": 1, "num_gpus": 0}, test_eager=True)
- class TestEagerSupportOffPolicy(unittest.TestCase):
- def setUp(self):
- ray.init(num_cpus=4)
- def tearDown(self):
- ray.shutdown()
- def test_simple_q(self):
- check_support(
- "SimpleQ",
- {
- "num_workers": 0,
- "replay_buffer_config": {"num_steps_sampled_before_learning_starts": 0},
- },
- )
- def test_dqn(self):
- check_support(
- "DQN",
- {
- "num_workers": 0,
- "num_steps_sampled_before_learning_starts": 0,
- },
- )
- def test_ddpg(self):
- check_support("DDPG", {"num_workers": 0})
- # def test_apex_ddpg(self):
- # check_support("APEX_DDPG", {"num_workers": 1})
- def test_td3(self):
- check_support("TD3", {"num_workers": 0})
- def test_apex_dqn(self):
- check_support(
- "APEX",
- {
- "num_workers": 2,
- "replay_buffer_config": {"num_steps_sampled_before_learning_starts": 0},
- "num_gpus": 0,
- "min_time_s_per_iteration": 1,
- "min_sample_timesteps_per_iteration": 100,
- "optimizer": {
- "num_replay_buffer_shards": 1,
- },
- },
- )
- def test_sac(self):
- check_support(
- "SAC",
- {
- "num_workers": 0,
- "num_steps_sampled_before_learning_starts": 0,
- },
- )
if __name__ == "__main__":
    import sys

    # Under TF 2.x there is nothing to test here (everything is eager
    # anyways), so exit successfully right away.
    # TODO: (sven) remove entire file in the future.
    if tfv == 2:
        print("\tskip due to tf==2.x")
        sys.exit(0)

    import pytest

    # An optional first command line argument names the single TestCase class
    # to run; without it, all unittest.TestCase classes in this file are run.
    target = __file__
    if len(sys.argv) > 1:
        target += "::" + sys.argv[1]

    sys.exit(pytest.main(["-v", target]))
|