123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- #!/usr/bin/env python
- import os
- import shutil
- import unittest
- import ray
- from ray.rllib.agents.registry import get_trainer_class
- from ray.rllib.utils.framework import try_import_tf
- from ray.tune.trial import ExportFormat
- tf1, tf, tfv = try_import_tf()
- CONFIGS = {
- "A3C": {
- "explore": False,
- "num_workers": 1,
- },
- "APEX_DDPG": {
- "explore": False,
- "observation_filter": "MeanStdFilter",
- "num_workers": 2,
- "min_iter_time_s": 1,
- "optimizer": {
- "num_replay_buffer_shards": 1,
- },
- },
- "ARS": {
- "explore": False,
- "num_rollouts": 10,
- "num_workers": 2,
- "noise_size": 2500000,
- "observation_filter": "MeanStdFilter",
- },
- "DDPG": {
- "explore": False,
- "timesteps_per_iteration": 100,
- },
- "DQN": {
- "explore": False,
- },
- "ES": {
- "explore": False,
- "episodes_per_batch": 10,
- "train_batch_size": 100,
- "num_workers": 2,
- "noise_size": 2500000,
- "observation_filter": "MeanStdFilter",
- },
- "PPO": {
- "explore": False,
- "num_sgd_iter": 5,
- "train_batch_size": 1000,
- "num_workers": 2,
- },
- "SAC": {
- "explore": False,
- },
- }
- def export_test(alg_name, failures, framework="tf"):
- def valid_tf_model(model_dir):
- return os.path.exists(os.path.join(model_dir, "saved_model.pb")) \
- and os.listdir(os.path.join(model_dir, "variables"))
- def valid_tf_checkpoint(checkpoint_dir):
- return os.path.exists(os.path.join(checkpoint_dir, "model.meta")) \
- and os.path.exists(os.path.join(checkpoint_dir, "model.index")) \
- and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))
- cls = get_trainer_class(alg_name)
- config = CONFIGS[alg_name].copy()
- config["framework"] = framework
- if "DDPG" in alg_name or "SAC" in alg_name:
- algo = cls(config=config, env="Pendulum-v1")
- else:
- algo = cls(config=config, env="CartPole-v0")
- for _ in range(1):
- res = algo.train()
- print("current status: " + str(res))
- export_dir = os.path.join(ray._private.utils.get_user_temp_dir(),
- "export_dir_%s" % alg_name)
- print("Exporting model ", alg_name, export_dir)
- algo.export_policy_model(export_dir)
- if framework == "tf" and not valid_tf_model(export_dir):
- failures.append(alg_name)
- shutil.rmtree(export_dir)
- if framework == "tf":
- print("Exporting checkpoint", alg_name, export_dir)
- algo.export_policy_checkpoint(export_dir)
- if framework == "tf" and not valid_tf_checkpoint(export_dir):
- failures.append(alg_name)
- shutil.rmtree(export_dir)
- print("Exporting default policy", alg_name, export_dir)
- algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL],
- export_dir)
- if not valid_tf_model(os.path.join(export_dir, ExportFormat.MODEL)) \
- or not valid_tf_checkpoint(
- os.path.join(export_dir, ExportFormat.CHECKPOINT)):
- failures.append(alg_name)
- # Test loading the exported model.
- model = tf.saved_model.load(
- os.path.join(export_dir, ExportFormat.MODEL))
- assert model
- shutil.rmtree(export_dir)
- algo.stop()
- class TestExport(unittest.TestCase):
- @classmethod
- def setUpClass(cls) -> None:
- ray.init(num_cpus=4)
- @classmethod
- def tearDownClass(cls) -> None:
- ray.shutdown()
- def test_export_a3c(self):
- failures = []
- export_test("A3C", failures, "tf")
- assert not failures, failures
- def test_export_ddpg(self):
- failures = []
- export_test("DDPG", failures, "tf")
- assert not failures, failures
- def test_export_dqn(self):
- failures = []
- export_test("DQN", failures, "tf")
- assert not failures, failures
- def test_export_ppo(self):
- failures = []
- export_test("PPO", failures, "torch")
- export_test("PPO", failures, "tf")
- assert not failures, failures
- def test_export_sac(self):
- failures = []
- export_test("SAC", failures, "tf")
- assert not failures, failures
- print("All export tests passed!")
- if __name__ == "__main__":
- import pytest
- import sys
- sys.exit(pytest.main(["-v", __file__]))
|