test_export.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. #!/usr/bin/env python
  2. import os
  3. import shutil
  4. import unittest
  5. import ray
  6. from ray.rllib.agents.registry import get_trainer_class
  7. from ray.rllib.utils.framework import try_import_tf
  8. from ray.tune.trial import ExportFormat
  9. tf1, tf, tfv = try_import_tf()
  10. CONFIGS = {
  11. "A3C": {
  12. "explore": False,
  13. "num_workers": 1,
  14. },
  15. "APEX_DDPG": {
  16. "explore": False,
  17. "observation_filter": "MeanStdFilter",
  18. "num_workers": 2,
  19. "min_iter_time_s": 1,
  20. "optimizer": {
  21. "num_replay_buffer_shards": 1,
  22. },
  23. },
  24. "ARS": {
  25. "explore": False,
  26. "num_rollouts": 10,
  27. "num_workers": 2,
  28. "noise_size": 2500000,
  29. "observation_filter": "MeanStdFilter",
  30. },
  31. "DDPG": {
  32. "explore": False,
  33. "timesteps_per_iteration": 100,
  34. },
  35. "DQN": {
  36. "explore": False,
  37. },
  38. "ES": {
  39. "explore": False,
  40. "episodes_per_batch": 10,
  41. "train_batch_size": 100,
  42. "num_workers": 2,
  43. "noise_size": 2500000,
  44. "observation_filter": "MeanStdFilter",
  45. },
  46. "PPO": {
  47. "explore": False,
  48. "num_sgd_iter": 5,
  49. "train_batch_size": 1000,
  50. "num_workers": 2,
  51. },
  52. "SAC": {
  53. "explore": False,
  54. },
  55. }
  56. def export_test(alg_name, failures, framework="tf"):
  57. def valid_tf_model(model_dir):
  58. return os.path.exists(os.path.join(model_dir, "saved_model.pb")) \
  59. and os.listdir(os.path.join(model_dir, "variables"))
  60. def valid_tf_checkpoint(checkpoint_dir):
  61. return os.path.exists(os.path.join(checkpoint_dir, "model.meta")) \
  62. and os.path.exists(os.path.join(checkpoint_dir, "model.index")) \
  63. and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))
  64. cls = get_trainer_class(alg_name)
  65. config = CONFIGS[alg_name].copy()
  66. config["framework"] = framework
  67. if "DDPG" in alg_name or "SAC" in alg_name:
  68. algo = cls(config=config, env="Pendulum-v1")
  69. else:
  70. algo = cls(config=config, env="CartPole-v0")
  71. for _ in range(1):
  72. res = algo.train()
  73. print("current status: " + str(res))
  74. export_dir = os.path.join(ray._private.utils.get_user_temp_dir(),
  75. "export_dir_%s" % alg_name)
  76. print("Exporting model ", alg_name, export_dir)
  77. algo.export_policy_model(export_dir)
  78. if framework == "tf" and not valid_tf_model(export_dir):
  79. failures.append(alg_name)
  80. shutil.rmtree(export_dir)
  81. if framework == "tf":
  82. print("Exporting checkpoint", alg_name, export_dir)
  83. algo.export_policy_checkpoint(export_dir)
  84. if framework == "tf" and not valid_tf_checkpoint(export_dir):
  85. failures.append(alg_name)
  86. shutil.rmtree(export_dir)
  87. print("Exporting default policy", alg_name, export_dir)
  88. algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL],
  89. export_dir)
  90. if not valid_tf_model(os.path.join(export_dir, ExportFormat.MODEL)) \
  91. or not valid_tf_checkpoint(
  92. os.path.join(export_dir, ExportFormat.CHECKPOINT)):
  93. failures.append(alg_name)
  94. # Test loading the exported model.
  95. model = tf.saved_model.load(
  96. os.path.join(export_dir, ExportFormat.MODEL))
  97. assert model
  98. shutil.rmtree(export_dir)
  99. algo.stop()
  100. class TestExport(unittest.TestCase):
  101. @classmethod
  102. def setUpClass(cls) -> None:
  103. ray.init(num_cpus=4)
  104. @classmethod
  105. def tearDownClass(cls) -> None:
  106. ray.shutdown()
  107. def test_export_a3c(self):
  108. failures = []
  109. export_test("A3C", failures, "tf")
  110. assert not failures, failures
  111. def test_export_ddpg(self):
  112. failures = []
  113. export_test("DDPG", failures, "tf")
  114. assert not failures, failures
  115. def test_export_dqn(self):
  116. failures = []
  117. export_test("DQN", failures, "tf")
  118. assert not failures, failures
  119. def test_export_ppo(self):
  120. failures = []
  121. export_test("PPO", failures, "torch")
  122. export_test("PPO", failures, "tf")
  123. assert not failures, failures
  124. def test_export_sac(self):
  125. failures = []
  126. export_test("SAC", failures, "tf")
  127. assert not failures, failures
  128. print("All export tests passed!")
  129. if __name__ == "__main__":
  130. import pytest
  131. import sys
  132. sys.exit(pytest.main(["-v", __file__]))