# test_policy.py
  1. import unittest
  2. import ray
  3. from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG
  4. from ray.rllib.utils.test_utils import check, framework_iterator
  5. class TestPolicy(unittest.TestCase):
  6. @classmethod
  7. def setUpClass(cls) -> None:
  8. ray.init()
  9. @classmethod
  10. def tearDownClass(cls) -> None:
  11. ray.shutdown()
  12. def test_policy_save_restore(self):
  13. config = DEFAULT_CONFIG.copy()
  14. for _ in framework_iterator(config):
  15. trainer = DQNTrainer(config=config, env="CartPole-v0")
  16. policy = trainer.get_policy()
  17. state1 = policy.get_state()
  18. trainer.train()
  19. state2 = policy.get_state()
  20. check(
  21. state1["_exploration_state"]["last_timestep"],
  22. state2["_exploration_state"]["last_timestep"],
  23. false=True)
  24. check(
  25. state1["global_timestep"],
  26. state2["global_timestep"],
  27. false=True)
  28. # Reset policy to its original state and compare.
  29. policy.set_state(state1)
  30. state3 = policy.get_state()
  31. # Make sure everything is the same.
  32. check(state1, state3)
  33. if __name__ == "__main__":
  34. import pytest
  35. import sys
  36. sys.exit(pytest.main(["-v", __file__]))