# test_mbmpo.py (1.3 KB)

import copy
import unittest

import ray
import ray.rllib.agents.mbmpo as mbmpo
from ray.rllib.utils.test_utils import check_compute_single_action, \
    check_train_results, framework_iterator
  6. class TestMBMPO(unittest.TestCase):
  7. @classmethod
  8. def setUpClass(cls):
  9. ray.init()
  10. @classmethod
  11. def tearDownClass(cls):
  12. ray.shutdown()
  13. def test_mbmpo_compilation(self):
  14. """Test whether an MBMPOTrainer can be built with all frameworks."""
  15. config = mbmpo.DEFAULT_CONFIG.copy()
  16. config["num_workers"] = 2
  17. config["horizon"] = 200
  18. config["dynamics_model"]["ensemble_size"] = 2
  19. num_iterations = 1
  20. # Test for torch framework (tf not implemented yet).
  21. for _ in framework_iterator(config, frameworks="torch"):
  22. trainer = mbmpo.MBMPOTrainer(
  23. config=config,
  24. env="ray.rllib.examples.env.mbmpo_env.CartPoleWrapper")
  25. for i in range(num_iterations):
  26. results = trainer.train()
  27. check_train_results(results)
  28. print(results)
  29. check_compute_single_action(
  30. trainer, include_prev_action_reward=False)
  31. trainer.stop()
  32. if __name__ == "__main__":
  33. import pytest
  34. import sys
  35. sys.exit(pytest.main(["-v", __file__]))