123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526 |
- from __future__ import annotations
- import logging
- from sweagent import CONFIG_DIR
- from sweagent.utils.log import add_file_handler, get_logger
- try:
- import rich
- except ModuleNotFoundError as e:
- msg = (
- "You probably either forgot to install the dependencies "
- "or forgot to activate your conda or virtual environment."
- )
- raise RuntimeError(msg) from e
- import json
- import re
- import subprocess
- import traceback
- from typing import Any
- import rich.console
- import rich.markdown
- import rich.panel
- try:
- from rich_argparse import RichHelpFormatter
- except ImportError:
- msg = "Please install the rich_argparse package with `pip install rich_argparse`."
- raise ImportError(msg)
- import datetime
- from dataclasses import dataclass
- from getpass import getuser
- from pathlib import Path
- import yaml
- from rich.markdown import Markdown
- from simple_parsing import parse
- from simple_parsing.helpers.flatten import FlattenedAccess
- from simple_parsing.helpers.serialization.serializable import FrozenSerializable
- from swebench import KEY_INSTANCE_ID, KEY_MODEL, KEY_PREDICTION
- from unidiff import PatchSet
- from sweagent.agent.agents import Agent, AgentArguments
- from sweagent.agent.models import ModelArguments
- from sweagent.environment.swe_env import EnvironmentArguments, SWEEnv
- from sweagent.environment.utils import (
- InvalidGithubURL,
- get_associated_commit_urls,
- get_data_path_name,
- get_gh_issue_data,
- parse_gh_issue_url,
- )
- __doc__: str = """ Run inference. Usage examples:
- ```bash
- # Run over a github issue:
- python run.py --model_name "gpt4" --data_path "https://github.com/pvlib/pvlib-python/issues/1603" --config_file "config/default_from_url.yaml"
- # Apply a patch in a local repository to an issue specified as Markdown file and run a custom installer script in the container
- python run.py --model_name "gpt4" --data_path "/path/to/my_issue.md" --repo_path "/path/to/my/local/repo" --environment_setup "/path/to/setup.sh" --config_file "config/default_from_url.yaml" --apply_patch_locally
- ```
- **For more information**: https://princeton-nlp.github.io/SWE-agent/usage/cl_tutorial/
- """
- logger = get_logger("swe-agent-run")
- logging.getLogger("simple_parsing").setLevel(logging.WARNING)
- @dataclass(frozen=True)
- class ActionsArguments(FlattenedAccess, FrozenSerializable):
- """Run real-life actions (opening PRs, etc.) if we can solve the issue."""
- # Open a PR with the patch if we can solve the issue
- open_pr: bool = False
- # When working with local repository: Apply patch
- apply_patch_locally: bool = False
- # Option to be used with open_pr: Skip action if there are already commits claiming
- # to fix the issue. Please only set this to False if you are sure the commits are
- # not fixes or if this is your own repository!
- skip_if_commits_reference_issue: bool = True
- # OBSOLETE. Do not use, will raise error. Please specify --repo_path instead.
- push_gh_repo_url: str = ""
- def __post_init__(self):
- if self.push_gh_repo_url:
- msg = "push_gh_repo_url is obsolete. Use repo_path instead"
- raise ValueError(msg)
- @dataclass(frozen=True)
- class ScriptArguments(FlattenedAccess, FrozenSerializable):
- """Configure the control flow of the run.py script"""
- environment: EnvironmentArguments
- agent: AgentArguments
- actions: ActionsArguments
- # Only run instances that completely match this regex
- instance_filter: str = ".*"
- # Skip instances with existing trajectories
- skip_existing: bool = True
- # Suffix for the run name (used for example in trajectory directory naming)
- suffix: str = ""
- # Raise unhandled exceptions during the run (useful for debugging)
- raise_exceptions: bool = False
- # Dump the entire config to the log
- print_config: bool = True
- @property
- def run_name(self) -> str:
- """Generate a unique name for this run based on the arguments."""
- model_name = self.agent.model.model_name.replace(":", "-")
- data_stem = get_data_path_name(self.environment.data_path)
- assert self.agent.config_file is not None # mypy
- config_stem = Path(self.agent.config_file).stem
- temp = self.agent.model.temperature
- top_p = self.agent.model.top_p
- per_instance_cost_limit = self.agent.model.per_instance_cost_limit
- install_env = self.environment.install_environment
- return (
- f"{model_name}__{data_stem}__{config_stem}__t-{temp:.2f}__p-{top_p:.2f}"
- + f"__c-{per_instance_cost_limit:.2f}__install-{int(install_env)}"
- + (f"__{self.suffix}" if self.suffix else "")
- )
- class _ContinueLoop(Exception):
- """Used for internal control flow"""
- class MainHook:
- """Hook structure for the web server or other addons to interface with"""
- @staticmethod
- def _is_promising_patch(info: dict[str, Any]) -> bool:
- """Do we actually believe that the patch will solve the issue?
- Or are we just submitting the last patch we generated before hitting an error?
- """
- # The exit status can also be `submitted (exit_cost)` etc.
- return info["exit_status"] == "submitted" and info.get("submission") is not None
- def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
- """Called when hook is initialized"""
- def on_start(self):
- """Called at the beginning of `Main.main`"""
- def on_end(self):
- """Called at the end of `Main.main`"""
- def on_instance_start(self, *, index: int, instance: dict[str, Any]):
- """Called at the beginning of each instance loop in `Main.run`"""
- def on_instance_skipped(
- self,
- ):
- """Called when an instance is skipped in `Main.run`"""
- def on_instance_completed(self, *, info, trajectory):
- """Called when an instance is completed in `Main.run`"""
- class SaveApplyPatchHook(MainHook):
- """This hook saves patches to a separate directory and optionally applies them to a local repository."""
- def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
- self._traj_dir = traj_dir
- self._apply_patch_locally = args.actions.apply_patch_locally
- self._instance = None
- def on_instance_start(self, *, index: int, instance: dict[str, Any]):
- self._instance = instance
- def on_instance_completed(self, *, info, trajectory):
- assert self._instance is not None # mypy
- instance_id = self._instance["instance_id"]
- patch_path = self._save_patch(instance_id, info)
- if patch_path:
- if not self._apply_patch_locally:
- return
- if not self._is_promising_patch(info):
- return
- assert self._instance # mypy
- if self._instance["repo_type"] != "local":
- return
- local_dir = Path(self._instance["repo"])
- self._apply_patch(patch_path, local_dir)
- @staticmethod
- def _print_patch_message(patch_output_file: Path):
- console = rich.console.Console()
- msg = [
- "SWE-agent has produced a patch that it believes will solve the issue you submitted!",
- "Use the code snippet below to inspect or apply it!",
- ]
- panel = rich.panel.Panel.fit(
- "\n".join(msg),
- title="🎉 Submission successful 🎉",
- )
- console.print(panel)
- content = [
- "```bash",
- "# The patch has been saved to your local filesystem at:",
- f"PATCH_FILE_PATH='{patch_output_file.resolve()}'",
- "# Inspect it:",
- 'cat "${PATCH_FILE_PATH}"',
- "# Apply it to a local repository:",
- "cd <your local repo root>",
- 'git apply "${PATCH_FILE_PATH}"',
- "```",
- ]
- console.print(rich.markdown.Markdown("\n".join(content)))
- def _save_patch(self, instance_id: str, info) -> Path | None:
- """Create patch files that can be applied with `git am`.
- Returns:
- The path to the patch file, if it was saved. Otherwise, returns None.
- """
- patch_output_dir = self._traj_dir / "patches"
- patch_output_dir.mkdir(exist_ok=True, parents=True)
- patch_output_file = patch_output_dir / f"{instance_id}.patch"
- if not info.get("submission"):
- logger.info("No patch to save.")
- return None
- model_patch = info["submission"]
- patch_output_file.write_text(model_patch)
- if self._is_promising_patch(info):
- # Only print big congratulations if we actually believe
- # the patch will solve the issue
- self._print_patch_message(patch_output_file)
- return patch_output_file
- def _apply_patch(self, patch_file: Path, local_dir: Path) -> None:
- """Apply a patch to a local directory."""
- assert local_dir.is_dir()
- assert patch_file.exists()
- # The resolve() is important, because we're gonna run the cmd
- # somewhere else
- cmd = ["git", "apply", str(patch_file.resolve())]
- try:
- subprocess.run(cmd, cwd=local_dir, check=True)
- except subprocess.CalledProcessError as e:
- logger.error(f"Failed to apply patch {patch_file} to {local_dir}: {e}")
- return
- logger.info(f"Applied patch {patch_file} to {local_dir}")
- class OpenPRHook(MainHook):
- """This hook opens a PR if the issue is solved and the user has enabled the option."""
- def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
- self._env = env
- self._token: str = env._github_token
- self._data_path = args.environment.data_path
- self._open_pr = args.actions.open_pr
- self._skip_if_commits_reference_issue = args.actions.skip_if_commits_reference_issue
- def on_instance_completed(self, *, info, trajectory):
- if self._open_pr and self.should_open_pr(info):
- self._env.open_pr(trajectory=trajectory)
- def should_open_pr(self, info: dict[str, Any]) -> bool:
- """Does opening a PR make sense?"""
- if not info.get("submission"):
- logger.info("Not opening PR because no submission was made.")
- return False
- if info["exit_status"] != "submitted":
- logger.info("Not opening PR because exit status was %s and not submitted.", info["exit_status"])
- return False
- try:
- issue = get_gh_issue_data(self._data_path, token=self._token)
- except InvalidGithubURL:
- logger.info("Currently only GitHub is supported to open PRs to. Skipping PR creation.")
- return False
- if issue.state != "open":
- logger.info(f"Issue is not open (state={issue.state}. Skipping PR creation.")
- return False
- if issue.assignee:
- logger.info("Issue is already assigned. Skipping PR creation. Be nice :)")
- return False
- if issue.locked:
- logger.info("Issue is locked. Skipping PR creation.")
- return False
- org, repo, issue_number = parse_gh_issue_url(self._data_path)
- associated_commits = get_associated_commit_urls(org, repo, issue_number, token=self._token)
- if associated_commits:
- commit_url_strs = ", ".join(associated_commits)
- if self._skip_if_commits_reference_issue:
- logger.info(f"Issue already has associated commits (see {commit_url_strs}). Skipping PR creation.")
- return False
- else:
- logger.warning(
- "Proceeding with PR creation even though there are already commits "
- f"({commit_url_strs}) associated with the issue. Please only do this for your own repositories "
- "or after verifying that the existing commits do not fix the issue.",
- )
- return True
- class Main:
- def __init__(self, args: ScriptArguments):
- self.traj_dir = Path("trajectories") / Path(getuser()) / args.run_name
- self.traj_dir.mkdir(parents=True, exist_ok=True)
- timestamp = datetime.datetime.now().strftime("%y%m%d%H%M%S")
- log_path = self.traj_dir / f"run-{timestamp}.log"
- logger.info("Logging to %s", log_path)
- add_file_handler(log_path)
- if args.print_config:
- logger.info(f"📙 Arguments: {args.dumps_yaml()}")
- self.args = args
- self.agent = Agent("primary", args.agent)
- self.env = SWEEnv(args.environment)
- self._save_arguments()
- default_hooks = [
- SaveApplyPatchHook(),
- OpenPRHook(),
- ]
- self.hooks: list[MainHook] = []
- for hook in default_hooks:
- self.add_hook(hook)
- def add_hook(self, hook: MainHook):
- hook.on_init(args=self.args, agent=self.agent, env=self.env, traj_dir=self.traj_dir)
- self.hooks.append(hook)
- def run(self, index):
- # Reset environment
- instance_id = self.env.data[index]["instance_id"]
- for hook in self.hooks:
- hook.on_instance_start(index=index, instance=self.env.data[index])
- assert isinstance(instance_id, str) # mypy
- if self.should_skip(instance_id):
- for hook in self.hooks:
- hook.on_instance_skipped()
- raise _ContinueLoop
- logger.info("▶️ Beginning task " + str(index))
- observation, info = self.env.reset(index)
- if info is None:
- raise _ContinueLoop
- # Get info, patch information
- issue = getattr(self.env, "query", None)
- files = []
- assert self.env.record is not None # mypy
- if "patch" in self.env.record:
- files = "\n".join([f"- {x.path}" for x in PatchSet(self.env.record["patch"]).modified_files])
- # Get test files, F2P tests information
- test_files = []
- if "test_patch" in self.env.record:
- test_patch_obj = PatchSet(self.env.record["test_patch"])
- test_files = "\n".join([f"- {x.path}" for x in test_patch_obj.modified_files + test_patch_obj.added_files])
- tests = ""
- if "FAIL_endTO_PASS" in self.env.record:
- tests = "\n".join([f"- {x}" for x in self.env.record["FAIL_TO_PASS"]])
- setup_args = {"issue": issue, "files": files, "test_files": test_files, "tests": tests}
- info, trajectory = self.agent.run(
- setup_args=setup_args,
- env=self.env,
- observation=observation,
- traj_dir=self.traj_dir,
- return_type="info_trajectory",
- )
- self._save_predictions(instance_id, info)
- for hook in self.hooks:
- hook.on_instance_completed(info=info, trajectory=trajectory)
- def main(self):
- for hook in self.hooks:
- hook.on_start()
- for index in range(len(self.env.data)):
- try:
- self.run(index)
- except _ContinueLoop:
- continue
- except KeyboardInterrupt:
- logger.info("Exiting InterCode environment...")
- self.env.close()
- break
- except SystemExit:
- logger.critical("❌ Exiting because SystemExit was called")
- self.env.close()
- logger.info("Container closed")
- raise
- except Exception as e:
- traceback.print_exc()
- if self.args.raise_exceptions:
- self.env.close()
- raise e
- if self.env.record:
- logger.warning(f"❌ Failed on {self.env.record['instance_id']}: {e}")
- else:
- logger.warning("❌ Failed on unknown instance")
- self.env.reset_container()
- continue
- for hook in self.hooks:
- hook.on_end()
- def _save_arguments(self) -> None:
- """Save the arguments to a yaml file to the run's trajectory directory."""
- log_path = self.traj_dir / "args.yaml"
- if log_path.exists():
- try:
- other_args = self.args.load_yaml(log_path)
- if self.args.dumps_yaml() != other_args.dumps_yaml(): # check yaml equality instead of object equality
- logger.warning("**************************************************")
- logger.warning("Found existing args.yaml with different arguments!")
- logger.warning("**************************************************")
- except Exception as e:
- logger.warning(f"Failed to load existing args.yaml: {e}")
- with log_path.open("w") as f:
- self.args.dump_yaml(f)
- def should_skip(self, instance_id: str) -> bool:
- """Check if we should skip this instance based on the instance filter and skip_existing flag."""
- # Skip instances that don't match the instance filter
- if re.match(self.args.instance_filter, instance_id) is None:
- logger.info(f"⏭️ Instance filter not matched. Skipping instance {instance_id}")
- return True
- # If flag is set to False, don't skip
- if not self.args.skip_existing:
- return False
- # Check if there's an existing trajectory for this instance
- log_path = self.traj_dir / (instance_id + ".traj")
- if not log_path.exists():
- return False
- content = log_path.read_text()
- if not content.strip():
- logger.warning("Found empty trajectory: %s. Removing.", log_path)
- log_path.unlink()
- return False
- data = json.loads(content)
- # If the trajectory has no exit status, it's incomplete and we will redo it
- exit_status = data["info"].get("exit_status", None)
- if exit_status == "early_exit" or exit_status is None:
- logger.warning(f"Found existing trajectory with no exit status: {log_path}. Removing.")
- log_path.unlink()
- return False
- logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
- return True
- def _save_predictions(self, instance_id: str, info):
- output_file = self.traj_dir / "all_preds.jsonl"
- model_patch = info["submission"] if "submission" in info else None
- datum = {
- KEY_MODEL: Path(self.traj_dir).name,
- KEY_INSTANCE_ID: instance_id,
- KEY_PREDICTION: model_patch,
- }
- with open(output_file, "a+") as fp:
- print(json.dumps(datum), file=fp, flush=True)
- logger.info(f"Saved predictions to {output_file}")
- def get_args(args=None) -> ScriptArguments:
- """Parse command line arguments and return a ScriptArguments object.
- Args:
- args: Optional list of arguments to parse. If not provided, uses sys.argv.
- """
- defaults = ScriptArguments(
- suffix="",
- environment=EnvironmentArguments(
- image_name="sweagent/swe-agent:latest",
- data_path="princeton-nlp/SWE-bench_Lite",
- split="dev",
- verbose=True,
- install_environment=True,
- cache_task_images=False,
- ),
- skip_existing=True,
- agent=AgentArguments(
- model=ModelArguments(
- model_name="gpt4",
- total_cost_limit=0.0,
- per_instance_cost_limit=3.0,
- temperature=0.0,
- top_p=0.95,
- ),
- config_file=CONFIG_DIR / "default.yaml",
- ),
- actions=ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True),
- )
- # Nicer yaml dumping of multiline strings
- def multiline_representer(dumper, data):
- """configures yaml for dumping multiline strings
- Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
- """
- if data.count("\n") > 0: # check for multiline string
- return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
- return dumper.represent_scalar("tag:yaml.org,2002:str", data)
- yaml.add_representer(str, multiline_representer)
- return parse(
- ScriptArguments,
- default=defaults,
- add_config_path_arg=False,
- args=args,
- formatter_class=RichHelpFormatter,
- description=Markdown(__doc__),
- )
- def main(args: ScriptArguments):
- Main(args).main()
- if __name__ == "__main__":
- main(get_args())
|