# test_a3c.py
  1. import unittest
  2. from rllib_a3c.a3c import A3CConfig
  3. import ray
  4. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  5. from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
  6. from ray.rllib.utils.test_utils import (
  7. check_compute_single_action,
  8. check_train_results,
  9. framework_iterator,
  10. )
  11. class TestA3C(unittest.TestCase):
  12. """Sanity tests for A2C exec impl."""
  13. def setUp(self):
  14. ray.init(num_cpus=4)
  15. def tearDown(self):
  16. ray.shutdown()
  17. def test_a3c_compilation(self):
  18. """Test whether an A3C can be built with both frameworks."""
  19. config = A3CConfig().rollouts(num_rollout_workers=2, num_envs_per_worker=2)
  20. num_iterations = 2
  21. # Test against all frameworks.
  22. for _ in framework_iterator(config):
  23. config.eager_tracing = False
  24. for env in ["CartPole-v1", "Pendulum-v1"]:
  25. print("env={}".format(env))
  26. config.model["use_lstm"] = env == "CartPole-v1"
  27. algo = config.build(env=env)
  28. for i in range(num_iterations):
  29. results = algo.train()
  30. check_train_results(results)
  31. print(results)
  32. check_compute_single_action(
  33. algo, include_state=config.model["use_lstm"]
  34. )
  35. algo.stop()
  36. def test_a3c_entropy_coeff_schedule(self):
  37. """Test A3C entropy coeff schedule support."""
  38. config = A3CConfig().rollouts(
  39. num_rollout_workers=1,
  40. num_envs_per_worker=1,
  41. batch_mode="truncate_episodes",
  42. rollout_fragment_length=10,
  43. )
  44. # Initial entropy coeff, doesn't really matter because of the schedule below.
  45. config.training(
  46. train_batch_size=20,
  47. entropy_coeff=0.01,
  48. entropy_coeff_schedule=[
  49. [0, 0.01],
  50. [120, 0.0001],
  51. ],
  52. )
  53. # 0 metrics reporting delay, this makes sure timestep,
  54. # which entropy coeff depends on, is updated after each worker rollout.
  55. config.reporting(
  56. min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=20
  57. )
  58. def _step_n_times(trainer, n: int):
  59. """Step trainer n times.
  60. Returns:
  61. learning rate at the end of the execution.
  62. """
  63. for _ in range(n):
  64. results = trainer.train()
  65. return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
  66. "entropy_coeff"
  67. ]
  68. # Test against all frameworks.
  69. for _ in framework_iterator(config, frameworks=("torch", "tf")):
  70. config.eager_tracing = False
  71. algo = config.build(env="CartPole-v1")
  72. coeff = _step_n_times(algo, 1) # 20 timesteps
  73. # Should be close to the starting coeff of 0.01
  74. self.assertGreaterEqual(coeff, 0.005)
  75. coeff = _step_n_times(algo, 10) # 200 timesteps
  76. # Should have annealed to the final coeff of 0.0001.
  77. self.assertLessEqual(coeff, 0.00011)
  78. algo.stop()
  79. if __name__ == "__main__":
  80. import sys
  81. import pytest
  82. sys.exit(pytest.main(["-v", __file__]))