evaluation.py

from __future__ import annotations

import argparse
import json
import os
import traceback
from collections import Counter
from pathlib import Path

from rich import print
from swebench import (
    KEY_INSTANCE_ID,
    KEY_MODEL,
    KEY_PREDICTION,
    get_eval_refs,
    get_eval_report,
    get_logs_eval,
    get_model_report,
    get_resolution_status,
    run_evaluation,
)
from swebench.harness.constants import (
    INSTALL_FAIL,
)
from unidiff import PatchSet


def main(
    predictions_path,
    log_dir,
    swe_bench_tasks,
    testbed,
    skip_existing,
    timeout,
    verbose,
    conda_link,
    log_suffix,
    num_processes,
):
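    """Evaluate a SWE-bench predictions file and summarize the results.

    Filters out empty predictions, runs the swebench evaluation harness, then
    builds a per-instance scorecard and writes scorecards.json and
    results.json next to the predictions file.

    Args:
        predictions_path: Path to predictions file (.jsonl).
        log_dir: Path to log directory.
        swe_bench_tasks: Path to SWE-bench task instances file.
        testbed: Path to testbed directory.
        skip_existing: Skip instances that already have evaluation logs.
        timeout: Timeout in seconds.
        verbose: Verbose mode.
        conda_link: Optional URL to a conda installation to use.
        log_suffix: Optional log suffix.
        num_processes: Number of processes (CLI default: -1).
    """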
    # Check if paths exist
    if not os.path.exists(predictions_path):
        msg = f"Predictions path {predictions_path} does not exist"
        raise FileNotFoundError(msg)
    eval_refs = get_eval_refs(swe_bench_tasks)
    for k, v in eval_refs.items():
        eval_refs[k] = {key: v[key] for key in [KEY_INSTANCE_ID, "FAIL_TO_PASS", "PASS_TO_PASS"]}

    # Change model_name_or_path field to directory name for all predictions
    directory = os.path.dirname(predictions_path)
    directory_name = directory.rsplit("/", 1)[-1]
    pred_path_orig = predictions_path
    pred_path_temp = predictions_path.replace(".jsonl", "_filtered.jsonl")

    pred_total, pred_will_eval = 0, 0
    with open(pred_path_temp, "w") as f:
        for l in Path(pred_path_orig).read_text().splitlines(keepends=True):
            pred_total += 1
            p = json.loads(l)
            # Exclude predictions w/ empty strings
            if p[KEY_PREDICTION] is not None and p[KEY_PREDICTION].strip() != "":
                p[KEY_MODEL] = directory_name
                json.dump(p, f)
                f.write("\n")
                pred_will_eval += 1
    print(
        f"Found {pred_total} total predictions, will evaluate {pred_will_eval} ({pred_total-pred_will_eval} are empty)",
    )

    # Run evaluation
    predictions_path = pred_path_temp
    try:
        print("🏃 Beginning evaluation...")
        run_evaluation(
            predictions_path=predictions_path,
            log_dir=log_dir,
            swe_bench_tasks=swe_bench_tasks,
            testbed=testbed,
            skip_existing=skip_existing,
            timeout=timeout,
            verbose=verbose,
            conda_link=conda_link,
            log_suffix=log_suffix,
            num_processes=num_processes,
        )
        print("✅ Finished evaluation")
    except Exception as e:
        print(f"❌ Evaluation failed: {e}\n{traceback.format_exc()}")
    print("==================================")
    os.remove(pred_path_temp)

    # Get predictions, define log_dir
    predictions = [json.loads(l) for l in Path(pred_path_orig).read_text().splitlines()]
    log_dir = os.path.join(log_dir, directory_name)
    print(f"Log directory for evaluation run: {log_dir}")
    # Iterate through predictions
    scorecards = []
    for p in predictions:
        scorecard = {KEY_INSTANCE_ID: p[KEY_INSTANCE_ID], "statuses": [], "stats": {}}

        # Add trajectory statistics if traj_path exists
        traj_path = os.path.join(directory, f"{p[KEY_INSTANCE_ID]}.traj")
        if os.path.exists(traj_path):
            with open(traj_path) as f:
                traj_data = json.load(f)
            scorecard["stats"]["traj_num_steps"] = len(traj_data["trajectory"])
            scorecard["stats"]["traj_action_dist"] = dict(
                Counter(
                    [
                        entry["action"].strip().split()[0]
                        if entry["role"] == "assistant" and "action" in entry and len(entry["action"]) > 0
                        else None
                        for entry in traj_data["history"]
                    ],
                ),
            )
            scorecard["exit_status"] = traj_data["info"]["exit_status"] if "exit_status" in traj_data["info"] else "n/a"

        # Check that a prediction was generated
        if p[KEY_PREDICTION] is None or p[KEY_PREDICTION].strip() == "":
            scorecard["statuses"].append("not_generated")
            scorecards.append(scorecard)
            continue
        scorecard["statuses"].append("generated")

        # Get log file
        log_path = os.path.join(log_dir, f"{p[KEY_INSTANCE_ID]}.{directory_name}.eval.log")
        if not os.path.exists(log_path):
            scorecard["statuses"].append("build_failure")
            scorecards.append(scorecard)
            continue

        # Get evaluation logs
        eval_sm, found = get_logs_eval(log_path)

        # Check that the patch was applied and evaluation results were found
        if not found:
            scorecards.append(scorecard)
            continue
        scorecard["statuses"].append("applied")

        with open(log_path) as f:
            log_contents = f.read()
        if INSTALL_FAIL in log_contents:
            scorecard["statuses"].append("install_fail")

        # Get resolution status
        report = get_eval_report(eval_sm, eval_refs[p[KEY_INSTANCE_ID]])
        scorecard["test_results"] = {
            "failure": {
                "FAIL_TO_PASS": report["FAIL_TO_PASS"]["failure"],
                "PASS_TO_PASS": report["PASS_TO_PASS"]["failure"],
            },
            "success": {
                "FAIL_TO_PASS": report["FAIL_TO_PASS"]["success"],
                "PASS_TO_PASS": report["PASS_TO_PASS"]["success"],
            },
        }
        resolution_status = get_resolution_status(report)
        scorecard["statuses"].append(resolution_status)

        try:
            diff_obj = PatchSet(p[KEY_PREDICTION])
            scorecard["patch_files"] = [
                x.path for x in diff_obj.modified_files + diff_obj.added_files + diff_obj.removed_files
            ]
            scorecard["patch_lines_add"] = sum(f.added for f in diff_obj)
            scorecard["patch_lines_del"] = sum(f.removed for f in diff_obj)
        except Exception as e:
            print(f"[{p[KEY_INSTANCE_ID]}] Error parsing prediction diff: {e}")
            scorecard["patch_files"] = []
            scorecard["patch_lines_add"] = 0
            scorecard["patch_lines_del"] = 0
        scorecards.append(scorecard)

    # Save to summary, scorecard json
    path_scorecards = os.path.join(directory, "scorecards.json")
    with open(path_scorecards, "w") as f:
        json.dump(scorecards, fp=f, indent=2)
    print(f"- Wrote per-instance scorecards to {path_scorecards}")

    # Get results and write to file
    print("Reference Report:")
    report = get_model_report(directory_name, pred_path_orig, swe_bench_tasks, log_dir)
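    # get_model_report (from swebench) maps report categories to the
    # instances in each; only the per-category counts are printed here, and
    # the full report is written to results.json below.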
    for k, v in report.items():
        print(f"- {k}: {len(v)}")

    path_results = os.path.join(directory, "results.json")
    with open(path_results, "w") as f:
        json.dump(report, f, indent=2)
    print(f"- Wrote summary of run to {path_results}")


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--predictions_path",
        type=str,
        help="Path to predictions file (.jsonl)",
        required=True,
    )
    parser.add_argument("--log_dir", type=str, help="Path to log directory", required=True)
    parser.add_argument(
        "--swe_bench_tasks",
        type=str,
        help="Path to SWE-bench task instances file",
        required=True,
    )
    parser.add_argument("--testbed", type=str, help="Path to testbed directory", required=True)
    parser.add_argument("--skip_existing", action="store_true", help="(Optional) Skip existing logs")
    parser.add_argument(
        "--timeout",
        type=int,
        help="(Optional) Timeout in seconds (default: 900)",
        default=900,
    )
    parser.add_argument("--verbose", action="store_true", help="(Optional) Verbose mode")
    parser.add_argument("--conda_link", default=None, type=str, help="(Optional) URL to conda installation to use")
    parser.add_argument("--log_suffix", default=None, type=str, help="(Optional) Log suffix")
    parser.add_argument("--num_processes", default=-1, type=int, help="(Optional) Number of processes")
    args = parser.parse_args()
    main(**vars(args))
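

# Example invocation (the paths below are illustrative placeholders, not
# taken from any particular run):
#
#   python evaluation.py \
#       --predictions_path trajectories/<run_name>/all_preds.jsonl \
#       --log_dir ./eval_logs \
#       --swe_bench_tasks ./swe-bench-tasks.json \
#       --testbed ./testbed \
#       --skip_existing \
#       --verbose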