run.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. import json
  2. import logging
  3. import os
  4. import re
  5. import subprocess
  6. import traceback
  7. from typing import Any, Dict, Optional
  8. import rich.console
  9. import rich.markdown
  10. import rich.panel
  11. import rich.markdown
  12. import yaml
  13. from dataclasses import dataclass
  14. from getpass import getuser
  15. from pathlib import Path
  16. from rich.logging import RichHandler
  17. from simple_parsing import parse
  18. from simple_parsing.helpers.serialization.serializable import FrozenSerializable
  19. from simple_parsing.helpers.flatten import FlattenedAccess
  20. from sweagent import (
  21. Agent,
  22. AgentArguments,
  23. EnvironmentArguments,
  24. ModelArguments,
  25. SWEEnv,
  26. get_data_path_name,
  27. )
  28. from swebench import KEY_INSTANCE_ID, KEY_MODEL, KEY_PREDICTION
  29. from unidiff import PatchSet
  30. from sweagent.environment.utils import InvalidGithubURL, get_associated_commit_urls, get_gh_issue_data, parse_gh_issue_url
  31. handler = RichHandler(show_time=False, show_path=False)
  32. handler.setLevel(logging.DEBUG)
  33. logger = logging.getLogger("run_dev")
  34. logger.setLevel(logging.DEBUG)
  35. logger.addHandler(handler)
  36. logger.propagate = False
  37. logging.getLogger("simple_parsing").setLevel(logging.WARNING)
  38. @dataclass(frozen=True)
  39. class ActionsArguments(FlattenedAccess, FrozenSerializable):
  40. """Run real-life actions (opening PRs, etc.) if we can solve the issue."""
  41. # Open a PR with the patch if we can solve the issue
  42. open_pr: bool = False
  43. # When working with local repository: Apply patch
  44. apply_patch_locally: bool = False
  45. # Option to be used with open_pr: Skip action if there are already commits claiming
  46. # to fix the issue. Please only set this to False if you are sure the commits are
  47. # not fixes or if this is your own repository!
  48. skip_if_commits_reference_issue: bool = True
  49. # OBSOLETE. Do not use, will raise error. Please specify --repo_path instead.
  50. push_gh_repo_url: str = ""
  51. def __post_init__(self):
  52. if self.push_gh_repo_url:
  53. raise ValueError("push_gh_repo_url is obsolete. Use repo_path instead")
  54. @dataclass(frozen=True)
  55. class ScriptArguments(FlattenedAccess, FrozenSerializable):
  56. """Configure the control flow of the run.py script"""
  57. environment: EnvironmentArguments
  58. agent: AgentArguments
  59. actions: ActionsArguments
  60. instance_filter: str = ".*" # Only run instances that completely match this regex
  61. skip_existing: bool = True # Skip instances with existing trajectories
  62. suffix: str = ""
  63. # Raise unhandled exceptions during the run (useful for debugging)
  64. raise_exceptions: bool = False
  65. @property
  66. def run_name(self):
  67. """Generate a unique name for this run based on the arguments."""
  68. model_name = self.agent.model.model_name.replace(":", "-")
  69. data_stem = get_data_path_name(self.environment.data_path)
  70. config_stem = Path(self.agent.config_file).stem
  71. temp = self.agent.model.temperature
  72. top_p = self.agent.model.top_p
  73. per_instance_cost_limit = self.agent.model.per_instance_cost_limit
  74. install_env = self.environment.install_environment
  75. return (
  76. f"{model_name}__{data_stem}__{config_stem}__t-{temp:.2f}__p-{top_p:.2f}"
  77. + f"__c-{per_instance_cost_limit:.2f}__install-{int(install_env)}"
  78. + (f"__{self.suffix}" if self.suffix else "")
  79. )
  80. def main(args: ScriptArguments):
  81. logger.info(f"📙 Arguments: {args.dumps_yaml()}")
  82. agent = Agent("primary", args.agent)
  83. env = SWEEnv(args.environment)
  84. traj_dir = Path("trajectories") / Path(getuser()) / args.run_name
  85. traj_dir.mkdir(parents=True, exist_ok=True)
  86. save_arguments(traj_dir, args)
  87. for index in range(len(env.data)):
  88. try:
  89. # Reset environment
  90. instance_id = env.data[index]["instance_id"]
  91. assert isinstance(instance_id, str) # mypy
  92. if should_skip(args, traj_dir, instance_id):
  93. continue
  94. logger.info("▶️ Beginning task " + str(index))
  95. observation, info = env.reset(index)
  96. if info is None:
  97. continue
  98. # Get info, patch information
  99. issue = getattr(env, "query", None)
  100. files = []
  101. assert env.record is not None # mypy
  102. if "patch" in env.record:
  103. files = "\n".join(
  104. [f"- {x.path}" for x in PatchSet(env.record["patch"]).modified_files]
  105. )
  106. # Get test files, F2P tests information
  107. test_files = []
  108. if "test_patch" in env.record:
  109. test_patch_obj = PatchSet(env.record["test_patch"])
  110. test_files = "\n".join(
  111. [f"- {x.path}" for x in test_patch_obj.modified_files + test_patch_obj.added_files]
  112. )
  113. tests = ""
  114. if "FAIL_TO_PASS" in env.record:
  115. tests = "\n".join([f"- {x}" for x in env.record["FAIL_TO_PASS"]])
  116. setup_args = {
  117. "issue": issue,
  118. "files": files,
  119. "test_files": test_files,
  120. "tests": tests
  121. }
  122. info, trajectory = agent.run(
  123. setup_args=setup_args,
  124. env=env,
  125. observation=observation,
  126. traj_dir=traj_dir,
  127. return_type="info_trajectory",
  128. )
  129. save_predictions(traj_dir, instance_id, info)
  130. patch_path = save_patch(traj_dir, instance_id, info)
  131. if args.actions.open_pr and should_open_pr(args, info, token=env._github_token):
  132. env.open_pr(trajectory=trajectory)
  133. if args.actions.apply_patch_locally and patch_path is not None and env.record["repo_type"] == "local":
  134. apply_patch(Path(args.environment.repo_path), patch_file=patch_path)
  135. except KeyboardInterrupt:
  136. logger.info("Exiting InterCode environment...")
  137. env.close()
  138. break
  139. except Exception as e:
  140. traceback.print_exc()
  141. logger.warning(f"❌ Failed on {env.record['instance_id']}: {e}")
  142. if args.raise_exceptions:
  143. raise e
  144. env.reset_container()
  145. continue
  146. def should_open_pr(args: ScriptArguments, info: Dict[str, Any], *, token: str="") -> bool:
  147. """Does opening a PR make sense?"""
  148. if not info.get("submission"):
  149. logger.info("Not opening PR because submission was made.")
  150. return False
  151. if info["exit_status"] != "submitted":
  152. logger.info("Not opening PR because exit status was %s and not submitted.", info["exit_status"])
  153. return False
  154. try:
  155. issue = get_gh_issue_data(args.environment.data_path, token=token)
  156. except InvalidGithubURL:
  157. logger.info("Currently only GitHub is supported to open PRs to. Skipping PR creation.")
  158. return False
  159. if issue.state != "open":
  160. logger.info(f"Issue is not open (state={issue.state}. Skipping PR creation.")
  161. return False
  162. if issue.assignee:
  163. logger.info("Issue is already assigned. Skipping PR creation. Be nice :)")
  164. return False
  165. if issue.locked:
  166. logger.info("Issue is locked. Skipping PR creation.")
  167. return False
  168. org, repo, issue_number = parse_gh_issue_url(args.environment.data_path)
  169. associated_commits = get_associated_commit_urls(org, repo, issue_number, token=token)
  170. if associated_commits:
  171. commit_url_strs = ", ".join(associated_commits)
  172. if args.actions.skip_if_commits_reference_issue:
  173. logger.info(f"Issue already has associated commits (see {commit_url_strs}). Skipping PR creation.")
  174. return False
  175. else:
  176. logger.warning(
  177. "Proceeding with PR creation even though there are already commits "
  178. f"({commit_url_strs}) associated with the issue. Please only do this for your own repositories "
  179. "or after verifying that the existing commits do not fix the issue."
  180. )
  181. return True
  182. def save_arguments(traj_dir: Path, args: ScriptArguments) -> None:
  183. """Save the arguments to a yaml file to the run's trajectory directory."""
  184. log_path = traj_dir / "args.yaml"
  185. if log_path.exists():
  186. try:
  187. other_args = args.load_yaml(log_path)
  188. if (args.dumps_yaml() != other_args.dumps_yaml()): # check yaml equality instead of object equality
  189. logger.warning("**************************************************")
  190. logger.warning("Found existing args.yaml with different arguments!")
  191. logger.warning("**************************************************")
  192. except Exception as e:
  193. logger.warning(f"Failed to load existing args.yaml: {e}")
  194. with log_path.open("w") as f:
  195. args.dump_yaml(f)
  196. def should_skip(args: ScriptArguments, traj_dir: Path, instance_id: str) -> bool:
  197. """Check if we should skip this instance based on the instance filter and skip_existing flag."""
  198. # Skip instances that don't match the instance filter
  199. if re.match(args.instance_filter, instance_id) is None:
  200. logger.info(f"Instance filter not matched. Skipping instance {instance_id}")
  201. return True
  202. # If flag is set to False, don't skip
  203. if not args.skip_existing:
  204. return False
  205. # Check if there's an existing trajectory for this instance
  206. log_path = traj_dir / (instance_id + ".traj")
  207. if log_path.exists():
  208. with log_path.open("r") as f:
  209. data = json.load(f)
  210. # If the trajectory has no exit status, it's incomplete and we will redo it
  211. exit_status = data["info"].get("exit_status", None)
  212. if exit_status == "early_exit" or exit_status is None:
  213. logger.info(f"Found existing trajectory with no exit status: {log_path}")
  214. logger.info("Removing incomplete trajectory...")
  215. os.remove(log_path)
  216. else:
  217. logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
  218. return True
  219. return False
  220. def save_predictions(traj_dir: Path, instance_id: str, info):
  221. output_file = traj_dir / "all_preds.jsonl"
  222. model_patch = info["submission"] if "submission" in info else None
  223. datum = {
  224. KEY_MODEL: Path(traj_dir).name,
  225. KEY_INSTANCE_ID: instance_id,
  226. KEY_PREDICTION: model_patch,
  227. }
  228. with open(output_file, "a+") as fp:
  229. print(json.dumps(datum), file=fp, flush=True)
  230. logger.info(f"Saved predictions to {output_file}")
  231. def save_patch(traj_dir: Path, instance_id: str, info) -> Optional[Path]:
  232. """Create patch files that can be applied with `git am`.
  233. Returns:
  234. The path to the patch file, if it was saved. Otherwise, returns None.
  235. """
  236. patch_output_dir = traj_dir / "patches"
  237. patch_output_dir.mkdir(exist_ok=True, parents=True)
  238. patch_output_file = patch_output_dir / f"{instance_id}.patch"
  239. if not "submission" in info:
  240. logger.info("No patch to save.")
  241. return
  242. model_patch = info["submission"]
  243. patch_output_file.write_text(model_patch)
  244. _print_patch_message(patch_output_file)
  245. return patch_output_file
  246. def apply_patch(local_dir: Path, patch_file: Path) -> None:
  247. """Apply a patch to a local directory."""
  248. assert local_dir.is_dir()
  249. assert patch_file.exists()
  250. # The resolve() is important, because we're gonna run the cmd
  251. # somewhere else
  252. cmd = ["git", "apply", str(patch_file.resolve())]
  253. try:
  254. subprocess.run(cmd, cwd=local_dir, check=True)
  255. except subprocess.CalledProcessError as e:
  256. logger.error(f"Failed to apply patch {patch_file} to {local_dir}: {e}")
  257. return
  258. logger.info(f"Applied patch {patch_file} to {local_dir}")
  259. def _print_patch_message(patch_output_file: Path):
  260. console = rich.console.Console()
  261. msg = [
  262. "SWE-agent has produced a patch that it believes will solve the issue you submitted!",
  263. "Use the code snippet below to inspect or apply it!"
  264. ]
  265. panel = rich.panel.Panel.fit(
  266. "\n".join(msg),
  267. title="🎉 Submission successful 🎉",
  268. )
  269. console.print(panel)
  270. content = [
  271. "```bash",
  272. f"# The patch has been saved to your local filesystem at:",
  273. f"PATCH_FILE_PATH='{patch_output_file.resolve()}'",
  274. "# Inspect it:",
  275. "cat \"${PATCH_FILE_PATH}\"",
  276. "# Apply it to a local repository:",
  277. f"cd <your local repo root>",
  278. "git apply \"${PATCH_FILE_PATH}\"",
  279. "```",
  280. ]
  281. console.print(rich.markdown.Markdown("\n".join(content)))
  282. def get_args(args=None) -> ScriptArguments:
  283. """Parse command line arguments and return a ScriptArguments object.
  284. Args:
  285. args: Optional list of arguments to parse. If not provided, uses sys.argv.
  286. """
  287. defaults = ScriptArguments(
  288. suffix="",
  289. environment=EnvironmentArguments(
  290. image_name="sweagent/swe-agent:latest",
  291. data_path="princeton-nlp/SWE-bench_Lite",
  292. split="dev",
  293. verbose=True,
  294. install_environment=True,
  295. ),
  296. skip_existing=True,
  297. agent=AgentArguments(
  298. model=ModelArguments(
  299. model_name="gpt4",
  300. total_cost_limit=0.0,
  301. per_instance_cost_limit=3.0,
  302. temperature=0.0,
  303. top_p=0.95,
  304. ),
  305. config_file="config/default.yaml",
  306. ),
  307. actions=ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True),
  308. )
  309. # Nicer yaml dumping of multiline strings
  310. def multiline_representer(dumper, data):
  311. """configures yaml for dumping multiline strings
  312. Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
  313. """
  314. if data.count("\n") > 0: # check for multiline string
  315. return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
  316. return dumper.represent_scalar("tag:yaml.org,2002:str", data)
  317. yaml.add_representer(str, multiline_representer)
  318. return parse(ScriptArguments, default=defaults, add_config_path_arg=False, args=args)
  319. if __name__ == "__main__":
  320. args = get_args()
  321. main(args)