test_dqn.py

from copy import deepcopy
import numpy as np
import unittest

import ray
import ray.rllib.algorithms.dqn as dqn
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)


class TestDQN(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_dqn_compilation(self):
        """Test whether DQN can be built on all frameworks."""
        num_iterations = 1
        config = (
            dqn.dqn.DQNConfig()
            .environment("CartPole-v1")
            .rollouts(num_rollout_workers=2)
            .training(num_steps_sampled_before_learning_starts=0)
        )

        for _ in framework_iterator(config):
            # Double-dueling DQN.
            print("Double-dueling")
            algo = config.build()
            for i in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)

            check_compute_single_action(algo)
            algo.stop()

            # Rainbow.
            print("Rainbow")
            rainbow_config = deepcopy(config).training(
                num_atoms=10, noisy=True, double_q=True, dueling=True, n_step=5
            )
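            # These settings switch on the Rainbow-style components of RLlib's
            # DQN: num_atoms > 1 enables distributional Q-learning, noisy=True
            # uses noisy layers for exploration, double_q/dueling enable double
            # and dueling DQN, and n_step=5 uses 5-step returns.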
            algo = rainbow_config.build()
            for i in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)

            check_compute_single_action(algo)
            algo.stop()

    def test_dqn_compilation_integer_rewards(self):
        """Test whether DQN can be built on all frameworks.

        Unlike the previous test, this uses an environment with integer
        rewards in order to test that type conversions are working correctly.
        """
        num_iterations = 1
        config = (
            dqn.dqn.DQNConfig()
            .environment("Taxi-v3")
            .rollouts(num_rollout_workers=2)
            .training(num_steps_sampled_before_learning_starts=0)
        )
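        # Taxi-v3 emits plain integer rewards (e.g. -1 per step, +20 for a
        # successful dropoff), which exercises the reward type conversion this
        # test is about.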

        for _ in framework_iterator(config):
            # Double-dueling DQN.
            print("Double-dueling")
            algo = config.build()
            for i in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)

            check_compute_single_action(algo)
            algo.stop()

            # Rainbow.
            print("Rainbow")
            rainbow_config = deepcopy(config).training(
                num_atoms=10, noisy=True, double_q=True, dueling=True, n_step=5
            )
            algo = rainbow_config.build()
            for i in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)

            check_compute_single_action(algo)
            algo.stop()

    def test_dqn_exploration_and_soft_q_config(self):
        """Tests whether a DQN agent outputs exploration/softmaxed actions."""
        config = (
            dqn.dqn.DQNConfig()
            .environment(
                "FrozenLake-v1",
                env_config={"is_slippery": False, "map_name": "4x4"},
            )
            .rollouts(num_rollout_workers=0)
            .training(num_steps_sampled_before_learning_starts=0)
        )
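        # Query the algorithm repeatedly on a single, fixed observation; for
        # the deterministic 4x4 FrozenLake map, discrete observation 0 is the
        # start state.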
        obs = np.array(0)

        # Test against all frameworks.
        for _ in framework_iterator(config):
            # Default EpsilonGreedy setup.
            algo = config.build()
            # Setting explore=False should always return the same action.
            a_ = algo.compute_single_action(obs, explore=False)
            for _ in range(50):
                a = algo.compute_single_action(obs, explore=False)
                check(a, a_)
            # explore=None (default: explore) should return different actions.
            actions = []
            for _ in range(50):
                actions.append(algo.compute_single_action(obs))
            check(np.std(actions), 0.0, false=True)
            algo.stop()

            # Low softmax temperature. Behaves like argmax
            # (but no epsilon exploration).
            config.exploration(
                exploration_config={"type": "SoftQ", "temperature": 0.000001}
            )
            algo = config.build()
            # Due to the low temp, always expect the same action.
            actions = [algo.compute_single_action(obs)]
            for _ in range(50):
                actions.append(algo.compute_single_action(obs))
            check(np.std(actions), 0.0, decimals=3)
            algo.stop()

            # Higher softmax temperature.
            config.exploration_config["temperature"] = 1.0
            algo = config.build()

            # Even with the higher temperature, if we set explore=False, we
            # should expect the same actions always.
            a_ = algo.compute_single_action(obs, explore=False)
            for _ in range(50):
                a = algo.compute_single_action(obs, explore=False)
                check(a, a_)

            # Due to the higher temp, expect different actions averaging
            # around 1.5 (roughly uniform sampling over FrozenLake's four
            # actions 0-3).
            actions = []
            for _ in range(300):
                actions.append(algo.compute_single_action(obs))
            check(np.std(actions), 0.0, false=True)
            algo.stop()

            # With Random exploration.
            config.exploration(exploration_config={"type": "Random"}, explore=True)
            algo = config.build()
            actions = []
            for _ in range(300):
                actions.append(algo.compute_single_action(obs))
            check(np.std(actions), 0.0, false=True)
            algo.stop()
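

# Running this file directly invokes pytest on it. A single test case can also
# be selected via its pytest node id, e.g.:
#   pytest -v test_dqn.py::TestDQN::test_dqn_compilation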
if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))