# test_repro_dqn.py (1.4 KB)
  1. import unittest
  2. import os
  3. import ray
  4. from ray.tune import register_env
  5. import ray.rllib.algorithms.dqn as dqn
  6. from ray.rllib.examples.env.deterministic_envs import create_cartpole_deterministic
  7. from ray.rllib.utils.test_utils import check_reproducibilty
  8. class TestReproDQN(unittest.TestCase):
  9. @classmethod
  10. def setUpClass(cls):
  11. ray.init()
  12. @classmethod
  13. def tearDownClass(cls):
  14. ray.shutdown()
  15. def test_reproducibility_dqn_cartpole(self):
  16. """Tests whether the algorithm is reproducible within 3 iterations
  17. on discrete env cartpole."""
  18. register_env("DeterministicCartPole-v1", create_cartpole_deterministic)
  19. config = dqn.DQNConfig().environment(
  20. env="DeterministicCartPole-v1", env_config={"seed": 42}
  21. )
  22. # tf-gpu is excluded for determinism
  23. # reason: https://github.com/tensorflow/tensorflow/issues/2732
  24. # https://github.com/tensorflow/tensorflow/issues/2652
  25. frameworks = ["torch"]
  26. if int(os.environ.get("RLLIB_NUM_GPUS", 0)) == 0:
  27. frameworks.append("tf")
  28. check_reproducibilty(
  29. algo_class=dqn.DQN,
  30. algo_config=config,
  31. fw_kwargs={"frameworks": frameworks},
  32. training_iteration=3,
  33. )
  34. if __name__ == "__main__":
  35. import pytest
  36. import sys
  37. sys.exit(pytest.main(["-v", __file__]))