# run.py
  1. import json
  2. import logging
  3. import os
  4. import re
  5. import traceback
  6. import yaml
  7. from dataclasses import dataclass
  8. from getpass import getuser
  9. from pathlib import Path
  10. from rich.logging import RichHandler
  11. from simple_parsing import parse
  12. from simple_parsing.helpers import FrozenSerializable, FlattenedAccess
  13. from sweagent import (
  14. Agent,
  15. AgentArguments,
  16. EnvironmentArguments,
  17. ModelArguments,
  18. SWEEnv,
  19. get_data_path_name,
  20. )
  21. from swebench import KEY_INSTANCE_ID, KEY_MODEL, KEY_PREDICTION
  22. from unidiff import PatchSet
  23. handler = RichHandler(show_time=False, show_path=False)
  24. handler.setLevel(logging.DEBUG)
  25. logger = logging.getLogger("run_dev")
  26. logger.setLevel(logging.DEBUG)
  27. logger.addHandler(handler)
  28. logger.propagate = False
  29. logging.getLogger("simple_parsing").setLevel(logging.WARNING)
  30. @dataclass(frozen=True)
  31. class ScriptArguments(FlattenedAccess, FrozenSerializable):
  32. environment: EnvironmentArguments
  33. agent: AgentArguments
  34. instance_filter: str = ".*" # Only run instances that completely match this regex
  35. skip_existing: bool = True # Skip instances with existing trajectories
  36. suffix: str = ""
  37. @property
  38. def run_name(self):
  39. """Generate a unique name for this run based on the arguments."""
  40. model_name = args.agent.model.model_name.replace(":", "-")
  41. data_stem = get_data_path_name(args.environment.data_path)
  42. config_stem = Path(args.agent.config_file).stem
  43. temp = args.agent.model.temperature
  44. top_p = args.agent.model.top_p
  45. per_instance_cost_limit = args.agent.model.per_instance_cost_limit
  46. install_env = args.environment.install_environment
  47. return (
  48. f"{model_name}__{data_stem}__{config_stem}__t-{temp:.2f}__p-{top_p:.2f}"
  49. + f"__c-{per_instance_cost_limit:.2f}__install-{int(install_env)}"
  50. + (f"__{self.suffix}" if self.suffix else "")
  51. )
  52. def main(args: ScriptArguments):
  53. logger.info(f"📙 Arguments: {args.dumps_yaml()}")
  54. agent = Agent("primary", args.agent)
  55. env = SWEEnv(args.environment)
  56. traj_dir = Path("trajectories") / Path(getuser()) / args.run_name
  57. os.makedirs(traj_dir, exist_ok=True)
  58. save_arguments(traj_dir, args)
  59. for index in range(len(env.data)):
  60. try:
  61. # Reset environment
  62. instance_id = env.data[index]["instance_id"]
  63. if should_skip(args, traj_dir, instance_id):
  64. continue
  65. logger.info("▶️ Beginning task " + str(index))
  66. observation, info = env.reset(index)
  67. if info is None:
  68. continue
  69. # Get info, patch information
  70. issue = getattr(env, "query", None)
  71. files = []
  72. if "patch" in env.record:
  73. files = "\n".join(
  74. [f"- {x.path}" for x in PatchSet(env.record["patch"]).modified_files]
  75. )
  76. # Get test files, F2P tests information
  77. test_files = []
  78. if "test_patch" in env.record:
  79. test_patch_obj = PatchSet(env.record["test_patch"])
  80. test_files = "\n".join(
  81. [f"- {x.path}" for x in test_patch_obj.modified_files + test_patch_obj.added_files]
  82. )
  83. tests = ""
  84. if "FAIL_TO_PASS" in env.record:
  85. tests = "\n".join([f"- {x}" for x in env.record["FAIL_TO_PASS"]])
  86. setup_args = {
  87. "issue": issue,
  88. "files": files,
  89. "test_files": test_files,
  90. "tests": tests
  91. }
  92. info = agent.run(
  93. setup_args=setup_args,
  94. env=env,
  95. observation=observation,
  96. traj_dir=traj_dir,
  97. return_type="info",
  98. )
  99. save_predictions(traj_dir, instance_id, info)
  100. except KeyboardInterrupt:
  101. logger.info("Exiting InterCode environment...")
  102. env.close()
  103. break
  104. except Exception as e:
  105. traceback.print_exc()
  106. logger.warning(f"❌ Failed on {env.record['instance_id']}: {e}")
  107. env.reset_container()
  108. continue
  109. def save_arguments(traj_dir, args):
  110. """Save the arguments to a yaml file to the run's trajectory directory."""
  111. log_path = traj_dir / "args.yaml"
  112. if log_path.exists():
  113. try:
  114. other_args = args.load_yaml(log_path)
  115. if (args.dumps_yaml() != other_args.dumps_yaml()): # check yaml equality instead of object equality
  116. logger.warning("**************************************************")
  117. logger.warning("Found existing args.yaml with different arguments!")
  118. logger.warning("**************************************************")
  119. except Exception as e:
  120. logger.warning(f"Failed to load existing args.yaml: {e}")
  121. with log_path.open("w") as f:
  122. args.dump_yaml(f)
  123. def should_skip(args, traj_dir, instance_id):
  124. """Check if we should skip this instance based on the instance filter and skip_existing flag."""
  125. # Skip instances that don't match the instance filter
  126. if re.match(args.instance_filter, instance_id) is None:
  127. logger.info(f"Instance filter not matched. Skipping instance {instance_id}")
  128. return True
  129. # If flag is set to False, don't skip
  130. if not args.skip_existing:
  131. return False
  132. # Check if there's an existing trajectory for this instance
  133. log_path = traj_dir / (instance_id + ".traj")
  134. if log_path.exists():
  135. with log_path.open("r") as f:
  136. data = json.load(f)
  137. # If the trajectory has no exit status, it's incomplete and we will redo it
  138. exit_status = data["info"].get("exit_status", None)
  139. if exit_status == "early_exit" or exit_status is None:
  140. logger.info(f"Found existing trajectory with no exit status: {log_path}")
  141. logger.info("Removing incomplete trajectory...")
  142. os.remove(log_path)
  143. else:
  144. logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
  145. return True
  146. return False
  147. def save_predictions(traj_dir, instance_id, info):
  148. output_file = Path(traj_dir) / "all_preds.jsonl"
  149. model_patch = info["submission"] if "submission" in info else None
  150. datum = {
  151. KEY_MODEL: Path(traj_dir).name,
  152. KEY_INSTANCE_ID: instance_id,
  153. KEY_PREDICTION: model_patch,
  154. }
  155. with open(output_file, "a+") as fp:
  156. print(json.dumps(datum), file=fp, flush=True)
  157. logger.info(f"Saved predictions to {output_file}")
  158. if __name__ == "__main__":
  159. defaults = ScriptArguments(
  160. suffix="",
  161. environment=EnvironmentArguments(
  162. image_name="swe-agent",
  163. data_path="princeton-nlp/SWE-bench_Lite",
  164. split="dev",
  165. verbose=True,
  166. install_environment=True,
  167. ),
  168. skip_existing=True,
  169. agent=AgentArguments(
  170. model=ModelArguments(
  171. model_name="gpt4",
  172. total_cost_limit=0.0,
  173. per_instance_cost_limit=2.0,
  174. temperature=0.2,
  175. top_p=0.95,
  176. ),
  177. config_file="config/default.yaml",
  178. ),
  179. )
  180. # Nicer yaml dumping of multiline strings
  181. def multiline_representer(dumper, data):
  182. """configures yaml for dumping multiline strings
  183. Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
  184. """
  185. if data.count("\n") > 0: # check for multiline string
  186. return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
  187. return dumper.represent_scalar("tag:yaml.org,2002:str", data)
  188. yaml.add_representer(str, multiline_representer)
  189. args = parse(ScriptArguments, default=defaults, add_config_path_arg=False)
  190. main(args)