123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- #!/usr/bin/env python
- import unittest
- import ray
- from ray.rllib.algorithms.apex_ddpg import ApexDDPGConfig
- from ray.rllib.algorithms.sac import SACConfig
- from ray.rllib.algorithms.simple_q import SimpleQConfig
- from ray.rllib.algorithms.ppo import PPOConfig
- from ray.rllib.algorithms.es import ESConfig
- from ray.rllib.algorithms.dqn import DQNConfig
- from ray.rllib.algorithms.ddpg import DDPGConfig
- from ray.rllib.algorithms.ars import ARSConfig
- from ray.rllib.algorithms.a3c import A3CConfig
- from ray.rllib.utils.test_utils import test_ckpt_restore
- import os
# As we transition things to the RLModule API, explore=False will get
# deprecated. For now, we will just not set it. The reason is that the RLModule
# API has a forward_exploration() method that can be overridden if the user
# really needs to turn off the stochasticity. This test in particular is robust
# to explore=None because we compare the means of the action distributions for
# the same observation, which must match.
# All configs share the same GPU count; resolve it from the environment once
# instead of repeating the lookup in every entry.
NUM_GPUS = int(os.environ.get("RLLIB_NUM_GPUS", "0"))

# Maps an algorithm name to the AlgorithmConfig exercised by the
# checkpoint-restore tests below. Each config keeps rollouts/training small so
# the tests stay fast.
algorithms_and_configs = {
    "A3C": (
        A3CConfig()
        .exploration(explore=False)
        .rollouts(num_rollout_workers=1)
        .resources(num_gpus=NUM_GPUS)
    ),
    "APEX_DDPG": (
        ApexDDPGConfig()
        .exploration(explore=False)
        .rollouts(observation_filter="MeanStdFilter", num_rollout_workers=2)
        .reporting(min_time_s_per_iteration=1)
        .training(
            optimizer={"num_replay_buffer_shards": 1},
            # Learn immediately so the restored run has optimizer state.
            num_steps_sampled_before_learning_starts=0,
        )
        .resources(num_gpus=NUM_GPUS)
    ),
    "ARS": (
        ARSConfig()
        .exploration(explore=False)
        .rollouts(num_rollout_workers=2, observation_filter="MeanStdFilter")
        .training(num_rollouts=10, noise_size=2500000)
        .resources(num_gpus=NUM_GPUS)
    ),
    "DDPG": (
        DDPGConfig()
        .exploration(explore=False)
        .reporting(min_sample_timesteps_per_iteration=100)
        .training(num_steps_sampled_before_learning_starts=0)
        .resources(num_gpus=NUM_GPUS)
    ),
    "DQN": (
        DQNConfig()
        .exploration(explore=False)
        .training(num_steps_sampled_before_learning_starts=0)
        .resources(num_gpus=NUM_GPUS)
    ),
    "ES": (
        ESConfig()
        .exploration(explore=False)
        .training(episodes_per_batch=10, train_batch_size=100, noise_size=2500000)
        .rollouts(observation_filter="MeanStdFilter", num_rollout_workers=2)
        .resources(num_gpus=NUM_GPUS)
    ),
    "PPO": (
        # See the comment before the `algorithms_and_configs` dict.
        # explore is set to None for PPO in favor of RLModule API support.
        PPOConfig()
        .training(num_sgd_iter=5, train_batch_size=1000)
        .rollouts(num_rollout_workers=2)
        .resources(num_gpus=NUM_GPUS)
    ),
    "SimpleQ": (
        SimpleQConfig()
        .exploration(explore=False)
        .training(num_steps_sampled_before_learning_starts=0)
        .resources(num_gpus=NUM_GPUS)
    ),
    "SAC": (
        SACConfig()
        .exploration(explore=False)
        .training(num_steps_sampled_before_learning_starts=0)
        .resources(num_gpus=NUM_GPUS)
    ),
}
class TestCheckpointRestorePG(unittest.TestCase):
    """Checkpoint save/restore round-trip tests for policy-gradient algorithms."""

    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_a3c_checkpoint_restore(self):
        # TODO(Kourosh) A3C cannot run a restored algorithm for some reason.
        config = algorithms_and_configs["A3C"]
        test_ckpt_restore(config, "CartPole-v1", run_restored_algorithm=False)

    def test_ppo_checkpoint_restore(self):
        config = algorithms_and_configs["PPO"]
        test_ckpt_restore(config, "CartPole-v1")
class TestCheckpointRestoreOffPolicy(unittest.TestCase):
    """Checkpoint save/restore round-trip tests for off-policy algorithms.

    Most of these also verify that the replay buffer survives the restore.
    """

    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_apex_ddpg_checkpoint_restore(self):
        config = algorithms_and_configs["APEX_DDPG"]
        test_ckpt_restore(config, "Pendulum-v1")

    def test_ddpg_checkpoint_restore(self):
        config = algorithms_and_configs["DDPG"]
        test_ckpt_restore(config, "Pendulum-v1", replay_buffer=True)

    def test_dqn_checkpoint_restore(self):
        config = algorithms_and_configs["DQN"]
        test_ckpt_restore(config, "CartPole-v1", replay_buffer=True)

    def test_sac_checkpoint_restore(self):
        config = algorithms_and_configs["SAC"]
        test_ckpt_restore(config, "Pendulum-v1", replay_buffer=True)

    def test_simpleq_checkpoint_restore(self):
        config = algorithms_and_configs["SimpleQ"]
        test_ckpt_restore(config, "CartPole-v1", replay_buffer=True)
class TestCheckpointRestoreEvolutionAlgos(unittest.TestCase):
    """Checkpoint save/restore round-trip tests for evolution-strategy algorithms."""

    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_ars_checkpoint_restore(self):
        config = algorithms_and_configs["ARS"]
        test_ckpt_restore(config, "CartPole-v1")

    def test_es_checkpoint_restore(self):
        config = algorithms_and_configs["ES"]
        test_ckpt_restore(config, "CartPole-v1")
if __name__ == "__main__":
    import sys

    import pytest

    # An optional first CLI argument names a single TestCase class to run;
    # with no argument, pytest collects every TestCase class in this file.
    target = __file__
    if len(sys.argv) > 1:
        target += "::" + sys.argv[1]
    sys.exit(pytest.main(["-v", target]))
|