agents.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812
  1. import json
  2. import re
  3. import logging
  4. from dataclasses import dataclass
  5. from pathlib import Path
  6. from simple_parsing.helpers.fields import field
  7. from simple_parsing.helpers.serialization.serializable import FrozenSerializable
  8. from simple_parsing.helpers.flatten import FlattenedAccess
  9. from sweagent.utils import debug_time
  10. from sweagent.agent.commands import Command, ParseCommand
  11. from sweagent.agent.history_processors import HistoryProcessor
  12. from sweagent.agent.models import (
  13. APIStats,
  14. ContextWindowExceededError,
  15. CostLimitExceededError,
  16. ModelArguments,
  17. get_model,
  18. )
  19. from sweagent.agent.parsing import ParseFunction, FormatError
  20. from sweagent.environment.utils import LOGGER_NAME
  21. from sweagent.environment.swe_env import SWEEnv
  22. from tenacity import RetryError
  23. from typing import Dict, List, Optional, Tuple, Any
  24. from typing import Optional, Tuple, Any
  25. logger = logging.getLogger(LOGGER_NAME)
  @dataclass(frozen=True)
  class Subroutine(FrozenSerializable):
      """Description of a command that is served by a separate agent.

      ``AgentConfig.__post_init__`` registers each instance under ``name`` and
      overwrites ``agent_args`` with the arguments used to build the sub-agent.
      """

      # Name the subroutine is invoked by; AgentConfig rejects the name "submit".
      name: str
      # Passed as ``config_file`` when the sub-agent's AgentArguments are built.
      agent_file: str
      # one of "action", "observation", "response", "state", "thought"
      return_type: str = None  # type: ignore
      # Template formatted with ``args`` and run in the environment before the
      # sub-agent starts (see Agent.call_subroutine).
      init_observation: Optional[str] = None
      # Terminator of a multi-line invocation; None for single-line subroutines.
      end_name: Optional[str] = None
      # NOTE(review): signature/docstring are not read anywhere in this file;
      # presumably consumed when command docs are generated -- confirm.
      signature: Optional[str] = None
      docstring: Optional[str] = None
      # Passed as ``model`` when the sub-agent's AgentArguments are built.
      model: Optional[ModelArguments] = None
      # Overwritten by AgentConfig.__post_init__.
      agent_args: Optional[Any] = None
  @dataclass(frozen=True)
  class AgentConfig(FrozenSerializable):
      """Templates, commands and parsing settings that define one agent.

      Several string-valued fields (``parse_command``, ``parse_function``,
      ``history_processor``) are replaced in ``__post_init__`` by the objects
      looked up under those names.
      """

      system_template: str
      instance_template: str
      # Falls back to instance_template.
      next_step_template: Optional[str] = None
      # Falls back to next_step_template.
      next_step_no_output_template: Optional[str] = None
      strategy_template: Optional[str] = None
      demonstration_template: Optional[str] = None
      demonstrations: list[str] = field(default_factory=list)
      # If True, add demonstration to history instead of as a single message.
      put_demos_in_history: bool = False
      # Falls back to format_error_template of the selected ParseFunction.
      format_error_template: str = None  # type: ignore
      command_files: list[str] = field(default_factory=list)
      env_variables: dict[str, str] = field(default_factory=dict)
      util_functions: list[str] = field(default_factory=list)
      submit_command: str = "submit"
      parse_function: str = "ThoughtActionParser"
      parse_command: str = "ParseCommandBash"
      history_processor: str = "DefaultHistoryProcessor"
      history_processor_args: dict[str, Any] = field(default_factory=dict)
      # Generated in __post_init__ from the parsed commands and subroutines.
      command_docs: str = None  # type: ignore
      blocklist_error_template: str = (
          "Interactive operation '{name}' is not supported by this environment"
      )
      # Actions whose first word is listed here are rejected.
      blocklist: Tuple[str, ...] = (
          "vim",
          "vi",
          "emacs",
          "nano",
          "nohup",
          "git",
      )
      # Rejected only when the whole action is exactly one of these words.
      blocklist_standalone: Tuple[str, ...] = (
          "python",
          "python3",
          "ipython",
          "bash",
          "sh",
          "exit",
          "/bin/bash",
          "/bin/sh",
          "nohup",
          "vi",
          "vim",
          "emacs",
          "nano",
      )
      # Should extract environment state in a json readable form
      state_command: Command = Command(
          name="state",
          code="""state() {\necho '{"working_dir": "'$(realpath --relative-to=$ROOT/.. $PWD)'"}';\n};""",
      )
      _commands: list[Command] = field(default_factory=list)
      _subroutines: dict[str, Subroutine] = field(default_factory=dict)
      subroutine_types: list[Subroutine] = field(default_factory=list)

      def __post_init__(self):
          """Resolve template fallbacks and load commands, subroutines and parsers."""

          def assign(attr, value):
              # The dataclass is frozen, so regular attribute assignment is
              # unavailable; write through object.__setattr__ instead.
              object.__setattr__(self, attr, value)

          if self.next_step_template is None:
              assign("next_step_template", self.instance_template)
          if self.next_step_no_output_template is None:
              assign("next_step_no_output_template", self.next_step_template)

          assign("parse_command", ParseCommand.get(self.parse_command))

          # Names starting with "_" are collected as utility functions, the
          # rest as commands.
          for command_file in self.command_files:
              parsed = self.parse_command.parse_command_file(command_file)
              private = [item for item in parsed if item.name.startswith("_")]
              public = [item for item in parsed if not item.name.startswith("_")]
              assign("util_functions", self.util_functions + private)
              assign("_commands", self._commands + public)

          for sub in self.subroutine_types:
              if sub.name == "submit":
                  raise ValueError("Cannot use 'submit' as a subroutine name")
              sub_args = AgentArguments(model=sub.model, config_file=sub.agent_file)
              object.__setattr__(sub, "agent_args", sub_args)
              assign("_subroutines", {**self._subroutines, sub.name: sub})

          # Map every command/subroutine that has a terminator to that terminator.
          endings = {}
          for item in (*self._commands, *self._subroutines.values()):
              if item.end_name is not None:
                  endings[item.name] = item.end_name
          assign("multi_line_command_endings", endings)

          docs = self.parse_command.generate_command_docs(
              self._commands,
              self.subroutine_types,
              **self.env_variables,
          )
          assign("command_docs", docs)

          assign("parse_function", ParseFunction.get(self.parse_function))
          if self.format_error_template is None:
              assign("format_error_template", self.parse_function.format_error_template)
          # The template may reference any attribute set on this object so far.
          assign(
              "format_error_template",
              self.format_error_template.format(**self.__dict__),
          )

          # Only set when a command named like submit_command was loaded.
          for candidate in self._commands:
              if candidate.name == self.submit_command:
                  assign("submit_command_end_name", candidate.end_name)
                  break

          processor = HistoryProcessor.get(
              self.history_processor, **self.history_processor_args
          )
          assign("history_processor", processor)
  @dataclass(frozen=True)
  class AgentArguments(FlattenedAccess, FrozenSerializable):
      """Configure the agent's behaviour (templates, parse functions, blocklists, ...)."""

      model: ModelArguments = None
      # Policy can only be set via config yaml file from command line
      config_file: Optional[Path] = None
      config: Optional[AgentConfig] = field(default=None, cmd=False)

      def __post_init__(self):
          """Load ``config`` from ``config_file`` if needed and copy cost limits."""
          must_load = self.config is None and self.config_file is not None
          if must_load:
              # Store the file's contents alongside the overall arguments.
              loaded = AgentConfig.load_yaml(self.config_file)
              object.__setattr__(self, "config", loaded)
          assert self.config is not None  # mypy

          # NOTE(review): AgentConfig as declared in this file exposes
          # `_subroutines` / `subroutine_types` but no `subroutines` attribute,
          # so this loop looks like it never runs -- confirm intent.
          registered = getattr(self.config, "subroutines", {})
          for subroutine in registered.values():
              sub_model = subroutine.model
              for limit in ("per_instance_cost_limit", "total_cost_limit"):
                  object.__setattr__(sub_model, limit, getattr(self.model, limit))
  191. class Agent:
  192. """Agent handles the behaviour of the model and how it interacts with the environment."""
  def __init__(self, name: str, args: AgentArguments):
      """Create an agent called ``name`` from ``args``.

      Builds the model from ``args.model`` and the configured commands and
      subroutines, prepares the values used to fill the system template and
      compiles the command-matching patterns.
      """
      self.name = name
      callables = args.config._commands + args.config.subroutine_types
      self.model = get_model(args.model, callables)
      self.config = args.config
      assert self.config is not None  # mypy
      system_args = {"command_docs": self.config.command_docs}
      system_args.update(self.config.env_variables)
      self.system_args = system_args
      self.instance_args = None
      self._parse_command_patterns()
      self.history = []
      self.last_container_id = None
  def setup(self, instance_args, init_model_stats=None) -> None:
      """Setup the agent for a new instance.

      Resets the model statistics, stores ``instance_args`` and restarts the
      history with the rendered system message. When demonstrations are
      configured and the model provides ``history_to_messages``, each
      demonstration file is loaded and added to the history, either entry by
      entry (``put_demos_in_history``) or rendered into a single user message
      through ``demonstration_template``.

      Raises:
          ValueError: if demonstrations are configured but neither a
              demonstration template nor ``put_demos_in_history`` is set.
      """
      assert self.config is not None  # mypy
      self.model.reset_stats(init_model_stats)
      self.instance_args = instance_args

      system_msg = self.config.system_template.format(**self.system_args)
      logger.info(f"SYSTEM ({self.name})\n{system_msg}")
      self.history: List[Dict[str, Any]] = [
          {"role": "system", "content": system_msg, "agent": self.name},
      ]

      if not self.config.demonstrations or not hasattr(
          self.model, "history_to_messages"
      ):
          return

      if (
          self.config.demonstration_template is None
          and not self.config.put_demos_in_history
      ):
          raise ValueError(
              "Cannot use demonstrations without a demonstration template or put_demos_in_history=True"
          )

      for demonstration_path in self.config.demonstrations:
          # Load history
          logger.info(f"DEMONSTRATION: {demonstration_path}")
          # Context manager so the file handle is closed deterministically.
          with open(demonstration_path, "r") as demo_file:
              demo_history = json.load(demo_file)["history"]
          # Keep entries without an owner and entries owned by this agent.
          demo_history = [
              entry
              for entry in demo_history
              if entry.get("agent", self.name) == self.name
          ]

          if self.config.put_demos_in_history:
              if self.config.demonstration_template is not None:
                  logger.warning(
                      "Demonstration template is ignored for put_demos_in_history=True"
                  )
              # Add demonstration to history directly as separate messages
              for entry in demo_history:
                  if entry["role"] != "system":
                      entry["is_demo"] = True
                      self.history.append(entry)
          else:
              # Add demonstration as single message to history
              demo_message = self.model.history_to_messages(
                  demo_history,
                  is_demonstration=True,
              )
              demonstration = self.config.demonstration_template.format(
                  **{"demonstration": demo_message}
              )
              self.history.append(
                  {
                      "agent": self.name,
                      "content": demonstration,
                      "is_demo": True,
                      "role": "user",
                  }
              )
  @property
  def state_command(self) -> str:
      """Name of the configured command used to query the environment state."""
      command = self.config.state_command
      return command.name
  @property
  def local_history(self) -> list[dict[str, str]]:
      """History entries that belong to this agent, run through the history processor."""
      own_entries = []
      for item in self.history:
          if item["agent"] == self.name:
              own_entries.append(item)
      return self.config.history_processor(own_entries)
  def save_trajectory(self, trajectory, traj_dir, env, info):
      """Write the trajectory, full history and info to ``<instance_id>.traj`` in ``traj_dir``."""
      file_name = env.record["instance_id"] + ".traj"
      log_path = traj_dir / file_name
      payload = dict(
          environment=env.name,
          trajectory=trajectory,
          history=self.history,
          info=info,
      )
      with log_path.open("w") as out:
          json.dump(payload, out, indent=2)
      logger.info(f"Saved trajectory to {log_path}")
  286. def _get_first_match(self, action: str, pattern_type: str) -> Optional[re.Match]:
  287. """Return the first match of a command pattern in the action string."""
  288. assert self.config is not None # mypy
  289. if pattern_type == "subroutine":
  290. patterns = {k: v for k, v in self.subroutine_patterns.items()}
  291. elif pattern_type == "multi_line":
  292. patterns = {
  293. k: v
  294. for k, v in self.command_patterns.items()
  295. if k in self.config.multi_line_command_endings
  296. or k == self.config.submit_command
  297. }
  298. patterns += {
  299. k: v
  300. for k, v in self.subroutine_patterns.items()
  301. if k in self.config.multi_line_command_endings
  302. }
  303. elif pattern_type == "multi_line_no_subroutines":
  304. patterns = {
  305. k: v
  306. for k, v in self.command_patterns.items()
  307. if k in self.config.multi_line_command_endings
  308. }
  309. else:
  310. raise ValueError(f"Unknown pattern type: {pattern_type}")
  311. matches = list()
  312. for name, pat in patterns.items():
  313. match = pat.search(action)
  314. if match:
  315. matches.append(match)
  316. if len(matches) == 0:
  317. return None
  318. matches = sorted(matches, key=lambda x: x.start())
  319. return matches[0]
  320. def _guard_multiline_input(self, action: str) -> str:
  321. """Split action by multiline commands, then append the first line in each multiline command with "<< '{end_name}'".
  322. Multiline commands (which are specified by an end_name) are commands that span multiple lines and are terminated by a specific end_name.
  323. Their multi-line argument is sent using a heredoc, which is a way to send a multi-line string to a command in bash.
  324. """
  325. parsed_action = list()
  326. rem_action = action
  327. while rem_action.strip():
  328. first_match = self._get_first_match(rem_action, "multi_line_no_subroutines")
  329. if first_match:
  330. pre_action = rem_action[: first_match.start()]
  331. match_action = rem_action[first_match.start() : first_match.end()]
  332. rem_action = rem_action[first_match.end() :]
  333. if pre_action.strip():
  334. parsed_action.append(pre_action)
  335. if match_action.strip():
  336. eof = first_match.group(3).strip()
  337. if not match_action.split("\n")[0].strip().endswith(f"<< '{eof}'"):
  338. guarded_command = match_action[first_match.start() :]
  339. first_line = guarded_command.split("\n")[0]
  340. guarded_command = guarded_command.replace(
  341. first_line, first_line + f" << '{eof}'", 1
  342. )
  343. parsed_action.append(guarded_command)
  344. else:
  345. parsed_action.append(match_action)
  346. else:
  347. parsed_action.append(rem_action)
  348. rem_action = ""
  349. return "\n".join(parsed_action)
  def split_actions(self, action: str, pattern_type="subroutine") -> List[Dict[str, Any]]:
      """Greedily cut ``action`` into steps.

      Each step is a dict with "agent", "action" and "cmd_name". Text that
      matches no pattern belongs to this agent with ``cmd_name`` None. A
      matched submit command also belongs to this agent. Any other match is
      attributed to the agent named like the matched command and additionally
      carries the matched arguments under "args".
      """
      steps = []
      remaining = action
      while remaining.strip():
          hit = self._get_first_match(remaining, pattern_type)
          if not hit:
              steps.append({"agent": self.name, "action": remaining, "cmd_name": None})
              break

          begin, finish = hit.start(), hit.end()
          before = remaining[:begin]
          matched = remaining[begin:finish]
          remaining = remaining[finish:]

          if before.strip():
              steps.append({"agent": self.name, "action": before, "cmd_name": None})
          if not matched.strip():
              continue

          command_name = hit.group(1)
          if matched.split()[0] == self.config.submit_command:
              # submit command is not a subroutine
              step = {
                  "agent": self.name,
                  "action": matched,
                  "cmd_name": command_name,
              }
          else:
              step = {
                  "agent": command_name,
                  "args": hit.group(2),
                  "action": matched,
                  "cmd_name": command_name,
              }
          steps.append(step)
      return steps
  388. def _parse_command_patterns(self):
  389. assert self.config is not None # mypy
  390. self.command_patterns = dict()
  391. for command in self.config._commands:
  392. if command.end_name is not None:
  393. pat = re.compile(
  394. rf"^\s*({command.name})\s*(.*?)^({command.end_name})\s*$",
  395. re.DOTALL | re.MULTILINE,
  396. )
  397. self.command_patterns[command.name] = pat
  398. else:
  399. pat = re.compile(rf"^\s*({command.name})\s*(.*?)$", re.MULTILINE)
  400. self.command_patterns[command.name] = pat
  401. self.subroutine_patterns = dict()
  402. for _, subroutine in self.config._subroutines.items():
  403. if subroutine.end_name is None:
  404. pat = re.compile(rf"^\s*({subroutine.name})\s*(.*?)$", re.MULTILINE)
  405. self.subroutine_patterns[subroutine.name,] = pat
  406. else:
  407. pat = re.compile(
  408. rf"^\s*({subroutine.name})\s*(.*?)^({subroutine.end_name})\s*$",
  409. re.DOTALL | re.MULTILINE,
  410. )
  411. self.subroutine_patterns[subroutine.name] = pat
  412. if hasattr(self.config, "submit_command_end_name"):
  413. submit_pat = re.compile(
  414. rf"^\s*({self.config.submit_command})\s*(.*?)^({self.config.submit_command_end_name})\s*$",
  415. re.DOTALL | re.MULTILINE,
  416. )
  417. else:
  418. submit_pat = re.compile(
  419. rf"^\s*({self.config.submit_command})(\s*)$", re.MULTILINE
  420. ) # group 2 is nothing
  421. self.subroutine_patterns[self.config.submit_command] = submit_pat
  422. self.command_patterns[self.config.submit_command] = submit_pat
  def forward(
      self, observation: str, available_actions: list[str], state: str
  ) -> Tuple[str, str, str]:
      """Query the model for the next step and record its answer in the history.

      ``available_actions`` is accepted but not used here.

      Returns:
          The (thought, action, raw model output) triple.
      """
      debug_time()
      thought, action, output = self.forward_with_error_check(observation, state)

      assistant_entry = {
          "role": "assistant",
          "content": output,
          "thought": thought,
          "action": action,
          "agent": self.name,
      }
      self.history.append(assistant_entry)

      logger.info(f"💭 THOUGHT ({self.name})\n{thought}")
      logger.info(f"🎬 ACTION ({self.name})\n{action}")
      return thought, action, output
  def forward_model(self, observation: str, state: str) -> str:
      """Render the appropriate template for this step and query the model.

      ``state`` is parsed as JSON and its keys are made available to the
      template together with the instance and system arguments. The rendered
      message is appended to the history as a user entry before the model is
      queried with the local history.

      Returns:
          The model output.
      """
      assert self.config is not None  # mypy
      debug_time()
      state_vars = json.loads(state)

      # Pick the template(s) according to what preceded this step.
      previous = self.history[-1]
      if previous["role"] == "system" or previous.get("is_demo", False):
          # First real step: instance template, optionally followed by the strategy.
          templates = [self.config.instance_template]
          if self.config.strategy_template is not None:
              templates.append(self.config.strategy_template)
      elif observation is None or observation.strip() == "":
          # The previous action produced no output.
          templates = [self.config.next_step_no_output_template]
      else:
          templates = [self.config.next_step_template]

      # Fill the selected template(s) (e.g., issue, arguments, state).
      shown_observation = observation if observation is not None else ""
      messages = [
          template.format(
              **self.instance_args,
              **self.system_args,
              **state_vars,
              observation=shown_observation,
          )
          for template in templates
      ]
      message = "\n".join(messages)

      logger.info(f"🤖 MODEL INPUT\n{message}")
      self.history.append({"role": "user", "content": message, "agent": self.name})
      debug_time()
      return self.model.query(self.local_history)
  def retry_after_format_fail(self, output):
      """Re-query the model after malformatted output, leaving the persistent history untouched."""
      error_text = self.config.format_error_template
      logger.warning(f"MALFORMED OUTPUT\n{output}")
      logger.warning(f"FORMAT ERROR\n{error_text}")

      # The failed exchange only lives in this temporary copy of the history.
      failed_turn = {"role": "assistant", "content": output, "agent": self.name}
      correction = {"role": "user", "content": error_text, "agent": self.name}
      temp_history = self.local_history + [failed_turn, correction]
      return self.model.query(temp_history)
  def retry_after_blocklist_fail(self, output, action):
      """Re-query the model after a disallowed command, leaving the persistent history untouched."""
      blocked_name = action.strip().split()[0]
      error_text = self.config.blocklist_error_template.format(name=blocked_name)
      logger.warning(f"BLOCKLISTED OUTPUT\n{output}")
      logger.warning(f"BLOCKLIST ERROR\n{error_text}")

      # The rejected exchange only lives in this temporary copy of the history.
      rejected_turn = {"role": "assistant", "content": output, "agent": self.name}
      correction = {"role": "user", "content": error_text, "agent": self.name}
      temp_history = self.local_history + [rejected_turn, correction]
      return self.model.query(temp_history)
  def should_block_action(self, action):
      """Tell whether ``action`` must be rejected.

      An action is blocked when its first word is in ``config.blocklist``, or
      when the whole action is a single word from
      ``config.blocklist_standalone``. Blank actions are never blocked.
      """
      stripped = action.strip()
      words = stripped.split()
      if not words:
          return False
      head = words[0]
      always_blocked = head in self.config.blocklist
      blocked_alone = head in self.config.blocklist_standalone and head == stripped
      return always_blocked or blocked_alone
  def check_format_and_requery(
      self,
      output: str,
  ) -> Tuple[str, str, str]:
      """Parse model output into a thought and an action, re-querying on failure.

      Output of the "human" model is returned as the action with an empty
      thought; output of "human_thought" is parsed once with
      ThoughtActionParser. For any other model the output is parsed with the
      configured parse function; a format error or a blocked action triggers
      a re-query, and the loop stops once more than two failures have
      accumulated.

      Returns:
          The (thought, action, raw model output) triple, or
          ("Exit due to format error", "exit_format", output) when the retry
          limit is reached.
      """
      debug_time()
      # Condition for handling outputs with no thought (just action)
      model_name = self.model.args.model_name
      if model_name == "human":
          return "", output, output
      if model_name == "human_thought":
          human_parser = ParseFunction.get("ThoughtActionParser")
          thought, action = human_parser(
              output,
              self.config._commands + self.config.subroutine_types,
              strict=False,
          )
          return thought, action, output

      format_fails = 0
      blocklist_fails = 0
      while format_fails + blocklist_fails <= 2:
          try:
              thought, action = self.config.parse_function(
                  output,
                  self.config._commands + self.config.subroutine_types,
                  strict=False,
              )
          except KeyboardInterrupt:
              raise
          except FormatError:
              format_fails += 1
              output = self.retry_after_format_fail(output)
              continue
          if not self.should_block_action(action):
              return thought, action, output
          blocklist_fails += 1
          output = self.retry_after_blocklist_fail(output, action)

      logger.debug("format fails %d, blocklist fails %d", format_fails, blocklist_fails)
      logger.warning(f"Malformat limit reached: \n{output}")
      debug_time()
      return "Exit due to format error", "exit_format", output
  def forward_with_error_check(
      self, observation: str, state: str
  ) -> Tuple[str, str, str]:
      """Query the model and turn known failures into exit actions.

      Runtime errors, an exceeded context window, an exceeded cost limit and
      retry errors raised while querying are logged and converted into a
      (thought, exit action, output) triple. On success the output is
      handed to ``check_format_and_requery``.
      """
      try:
          output = self.forward_model(observation, state)
      except KeyboardInterrupt:
          raise
      except RuntimeError as e:
          logger.warning(f"Runtime error: {e}")
          failure = (
              f"Exit due to runtime error: {e}",
              "exit_error",
              f"exit due to runtime error: {e}",
          )
      except ContextWindowExceededError:
          logger.warning("Context window exceeded")
          failure = (
              "Exit due to context window",
              "exit_context",
              "Exit due to context window",
          )
      except CostLimitExceededError:
          logger.warning("Cost limit exceeded")
          failure = ("Exit due to cost limit", "exit_cost", "Exit due to cost limit")
      except RetryError as e:
          logger.warning(f"Retry error: {e}")
          failure = (
              f"Exit due to retry error: {e}",
              "exit_api",
              f"exit due to retry error: {e}",
          )
      else:
          return self.check_format_and_requery(output)
      return failure
  def init_environment_vars(self, env):
      """Apply the configured environment variables to ``env``."""
      configured = self.config.env_variables
      self.set_environment_vars(env, configured)
  def set_environment_vars(self, env, env_variables):
      """Define the state command and variables in ``env`` and register command files.

      The state command's code and one ``NAME=value`` line per entry of
      ``env_variables`` are sent to the environment in a single call. Each
      configured command file is then read and classified before being handed
      to ``env.add_commands``:

      * no shebang and a ``.sh`` name -> "source_file";
      * no shebang and a name starting with "_" -> "utility";
      * shebang -> "script", named after the file without its extension.

      Raises:
          RuntimeError: if the environment reports a nonzero return code.
          ValueError: if a file has no shebang and is neither a ``.sh`` file
              nor prefixed with an underscore.
      """
      assert self.config is not None  # mypy
      commands_to_execute = [self.config.state_command.code]
      commands_to_execute += [f"{k}={v}" for k, v in env_variables.items()]
      commands = "\n".join(commands_to_execute)
      try:
          output = env.communicate(commands)
          if env.returncode != 0:
              raise RuntimeError(
                  f"Nonzero return code: {env.returncode}\nOutput: {output}"
              )
      except Exception:
          logger.warning("Failed to set environment variables")
          # Bare raise keeps the original traceback intact.
          raise

      command_files = []
      for file in self.config.command_files:
          # Context manager so the file handle is closed deterministically.
          with open(file, "r") as handle:
              contents = handle.read()
          filename = Path(file).name
          datum = {"contents": contents}
          if contents.strip().startswith("#!"):
              # scripts are made executable
              datum["name"] = filename.rsplit(".", 1)[0]
              datum["type"] = "script"
          elif filename.endswith(".sh"):
              # files are sourced, so they are not executable
              datum["name"] = filename
              datum["type"] = "source_file"
          elif filename.startswith("_"):
              # files are sourced, so they are not executable
              datum["name"] = filename
              datum["type"] = "utility"
          else:
              raise ValueError(
                  (
                      f"Non-shell script file {file} does not start with shebang.\n"
                      "Either add a shebang (#!) or change the file extension to .sh if you want to source it.\n"
                      "You can override this behavior by adding an underscore to the file name (e.g. _utils.py)."
                  )
              )
          command_files.append(datum)
      env.add_commands(command_files)
  def get_environment_vars(self, env):
      """Read the current value of every configured variable from ``env``.

      Returns:
          A dict mapping each name in ``config.env_variables`` to the stripped
          output of echoing that variable in the environment.
      """
      assert self.config is not None  # mypy
      return {
          name: env.communicate(f"echo ${name}").strip()
          for name in self.config.env_variables
      }
  def call_subroutine(self, agent_name, sub_action, env):
      """Run the subroutine ``agent_name`` as a separate agent and return its result.

      The current environment variables and working directory are captured
      first and restored once the sub-agent has finished. The sub-agent's
      history is appended to this agent's history and its model statistics
      replace this agent's.

      Raises:
          RuntimeError: if the environment reports a nonzero return code after
              the optional init observation was executed.
      """
      assert self.config is not None  # mypy
      saved_vars = self.get_environment_vars(env)
      saved_cwd = env.communicate("pwd -P").strip()

      init_observation = self.config._subroutines[agent_name].init_observation
      if init_observation is None:
          obs = None
      else:
          init_command = init_observation.format(args=sub_action["args"])
          obs, _, _, _ = env.step(init_command)
      if env.returncode != 0:
          self.history.append({"role": "user", "content": obs, "agent": agent_name})
          raise RuntimeError(
              f"Nonzero return code: {env.returncode} for init_observation in {agent_name}.\n{obs}"
          )

      return_type = self.config._subroutines[agent_name].return_type
      sub_agent = Agent(agent_name, self.config._subroutines[agent_name].agent_args)
      # The sub-agent starts from this agent's model statistics.
      result = sub_agent.run(
          {"issue": sub_action["args"]},
          env,
          observation=obs,
          return_type=return_type,
          init_model_stats=self.model.stats,
      )

      self.history += sub_agent.history
      self.set_environment_vars(env, saved_vars)
      env.communicate(f"cd {saved_cwd}")
      self.model.stats.replace(sub_agent.model.stats)
      return result
  def run(
      self,
      setup_args: Dict[str, Any],
      env: SWEEnv,
      observation: Optional[str] = None,
      traj_dir: Optional[Path] = None,
      return_type: Optional[str] = "info",
      init_model_stats: Optional[APIStats] = None,
  ):
      """
      Run the agent on an environment.

      Each iteration queries the model, guards multi-line commands, splits the
      action into steps and executes them: this agent's own steps and submit
      commands go to the environment, other steps are delegated to
      subroutines. The loop ends when the environment reports completion or a
      submit command was executed. If ``traj_dir`` is given, the trajectory is
      saved after every iteration.

      Return the final value of the specified return type: ``info`` for
      "info", ``(info, trajectory)`` for "info_trajectory", otherwise the
      entry of that name in the last trajectory step.
      """
      done = False
      assert env.container_obj is not None
      assert self.config is not None  # mypy

      # Environment variables only need to be set once per container.
      if env.container_obj.id != self.last_container_id:
          logger.info(
              f"Initializing agent settings for container {env.container_obj.id}"
          )
          self.init_environment_vars(env)
          self.last_container_id = env.container_obj.id

      # Re-initialize primary
      self.setup(setup_args, init_model_stats)

      # Run action/observation loop
      trajectory = []
      info = {}
      submit_name = self.config.submit_command
      while not done:
          debug_time()
          state = env.communicate(self.state_command) if self.state_command else None
          debug_time("state cmd")
          thought, action, output = self.forward(
              observation, env.get_available_actions(), state
          )
          debug_time("reset before postproc")

          observations = []
          run_action = self._guard_multiline_input(action)
          debug_time("guard multiline")
          for sub_action in self.split_actions(run_action):
              is_submit = sub_action["cmd_name"] == submit_name
              if sub_action["agent"] != self.name and not is_submit:
                  # Delegate to the subroutine agent named in the step.
                  observations.append(
                      self.call_subroutine(sub_action["agent"], sub_action, env)
                  )
                  continue
              obs, _, done, info = env.step(sub_action["action"])
              observations.append(obs)
              if is_submit:
                  done = True
              if done:
                  break

          observation = "\n".join(
              [entry for entry in observations if entry is not None]
          )
          debug_time("got observation")

          trajectory.append(
              {
                  "action": action,
                  "observation": observation,
                  "response": output,
                  "state": state,
                  "thought": thought,
              }
          )
          info["model_stats"] = self.model.stats.to_dict()
          if traj_dir:
              self.save_trajectory(trajectory, traj_dir, env, info)
          debug_time("traj saved")

      if return_type == "info":
          return info
      if return_type == "info_trajectory":
          return info, trajectory
      return trajectory[-1][return_type]
  741. return trajectory[-1][return_type]