run.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. import json
  2. import logging
  3. import os
  4. import re
  5. import traceback
  6. from typing import Any, Dict, Optional
  7. import yaml
  8. from dataclasses import dataclass
  9. from getpass import getuser
  10. from pathlib import Path
  11. from rich.logging import RichHandler
  12. from simple_parsing import parse
  13. from simple_parsing.helpers.serialization.serializable import FrozenSerializable
  14. from simple_parsing.helpers.flatten import FlattenedAccess
  15. from sweagent import (
  16. Agent,
  17. AgentArguments,
  18. EnvironmentArguments,
  19. ModelArguments,
  20. SWEEnv,
  21. get_data_path_name,
  22. )
  23. from swebench import KEY_INSTANCE_ID, KEY_MODEL, KEY_PREDICTION
  24. from unidiff import PatchSet
  25. from sweagent.environment.utils import InvalidGithubURL, get_associated_commit_urls, get_gh_issue_data, parse_gh_issue_url
  26. handler = RichHandler(show_time=False, show_path=False)
  27. handler.setLevel(logging.DEBUG)
  28. logger = logging.getLogger("run_dev")
  29. logger.setLevel(logging.DEBUG)
  30. logger.addHandler(handler)
  31. logger.propagate = False
  32. logging.getLogger("simple_parsing").setLevel(logging.WARNING)
  33. @dataclass(frozen=True)
  34. class ActionsArguments(FlattenedAccess, FrozenSerializable):
  35. """Run real-life actions (opening PRs, etc.) if we can solve the issue."""
  36. open_pr: bool = False # Open a PR with the patch if we can solve the issue
  37. # Skip action if there are already commits claiming to fix the issue. Please only
  38. # set this to False if you are sure the commits are not fixes or if this is your
  39. # own repository!
  40. skip_if_commits_reference_issue: bool = True
  41. # For PRs: If you want to push the branch to a fork (e.g., because you lack
  42. # permissions to push to the main repo), set this to the URL of the fork.
  43. push_gh_repo_url: str = ""
  44. def __post_init__(self):
  45. if not self.skip_if_commits_reference_issue and self.push_gh_repo_url:
  46. raise ValueError(
  47. "Overriding `skip_if_commits_reference_issue` when you are "
  48. "pushing to a fork is not supported. You should manually "
  49. "apply the patch to the forked repository."
  50. )
  51. @dataclass(frozen=True)
  52. class ScriptArguments(FlattenedAccess, FrozenSerializable):
  53. """Configure the control flow of the run.py script"""
  54. environment: EnvironmentArguments
  55. agent: AgentArguments
  56. actions: ActionsArguments
  57. instance_filter: str = ".*" # Only run instances that completely match this regex
  58. skip_existing: bool = True # Skip instances with existing trajectories
  59. suffix: str = ""
  60. # Raise unhandled exceptions during the run (useful for debugging)
  61. raise_exceptions: bool = False
  62. @property
  63. def run_name(self):
  64. """Generate a unique name for this run based on the arguments."""
  65. model_name = self.agent.model.model_name.replace(":", "-")
  66. data_stem = get_data_path_name(self.environment.data_path)
  67. config_stem = Path(self.agent.config_file).stem
  68. temp = self.agent.model.temperature
  69. top_p = self.agent.model.top_p
  70. per_instance_cost_limit = self.agent.model.per_instance_cost_limit
  71. install_env = self.environment.install_environment
  72. return (
  73. f"{model_name}__{data_stem}__{config_stem}__t-{temp:.2f}__p-{top_p:.2f}"
  74. + f"__c-{per_instance_cost_limit:.2f}__install-{int(install_env)}"
  75. + (f"__{self.suffix}" if self.suffix else "")
  76. )
  77. def main(args: ScriptArguments):
  78. logger.info(f"📙 Arguments: {args.dumps_yaml()}")
  79. agent = Agent("primary", args.agent)
  80. env = SWEEnv(args.environment)
  81. traj_dir = Path("trajectories") / Path(getuser()) / args.run_name
  82. traj_dir.mkdir(parents=True, exist_ok=True)
  83. save_arguments(traj_dir, args)
  84. for index in range(len(env.data)):
  85. try:
  86. # Reset environment
  87. instance_id = env.data[index]["instance_id"]
  88. assert isinstance(instance_id, str) # mypy
  89. if should_skip(args, traj_dir, instance_id):
  90. continue
  91. logger.info("▶️ Beginning task " + str(index))
  92. observation, info = env.reset(index)
  93. if info is None:
  94. continue
  95. # Get info, patch information
  96. issue = getattr(env, "query", None)
  97. files = []
  98. if "patch" in env.record:
  99. files = "\n".join(
  100. [f"- {x.path}" for x in PatchSet(env.record["patch"]).modified_files]
  101. )
  102. # Get test files, F2P tests information
  103. test_files = []
  104. if "test_patch" in env.record:
  105. test_patch_obj = PatchSet(env.record["test_patch"])
  106. test_files = "\n".join(
  107. [f"- {x.path}" for x in test_patch_obj.modified_files + test_patch_obj.added_files]
  108. )
  109. tests = ""
  110. if "FAIL_TO_PASS" in env.record:
  111. tests = "\n".join([f"- {x}" for x in env.record["FAIL_TO_PASS"]])
  112. setup_args = {
  113. "issue": issue,
  114. "files": files,
  115. "test_files": test_files,
  116. "tests": tests
  117. }
  118. info, trajectory = agent.run(
  119. setup_args=setup_args,
  120. env=env,
  121. observation=observation,
  122. traj_dir=traj_dir,
  123. return_type="info_trajectory",
  124. )
  125. save_predictions(traj_dir, instance_id, info)
  126. save_patch(traj_dir, instance_id, info)
  127. if args.actions.open_pr and should_open_pr(args, info, token=env._github_token):
  128. env.open_pr(trajectory=trajectory, push_gh_repo_url=args.actions.push_gh_repo_url)
  129. except KeyboardInterrupt:
  130. logger.info("Exiting InterCode environment...")
  131. env.close()
  132. break
  133. except Exception as e:
  134. traceback.print_exc()
  135. logger.warning(f"❌ Failed on {env.record['instance_id']}: {e}")
  136. if args.raise_exceptions:
  137. raise e
  138. env.reset_container()
  139. continue
  140. def should_open_pr(args: ScriptArguments, info: Dict[str, Any], *, token: str="") -> bool:
  141. """Does opening a PR make sense?"""
  142. if not info.get("submission"):
  143. logger.info("Not opening PR because submission was made.")
  144. return False
  145. if info["exit_status"] != "submitted":
  146. logger.info("Not opening PR because exit status was %s and not submitted.", info["exit_status"])
  147. return False
  148. try:
  149. issue = get_gh_issue_data(args.environment.data_path, token=token)
  150. except InvalidGithubURL:
  151. logger.info("Currently only GitHub is supported to open PRs to. Skipping PR creation.")
  152. return False
  153. if issue.state != "open":
  154. logger.info(f"Issue is not open (state={issue.state}. Skipping PR creation.")
  155. return False
  156. if issue.assignee:
  157. logger.info("Issue is already assigned. Skipping PR creation. Be nice :)")
  158. return False
  159. if issue.locked:
  160. logger.info("Issue is locked. Skipping PR creation.")
  161. return False
  162. org, repo, issue_number = parse_gh_issue_url(args.environment.data_path)
  163. associated_commits = get_associated_commit_urls(org, repo, issue_number, token=token)
  164. if associated_commits:
  165. commit_url_strs = ", ".join(associated_commits)
  166. if args.actions.skip_if_commits_reference_issue:
  167. logger.info(f"Issue already has associated commits (see {commit_url_strs}). Skipping PR creation.")
  168. return False
  169. else:
  170. logger.warning(
  171. "Proceeding with PR creation even though there are already commits "
  172. f"({commit_url_strs}) associated with the issue. Please only do this for your own repositories "
  173. "or after verifying that the existing commits do not fix the issue."
  174. )
  175. return True
  176. def save_arguments(traj_dir: Path, args: ScriptArguments) -> None:
  177. """Save the arguments to a yaml file to the run's trajectory directory."""
  178. log_path = traj_dir / "args.yaml"
  179. if log_path.exists():
  180. try:
  181. other_args = args.load_yaml(log_path)
  182. if (args.dumps_yaml() != other_args.dumps_yaml()): # check yaml equality instead of object equality
  183. logger.warning("**************************************************")
  184. logger.warning("Found existing args.yaml with different arguments!")
  185. logger.warning("**************************************************")
  186. except Exception as e:
  187. logger.warning(f"Failed to load existing args.yaml: {e}")
  188. with log_path.open("w") as f:
  189. args.dump_yaml(f)
  190. def should_skip(args: ScriptArguments, traj_dir: Path, instance_id: str) -> bool:
  191. """Check if we should skip this instance based on the instance filter and skip_existing flag."""
  192. # Skip instances that don't match the instance filter
  193. if re.match(args.instance_filter, instance_id) is None:
  194. logger.info(f"Instance filter not matched. Skipping instance {instance_id}")
  195. return True
  196. # If flag is set to False, don't skip
  197. if not args.skip_existing:
  198. return False
  199. # Check if there's an existing trajectory for this instance
  200. log_path = traj_dir / (instance_id + ".traj")
  201. if log_path.exists():
  202. with log_path.open("r") as f:
  203. data = json.load(f)
  204. # If the trajectory has no exit status, it's incomplete and we will redo it
  205. exit_status = data["info"].get("exit_status", None)
  206. if exit_status == "early_exit" or exit_status is None:
  207. logger.info(f"Found existing trajectory with no exit status: {log_path}")
  208. logger.info("Removing incomplete trajectory...")
  209. os.remove(log_path)
  210. else:
  211. logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
  212. return True
  213. return False
  214. def save_predictions(traj_dir: Path, instance_id: str, info):
  215. output_file = traj_dir / "all_preds.jsonl"
  216. model_patch = info["submission"] if "submission" in info else None
  217. datum = {
  218. KEY_MODEL: Path(traj_dir).name,
  219. KEY_INSTANCE_ID: instance_id,
  220. KEY_PREDICTION: model_patch,
  221. }
  222. with open(output_file, "a+") as fp:
  223. print(json.dumps(datum), file=fp, flush=True)
  224. logger.info(f"Saved predictions to {output_file}")
  225. def save_patch(traj_dir: Path, instance_id: str, info) -> Optional[Path]:
  226. """Create patch files that can be applied with `git am`.
  227. Returns:
  228. The path to the patch file, if it was saved. Otherwise, returns None.
  229. """
  230. patch_output_dir = traj_dir / "patches"
  231. patch_output_dir.mkdir(exist_ok=True, parents=True)
  232. patch_output_file = patch_output_dir / f"{instance_id}.patch"
  233. if not "submission" in info:
  234. logger.info("No patch to save.")
  235. return
  236. model_patch = info["submission"]
  237. patch_output_file.write_text(model_patch)
  238. logger.info(f"Saved patch to {patch_output_file}")
  239. return patch_output_file
  240. def get_args(args=None) -> ScriptArguments:
  241. """Parse command line arguments and return a ScriptArguments object.
  242. Args:
  243. args: Optional list of arguments to parse. If not provided, uses sys.argv.
  244. """
  245. defaults = ScriptArguments(
  246. suffix="",
  247. environment=EnvironmentArguments(
  248. image_name="sweagent/swe-agent:latest",
  249. data_path="princeton-nlp/SWE-bench_Lite",
  250. split="dev",
  251. verbose=True,
  252. install_environment=True,
  253. ),
  254. skip_existing=True,
  255. agent=AgentArguments(
  256. model=ModelArguments(
  257. model_name="gpt4",
  258. total_cost_limit=0.0,
  259. per_instance_cost_limit=3.0,
  260. temperature=0.0,
  261. top_p=0.95,
  262. ),
  263. config_file="config/default.yaml",
  264. ),
  265. actions=ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True),
  266. )
  267. # Nicer yaml dumping of multiline strings
  268. def multiline_representer(dumper, data):
  269. """configures yaml for dumping multiline strings
  270. Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
  271. """
  272. if data.count("\n") > 0: # check for multiline string
  273. return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
  274. return dumper.represent_scalar("tag:yaml.org,2002:str", data)
  275. yaml.add_representer(str, multiline_representer)
  276. return parse(ScriptArguments, default=defaults, add_config_path_arg=False, args=args)
  277. if __name__ == "__main__":
  278. args = get_args()
  279. main(args)