run.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. """Learning regression tests for RLlib (torch and tf).
  2. Runs Atari/MuJoCo benchmarks for all major algorithms.
  3. """
import argparse
import json
import os
from pathlib import Path

from ray.rllib.utils.test_utils import run_learning_tests_from_yaml
  8. if __name__ == "__main__":
  9. import argparse
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument(
  12. "--smoke-test",
  13. action="store_true",
  14. default=False,
  15. help="Finish quickly for training.",
  16. )
  17. parser.add_argument(
  18. "--yaml-sub-dir",
  19. type=str,
  20. default="",
  21. help="Sub directory under yaml_files/ to look for test files.",
  22. )
  23. parser.add_argument(
  24. "--framework",
  25. type=str,
  26. default="tf",
  27. help="The framework (tf|tf2|torch) to use.",
  28. )
  29. args = parser.parse_args()
  30. assert args.yaml_sub_dir, "--yaml-sub-dir can't be empty."
  31. # Get path of this very script to look for yaml files.
  32. abs_yaml_path = os.path.join(
  33. str(Path(__file__).parent), "yaml_files", args.yaml_sub_dir
  34. )
  35. print("abs_yaml_path={}".format(abs_yaml_path))
  36. yaml_files = Path(abs_yaml_path).rglob("*.yaml")
  37. yaml_files = sorted(
  38. map(lambda path: str(path.absolute()), yaml_files), reverse=True
  39. )
  40. # Run all tests in the found yaml files.
  41. results = run_learning_tests_from_yaml(
  42. yaml_files=yaml_files,
  43. # Note(jungong) : run learning tests to full desired duration
  44. # for performance regression purpose.
  45. # Talk to jungong@ if you have questions about why we do this.
  46. use_pass_criteria_as_stop=False,
  47. smoke_test=args.smoke_test,
  48. framework=args.framework,
  49. )
  50. test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/learning_test.json")
  51. with open(test_output_json, "wt") as f:
  52. json.dump(results, f)
  53. if len(results["not_passed"]) > 0:
  54. raise ValueError(
  55. "Not all learning tests successfully learned the tasks.\n"
  56. f"Results=\n{results}"
  57. )
  58. else:
  59. print("Ok.")