# test_timesteps.py (1.7 KB)

  1. import numpy as np
  2. import unittest
  3. import ray
  4. import ray.rllib.agents.pg as pg
  5. from ray.rllib.examples.env.random_env import RandomEnv
  6. from ray.rllib.utils.test_utils import framework_iterator
  7. class TestTimeSteps(unittest.TestCase):
  8. @classmethod
  9. def setUpClass(cls):
  10. ray.init()
  11. @classmethod
  12. def tearDownClass(cls):
  13. ray.shutdown()
  14. def test_timesteps(self):
  15. """Test whether a PGTrainer can be built with both frameworks."""
  16. config = pg.DEFAULT_CONFIG.copy()
  17. config["num_workers"] = 0 # Run locally.
  18. config["model"]["fcnet_hiddens"] = [1]
  19. config["model"]["fcnet_activation"] = None
  20. obs = np.array(1)
  21. obs_one_hot = np.array([[0.0, 1.0]])
  22. for _ in framework_iterator(config):
  23. trainer = pg.PGTrainer(config=config, env=RandomEnv)
  24. policy = trainer.get_policy()
  25. for i in range(1, 21):
  26. trainer.compute_single_action(obs)
  27. self.assertEqual(policy.global_timestep, i)
  28. for i in range(1, 21):
  29. policy.compute_actions(obs_one_hot)
  30. self.assertEqual(policy.global_timestep, i + 20)
  31. # Artificially set ts to 100Bio, then keep computing actions and
  32. # train.
  33. crazy_timesteps = int(1e11)
  34. policy.global_timestep = crazy_timesteps
  35. # Run for 10 more ts.
  36. for i in range(1, 11):
  37. policy.compute_actions(obs_one_hot)
  38. self.assertEqual(policy.global_timestep, i + crazy_timesteps)
  39. trainer.train()
  40. if __name__ == "__main__":
  41. import pytest
  42. import sys
  43. sys.exit(pytest.main(["-v", __file__]))