test_algorithm_checkpoint_restore.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. #!/usr/bin/env python
  2. import unittest
  3. import ray
  4. from ray.rllib.algorithms.apex_ddpg import ApexDDPGConfig
  5. from ray.rllib.algorithms.sac import SACConfig
  6. from ray.rllib.algorithms.simple_q import SimpleQConfig
  7. from ray.rllib.algorithms.ppo import PPOConfig
  8. from ray.rllib.algorithms.es import ESConfig
  9. from ray.rllib.algorithms.dqn import DQNConfig
  10. from ray.rllib.algorithms.ddpg import DDPGConfig
  11. from ray.rllib.algorithms.ars import ARSConfig
  12. from ray.rllib.algorithms.a3c import A3CConfig
  13. from ray.rllib.utils.test_utils import test_ckpt_restore
  14. import os
  15. # As we transition things to RLModule API the explore=False will get
  16. # deprecated. For now, we will just not set it. The reason is that the RLModule
  17. # API has forward_exploration() method that can be overriden if user needs to
  18. # really turn of the stochasticity. This test in particular is robust to
  19. # explore=None if we compare the mean of the distribution of actions for the
  20. # same observation to be the same.
  21. algorithms_and_configs = {
  22. "A3C": (
  23. A3CConfig()
  24. .exploration(explore=False)
  25. .rollouts(num_rollout_workers=1)
  26. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  27. ),
  28. "APEX_DDPG": (
  29. ApexDDPGConfig()
  30. .exploration(explore=False)
  31. .rollouts(observation_filter="MeanStdFilter", num_rollout_workers=2)
  32. .reporting(min_time_s_per_iteration=1)
  33. .training(
  34. optimizer={"num_replay_buffer_shards": 1},
  35. num_steps_sampled_before_learning_starts=0,
  36. )
  37. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  38. ),
  39. "ARS": (
  40. ARSConfig()
  41. .exploration(explore=False)
  42. .rollouts(num_rollout_workers=2, observation_filter="MeanStdFilter")
  43. .training(num_rollouts=10, noise_size=2500000)
  44. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  45. ),
  46. "DDPG": (
  47. DDPGConfig()
  48. .exploration(explore=False)
  49. .reporting(min_sample_timesteps_per_iteration=100)
  50. .training(num_steps_sampled_before_learning_starts=0)
  51. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  52. ),
  53. "DQN": (
  54. DQNConfig()
  55. .exploration(explore=False)
  56. .training(num_steps_sampled_before_learning_starts=0)
  57. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  58. ),
  59. "ES": (
  60. ESConfig()
  61. .exploration(explore=False)
  62. .training(episodes_per_batch=10, train_batch_size=100, noise_size=2500000)
  63. .rollouts(observation_filter="MeanStdFilter", num_rollout_workers=2)
  64. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  65. ),
  66. "PPO": (
  67. # See the comment before the `algorithms_and_configs` dict.
  68. # explore is set to None for PPO in favor of RLModule API support.
  69. PPOConfig()
  70. .training(num_sgd_iter=5, train_batch_size=1000)
  71. .rollouts(num_rollout_workers=2)
  72. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  73. ),
  74. "SimpleQ": (
  75. SimpleQConfig()
  76. .exploration(explore=False)
  77. .training(num_steps_sampled_before_learning_starts=0)
  78. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  79. ),
  80. "SAC": (
  81. SACConfig()
  82. .exploration(explore=False)
  83. .training(num_steps_sampled_before_learning_starts=0)
  84. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
  85. ),
  86. }
  87. class TestCheckpointRestorePG(unittest.TestCase):
  88. @classmethod
  89. def setUpClass(cls):
  90. ray.init()
  91. @classmethod
  92. def tearDownClass(cls):
  93. ray.shutdown()
  94. def test_a3c_checkpoint_restore(self):
  95. # TODO(Kourosh) A3C cannot run a restored algorithm for some reason.
  96. test_ckpt_restore(
  97. algorithms_and_configs["A3C"], "CartPole-v1", run_restored_algorithm=False
  98. )
  99. def test_ppo_checkpoint_restore(self):
  100. test_ckpt_restore(algorithms_and_configs["PPO"], "CartPole-v1")
  101. class TestCheckpointRestoreOffPolicy(unittest.TestCase):
  102. @classmethod
  103. def setUpClass(cls):
  104. ray.init()
  105. @classmethod
  106. def tearDownClass(cls):
  107. ray.shutdown()
  108. def test_apex_ddpg_checkpoint_restore(self):
  109. test_ckpt_restore(algorithms_and_configs["APEX_DDPG"], "Pendulum-v1")
  110. def test_ddpg_checkpoint_restore(self):
  111. test_ckpt_restore(
  112. algorithms_and_configs["DDPG"], "Pendulum-v1", replay_buffer=True
  113. )
  114. def test_dqn_checkpoint_restore(self):
  115. test_ckpt_restore(
  116. algorithms_and_configs["DQN"],
  117. "CartPole-v1",
  118. replay_buffer=True,
  119. )
  120. def test_sac_checkpoint_restore(self):
  121. test_ckpt_restore(
  122. algorithms_and_configs["SAC"], "Pendulum-v1", replay_buffer=True
  123. )
  124. def test_simpleq_checkpoint_restore(self):
  125. test_ckpt_restore(
  126. algorithms_and_configs["SimpleQ"], "CartPole-v1", replay_buffer=True
  127. )
  128. class TestCheckpointRestoreEvolutionAlgos(unittest.TestCase):
  129. @classmethod
  130. def setUpClass(cls):
  131. ray.init()
  132. @classmethod
  133. def tearDownClass(cls):
  134. ray.shutdown()
  135. def test_ars_checkpoint_restore(self):
  136. test_ckpt_restore(algorithms_and_configs["ARS"], "CartPole-v1")
  137. def test_es_checkpoint_restore(self):
  138. test_ckpt_restore(algorithms_and_configs["ES"], "CartPole-v1")
  139. if __name__ == "__main__":
  140. import pytest
  141. import sys
  142. # One can specify the specific TestCase class to run.
  143. # None for all unittest.TestCase classes in this file.
  144. class_ = sys.argv[1] if len(sys.argv) > 1 else None
  145. sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)]))