# test_mbmpo.py (1.5 KB)
  1. from gymnasium.wrappers import TimeLimit
  2. import unittest
  3. import ray
  4. import ray.rllib.algorithms.mbmpo as mbmpo
  5. from ray.rllib.examples.env.mbmpo_env import CartPoleWrapper
  6. from ray.rllib.utils.test_utils import (
  7. check_compute_single_action,
  8. check_train_results,
  9. framework_iterator,
  10. )
  11. from ray.tune.registry import register_env
  12. class TestMBMPO(unittest.TestCase):
  13. @classmethod
  14. def setUpClass(cls):
  15. ray.init()
  16. register_env(
  17. "cartpole-mbmpo",
  18. lambda env_ctx: TimeLimit(CartPoleWrapper(), max_episode_steps=200),
  19. )
  20. @classmethod
  21. def tearDownClass(cls):
  22. ray.shutdown()
  23. def test_mbmpo_compilation(self):
  24. """Test whether MBMPO can be built with all frameworks."""
  25. config = (
  26. mbmpo.MBMPOConfig()
  27. .environment("cartpole-mbmpo")
  28. .rollouts(num_rollout_workers=2)
  29. .training(dynamics_model={"ensemble_size": 2})
  30. )
  31. num_iterations = 1
  32. # Test for torch framework (tf not implemented yet).
  33. for _ in framework_iterator(config, frameworks="torch"):
  34. algo = config.build()
  35. for i in range(num_iterations):
  36. results = algo.train()
  37. check_train_results(results)
  38. print(results)
  39. check_compute_single_action(algo, include_prev_action_reward=False)
  40. algo.stop()
  41. if __name__ == "__main__":
  42. import pytest
  43. import sys
  44. sys.exit(pytest.main(["-v", __file__]))