evaluation.py

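"""Evaluate SWE-bench predictions and summarize results per instance.

Flow (derived from the code below): filter out empty predictions, run the
SWE-bench evaluation harness on the rest, then build a per-instance
"scorecard" (trajectory stats, patch stats, test outcomes) and an overall
report. Outputs are written next to the predictions file as scorecards.json
and results.json.
"""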
import argparse
import json
import os
import traceback
from collections import Counter

from rich import print
from swebench import (
    KEY_INSTANCE_ID,
    KEY_MODEL,
    KEY_PREDICTION,
    get_eval_report,
    get_logs_eval,
    get_model_report,
    get_resolution_status,
    run_evaluation,
    get_eval_refs,
)
from swebench.harness.constants import (
    INSTALL_FAIL,
)
from unidiff import PatchSet


def main(
    predictions_path,
    log_dir,
    swe_bench_tasks,
    testbed,
    skip_existing,
    timeout,
    verbose,
    conda_link,
    log_suffix,
    num_processes,
):
    # Check if paths exist
    if not os.path.exists(predictions_path):
        raise FileNotFoundError(f"Predictions path {predictions_path} does not exist")
    eval_refs = get_eval_refs(swe_bench_tasks)
    for k, v in eval_refs.items():
        eval_refs[k] = {key: v[key] for key in [KEY_INSTANCE_ID, "FAIL_TO_PASS", "PASS_TO_PASS"]}

    # Change model_name_or_path field to directory name for all predictions
    directory = os.path.dirname(predictions_path)
    directory_name = directory.rsplit("/", 1)[-1]
    pred_path_orig = predictions_path
    pred_path_temp = predictions_path.replace(".jsonl", "_filtered.jsonl")

    pred_total, pred_will_eval = 0, 0
    with open(pred_path_temp, "w") as f:
        for line in open(pred_path_orig, "r").readlines():
            pred_total += 1
            p = json.loads(line)
            # Exclude predictions w/ empty strings
            if p[KEY_PREDICTION] is not None and p[KEY_PREDICTION].strip() != "":
                p[KEY_MODEL] = directory_name
                json.dump(p, f)
                f.write("\n")
                pred_will_eval += 1
    print(
        f"Found {pred_total} total predictions, will evaluate {pred_will_eval} ({pred_total-pred_will_eval} are empty)"
    )

    # Run evaluation
    predictions_path = pred_path_temp
    try:
        print("🏃 Beginning evaluation...")
        run_evaluation(
            predictions_path=predictions_path,
            log_dir=log_dir,
            swe_bench_tasks=swe_bench_tasks,
            testbed=testbed,
            skip_existing=skip_existing,
            timeout=timeout,
            verbose=verbose,
            conda_link=conda_link,
            log_suffix=log_suffix,
            num_processes=num_processes,
        )
        print("✅ Finished evaluation")
    except Exception as e:
        print(f"❌ Evaluation failed: {e}\n{traceback.format_exc()}")
    print("==================================")
    os.remove(pred_path_temp)

    # Get predictions, define log_dir
    predictions = [json.loads(line) for line in open(pred_path_orig, "r").readlines()]
    log_dir = os.path.join(log_dir, directory_name)
    print(f"Log directory for evaluation run: {log_dir}")
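
    # Each scorecard accumulates statuses as the checks below run:
    #   "not_generated" - prediction missing or empty
    #   "generated"     - a non-empty patch was produced
    #   "build_failure" - no evaluation log exists for the instance
    #   "applied"       - the evaluation log was parsed (patch applied)
    #   "install_fail"  - the log contains the harness's INSTALL_FAIL marker
    #   plus the resolution status returned by get_resolution_status()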
    # Iterate through predictions
    scorecards = []
    for p in predictions:
        scorecard = {KEY_INSTANCE_ID: p[KEY_INSTANCE_ID], "statuses": [], "stats": {}}

        # Add trajectory statistics if traj_path exists
        traj_path = os.path.join(directory, f"{p[KEY_INSTANCE_ID]}.traj")
        if os.path.exists(traj_path):
            traj_data = json.load(open(traj_path, "r"))
            scorecard["stats"]["traj_num_steps"] = len(traj_data["trajectory"])
            scorecard["stats"]["traj_action_dist"] = dict(
                Counter(
                    [
                        entry["action"].strip().split()[0]
                        if entry["role"] == "assistant" and "action" in entry and len(entry["action"]) > 0
                        else None
                        for entry in traj_data["history"]
                    ]
                )
            )
            scorecard["exit_status"] = (
                traj_data["info"]["exit_status"] if "exit_status" in traj_data["info"] else "n/a"
            )

        # Check that a prediction was generated
        if p[KEY_PREDICTION] is None or p[KEY_PREDICTION].strip() == "":
            scorecard["statuses"].append("not_generated")
            scorecards.append(scorecard)
            continue
        scorecard["statuses"].append("generated")

        # Get log file
        log_path = os.path.join(log_dir, f"{p[KEY_INSTANCE_ID]}.{directory_name}.eval.log")
        if not os.path.exists(log_path):
            scorecard["statuses"].append("build_failure")
            scorecards.append(scorecard)
            continue

        # Get evaluation logs
        eval_sm, found = get_logs_eval(log_path)

        # Check that the patch applied and the evaluation log could be parsed
        if not found:
            scorecards.append(scorecard)
            continue
        scorecard["statuses"].append("applied")

        with open(log_path, "r") as f:
            log_contents = f.read()
        if INSTALL_FAIL in log_contents:
            scorecard["statuses"].append("install_fail")

        # Get resolution status
        report = get_eval_report(eval_sm, eval_refs[p[KEY_INSTANCE_ID]])
        scorecard["test_results"] = {
            "failure": {
                "FAIL_TO_PASS": report["FAIL_TO_PASS"]["failure"],
                "PASS_TO_PASS": report["PASS_TO_PASS"]["failure"],
            },
            "success": {
                "FAIL_TO_PASS": report["FAIL_TO_PASS"]["success"],
                "PASS_TO_PASS": report["PASS_TO_PASS"]["success"],
            },
        }
        resolution_status = get_resolution_status(report)
        scorecard["statuses"].append(resolution_status)

        try:
            diff_obj = PatchSet(p[KEY_PREDICTION])
            scorecard["patch_files"] = [
                x.path for x in diff_obj.modified_files + diff_obj.added_files + diff_obj.removed_files
            ]
            scorecard["patch_lines_add"] = sum([f.added for f in diff_obj])
            scorecard["patch_lines_del"] = sum([f.removed for f in diff_obj])
        except Exception as e:
            print(f"[{p[KEY_INSTANCE_ID]}] Error parsing prediction diff: {e}")
            scorecard["patch_files"] = []
            scorecard["patch_lines_add"] = 0
            scorecard["patch_lines_del"] = 0
        scorecards.append(scorecard)

    # Save to summary, scorecard json
    path_scorecards = os.path.join(directory, "scorecards.json")
    with open(path_scorecards, "w") as f:
        json.dump(scorecards, fp=f, indent=2)
    print(f"- Wrote per-instance scorecards to {path_scorecards}")

    # Get results and write to file
    print("Reference Report:")
    report = get_model_report(directory_name, pred_path_orig, swe_bench_tasks, log_dir)
    for k, v in report.items():
        print(f"- {k}: {len(v)}")
    path_results = os.path.join(directory, "results.json")
    with open(path_results, "w") as f:
        json.dump(report, f, indent=2)
    print(f"- Wrote summary of run to {path_results}")


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--predictions_path",
        type=str,
        help="Path to predictions file (.jsonl)",
        required=True,
    )
    parser.add_argument("--log_dir", type=str, help="Path to log directory", required=True)
    parser.add_argument(
        "--swe_bench_tasks",
        type=str,
        help="Path to SWE-bench task instances file",
        required=True,
    )
    parser.add_argument("--testbed", type=str, help="Path to testbed directory", required=True)
    parser.add_argument("--skip_existing", action="store_true", help="(Optional) Skip existing logs")
    parser.add_argument(
        "--timeout",
        type=int,
        help="(Optional) Timeout in seconds (default: 900)",
        default=900,
    )
    parser.add_argument("--verbose", action="store_true", help="(Optional) Verbose mode")
    parser.add_argument(
        "--conda_link",
        default=None,
        type=str,
        help="(Optional) URL to conda installation to use",
    )
    parser.add_argument("--log_suffix", default=None, type=str, help="(Optional) Log suffix")
    parser.add_argument("--num_processes", default=-1, type=int, help="Num processes")
    args = parser.parse_args()
    main(**vars(args))
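
# Example invocation (paths below are illustrative placeholders, not taken from this repo):
#   python evaluation.py \
#       --predictions_path trajectories/my_run/all_preds.jsonl \
#       --log_dir ./eval_logs \
#       --swe_bench_tasks ./swe-bench-tasks.json \
#       --testbed ./testbed \
#       --skip_existing \
#       --verbose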