run.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. from __future__ import annotations
  2. import logging
  3. from sweagent import CONFIG_DIR
  4. from sweagent.utils.log import add_file_handler, get_logger
  5. try:
  6. import rich
  7. except ModuleNotFoundError as e:
  8. msg = (
  9. "You probably either forgot to install the dependencies "
  10. "or forgot to activate your conda or virtual environment."
  11. )
  12. raise RuntimeError(msg) from e
  13. import json
  14. import re
  15. import subprocess
  16. import traceback
  17. from typing import Any
  18. import rich.console
  19. import rich.markdown
  20. import rich.panel
  21. try:
  22. from rich_argparse import RichHelpFormatter
  23. except ImportError:
  24. msg = "Please install the rich_argparse package with `pip install rich_argparse`."
  25. raise ImportError(msg)
  26. import datetime
  27. from dataclasses import dataclass
  28. from getpass import getuser
  29. from pathlib import Path
  30. import yaml
  31. from rich.markdown import Markdown
  32. from simple_parsing import parse
  33. from simple_parsing.helpers.flatten import FlattenedAccess
  34. from simple_parsing.helpers.serialization.serializable import FrozenSerializable
  35. from swebench.harness.constants import KEY_INSTANCE_ID, KEY_MODEL, KEY_PREDICTION
  36. from unidiff import PatchSet
  37. from sweagent.agent.agents import Agent, AgentArguments
  38. from sweagent.agent.models import ModelArguments
  39. from sweagent.environment.swe_env import EnvironmentArguments, SWEEnv
  40. from sweagent.environment.utils import (
  41. InvalidGithubURL,
  42. extract_flag_format,
  43. get_associated_commit_urls,
  44. get_data_path_name,
  45. get_gh_issue_data,
  46. parse_gh_issue_url,
  47. )
  # Module help text. It is rendered as Markdown in the CLI `--help` output (see `get_args`),
  # so its exact content is user-facing.
  __doc__: str = (
      " Run inference. Usage examples:\n"
      "```bash\n"
      "# Run over a github issue:\n"
      'python run.py --model_name "gpt4" --data_path "https://github.com/pvlib/pvlib-python/issues/1603" --config_file "config/default_from_url.yaml"\n'
      "# Apply a patch in a local repository to an issue specified as Markdown file and run a custom installer script in the container\n"
      'python run.py --model_name "gpt4" --data_path "/path/to/my_issue.md" --repo_path "/path/to/my/local/repo" --environment_setup "/path/to/setup.sh" --config_file "config/default_from_url.yaml" --apply_patch_locally\n'
      "```\n"
      "**For more information**: https://princeton-nlp.github.io/SWE-agent/usage/cl_tutorial/\n"
  )

  logger = get_logger("swe-agent-run")
  # Only let the simple_parsing logger emit records at WARNING level or above.
  logging.getLogger("simple_parsing").setLevel(logging.WARNING)
  @dataclass(frozen=True)
  class ActionsArguments(FlattenedAccess, FrozenSerializable):
      """Run real-life actions (opening PRs, etc.) if we can solve the issue."""

      # Open a PR with the patch if we can solve the issue
      open_pr: bool = False
      # When working with local repository: Apply patch
      apply_patch_locally: bool = False
      # Option to be used with open_pr: Skip action if there are already commits claiming
      # to fix the issue. Please only set this to False if you are sure the commits are
      # not fixes or if this is your own repository!
      skip_if_commits_reference_issue: bool = True
      # OBSOLETE. Do not use, will raise error. Please specify --repo_path instead.
      push_gh_repo_url: str = ""

      def __post_init__(self):
          # NOTE: simple_parsing derives `--help` text from the class docstring and the
          # field comments above, so those are intentionally kept verbatim.
          # Reject the removed option loudly rather than silently ignoring it.
          if not self.push_gh_repo_url:
              return
          raise ValueError("push_gh_repo_url is obsolete. Use repo_path instead")
  @dataclass(frozen=True)
  class ScriptArguments(FlattenedAccess, FrozenSerializable):
      """Configure the control flow of the run.py script"""

      environment: EnvironmentArguments
      agent: AgentArguments
      actions: ActionsArguments
      # Only run instances whose ID matches this regex from the start (`re.match`); end it with `$` to require a complete match
      instance_filter: str = ".*"
      # Skip instances with existing trajectories
      skip_existing: bool = True
      # Suffix for the run name (used for example in trajectory directory naming)
      suffix: str = ""
      # Raise unhandled exceptions during the run (useful for debugging)
      raise_exceptions: bool = False
      # Dump the entire config to the log
      print_config: bool = True
      # Run the agent in CTF mode (SWE-agent: EnIGMA)
      ctf: bool = False

      @property
      def run_name(self) -> str:
          """Generate a unique name for this run based on the arguments.

          `Main` uses this name as the trajectory directory, so the format must stay
          stable for `skip_existing` to find trajectories of earlier runs.
          """
          model_args = self.agent.model
          data_stem = get_data_path_name(self.environment.data_path)
          assert self.agent.config_file is not None  # mypy
          parts = [
              model_args.model_name.replace(":", "-"),
              data_stem,
              Path(self.agent.config_file).stem,
              f"t-{model_args.temperature:.2f}",
              f"p-{model_args.top_p:.2f}",
              f"c-{model_args.per_instance_cost_limit:.2f}",
              f"install-{int(self.environment.install_environment)}",
          ]
          # An empty suffix contributes nothing (no trailing "__").
          if self.suffix:
              parts.append(self.suffix)
          return "__".join(parts)
  110. class _ContinueLoop(Exception):
  111. """Used for internal control flow"""
  class MainHook:
      """Hook structure for the web server or other addons to interface with.

      `Main` invokes these callbacks at each stage of a run. Every default
      implementation does nothing, so subclasses override only what they need.
      """

      @staticmethod
      def _is_promising_patch(info: dict[str, Any]) -> bool:
          """Do we actually believe that the patch will solve the issue?

          Returns False when we are merely holding the last patch generated before
          hitting an error. Missing keys count as "not promising" instead of raising.
          """
          # The exit status can also be `submitted (exit_cost)` etc.; only a plain
          # "submitted" counts as a deliberate submission.
          return info.get("exit_status") == "submitted" and info.get("submission") is not None

      def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
          """Called when hook is initialized (from `Main.add_hook`)"""

      def on_start(self):
          """Called at the beginning of `Main.main`"""

      def on_end(self):
          """Called at the end of `Main.main`"""

      def on_instance_start(self, *, index: int, instance: dict[str, Any]):
          """Called at the beginning of each instance loop in `Main.run`"""

      def on_instance_skipped(self):
          """Called when an instance is skipped in `Main.run`"""

      def on_instance_completed(self, *, info: dict[str, Any], trajectory):
          """Called when an instance is completed in `Main.run`"""
  class SaveApplyPatchHook(MainHook):
      """This hook saves patches to a separate directory and optionally applies them to a local repository.

      Patches are written to ``<traj_dir>/patches/<instance_id>.patch``. With
      ``--apply_patch_locally``, promising patches for instances whose ``repo_type``
      is ``"local"`` are also applied to that repository with ``git apply``.
      """

      def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
          self._traj_dir = traj_dir
          self._apply_patch_locally = args.actions.apply_patch_locally
          # The instance currently being processed; set in `on_instance_start`.
          self._instance: dict[str, Any] | None = None

      def on_instance_start(self, *, index: int, instance: dict[str, Any]):
          self._instance = instance

      def on_instance_completed(self, *, info, trajectory):
          assert self._instance is not None  # mypy
          instance = self._instance
          patch_path = self._save_patch(instance["instance_id"], info)
          # Apply only if a patch was saved, the user opted in, the agent deliberately
          # submitted it, and the instance's repository is a local directory.
          if patch_path is None or not self._apply_patch_locally:
              return
          if not self._is_promising_patch(info) or instance["repo_type"] != "local":
              return
          self._apply_patch(patch_path, Path(instance["repo"]))

      @staticmethod
      def _print_patch_message(patch_output_file: Path):
          """Print a success panel plus a bash snippet showing how to inspect/apply the patch."""
          console = rich.console.Console()
          msg = [
              "SWE-agent has produced a patch that it believes will solve the issue you submitted!",
              "Use the code snippet below to inspect or apply it!",
          ]
          panel = rich.panel.Panel.fit(
              "\n".join(msg),
              title="🎉 Submission successful 🎉",
          )
          console.print(panel)
          content = [
              "```bash",
              "# The patch has been saved to your local filesystem at:",
              f"PATCH_FILE_PATH='{patch_output_file.resolve()}'",
              "# Inspect it:",
              'cat "${PATCH_FILE_PATH}"',
              "# Apply it to a local repository:",
              "cd <your local repo root>",
              'git apply "${PATCH_FILE_PATH}"',
              "```",
          ]
          console.print(rich.markdown.Markdown("\n".join(content)))

      def _save_patch(self, instance_id: str, info) -> Path | None:
          """Write the submitted diff to ``<traj_dir>/patches/<instance_id>.patch``.

          The file holds ``info["submission"]`` verbatim and is meant to be applied
          with ``git apply`` (as `_apply_patch` and the printed instructions do).

          Returns:
              The path to the patch file, if it was saved. Otherwise (no submission), None.
          """
          model_patch = info.get("submission")
          if model_patch is None:
              logger.info("No patch to save.")
              return None
          # Only create the directory once there is actually something to write.
          patch_output_dir = self._traj_dir / "patches"
          patch_output_dir.mkdir(exist_ok=True, parents=True)
          patch_output_file = patch_output_dir / f"{instance_id}.patch"
          # Explicit encoding: model output may contain non-ASCII characters that the
          # platform default encoding (e.g. cp1252 on Windows) cannot encode.
          patch_output_file.write_text(model_patch, encoding="utf-8")
          if self._is_promising_patch(info):
              # Only print big congratulations if we actually believe
              # the patch will solve the issue
              self._print_patch_message(patch_output_file)
          return patch_output_file

      def _apply_patch(self, patch_file: Path, local_dir: Path) -> None:
          """Apply a patch to a local directory with ``git apply``; failures are logged, not raised."""
          assert local_dir.is_dir()
          assert patch_file.exists()
          # The resolve() is important, because we're gonna run the cmd
          # somewhere else
          cmd = ["git", "apply", str(patch_file.resolve())]
          try:
              subprocess.run(cmd, cwd=local_dir, check=True)
          except subprocess.CalledProcessError as e:
              logger.error(f"Failed to apply patch {patch_file} to {local_dir}: {e}")
              return
          logger.info(f"Applied patch {patch_file} to {local_dir}")
  class OpenPRHook(MainHook):
      """This hook opens a PR if the issue is solved and the user has enabled the option."""

      def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
          self._env = env
          self._token: str = env._github_token
          self._data_path = args.environment.data_path
          self._open_pr = args.actions.open_pr
          self._skip_if_commits_reference_issue = args.actions.skip_if_commits_reference_issue

      def on_instance_completed(self, *, info, trajectory):
          # Check the flag first so the GitHub lookups in `should_open_pr` are skipped
          # entirely when PR creation is disabled.
          if self._open_pr and self.should_open_pr(info):
              self._env.open_pr(trajectory=trajectory)

      def should_open_pr(self, info: dict[str, Any]) -> bool:
          """Does opening a PR make sense?

          Requires a submission whose exit status is exactly "submitted", and a GitHub
          issue that is open, unassigned and unlocked. If commits already reference the
          issue, the PR is skipped unless ``skip_if_commits_reference_issue`` is False.
          Every rejection is logged with its reason.
          """
          if not info.get("submission"):
              logger.info("Not opening PR because no submission was made.")
              return False
          exit_status = info.get("exit_status")
          if exit_status != "submitted":
              logger.info("Not opening PR because exit status was %s and not submitted.", exit_status)
              return False
          try:
              issue = get_gh_issue_data(self._data_path, token=self._token)
          except InvalidGithubURL:
              logger.info("Currently only GitHub is supported to open PRs to. Skipping PR creation.")
              return False
          if issue.state != "open":
              logger.info(f"Issue is not open (state={issue.state}). Skipping PR creation.")
              return False
          if issue.assignee:
              logger.info("Issue is already assigned. Skipping PR creation. Be nice :)")
              return False
          if issue.locked:
              logger.info("Issue is locked. Skipping PR creation.")
              return False
          org, repo, issue_number = parse_gh_issue_url(self._data_path)
          associated_commits = get_associated_commit_urls(org, repo, issue_number, token=self._token)
          if associated_commits:
              commit_url_strs = ", ".join(associated_commits)
              if self._skip_if_commits_reference_issue:
                  logger.info(f"Issue already has associated commits (see {commit_url_strs}). Skipping PR creation.")
                  return False
              logger.warning(
                  "Proceeding with PR creation even though there are already commits "
                  f"({commit_url_strs}) associated with the issue. Please only do this for your own repositories "
                  "or after verifying that the existing commits do not fix the issue.",
              )
          return True
  class Main:
      """Run the agent over every instance of the configured dataset.

      Creates ``trajectories/<user>/<run_name>``, attaches a run log file, builds the
      agent and environment, and notifies registered `MainHook`s at each stage.
      """

      def __init__(self, args: ScriptArguments):
          self.traj_dir = Path("trajectories") / Path(getuser()) / args.run_name
          self.traj_dir.mkdir(parents=True, exist_ok=True)
          timestamp = datetime.datetime.now().strftime("%y%m%d%H%M%S")
          log_path = self.traj_dir / f"run-{timestamp}.log"
          logger.info("Logging to %s", log_path)
          add_file_handler(log_path)
          if args.print_config:
              logger.info(f"📙 Arguments: {args.dumps_yaml()}")
          self.args = args
          self.agent = Agent("primary", args.agent)
          self.env = SWEEnv(args.environment)
          self._save_arguments()
          default_hooks = [
              SaveApplyPatchHook(),
              OpenPRHook(),
          ]
          self.hooks: list[MainHook] = []
          for hook in default_hooks:
              self.add_hook(hook)

      def add_hook(self, hook: MainHook):
          """Initialize ``hook`` with this run's args/agent/env/traj_dir and register it."""
          hook.on_init(args=self.args, agent=self.agent, env=self.env, traj_dir=self.traj_dir)
          self.hooks.append(hook)

      def run(self, index: int) -> None:
          """Run the agent on instance ``index`` of ``self.env.data``.

          Raises:
              _ContinueLoop: if the instance is skipped or ``env.reset`` returns no
                  info; `main` catches it and proceeds to the next instance.
          """
          instance = self.env.data[index]
          instance_id = instance["instance_id"]
          for hook in self.hooks:
              hook.on_instance_start(index=index, instance=instance)
          assert isinstance(instance_id, str)  # mypy
          if self.should_skip(instance_id):
              for hook in self.hooks:
                  hook.on_instance_skipped()
              raise _ContinueLoop

          # Reset environment
          logger.info("▶️ Beginning task " + str(index))
          observation, info = self.env.reset(index)
          if info is None:
              raise _ContinueLoop

          # Get info, patch information
          issue = getattr(self.env, "query", None)
          assert self.env.record is not None  # mypy
          record = self.env.record
          # Empty values are "" (as for `tests`), not [], so every entry is a string.
          files = ""
          if "patch" in record:
              files = "\n".join(f"- {x.path}" for x in PatchSet(record["patch"]).modified_files)
          # Get test files, F2P tests information
          test_files = ""
          if "test_patch" in record:
              test_patch_obj = PatchSet(record["test_patch"])
              test_files = "\n".join(
                  f"- {x.path}" for x in test_patch_obj.modified_files + test_patch_obj.added_files
              )
          tests = ""
          if "FAIL_TO_PASS" in record:
              tests = "\n".join(f"- {x}" for x in record["FAIL_TO_PASS"])

          setup_args: dict[str, Any] = {"issue": issue, "files": files, "test_files": test_files, "tests": tests}
          challenge = self.env.challenge
          if challenge is not None:
              setup_args["flag_format"] = extract_flag_format(challenge["flag"])
              setup_args["name"] = challenge["name"]
              setup_args["description"] = challenge["description"]
              setup_args["category_friendly"] = challenge["category_friendly"]
              setup_args["points"] = challenge["points"]
              setup_args["files"] = challenge["files"] or "No files included in this challenge."
              setup_args["box"] = challenge.get("server_name")
              setup_args["port"] = challenge.get("port")
              setup_args["server_description"] = challenge.get("server_description")
          info, trajectory = self.agent.run(
              setup_args=setup_args,
              env=self.env,
              observation=observation,
              traj_dir=self.traj_dir,
              return_type="info_trajectory",
          )
          self._save_predictions(instance_id, info, challenge)
          for hook in self.hooks:
              hook.on_instance_completed(info=info, trajectory=trajectory)

      def main(self):
          """Run every instance, then close the environment and call the hooks' `on_end`."""
          for hook in self.hooks:
              hook.on_start()
          for index in range(len(self.env.data)):
              try:
                  self.run(index)
              except _ContinueLoop:
                  continue
              except KeyboardInterrupt:
                  # Don't close here: the environment is closed once after the loop.
                  logger.info("Exiting InterCode environment...")
                  break
              except SystemExit:
                  logger.critical("❌ Exiting because SystemExit was called")
                  self.env.close()
                  logger.info("Container closed")
                  raise
              except Exception as e:
                  logger.warning(traceback.format_exc())
                  if self.args.raise_exceptions:
                      self.env.close()
                      raise
                  if self.env.record:
                      logger.warning(f"❌ Failed on {self.env.record['instance_id']}: {e}")
                  else:
                      logger.warning("❌ Failed on unknown instance")
                  # Start the next instance from a fresh container.
                  self.env.reset_container()
                  continue
          self.env.close()
          for hook in self.hooks:
              hook.on_end()

      def _save_arguments(self) -> None:
          """Save the arguments to ``args.yaml`` in the run's trajectory directory.

          If the file already exists with different arguments, a warning is logged
          before it is overwritten.
          """
          args_path = self.traj_dir / "args.yaml"
          if args_path.exists():
              try:
                  other_args = self.args.load_yaml(args_path)
                  # Check yaml equality instead of object equality.
                  if self.args.dumps_yaml() != other_args.dumps_yaml():
                      logger.warning("**************************************************")
                      logger.warning("Found existing args.yaml with different arguments!")
                      logger.warning("**************************************************")
              except Exception:
                  # Best-effort comparison: an unreadable file must not stop the run.
                  logger.warning(f"Failed to load existing args.yaml: {traceback.format_exc()}")
          with args_path.open("w") as f:
              self.args.dump_yaml(f)

      def should_skip(self, instance_id: str) -> bool:
          """Check if we should skip this instance based on the instance filter and skip_existing flag.

          Empty, unparsable or incomplete trajectory files are deleted, and the
          instance is run again.
          """
          # Skip instances that don't match the instance filter.
          # NOTE: `re.match` anchors only at the start of the ID, so "django" also selects
          # "django__django-1234"; end the pattern with "$" to require a full match.
          if re.match(self.args.instance_filter, instance_id) is None:
              logger.info(f"⏭️ Instance filter not matched. Skipping instance {instance_id}")
              return True
          # If flag is set to False, don't skip
          if not self.args.skip_existing:
              return False
          # Check if there's an existing trajectory for this instance
          log_path = self.traj_dir / (instance_id + ".traj")
          if not log_path.exists():
              return False
          content = log_path.read_text()
          if not content.strip():
              logger.warning("Found empty trajectory: %s. Removing.", log_path)
              log_path.unlink()
              return False
          try:
              data = json.loads(content)
          except json.JSONDecodeError:
              # e.g. a trajectory truncated by a crash mid-write; treat it as incomplete.
              logger.warning("Found unparsable trajectory: %s. Removing.", log_path)
              log_path.unlink()
              return False
          # If the trajectory has no exit status, it's incomplete and we will redo it
          exit_status = data.get("info", {}).get("exit_status")
          if exit_status == "early_exit" or exit_status is None:
              logger.warning(f"Found existing trajectory with no exit status: {log_path}. Removing.")
              log_path.unlink()
              return False
          logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
          return True

      def _save_predictions(self, instance_id: str, info, challenge: dict[str, str] | None):
          """Append this instance's prediction as one JSON line to ``all_preds.jsonl``."""
          output_file = self.traj_dir / "all_preds.jsonl"
          datum = {
              KEY_MODEL: Path(self.traj_dir).name,
              KEY_INSTANCE_ID: instance_id,
              KEY_PREDICTION: info.get("submission"),
          }
          if challenge is not None:
              datum.update(
                  {
                      "challenge_name": challenge["name"],
                      "challenge_category": challenge["category"],
                      "challenge_path": challenge["file_path"],
                  }
              )
          with open(output_file, "a") as fp:
              print(json.dumps(datum), file=fp, flush=True)
          logger.info(f"Saved predictions to {output_file}")
  def get_args(args=None) -> ScriptArguments:
      """Parse command line arguments and return a ScriptArguments object.

      Values given on the command line override the defaults assembled below.

      Args:
          args: Optional list of arguments to parse. If not provided, uses sys.argv.
      """
      environment_defaults = EnvironmentArguments(
          image_name="sweagent/swe-agent:latest",
          data_path="princeton-nlp/SWE-bench_Lite",
          split="dev",
          verbose=True,
          install_environment=True,
          cache_task_images=False,
      )
      model_defaults = ModelArguments(
          model_name="gpt4",
          total_cost_limit=0.0,
          per_instance_cost_limit=3.0,
          temperature=0.0,
          top_p=0.95,
      )
      agent_defaults = AgentArguments(model=model_defaults, config_file=CONFIG_DIR / "default.yaml")
      actions_defaults = ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True)
      defaults = ScriptArguments(
          suffix="",
          environment=environment_defaults,
          skip_existing=True,
          agent=agent_defaults,
          actions=actions_defaults,
          ctf=False,
      )

      def multiline_representer(dumper, data):
          """Emit strings that contain a newline as YAML literal blocks (``|``).

          Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
          """
          block_style = "|" if "\n" in data else None
          return dumper.represent_scalar("tag:yaml.org,2002:str", data, style=block_style)

      # Nicer yaml dumping of multiline strings (registered on PyYAML's default Dumper).
      yaml.add_representer(str, multiline_representer)
      return parse(
          ScriptArguments,
          default=defaults,
          add_config_path_arg=False,
          args=args,
          formatter_class=RichHelpFormatter,
          description=Markdown(__doc__),
      )
  def main(args: ScriptArguments):
      """Script entry point: build a `Main` runner from ``args`` and execute it."""
      runner = Main(args)
      runner.main()


  if __name__ == "__main__":
      main(get_args())