test_perf.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import gym
  2. import time
  3. import unittest
  4. import ray
  5. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  6. from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
  7. class TestPerf(unittest.TestCase):
  8. @classmethod
  9. def setUpClass(cls):
  10. ray.init(num_cpus=5)
  11. @classmethod
  12. def tearDownClass(cls):
  13. ray.shutdown()
  14. # Tested on Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
  15. # 11/23/18: Samples per second 8501.125113727468
  16. # 03/01/19: Samples per second 8610.164353268685
  17. def test_baseline_performance(self):
  18. for _ in range(20):
  19. ev = RolloutWorker(
  20. env_creator=lambda _: gym.make("CartPole-v0"),
  21. policy_spec=MockPolicy,
  22. rollout_fragment_length=100)
  23. start = time.time()
  24. count = 0
  25. while time.time() - start < 1:
  26. count += ev.sample().count
  27. print()
  28. print("Samples per second {}".format(
  29. count / (time.time() - start)))
  30. print()
  31. if __name__ == "__main__":
  32. import pytest
  33. import sys
  34. sys.exit(pytest.main(["-v", __file__]))