test_rllib_train_and_evaluate.py

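"""Integration tests for RLlib's command line scripts `train.py` and `evaluate.py`.

Each test trains an algorithm briefly (via `train.py` or `tune.run`), locates
the resulting checkpoint, rolls it out via `evaluate.py`, and checks that the
rollout output file exists; the learning tests additionally check that the
rollout's mean episode reward reaches the training stop-reward.

Run all tests, or pass a single TestCase class name as the first argument:
    python test_rllib_train_and_evaluate.py TestEvaluate1
"""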
import os
from pathlib import Path
import re
import sys
import unittest

import ray
from ray import tune
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import framework_iterator
def evaluate_test(algo, env="CartPole-v0", test_episode_rollout=False):
    """Trains `algo` for one iteration, then rolls out the saved checkpoint."""
    extra_config = ""
    if algo == "ARS":
        extra_config = ",\"train_batch_size\": 10, \"noise_size\": 250000"
    elif algo == "ES":
        extra_config = ",\"episodes_per_batch\": 1,\"train_batch_size\": 10, " \
                       "\"noise_size\": 250000"

    for fw in framework_iterator(frameworks=("tf", "torch")):
        fw_ = ", \"framework\": \"{}\"".format(fw)

        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir,
                                                 os.path.exists(rllib_dir)))

        # Run a very short training via the `train.py` CLI script.
        os.system(
            "python {}/train.py --local-dir={} --run={} "
            "--checkpoint-freq=1 ".format(rllib_dir, tmp_dir, algo) +
            "--config='{" + "\"num_workers\": 1, \"num_gpus\": 0{}{}".format(
                fw_, extra_config) +
            ", \"timesteps_per_iteration\": 5,\"min_time_s_per_reporting\": 0.1, "
            "\"model\": {\"fcnet_hiddens\": [10]}"
            "}' --stop='{\"training_iteration\": 1}'" +
            " --env={} --no-ray-ui".format(env))

        checkpoint_path = os.popen("ls {}/default/*/checkpoint_000001/"
                                   "checkpoint-1".format(tmp_dir)).read()[:-1]
        if not os.path.exists(checkpoint_path):
            sys.exit(1)
        print("Checkpoint path {} (exists)".format(checkpoint_path))

        # Test rolling out n steps.
        os.popen("python {}/evaluate.py --run={} \"{}\" --steps=10 "
                 "--out=\"{}/rollouts_10steps.pkl\" --no-render".format(
                     rllib_dir, algo, checkpoint_path, tmp_dir)).read()
        if not os.path.exists(tmp_dir + "/rollouts_10steps.pkl"):
            sys.exit(1)
        print("evaluate output (10 steps) exists!")

        # Test rolling out 1 episode.
        if test_episode_rollout:
            os.popen("python {}/evaluate.py --run={} \"{}\" --episodes=1 "
                     "--out=\"{}/rollouts_1episode.pkl\" --no-render".format(
                         rllib_dir, algo, checkpoint_path, tmp_dir)).read()
            if not os.path.exists(tmp_dir + "/rollouts_1episode.pkl"):
                sys.exit(1)
            print("evaluate output (1 ep) exists!")

        # Cleanup.
        os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
def learn_test_plus_evaluate(algo, env="CartPole-v0"):
    """Trains `algo` up to a target reward, then verifies the rollout reward."""
    for fw in framework_iterator(frameworks=("tf", "torch")):
        fw_ = ", \\\"framework\\\": \\\"{}\\\"".format(fw)

        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            # Last resort: Resolve via the underlying tempdir (and cut the
            # leading "/tmp").
            tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
            if not os.path.exists(tmp_dir):
                sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir,
                                                 os.path.exists(rllib_dir)))

        # Train until the stopping reward has been reached.
        os.system("python {}/train.py --local-dir={} --run={} "
                  "--checkpoint-freq=1 --checkpoint-at-end ".format(
                      rllib_dir, tmp_dir, algo) +
                  "--config=\"{\\\"num_gpus\\\": 0, \\\"num_workers\\\": 1, "
                  "\\\"evaluation_config\\\": {\\\"explore\\\": false}" + fw_ +
                  "}\" " + "--stop=\"{\\\"episode_reward_mean\\\": 100.0}\"" +
                  " --env={}".format(env))

        # Find the last checkpoint and use that for the rollout.
        checkpoint_path = os.popen("ls {}/default/*/checkpoint_*/"
                                   "checkpoint-*".format(tmp_dir)).read()[:-1]
        checkpoints = [
            cp for cp in checkpoint_path.split("\n")
            if re.match(r"^.+checkpoint-\d+$", cp)
        ]
        # Sort by number and pick last (which should be the best checkpoint).
        last_checkpoint = sorted(
            checkpoints,
            key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1)))[-1]
        assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
        if not os.path.exists(last_checkpoint):
            sys.exit(1)
        print("Best checkpoint={} (exists)".format(last_checkpoint))

        # Test rolling out n steps.
        result = os.popen(
            "python {}/evaluate.py --run={} "
            "--steps=400 "
            "--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
                rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
        if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
            sys.exit(1)
        print("Rollout output exists -> Checking reward ...")

        # Parse the per-episode rewards from `evaluate.py`'s stdout and make
        # sure their mean is at least the training stop-reward.
        episodes = result.split("\n")
        mean_reward = 0.0
        num_episodes = 0
        for ep in episodes:
            mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
            if mo:
                mean_reward += float(mo.group(1))
                num_episodes += 1
        mean_reward /= num_episodes
        print("Rollout's mean episode reward={}".format(mean_reward))
        assert mean_reward >= 100.0

        # Cleanup.
        os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
def learn_test_multi_agent_plus_evaluate(algo):
    """Trains a multi-agent setup via `tune.run`, then rolls out the checkpoint."""
    for fw in framework_iterator(frameworks=("tf", "torch")):
        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            # Last resort: Resolve via the underlying tempdir (and cut the
            # leading "/tmp").
            tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
            if not os.path.exists(tmp_dir):
                sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir,
                                                 os.path.exists(rllib_dir)))

        # Map each agent to its own policy ("pol0", "pol1", ...).
        def policy_fn(agent_id, episode, **kwargs):
            return "pol{}".format(agent_id)

        config = {
            "num_gpus": 0,
            "num_workers": 1,
            "evaluation_config": {
                "explore": False,
            },
            "framework": fw,
            "env": MultiAgentCartPole,
            "multiagent": {
                "policies": {"pol0", "pol1"},
                "policy_mapping_fn": policy_fn,
            },
        }
        stop = {"episode_reward_mean": 100.0}

        tune.run(
            algo,
            config=config,
            stop=stop,
            checkpoint_freq=1,
            checkpoint_at_end=True,
            local_dir=tmp_dir,
            verbose=1)

        # Find the last checkpoint and use that for the rollout.
        checkpoint_path = os.popen("ls {}/PPO/*/checkpoint_*/"
                                   "checkpoint-*".format(tmp_dir)).read()[:-1]
        checkpoint_paths = checkpoint_path.split("\n")
        assert len(checkpoint_paths) > 0
        checkpoints = [
            cp for cp in checkpoint_paths
            if re.match(r"^.+checkpoint-\d+$", cp)
        ]
        # Sort by number and pick last (which should be the best checkpoint).
        last_checkpoint = sorted(
            checkpoints,
            key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1)))[-1]
        assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
        if not os.path.exists(last_checkpoint):
            sys.exit(1)
        print("Best checkpoint={} (exists)".format(last_checkpoint))

        ray.shutdown()

        # Test rolling out n steps.
        result = os.popen(
            "python {}/evaluate.py --run={} "
            "--steps=400 "
            "--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
                rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
        if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
            sys.exit(1)
        print("Rollout output exists -> Checking reward ...")

        # Parse the per-episode rewards from `evaluate.py`'s stdout and make
        # sure their mean is at least the training stop-reward.
        episodes = result.split("\n")
        mean_reward = 0.0
        num_episodes = 0
        for ep in episodes:
            mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
            if mo:
                mean_reward += float(mo.group(1))
                num_episodes += 1
        mean_reward /= num_episodes
        print("Rollout's mean episode reward={}".format(mean_reward))
        assert mean_reward >= 100.0

        # Cleanup.
        os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
class TestEvaluate1(unittest.TestCase):
    def test_a3c(self):
        evaluate_test("A3C")

    def test_ddpg(self):
        evaluate_test("DDPG", env="Pendulum-v1")


class TestEvaluate2(unittest.TestCase):
    def test_dqn(self):
        evaluate_test("DQN")

    def test_es(self):
        evaluate_test("ES")


class TestEvaluate3(unittest.TestCase):
    def test_impala(self):
        evaluate_test("IMPALA", env="CartPole-v0")

    def test_ppo(self):
        evaluate_test("PPO", env="CartPole-v0", test_episode_rollout=True)


class TestEvaluate4(unittest.TestCase):
    def test_sac(self):
        evaluate_test("SAC", env="Pendulum-v1")


class TestTrainAndEvaluate(unittest.TestCase):
    def test_ppo_train_then_rollout(self):
        learn_test_plus_evaluate("PPO")

    def test_ppo_multi_agent_train_then_rollout(self):
        learn_test_multi_agent_plus_evaluate("PPO")
if __name__ == "__main__":
    import pytest

    # Optionally pass a TestCase class name as the first CLI argument to run
    # only that class; by default, all TestCase classes in this file are run.
    class_ = sys.argv[1] if len(sys.argv) > 1 else None
    sys.exit(
        pytest.main(
            ["-v", __file__ + ("" if class_ is None else "::" + class_)]))