agents.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681
  1. import json
  2. import re
  3. import logging
  4. from dataclasses import dataclass
  5. from pathlib import Path
  6. from simple_parsing.helpers import field, FrozenSerializable, FlattenedAccess
  7. from sweagent.agent.commands import Command, ParseCommand
  8. from sweagent.agent.history_processors import HistoryProcessor
  9. from sweagent.agent.models import (
  10. APIStats,
  11. ContextWindowExceededError,
  12. CostLimitExceededError,
  13. ModelArguments,
  14. get_model,
  15. )
  16. from sweagent.agent.parsing import ParseFunction, FormatError
  17. from sweagent.environment.utils import LOGGER_NAME
  18. from sweagent.environment.swe_env import SWEEnv
  19. from tenacity import RetryError
  20. from typing import Optional, Tuple, Any
  21. logger = logging.getLogger(LOGGER_NAME)
  22. @dataclass(frozen=True)
  23. class Subroutine(FrozenSerializable):
  24. name: str
  25. agent_file: str
  26. return_type: str = None # one of "action", "observation", "response", "state", "thought"
  27. init_observation: Optional[str] = None
  28. end_name: Optional[str] = None
  29. signature: Optional[str] = None
  30. docstring: Optional[str] = None
  31. model: Optional[ModelArguments] = None
  32. agent_args: Optional[Any] = None
  33. @dataclass(frozen=True)
  34. class AgentConfig(FrozenSerializable):
  35. system_template: str
  36. instance_template: str
  37. next_step_template: Optional[str] = None # defaults to instance_template
  38. next_step_no_output_template: Optional[str] = None # defaults to next_step_template
  39. strategy_template: Optional[str] = None
  40. demonstration_template: Optional[str] = None
  41. demonstrations: list[str] = field(default_factory=list)
  42. put_demos_in_history: bool = False # if True, add demonstration to history instead of as a single message
  43. format_error_template: str = None # defaults to format_error_template in ParseFunction
  44. command_files: list[str] = field(default_factory=list)
  45. env_variables: dict[str, str] = field(default_factory=dict)
  46. util_functions: list[str] = field(default_factory=list)
  47. submit_command: str = "submit"
  48. parse_function: str = "ThoughtActionParser"
  49. parse_command: str = "ParseCommandBash"
  50. history_processor: str = "DefaultHistoryProcessor"
  51. history_processor_args: dict[str, Any] = field(default_factory=dict)
  52. command_docs: str = None
  53. blocklist_error_template: str = "Interactive operation '{name}' is not supported by this environment"
  54. blocklist: Tuple[str] = (
  55. "vim",
  56. "vi",
  57. "emacs",
  58. "nano",
  59. "nohup",
  60. "git",
  61. )
  62. blocklist_standalone: Tuple[str] = (
  63. "python",
  64. "python3",
  65. "ipython",
  66. "bash",
  67. "sh",
  68. "exit",
  69. "/bin/bash",
  70. "/bin/sh",
  71. "nohup",
  72. "vi",
  73. "vim",
  74. "emacs",
  75. "nano",
  76. )
  77. # Should extract environment state in a json readable form
  78. state_command: Command = Command(
  79. name="state",
  80. code="""state() {
  81. echo '{"working_dir": "'$(realpath --relative-to=$ROOT/.. $PWD)'"}';
  82. };""",
  83. )
  84. _commands: list[Command] = field(default_factory=list)
  85. _subroutines: dict[str, Subroutine] = field(default_factory=dict)
  86. subroutine_types: list[Subroutine] = field(default_factory=list)
  87. def __post_init__(self):
  88. if self.next_step_template is None:
  89. object.__setattr__(self, "next_step_template", self.instance_template)
  90. if self.next_step_no_output_template is None:
  91. object.__setattr__(
  92. self, "next_step_no_output_template", self.next_step_template
  93. )
  94. object.__setattr__(self, "parse_command", ParseCommand.get(self.parse_command))
  95. for file in self.command_files:
  96. commands = self.parse_command.parse_command_file(file)
  97. util_functions = [
  98. command for command in commands if command.name.startswith("_")
  99. ]
  100. commands = [
  101. command for command in commands if not command.name.startswith("_")
  102. ]
  103. object.__setattr__(
  104. self, "util_functions", self.util_functions + util_functions
  105. )
  106. object.__setattr__(self, "_commands", self._commands + commands)
  107. for subroutine in self.subroutine_types:
  108. if subroutine.name == 'submit':
  109. raise ValueError("Cannot use 'submit' as a subroutine name")
  110. agent_args = AgentArguments(
  111. model=subroutine.model,
  112. config_file=subroutine.agent_file,
  113. )
  114. object.__setattr__(subroutine, "agent_args", agent_args)
  115. object.__setattr__(self, "_subroutines", {**self._subroutines, subroutine.name: subroutine})
  116. multi_line_command_endings = {
  117. command.name: command.end_name
  118. for command in [*self._commands, *self._subroutines.values()]
  119. if command.end_name is not None
  120. }
  121. object.__setattr__(self, "multi_line_command_endings", multi_line_command_endings)
  122. object.__setattr__(
  123. self,
  124. "command_docs",
  125. self.parse_command.generate_command_docs(
  126. self._commands,
  127. self.subroutine_types,
  128. **self.env_variables,
  129. ),
  130. )
  131. object.__setattr__(self, "parse_function", ParseFunction.get(self.parse_function))
  132. if self.format_error_template is None:
  133. object.__setattr__(
  134. self,
  135. "format_error_template",
  136. self.parse_function.format_error_template,
  137. )
  138. object.__setattr__(self, "format_error_template", self.format_error_template.format(**self.__dict__))
  139. for command in self._commands:
  140. if command.name == self.submit_command:
  141. object.__setattr__(self, "submit_command_end_name", command.end_name)
  142. break
  143. object.__setattr__(
  144. self, "history_processor",
  145. HistoryProcessor.get(self.history_processor, **self.history_processor_args)
  146. )
  147. @dataclass(frozen=True)
  148. class AgentArguments(FlattenedAccess, FrozenSerializable):
  149. model: ModelArguments = None
  150. # Policy can only be set via config yaml file from command line
  151. config_file: Optional[Path] = None
  152. config: Optional[AgentConfig] = field(default=None, cmd=False)
  153. def __post_init__(self):
  154. if self.config is None and self.config_file is not None:
  155. # If unassigned, we load the config from the file to store its contents with the overall arguments
  156. config = AgentConfig.load_yaml(self.config_file)
  157. object.__setattr__(self, "config", config)
  158. assert self.config is not None
  159. for subroutine in getattr(self.config, "subroutines", {}).values():
  160. model_args = getattr(subroutine, "model")
  161. object.__setattr__(model_args, "per_instance_cost_limit", self.model.per_instance_cost_limit)
  162. object.__setattr__(model_args, "total_cost_limit", self.model.total_cost_limit)
  163. class Agent:
  164. """Agent handles the behaviour of the model and how it interacts with the environment."""
  165. def __init__(self, name: str, args: AgentArguments):
  166. self.name = name
  167. self.model = get_model(args.model, args.config._commands + args.config.subroutine_types)
  168. self.config = args.config
  169. self.system_args = {
  170. "command_docs": self.config.command_docs,
  171. **self.config.env_variables,
  172. }
  173. self.instance_args = None
  174. self._parse_command_patterns()
  175. self.history = []
  176. self.last_container_id = None
  177. def setup(self, instance_args, init_model_stats=None) -> None:
  178. """Setup the agent for a new instance."""
  179. self.model.reset_stats(init_model_stats)
  180. self.instance_args = instance_args
  181. system_msg = self.config.system_template.format(**self.system_args)
  182. logger.info(f"SYSTEM ({self.name})\n{system_msg}")
  183. self.history = [
  184. {"role": "system", "content": system_msg, "agent": self.name},
  185. ]
  186. if len(self.config.demonstrations) > 0 and "history_to_messages" in dir(
  187. self.model
  188. ):
  189. for demonstration_path in self.config.demonstrations:
  190. if self.config.demonstration_template is None and not self.config.put_demos_in_history:
  191. raise ValueError("Cannot use demonstrations without a demonstration template or put_demos_in_history=True")
  192. # Load history
  193. logger.info(f"DEMONSTRATION: {demonstration_path}")
  194. demo_history = json.load(open(demonstration_path, "r"))["history"]
  195. demo_history = [
  196. entry for entry in demo_history
  197. if ("agent" not in entry) or
  198. ("agent" in entry and entry["agent"] == self.name)
  199. ]
  200. if self.config.put_demos_in_history:
  201. if self.config.demonstration_template is not None:
  202. logger.warning("Demonstration template is ignored for put_demos_in_history=True")
  203. # Add demonstration to history directly as separate messages
  204. for entry in demo_history:
  205. if entry["role"] != "system":
  206. entry["is_demo"] = True
  207. self.history.append(entry)
  208. else:
  209. # Add demonstration as single message to history
  210. demo_message = self.model.history_to_messages(
  211. demo_history,
  212. is_demonstration=True,
  213. )
  214. demonstration = self.config.demonstration_template.format(
  215. **{"demonstration": demo_message}
  216. )
  217. self.history.append({
  218. "agent": self.name,
  219. "content": demonstration,
  220. "is_demo": True,
  221. "role": "user",
  222. })
  223. @property
  224. def state_command(self) -> str:
  225. """Return the bash command that will be used to extract the environment state."""
  226. return self.config.state_command.name
  227. @property
  228. def local_history(self) -> list[dict[str, str]]:
  229. """Return the history of the agent since the last reset."""
  230. return self.config.history_processor([entry for entry in self.history if entry["agent"] == self.name])
  231. def save_trajectory(self, trajectory, traj_dir, env, info):
  232. log_path = traj_dir / (env.record['instance_id'] + ".traj")
  233. log_dict = {
  234. "environment": env.name,
  235. "trajectory": trajectory,
  236. "history": self.history,
  237. "info": info,
  238. }
  239. with log_path.open("w") as f:
  240. json.dump(log_dict, f, indent=2)
  241. logger.info(f"Saved trajectory to {log_path}")
  242. def _get_first_match(self, action: str, pattern_type: str) -> Optional[re.Match]:
  243. """Return the first match of a command pattern in the action string."""
  244. if pattern_type == "subroutine":
  245. patterns = {k: v for k, v in self.subroutine_patterns.items()}
  246. elif pattern_type == "multi_line":
  247. patterns = {k: v for k, v in self.command_patterns.items() if k in self.config.multi_line_command_endings or k == self.config.submit_command}
  248. patterns += {k: v for k, v in self.subroutine_patterns.items() if k in self.config.multi_line_command_endings}
  249. elif pattern_type == "multi_line_no_subroutines":
  250. patterns = {k: v for k, v in self.command_patterns.items() if k in self.config.multi_line_command_endings}
  251. else:
  252. raise ValueError(f"Unknown pattern type: {pattern_type}")
  253. matches = list()
  254. for name, pat in patterns.items():
  255. match = pat.search(action)
  256. if match:
  257. matches.append(match)
  258. if len(matches) == 0:
  259. return None
  260. matches = sorted(matches, key=lambda x: x.start())
  261. return matches[0]
  262. def _guard_multiline_input(self, action: str) -> str:
  263. """Split action by multiline commands, then append the first line in each multiline command with "<< '{end_name}'".
  264. Multiline commands (which are specified by an end_name) are commands that span multiple lines and are terminated by a specific end_name.
  265. Their multi-line argument is sent using a heredoc, which is a way to send a multi-line string to a command in bash.
  266. """
  267. parsed_action = list()
  268. rem_action = action
  269. while rem_action.strip():
  270. first_match = self._get_first_match(rem_action, "multi_line_no_subroutines")
  271. if first_match:
  272. pre_action = rem_action[:first_match.start()]
  273. match_action = rem_action[first_match.start():first_match.end()]
  274. rem_action = rem_action[first_match.end():]
  275. if pre_action.strip():
  276. parsed_action.append(pre_action)
  277. if match_action.strip():
  278. eof = first_match.group(3).strip()
  279. if not match_action.split('\n')[0].strip().endswith(f"<< '{eof}'"):
  280. guarded_command = match_action[first_match.start():]
  281. first_line = guarded_command.split('\n')[0]
  282. guarded_command = guarded_command.replace(
  283. first_line,
  284. first_line + f" << '{eof}'",
  285. 1
  286. )
  287. parsed_action.append(guarded_command)
  288. else:
  289. parsed_action.append(match_action)
  290. else:
  291. parsed_action.append(rem_action)
  292. rem_action = ""
  293. return '\n'.join(parsed_action)
  294. def split_actions(self, action: str, pattern_type="subroutine") -> list[str]:
  295. """Split an action into a list of actions in a greedy manner, each of which is a subroutine call or a single command."""
  296. parsed_action = list()
  297. rem_action = action
  298. while rem_action.strip():
  299. first_match = self._get_first_match(rem_action, pattern_type)
  300. if first_match:
  301. pre_action = rem_action[:first_match.start()]
  302. match_action = rem_action[first_match.start():first_match.end()]
  303. rem_action = rem_action[first_match.end():]
  304. if pre_action.strip():
  305. parsed_action.append({'agent': self.name, 'action': pre_action, 'cmd_name': None})
  306. if match_action.strip():
  307. if match_action.split()[0] == self.config.submit_command:
  308. parsed_action.append({'agent': self.name, 'action': match_action, 'cmd_name': first_match.group(1)}) # submit command is not a subroutine
  309. else:
  310. parsed_action.append({'agent': first_match.group(1), 'args': first_match.group(2), 'action': match_action, 'cmd_name': first_match.group(1)})
  311. else:
  312. parsed_action.append({'agent': self.name, 'action': rem_action, 'cmd_name': None})
  313. rem_action = ""
  314. return parsed_action
  315. def _parse_command_patterns(self):
  316. self.command_patterns = dict()
  317. for command in self.config._commands:
  318. if command.end_name is not None:
  319. pat = re.compile(fr'^\s*({command.name})\s*(.*?)^({command.end_name})\s*$', re.DOTALL | re.MULTILINE)
  320. self.command_patterns[command.name] = pat
  321. else:
  322. pat = re.compile(fr'^\s*({command.name})\s*(.*?)$', re.MULTILINE)
  323. self.command_patterns[command.name] = pat
  324. self.subroutine_patterns = dict()
  325. for _, subroutine in self.config._subroutines.items():
  326. if subroutine.end_name is None:
  327. pat = re.compile(fr'^\s*({subroutine.name})\s*(.*?)$', re.MULTILINE)
  328. self.subroutine_patterns[subroutine.name,] = pat
  329. else:
  330. pat = re.compile(fr'^\s*({subroutine.name})\s*(.*?)^({subroutine.end_name})\s*$', re.DOTALL | re.MULTILINE)
  331. self.subroutine_patterns[subroutine.name] = pat
  332. if hasattr(self.config, 'submit_command_end_name'):
  333. submit_pat = re.compile(rf'^\s*({self.config.submit_command})\s*(.*?)^({self.config.submit_command_end_name})\s*$', re.DOTALL | re.MULTILINE)
  334. else:
  335. submit_pat = re.compile(rf'^\s*({self.config.submit_command})(\s*)$', re.MULTILINE) # group 2 is nothing
  336. self.subroutine_patterns[self.config.submit_command] = submit_pat
  337. self.command_patterns[self.config.submit_command] = submit_pat
  338. def forward(self, observation: str, available_actions: list[str], state: str) -> Tuple[str, str, str]:
  339. thought, action, output = self.forward_with_error_check(observation, state)
  340. self.history.append(
  341. {"role": "assistant",
  342. "content": output,
  343. "thought": thought,
  344. "action": action,
  345. "agent": self.name,
  346. }
  347. )
  348. logger.info(f"💭 THOUGHT ({self.name})\n{thought}")
  349. logger.info(f"🎬 ACTION ({self.name})\n{action}")
  350. return thought, action, output
  351. def forward_model(self, observation: str, state: str) -> str:
  352. """Query the model with the current state and observation with the appropriate template.
  353. Returns the model output."""
  354. state_vars = json.loads(state)
  355. templates = []
  356. # Determine observation template based on what prior observation was
  357. if self.history[-1]["role"] == "system" or self.history[-1].get("is_demo", False):
  358. # Show instance template if prev. obs. was initial system message
  359. templates = [self.config.instance_template]
  360. if self.config.strategy_template is not None:
  361. templates.append(self.config.strategy_template)
  362. elif observation is None or observation.strip() == "":
  363. # Show no output template if observation content was empty
  364. templates = [self.config.next_step_no_output_template]
  365. else:
  366. # Show standard output template if there is observation content
  367. templates = [self.config.next_step_template]
  368. # Populate selected template(s) with information (e.g., issue, arguments, state)
  369. messages = []
  370. for template in templates:
  371. messages.append(
  372. template.format(
  373. **self.instance_args,
  374. **self.system_args,
  375. **state_vars,
  376. observation=(observation if observation is not None else ""),
  377. )
  378. )
  379. message = "\n".join(messages)
  380. logger.info(f"🤖 MODEL INPUT\n{message}")
  381. self.history.append({"role": "user", "content": message, "agent": self.name})
  382. return self.model.query(self.local_history)
  383. def retry_after_format_fail(self, output):
  384. """Ask the model to correct (without committing to persistent history) after a malformatted model output"""
  385. format_error_template = self.config.format_error_template
  386. logger.warning(f"MALFORMED OUTPUT\n{output}")
  387. logger.warning(f"FORMAT ERROR\n{format_error_template}")
  388. temp_history = self.local_history + [
  389. {"role": "assistant", "content": output, "agent": self.name},
  390. {"role": "user", "content": format_error_template, "agent": self.name},
  391. ]
  392. return self.model.query(temp_history)
  393. def retry_after_blocklist_fail(self, output, action):
  394. """Ask the model to correct (without committing to persistent history) after a disallowed command"""
  395. name = action.strip().split()[0]
  396. blocklist_error_message = self.config.blocklist_error_template.format(name=name)
  397. logger.warning(f"BLOCKLISTED OUTPUT\n{output}")
  398. logger.warning(f"BLOCKLIST ERROR\n{blocklist_error_message}")
  399. temp_history = self.local_history + [
  400. {"role": "assistant", "content": output, "agent": self.name},
  401. {"role": "user", "content": blocklist_error_message, "agent": self.name},
  402. ]
  403. return self.model.query(temp_history)
  404. def should_block_action(self, action):
  405. """Check if the command should be blocked."""
  406. names = action.strip().split()
  407. if len(names) == 0:
  408. return False
  409. name = names[0]
  410. if name in self.config.blocklist:
  411. return True
  412. if name in self.config.blocklist_standalone and name == action.strip():
  413. return True
  414. return False
  415. def check_format_and_requery(
  416. self, output: str,
  417. ) -> Tuple[str, str, str]:
  418. """Query the model with the current state and observation with the appropriate template.
  419. Try to parse the output into a thought and action. Retry if the output is malformatted or the action is blocked.
  420. Returns the thought, action, and raw model output.
  421. """
  422. # Condition for handling outputs with no thought (just action)
  423. if self.model.args.model_name == "human":
  424. return "", output, output
  425. elif self.model.args.model_name == "human_thought":
  426. thought, action = ParseFunction.get("ThoughtActionParser")(
  427. output,
  428. self.config._commands + self.config.subroutine_types,
  429. strict=False,
  430. )
  431. return thought, action, output
  432. format_fails = blocklist_fails = 0
  433. while format_fails + blocklist_fails <= 2:
  434. try:
  435. thought, action = self.config.parse_function(
  436. output,
  437. self.config._commands + self.config.subroutine_types,
  438. strict=False,
  439. )
  440. except KeyboardInterrupt:
  441. raise
  442. except FormatError as e:
  443. format_fails += 1
  444. output = self.retry_after_format_fail(output)
  445. continue
  446. if self.should_block_action(action):
  447. blocklist_fails += 1
  448. output = self.retry_after_blocklist_fail(output, action)
  449. else:
  450. return thought, action, output
  451. logger.warning(f"Malformat limit reached: \n{output}")
  452. return "Exit due to format error", "exit_format", output
  453. def forward_with_error_check(self, observation: str, state: str) -> Tuple[str, str, str]:
  454. try:
  455. output = self.forward_model(observation, state)
  456. except KeyboardInterrupt:
  457. raise
  458. except RuntimeError as e:
  459. logger.warning(f"Runtime error: {e}")
  460. return (
  461. f"Exit due to runtime error: {e}",
  462. "exit_error",
  463. f"exit due to runtime error: {e}",
  464. )
  465. except ContextWindowExceededError as e:
  466. logger.warning(f"Context window exceeded")
  467. return "Exit due to context window", "exit_context", "Exit due to context window"
  468. except CostLimitExceededError as e:
  469. logger.warning(f"Cost limit exceeded")
  470. return "Exit due to cost limit", "exit_cost", "Exit due to cost limit"
  471. except RetryError as e:
  472. logger.warning(f"Retry error: {e}")
  473. return (
  474. f"Exit due to retry error: {e}",
  475. "exit_api",
  476. f"exit due to retry error: {e}",
  477. )
  478. return self.check_format_and_requery(output)
  479. def init_environment_vars(self, env):
  480. self.set_environment_vars(env, self.config.env_variables)
  481. def set_environment_vars(self, env, env_variables):
  482. commands_to_execute = (
  483. [self.config.state_command.code] +
  484. # [code for code in self.config.util_functions] +
  485. # [command.code for command in self.config._commands] +
  486. [f"{k}={v}" for k,v in env_variables.items()]
  487. )
  488. commands = "\n".join(commands_to_execute)
  489. try:
  490. output = env.communicate(commands)
  491. if env.returncode != 0:
  492. raise RuntimeError(f"Nonzero return code: {env.returncode}\nOutput: {output}")
  493. except KeyboardInterrupt:
  494. raise
  495. except Exception as e:
  496. logger.warning("Failed to set environment variables")
  497. raise e
  498. command_files = list()
  499. for file in self.config.command_files:
  500. datum = dict()
  501. contents = open(file, 'r').read()
  502. datum['contents'] = contents
  503. filename = Path(file).name
  504. if not contents.strip().startswith('#!'):
  505. if filename.endswith('.sh'):
  506. # files are sourced, so they are not executable
  507. datum['name'] = Path(file).name
  508. datum['type'] = 'source_file'
  509. elif filename.startswith('_'):
  510. # files are sourced, so they are not executable
  511. datum['name'] = Path(file).name
  512. datum['type'] = 'utility'
  513. else:
  514. raise ValueError((
  515. f"Non-shell script file {file} does not start with shebang.\n"
  516. "Either add a shebang (#!) or change the file extension to .sh if you want to source it.\n"
  517. "You can override this behavior by adding an underscore to the file name (e.g. _utils.py)."
  518. ))
  519. else:
  520. # scripts are made executable
  521. datum['name'] = Path(file).name.rsplit('.', 1)[0]
  522. datum['type'] = 'script'
  523. command_files.append(datum)
  524. env.add_commands(command_files)
  525. def get_environment_vars(self, env):
  526. env_vars = dict()
  527. for var in self.config.env_variables:
  528. env_vars[var] = env.communicate(f"echo ${var}").strip()
  529. return env_vars
  530. def call_subroutine(self, agent_name, sub_action, env):
  531. env_vars = self.get_environment_vars(env)
  532. cwd = env.communicate("pwd -P").strip()
  533. init_observation = self.config._subroutines[agent_name].init_observation
  534. if init_observation is not None:
  535. obs, _, _, _ = env.step(init_observation.format(args=sub_action['args']))
  536. else:
  537. obs = None
  538. if env.returncode != 0:
  539. self.history.append({"role": "user", "content": obs, "agent": agent_name})
  540. raise RuntimeError(f"Nonzero return code: {env.returncode} for init_observation in {agent_name}.\n{obs}")
  541. return_type = self.config._subroutines[agent_name].return_type
  542. sub_agent = Agent(agent_name, self.config._subroutines[agent_name].agent_args)
  543. sub_agent_output = sub_agent.run(
  544. {"issue": sub_action['args']},
  545. env,
  546. observation=obs,
  547. return_type=return_type,
  548. init_model_stats=self.model.stats,
  549. )
  550. self.history += sub_agent.history
  551. self.set_environment_vars(env, env_vars)
  552. env.communicate(f"cd {cwd}")
  553. self.model.stats.replace(sub_agent.model.stats)
  554. return sub_agent_output
  555. def run(
  556. self,
  557. setup_args,
  558. env: SWEEnv,
  559. observation: str = None,
  560. traj_dir: Optional[Path] = None,
  561. return_type: Optional[str] = "info",
  562. init_model_stats: Optional[APIStats] = None,
  563. ):
  564. """
  565. Run the agent on an environment.
  566. Return the final value of the specified return type.
  567. """
  568. done = False
  569. if env.container_obj.id != self.last_container_id:
  570. logger.info(f"Initializing agent settings for container {env.container_obj.id}")
  571. self.init_environment_vars(env)
  572. self.last_container_id = env.container_obj.id
  573. # Re-initialize primary
  574. self.setup(setup_args, init_model_stats)
  575. # Run action/observation loop
  576. trajectory = []
  577. info = {}
  578. while not done:
  579. state = env.communicate(self.state_command) if self.state_command else None
  580. thought, action, output = self.forward(
  581. observation,
  582. env.get_available_actions(),
  583. state)
  584. observations = list()
  585. run_action = self._guard_multiline_input(action)
  586. for sub_action in self.split_actions(run_action):
  587. if sub_action['agent'] == self.name or sub_action['cmd_name'] == self.config.submit_command:
  588. obs, _, done, info = env.step(sub_action['action'])
  589. observations.append(obs)
  590. if sub_action['cmd_name'] == self.config.submit_command:
  591. done = True
  592. if done:
  593. break
  594. else:
  595. agent_name = sub_action['agent']
  596. sub_agent_output = self.call_subroutine(agent_name, sub_action, env)
  597. observations.append(sub_agent_output)
  598. observation = '\n'.join([obs for obs in observations if obs is not None])
  599. trajectory.append(
  600. {
  601. "action": action,
  602. "observation": observation,
  603. "response": output,
  604. "state": state,
  605. "thought": thought,
  606. }
  607. )
  608. info['model_stats'] = self.model.stats.to_dict()
  609. if traj_dir:
  610. self.save_trajectory(trajectory, traj_dir, env, info)
  611. if return_type != "info":
  612. return trajectory[-1][return_type]
  613. else:
  614. return info