run.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. from __future__ import annotations
  2. import logging
  3. from sweagent import CONFIG_DIR
  4. from sweagent.utils.log import get_logger
  5. try:
  6. import rich
  7. except ModuleNotFoundError as e:
  8. msg = (
  9. "You probably either forgot to install the dependencies "
  10. "or forgot to activate your conda or virtual environment."
  11. )
  12. raise RuntimeError(msg) from e
  13. import json
  14. import os
  15. import re
  16. import subprocess
  17. import traceback
  18. from typing import Any
  19. import rich.console
  20. import rich.markdown
  21. import rich.panel
  22. try:
  23. from rich_argparse import RichHelpFormatter
  24. except ImportError:
  25. msg = "Please install the rich_argparse package with `pip install rich_argparse`."
  26. raise ImportError(msg)
  27. from dataclasses import dataclass
  28. from getpass import getuser
  29. from pathlib import Path
  30. import yaml
  31. from rich.markdown import Markdown
  32. from simple_parsing import parse
  33. from simple_parsing.helpers.flatten import FlattenedAccess
  34. from simple_parsing.helpers.serialization.serializable import FrozenSerializable
  35. from swebench import KEY_INSTANCE_ID, KEY_MODEL, KEY_PREDICTION
  36. from unidiff import PatchSet
  37. from sweagent.agent.agents import Agent, AgentArguments
  38. from sweagent.agent.models import ModelArguments
  39. from sweagent.environment.swe_env import EnvironmentArguments, SWEEnv
  40. from sweagent.environment.utils import (
  41. InvalidGithubURL,
  42. get_associated_commit_urls,
  43. get_data_path_name,
  44. get_gh_issue_data,
  45. parse_gh_issue_url,
  46. )
  47. __doc__: str = """ Run inference. Usage examples:
  48. ```bash
  49. # Run over a github issue:
  50. python run.py --model_name "gpt4" --data_path "https://github.com/pvlib/pvlib-python/issues/1603" --config_file "config/default_from_url.yaml"
  51. # Apply a patch in a local repository to an issue specified as Markdown file and run a custom installer script in the container
  52. python run.py --model_name "gpt4" --data_path "/path/to/my_issue.md" --repo_path "/path/to/my/local/repo" --environment_setup "/path/to/setup.sh" --config_file "config/default_from_url.yaml" --apply_patch_locally
  53. ```
  54. **For more information**: https://princeton-nlp.github.io/SWE-agent/usage/cl_tutorial/
  55. """
# Module-level logger shared by the hooks and the Main driver in this file.
logger = get_logger("swe-agent-run")
# Only let WARNING and above through from the simple_parsing logger.
logging.getLogger("simple_parsing").setLevel(logging.WARNING)
  58. @dataclass(frozen=True)
  59. class ActionsArguments(FlattenedAccess, FrozenSerializable):
  60. """Run real-life actions (opening PRs, etc.) if we can solve the issue."""
  61. # Open a PR with the patch if we can solve the issue
  62. open_pr: bool = False
  63. # When working with local repository: Apply patch
  64. apply_patch_locally: bool = False
  65. # Option to be used with open_pr: Skip action if there are already commits claiming
  66. # to fix the issue. Please only set this to False if you are sure the commits are
  67. # not fixes or if this is your own repository!
  68. skip_if_commits_reference_issue: bool = True
  69. # OBSOLETE. Do not use, will raise error. Please specify --repo_path instead.
  70. push_gh_repo_url: str = ""
  71. def __post_init__(self):
  72. if self.push_gh_repo_url:
  73. msg = "push_gh_repo_url is obsolete. Use repo_path instead"
  74. raise ValueError(msg)
  75. @dataclass(frozen=True)
  76. class ScriptArguments(FlattenedAccess, FrozenSerializable):
  77. """Configure the control flow of the run.py script"""
  78. environment: EnvironmentArguments
  79. agent: AgentArguments
  80. actions: ActionsArguments
  81. # Only run instances that completely match this regex
  82. instance_filter: str = ".*"
  83. # Skip instances with existing trajectories
  84. skip_existing: bool = True
  85. # Suffix for the run name (used for example in trajectory directory naming)
  86. suffix: str = ""
  87. # Raise unhandled exceptions during the run (useful for debugging)
  88. raise_exceptions: bool = False
  89. # Dump the entire config to the log
  90. print_config: bool = True
  91. @property
  92. def run_name(self):
  93. """Generate a unique name for this run based on the arguments."""
  94. model_name = self.agent.model.model_name.replace(":", "-")
  95. data_stem = get_data_path_name(self.environment.data_path)
  96. assert self.agent.config_file is not None # mypy
  97. config_stem = Path(self.agent.config_file).stem
  98. temp = self.agent.model.temperature
  99. top_p = self.agent.model.top_p
  100. per_instance_cost_limit = self.agent.model.per_instance_cost_limit
  101. install_env = self.environment.install_environment
  102. return (
  103. f"{model_name}__{data_stem}__{config_stem}__t-{temp:.2f}__p-{top_p:.2f}"
  104. + f"__c-{per_instance_cost_limit:.2f}__install-{int(install_env)}"
  105. + (f"__{self.suffix}" if self.suffix else "")
  106. )
  107. class _ContinueLoop(Exception):
  108. """Used for internal control flow"""
  109. class MainHook:
  110. """Hook structure for the web server or other addons to interface with"""
  111. @staticmethod
  112. def _is_promising_patch(info: dict[str, Any]) -> bool:
  113. """Do we actually believe that the patch will solve the issue?
  114. Or are we just submitting the last patch we generated before hitting an error?
  115. """
  116. # The exit status can also be `submitted (exit_cost)` etc.
  117. return info["exit_status"] == "submitted" and info.get("submission") is not None
  118. def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
  119. """Called when hook is initialized"""
  120. def on_start(self):
  121. """Called at the beginning of `Main.main`"""
  122. def on_end(self):
  123. """Called at the end of `Main.main`"""
  124. def on_instance_start(self, *, index: int, instance: dict[str, Any]):
  125. """Called at the beginning of each instance loop in `Main.run`"""
  126. def on_instance_skipped(
  127. self,
  128. ):
  129. """Called when an instance is skipped in `Main.run`"""
  130. def on_instance_completed(self, *, info, trajectory):
  131. """Called when an instance is completed in `Main.run`"""
  132. class SaveApplyPatchHook(MainHook):
  133. """This hook saves patches to a separate directory and optionally applies them to a local repository."""
  134. def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
  135. self._traj_dir = traj_dir
  136. self._apply_patch_locally = args.actions.apply_patch_locally
  137. self._instance = None
  138. def on_instance_start(self, *, index: int, instance: dict[str, Any]):
  139. self._instance = instance
  140. def on_instance_completed(self, *, info, trajectory):
  141. assert self._instance is not None # mypy
  142. instance_id = self._instance["instance_id"]
  143. patch_path = self._save_patch(instance_id, info)
  144. if patch_path:
  145. if not self._apply_patch_locally:
  146. return
  147. if not self._is_promising_patch(info):
  148. return
  149. assert self._instance # mypy
  150. if self._instance["repo_type"] != "local":
  151. return
  152. local_dir = Path(self._instance["repo"])
  153. self._apply_patch(patch_path, local_dir)
  154. @staticmethod
  155. def _print_patch_message(patch_output_file: Path):
  156. console = rich.console.Console()
  157. msg = [
  158. "SWE-agent has produced a patch that it believes will solve the issue you submitted!",
  159. "Use the code snippet below to inspect or apply it!",
  160. ]
  161. panel = rich.panel.Panel.fit(
  162. "\n".join(msg),
  163. title="🎉 Submission successful 🎉",
  164. )
  165. console.print(panel)
  166. content = [
  167. "```bash",
  168. "# The patch has been saved to your local filesystem at:",
  169. f"PATCH_FILE_PATH='{patch_output_file.resolve()}'",
  170. "# Inspect it:",
  171. 'cat "${PATCH_FILE_PATH}"',
  172. "# Apply it to a local repository:",
  173. "cd <your local repo root>",
  174. 'git apply "${PATCH_FILE_PATH}"',
  175. "```",
  176. ]
  177. console.print(rich.markdown.Markdown("\n".join(content)))
  178. def _save_patch(self, instance_id: str, info) -> Path | None:
  179. """Create patch files that can be applied with `git am`.
  180. Returns:
  181. The path to the patch file, if it was saved. Otherwise, returns None.
  182. """
  183. patch_output_dir = self._traj_dir / "patches"
  184. patch_output_dir.mkdir(exist_ok=True, parents=True)
  185. patch_output_file = patch_output_dir / f"{instance_id}.patch"
  186. if not info.get("submission"):
  187. logger.info("No patch to save.")
  188. return None
  189. model_patch = info["submission"]
  190. patch_output_file.write_text(model_patch)
  191. if self._is_promising_patch(info):
  192. # Only print big congratulations if we actually believe
  193. # the patch will solve the issue
  194. self._print_patch_message(patch_output_file)
  195. return patch_output_file
  196. def _apply_patch(self, patch_file: Path, local_dir: Path) -> None:
  197. """Apply a patch to a local directory."""
  198. assert local_dir.is_dir()
  199. assert patch_file.exists()
  200. # The resolve() is important, because we're gonna run the cmd
  201. # somewhere else
  202. cmd = ["git", "apply", str(patch_file.resolve())]
  203. try:
  204. subprocess.run(cmd, cwd=local_dir, check=True)
  205. except subprocess.CalledProcessError as e:
  206. logger.error(f"Failed to apply patch {patch_file} to {local_dir}: {e}")
  207. return
  208. logger.info(f"Applied patch {patch_file} to {local_dir}")
  209. class OpenPRHook(MainHook):
  210. """This hook opens a PR if the issue is solved and the user has enabled the option."""
  211. def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
  212. self._env = env
  213. self._token: str = env._github_token
  214. self._data_path = args.environment.data_path
  215. self._open_pr = args.actions.open_pr
  216. self._skip_if_commits_reference_issue = args.actions.skip_if_commits_reference_issue
  217. def on_instance_completed(self, *, info, trajectory):
  218. if self._open_pr and self.should_open_pr(info):
  219. self._env.open_pr(trajectory=trajectory)
  220. def should_open_pr(self, info: dict[str, Any]) -> bool:
  221. """Does opening a PR make sense?"""
  222. if not info.get("submission"):
  223. logger.info("Not opening PR because no submission was made.")
  224. return False
  225. if info["exit_status"] != "submitted":
  226. logger.info("Not opening PR because exit status was %s and not submitted.", info["exit_status"])
  227. return False
  228. try:
  229. issue = get_gh_issue_data(self._data_path, token=self._token)
  230. except InvalidGithubURL:
  231. logger.info("Currently only GitHub is supported to open PRs to. Skipping PR creation.")
  232. return False
  233. if issue.state != "open":
  234. logger.info(f"Issue is not open (state={issue.state}. Skipping PR creation.")
  235. return False
  236. if issue.assignee:
  237. logger.info("Issue is already assigned. Skipping PR creation. Be nice :)")
  238. return False
  239. if issue.locked:
  240. logger.info("Issue is locked. Skipping PR creation.")
  241. return False
  242. org, repo, issue_number = parse_gh_issue_url(self._data_path)
  243. associated_commits = get_associated_commit_urls(org, repo, issue_number, token=self._token)
  244. if associated_commits:
  245. commit_url_strs = ", ".join(associated_commits)
  246. if self._skip_if_commits_reference_issue:
  247. logger.info(f"Issue already has associated commits (see {commit_url_strs}). Skipping PR creation.")
  248. return False
  249. else:
  250. logger.warning(
  251. "Proceeding with PR creation even though there are already commits "
  252. f"({commit_url_strs}) associated with the issue. Please only do this for your own repositories "
  253. "or after verifying that the existing commits do not fix the issue.",
  254. )
  255. return True
  256. class Main:
  257. def __init__(self, args: ScriptArguments):
  258. if args.print_config:
  259. logger.info(f"📙 Arguments: {args.dumps_yaml()}")
  260. self.args = args
  261. self.agent = Agent("primary", args.agent)
  262. self.env = SWEEnv(args.environment)
  263. self.traj_dir = Path("trajectories") / Path(getuser()) / args.run_name
  264. self.traj_dir.mkdir(parents=True, exist_ok=True)
  265. self._save_arguments()
  266. default_hooks = [
  267. SaveApplyPatchHook(),
  268. OpenPRHook(),
  269. ]
  270. self.hooks: list[MainHook] = []
  271. for hook in default_hooks:
  272. self.add_hook(hook)
  273. def add_hook(self, hook: MainHook):
  274. hook.on_init(args=self.args, agent=self.agent, env=self.env, traj_dir=self.traj_dir)
  275. self.hooks.append(hook)
  276. def run(self, index):
  277. # Reset environment
  278. instance_id = self.env.data[index]["instance_id"]
  279. for hook in self.hooks:
  280. hook.on_instance_start(index=index, instance=self.env.data[index])
  281. assert isinstance(instance_id, str) # mypy
  282. if self.should_skip(instance_id):
  283. for hook in self.hooks:
  284. hook.on_instance_skipped()
  285. raise _ContinueLoop
  286. logger.info("▶️ Beginning task " + str(index))
  287. observation, info = self.env.reset(index)
  288. if info is None:
  289. raise _ContinueLoop
  290. # Get info, patch information
  291. issue = getattr(self.env, "query", None)
  292. files = []
  293. assert self.env.record is not None # mypy
  294. if "patch" in self.env.record:
  295. files = "\n".join([f"- {x.path}" for x in PatchSet(self.env.record["patch"]).modified_files])
  296. # Get test files, F2P tests information
  297. test_files = []
  298. if "test_patch" in self.env.record:
  299. test_patch_obj = PatchSet(self.env.record["test_patch"])
  300. test_files = "\n".join([f"- {x.path}" for x in test_patch_obj.modified_files + test_patch_obj.added_files])
  301. tests = ""
  302. if "FAIL_endTO_PASS" in self.env.record:
  303. tests = "\n".join([f"- {x}" for x in self.env.record["FAIL_TO_PASS"]])
  304. setup_args = {"issue": issue, "files": files, "test_files": test_files, "tests": tests}
  305. info, trajectory = self.agent.run(
  306. setup_args=setup_args,
  307. env=self.env,
  308. observation=observation,
  309. traj_dir=self.traj_dir,
  310. return_type="info_trajectory",
  311. )
  312. self._save_predictions(instance_id, info)
  313. for hook in self.hooks:
  314. hook.on_instance_completed(info=info, trajectory=trajectory)
  315. def main(self):
  316. for hook in self.hooks:
  317. hook.on_start()
  318. for index in range(len(self.env.data)):
  319. try:
  320. self.run(index)
  321. except _ContinueLoop:
  322. continue
  323. except KeyboardInterrupt:
  324. logger.info("Exiting InterCode environment...")
  325. self.env.close()
  326. break
  327. except SystemExit:
  328. logger.critical("❌ Exiting because SystemExit was called")
  329. self.env.close()
  330. logger.info("Container closed")
  331. raise
  332. except Exception as e:
  333. traceback.print_exc()
  334. if self.args.raise_exceptions:
  335. self.env.close()
  336. raise e
  337. if self.env.record:
  338. logger.warning(f"❌ Failed on {self.env.record['instance_id']}: {e}")
  339. else:
  340. logger.warning("❌ Failed on unknown instance")
  341. self.env.reset_container()
  342. continue
  343. for hook in self.hooks:
  344. hook.on_end()
  345. def _save_arguments(self) -> None:
  346. """Save the arguments to a yaml file to the run's trajectory directory."""
  347. log_path = self.traj_dir / "args.yaml"
  348. if log_path.exists():
  349. try:
  350. other_args = self.args.load_yaml(log_path)
  351. if self.args.dumps_yaml() != other_args.dumps_yaml(): # check yaml equality instead of object equality
  352. logger.warning("**************************************************")
  353. logger.warning("Found existing args.yaml with different arguments!")
  354. logger.warning("**************************************************")
  355. except Exception as e:
  356. logger.warning(f"Failed to load existing args.yaml: {e}")
  357. with log_path.open("w") as f:
  358. self.args.dump_yaml(f)
  359. def should_skip(self, instance_id: str) -> bool:
  360. """Check if we should skip this instance based on the instance filter and skip_existing flag."""
  361. # Skip instances that don't match the instance filter
  362. if re.match(self.args.instance_filter, instance_id) is None:
  363. logger.info(f"Instance filter not matched. Skipping instance {instance_id}")
  364. return True
  365. # If flag is set to False, don't skip
  366. if not self.args.skip_existing:
  367. return False
  368. # Check if there's an existing trajectory for this instance
  369. log_path = self.traj_dir / (instance_id + ".traj")
  370. if log_path.exists():
  371. with log_path.open("r") as f:
  372. data = json.load(f)
  373. # If the trajectory has no exit status, it's incomplete and we will redo it
  374. exit_status = data["info"].get("exit_status", None)
  375. if exit_status == "early_exit" or exit_status is None:
  376. logger.info(f"Found existing trajectory with no exit status: {log_path}")
  377. logger.info("Removing incomplete trajectory...")
  378. os.remove(log_path)
  379. else:
  380. logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
  381. return True
  382. return False
  383. def _save_predictions(self, instance_id: str, info):
  384. output_file = self.traj_dir / "all_preds.jsonl"
  385. model_patch = info["submission"] if "submission" in info else None
  386. datum = {
  387. KEY_MODEL: Path(self.traj_dir).name,
  388. KEY_INSTANCE_ID: instance_id,
  389. KEY_PREDICTION: model_patch,
  390. }
  391. with open(output_file, "a+") as fp:
  392. print(json.dumps(datum), file=fp, flush=True)
  393. logger.info(f"Saved predictions to {output_file}")
  394. def get_args(args=None) -> ScriptArguments:
  395. """Parse command line arguments and return a ScriptArguments object.
  396. Args:
  397. args: Optional list of arguments to parse. If not provided, uses sys.argv.
  398. """
  399. defaults = ScriptArguments(
  400. suffix="",
  401. environment=EnvironmentArguments(
  402. image_name="sweagent/swe-agent:latest",
  403. data_path="princeton-nlp/SWE-bench_Lite",
  404. split="dev",
  405. verbose=True,
  406. install_environment=True,
  407. cache_task_images=False,
  408. ),
  409. skip_existing=True,
  410. agent=AgentArguments(
  411. model=ModelArguments(
  412. model_name="gpt4",
  413. total_cost_limit=0.0,
  414. per_instance_cost_limit=3.0,
  415. temperature=0.0,
  416. top_p=0.95,
  417. ),
  418. config_file=CONFIG_DIR / "default.yaml",
  419. ),
  420. actions=ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True),
  421. )
  422. # Nicer yaml dumping of multiline strings
  423. def multiline_representer(dumper, data):
  424. """configures yaml for dumping multiline strings
  425. Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
  426. """
  427. if data.count("\n") > 0: # check for multiline string
  428. return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
  429. return dumper.represent_scalar("tag:yaml.org,2002:str", data)
  430. yaml.add_representer(str, multiline_representer)
  431. return parse(
  432. ScriptArguments,
  433. default=defaults,
  434. add_config_path_arg=False,
  435. args=args,
  436. formatter_class=RichHelpFormatter,
  437. description=Markdown(__doc__),
  438. )
  439. def main(args: ScriptArguments):
  440. Main(args).main()
  441. if __name__ == "__main__":
  442. main(get_args())