test_exec_api.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import unittest
  2. import ray
  3. from ray.rllib.agents.a3c import A2CTrainer
  4. from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
  5. STEPS_TRAINED_COUNTER
  6. from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
  7. from ray.rllib.utils.test_utils import framework_iterator
  8. class TestDistributedExecution(unittest.TestCase):
  9. """General tests for the distributed execution API."""
  10. @classmethod
  11. def setUpClass(cls):
  12. ray.init(num_cpus=4)
  13. @classmethod
  14. def tearDownClass(cls):
  15. ray.shutdown()
  16. def test_exec_plan_stats(ray_start_regular):
  17. for fw in framework_iterator(frameworks=("torch", "tf")):
  18. trainer = A2CTrainer(
  19. env="CartPole-v0",
  20. config={
  21. "min_time_s_per_reporting": 0,
  22. "framework": fw,
  23. })
  24. result = trainer.train()
  25. assert isinstance(result, dict)
  26. assert "info" in result
  27. assert LEARNER_INFO in result["info"]
  28. assert STEPS_SAMPLED_COUNTER in result["info"]
  29. assert STEPS_TRAINED_COUNTER in result["info"]
  30. assert "timers" in result
  31. assert "learn_time_ms" in result["timers"]
  32. assert "learn_throughput" in result["timers"]
  33. assert "sample_time_ms" in result["timers"]
  34. assert "sample_throughput" in result["timers"]
  35. assert "update_time_ms" in result["timers"]
  36. def test_exec_plan_save_restore(ray_start_regular):
  37. for fw in framework_iterator(frameworks=("torch", "tf")):
  38. trainer = A2CTrainer(
  39. env="CartPole-v0",
  40. config={
  41. "min_time_s_per_reporting": 0,
  42. "framework": fw,
  43. })
  44. res1 = trainer.train()
  45. checkpoint = trainer.save()
  46. for _ in range(2):
  47. res2 = trainer.train()
  48. assert res2["timesteps_total"] > res1["timesteps_total"], \
  49. (res1, res2)
  50. trainer.restore(checkpoint)
  51. # Should restore the timesteps counter to the same as res2.
  52. res3 = trainer.train()
  53. assert res3["timesteps_total"] < res2["timesteps_total"], \
  54. (res2, res3)
  55. if __name__ == "__main__":
  56. import pytest
  57. import sys
  58. sys.exit(pytest.main(["-v", __file__]))