test_eager_support.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import unittest
  2. import ray
  3. from ray import air
  4. from ray import tune
  5. from ray.rllib.utils.framework import try_import_tf
  6. from ray.tune.registry import get_trainable_cls
  7. tf1, tf, tfv = try_import_tf()
  8. def check_support(alg, config, test_eager=False, test_trace=True):
  9. config["framework"] = "tf2"
  10. config["log_level"] = "ERROR"
  11. # Test both continuous and discrete actions.
  12. for cont in [True, False]:
  13. if cont and alg in ["DQN", "APEX", "SimpleQ"]:
  14. continue
  15. elif not cont and alg in ["DDPG", "APEX_DDPG", "TD3"]:
  16. continue
  17. if cont:
  18. config["env"] = "Pendulum-v1"
  19. else:
  20. config["env"] = "CartPole-v1"
  21. a = get_trainable_cls(alg)
  22. if test_eager:
  23. print("tf-eager: alg={} cont.act={}".format(alg, cont))
  24. config["eager_tracing"] = False
  25. tune.Tuner(
  26. a,
  27. param_space=config,
  28. run_config=air.RunConfig(stop={"training_iteration": 1}, verbose=1),
  29. ).fit()
  30. if test_trace:
  31. config["eager_tracing"] = True
  32. print("tf-eager-tracing: alg={} cont.act={}".format(alg, cont))
  33. tune.Tuner(
  34. a,
  35. param_space=config,
  36. run_config=air.RunConfig(stop={"training_iteration": 1}, verbose=1),
  37. ).fit()
  38. class TestEagerSupportPG(unittest.TestCase):
  39. def setUp(self):
  40. ray.init(num_cpus=4)
  41. def tearDown(self):
  42. ray.shutdown()
  43. def test_simple_q(self):
  44. check_support(
  45. "SimpleQ",
  46. {
  47. "num_workers": 0,
  48. "num_steps_sampled_before_learning_starts": 0,
  49. },
  50. )
  51. def test_dqn(self):
  52. check_support(
  53. "DQN",
  54. {
  55. "num_workers": 0,
  56. "num_steps_sampled_before_learning_starts": 0,
  57. },
  58. )
  59. def test_ddpg(self):
  60. check_support("DDPG", {"num_workers": 0})
  61. # TODO(sven): Add these once APEX_DDPG supports eager.
  62. # def test_apex_ddpg(self):
  63. # check_support("APEX_DDPG", {"num_workers": 1})
  64. def test_td3(self):
  65. check_support("TD3", {"num_workers": 0})
  66. def test_a2c(self):
  67. check_support("A2C", {"num_workers": 0})
  68. def test_a3c(self):
  69. check_support("A3C", {"num_workers": 1})
  70. def test_pg(self):
  71. check_support("PG", {"num_workers": 0})
  72. def test_ppo(self):
  73. check_support("PPO", {"num_workers": 0})
  74. def test_appo(self):
  75. check_support("APPO", {"num_workers": 1, "num_gpus": 0})
  76. def test_impala(self):
  77. check_support("IMPALA", {"num_workers": 1, "num_gpus": 0}, test_eager=True)
  78. class TestEagerSupportOffPolicy(unittest.TestCase):
  79. def setUp(self):
  80. ray.init(num_cpus=4)
  81. def tearDown(self):
  82. ray.shutdown()
  83. def test_simple_q(self):
  84. check_support(
  85. "SimpleQ",
  86. {
  87. "num_workers": 0,
  88. "replay_buffer_config": {"num_steps_sampled_before_learning_starts": 0},
  89. },
  90. )
  91. def test_dqn(self):
  92. check_support(
  93. "DQN",
  94. {
  95. "num_workers": 0,
  96. "num_steps_sampled_before_learning_starts": 0,
  97. },
  98. )
  99. def test_ddpg(self):
  100. check_support("DDPG", {"num_workers": 0})
  101. # def test_apex_ddpg(self):
  102. # check_support("APEX_DDPG", {"num_workers": 1})
  103. def test_td3(self):
  104. check_support("TD3", {"num_workers": 0})
  105. def test_apex_dqn(self):
  106. check_support(
  107. "APEX",
  108. {
  109. "num_workers": 2,
  110. "replay_buffer_config": {"num_steps_sampled_before_learning_starts": 0},
  111. "num_gpus": 0,
  112. "min_time_s_per_iteration": 1,
  113. "min_sample_timesteps_per_iteration": 100,
  114. "optimizer": {
  115. "num_replay_buffer_shards": 1,
  116. },
  117. },
  118. )
  119. def test_sac(self):
  120. check_support(
  121. "SAC",
  122. {
  123. "num_workers": 0,
  124. "num_steps_sampled_before_learning_starts": 0,
  125. },
  126. )
  127. if __name__ == "__main__":
  128. import sys
  129. # Don't test anything for version 2.x (all tests are eager anyways).
  130. # TODO: (sven) remove entire file in the future.
  131. if tfv == 2:
  132. print("\tskip due to tf==2.x")
  133. sys.exit(0)
  134. # One can specify the specific TestCase class to run.
  135. # None for all unittest.TestCase classes in this file.
  136. import pytest
  137. class_ = sys.argv[1] if len(sys.argv) > 1 else None
  138. sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)]))