12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812 |
- import json
- import re
- import logging
- from dataclasses import dataclass
- from pathlib import Path
- from simple_parsing.helpers.fields import field
- from simple_parsing.helpers.serialization.serializable import FrozenSerializable
- from simple_parsing.helpers.flatten import FlattenedAccess
- from sweagent.utils import debug_time
- from sweagent.agent.commands import Command, ParseCommand
- from sweagent.agent.history_processors import HistoryProcessor
- from sweagent.agent.models import (
- APIStats,
- ContextWindowExceededError,
- CostLimitExceededError,
- ModelArguments,
- get_model,
- )
- from sweagent.agent.parsing import ParseFunction, FormatError
- from sweagent.environment.utils import LOGGER_NAME
- from sweagent.environment.swe_env import SWEEnv
- from tenacity import RetryError
- from typing import Dict, List, Optional, Tuple, Any
- from typing import Optional, Tuple, Any
- logger = logging.getLogger(LOGGER_NAME)
@dataclass(frozen=True)
class Subroutine(FrozenSerializable):
    """Static description of a sub-agent that the main agent can invoke by name.

    Instances are listed in ``AgentConfig.subroutine_types``; the owning config
    fills in ``agent_args`` during its own ``__post_init__``.
    """

    # Name the subroutine is registered and matched under.
    name: str
    # Config file passed as ``config_file`` when building the sub-agent's arguments.
    agent_file: str
    # one of "action", "observation", "response", "state", "thought"
    return_type: str = None  # type: ignore
    # Optional command template; formatted with ``args`` and stepped in the
    # environment before the sub-agent starts.
    init_observation: Optional[str] = None
    # Terminator line for multi-line invocations (None for single-line ones).
    end_name: Optional[str] = None
    # NOTE(review): signature/docstring are not read in this file; presumably
    # consumed when command docs are generated -- confirm.
    signature: Optional[str] = None
    docstring: Optional[str] = None
    # Model arguments used for the sub-agent.
    model: Optional[ModelArguments] = None
    # Populated by AgentConfig.__post_init__ with the sub-agent's AgentArguments.
    agent_args: Optional[Any] = None
@dataclass(frozen=True)
class AgentConfig(FrozenSerializable):
    """Templates, command files and parsing settings that define an agent.

    The dataclass is frozen, so ``__post_init__`` resolves derived values
    (parsed commands, subroutine table, parser objects, command docs, ...)
    through ``object.__setattr__``.
    """

    system_template: str
    instance_template: str
    next_step_template: Optional[str] = None  # defaults to instance_template
    next_step_no_output_template: Optional[str] = None  # defaults to next_step_template
    strategy_template: Optional[str] = None
    demonstration_template: Optional[str] = None
    demonstrations: list[str] = field(default_factory=list)
    # if True, add demonstration to history instead of as a single message
    put_demos_in_history: bool = False
    # defaults to format_error_template in ParseFunction
    format_error_template: str = None  # type: ignore
    command_files: list[str] = field(default_factory=list)
    env_variables: dict[str, str] = field(default_factory=dict)
    util_functions: list[str] = field(default_factory=list)
    submit_command: str = "submit"
    # The three names below are replaced by the looked-up objects in __post_init__.
    parse_function: str = "ThoughtActionParser"
    parse_command: str = "ParseCommandBash"
    history_processor: str = "DefaultHistoryProcessor"
    history_processor_args: dict[str, Any] = field(default_factory=dict)
    command_docs: str = None  # type: ignore
    blocklist_error_template: str = (
        "Interactive operation '{name}' is not supported by this environment"
    )
    # Blocked whenever they are the first word of an action.
    blocklist: Tuple[str, ...] = (
        "vim",
        "vi",
        "emacs",
        "nano",
        "nohup",
        "git",
    )
    # Blocked only when the whole action is exactly this word.
    blocklist_standalone: Tuple[str, ...] = (
        "python",
        "python3",
        "ipython",
        "bash",
        "sh",
        "exit",
        "/bin/bash",
        "/bin/sh",
        "nohup",
        "vi",
        "vim",
        "emacs",
        "nano",
    )
    # Should extract environment state in a json readable form
    state_command: Command = Command(
        name="state",
        code="""state() {
            echo '{"working_dir": "'$(realpath --relative-to=$ROOT/.. $PWD)'"}';
        };""",
    )
    _commands: list[Command] = field(default_factory=list)
    _subroutines: dict[str, Subroutine] = field(default_factory=dict)
    subroutine_types: list[Subroutine] = field(default_factory=list)

    def __post_init__(self):
        """Resolve defaults and derived attributes on the frozen instance."""

        def override(attr, value):
            # Frozen dataclass: regular assignment would raise.
            object.__setattr__(self, attr, value)

        # Template fallbacks: next_step -> instance, no_output -> next_step.
        if self.next_step_template is None:
            override("next_step_template", self.instance_template)
        if self.next_step_no_output_template is None:
            override("next_step_no_output_template", self.next_step_template)

        # Swap the parser name for the parser object, then read the command files.
        override("parse_command", ParseCommand.get(self.parse_command))
        for file in self.command_files:
            private, public = [], []
            for command in self.parse_command.parse_command_file(file):
                # Names starting with "_" are utility functions, not commands.
                bucket = private if command.name.startswith("_") else public
                bucket.append(command)
            override("util_functions", self.util_functions + private)
            override("_commands", self._commands + public)

        # Register subroutines and attach the arguments of their sub-agent.
        for subroutine in self.subroutine_types:
            if subroutine.name == "submit":
                raise ValueError("Cannot use 'submit' as a subroutine name")
            sub_args = AgentArguments(
                model=subroutine.model,
                config_file=subroutine.agent_file,
            )
            object.__setattr__(subroutine, "agent_args", sub_args)
            override("_subroutines", {**self._subroutines, subroutine.name: subroutine})

        # Map every multi-line command/subroutine to its terminator.
        endings = {}
        for entry in [*self._commands, *self._subroutines.values()]:
            if entry.end_name is not None:
                endings[entry.name] = entry.end_name
        override("multi_line_command_endings", endings)

        docs = self.parse_command.generate_command_docs(
            self._commands,
            self.subroutine_types,
            **self.env_variables,
        )
        override("command_docs", docs)

        override("parse_function", ParseFunction.get(self.parse_function))
        if self.format_error_template is None:
            override("format_error_template", self.parse_function.format_error_template)
        # The error template may reference any attribute of this config.
        override(
            "format_error_template",
            self.format_error_template.format(**self.__dict__),
        )

        # Remember the terminator of the submit command, if such a command exists.
        submit = next(
            (command for command in self._commands if command.name == self.submit_command),
            None,
        )
        if submit is not None:
            override("submit_command_end_name", submit.end_name)

        processor = HistoryProcessor.get(
            self.history_processor, **self.history_processor_args
        )
        override("history_processor", processor)
@dataclass(frozen=True)
class AgentArguments(FlattenedAccess, FrozenSerializable):
    """Configure the agent's behaviour (templates, parse functions, blocklists, ...)."""

    model: ModelArguments = None

    # Policy can only be set via config yaml file from command line
    config_file: Optional[Path] = None
    config: Optional[AgentConfig] = field(default=None, cmd=False)

    def __post_init__(self):
        """Load the config if needed and push the cost limits down to subroutines."""
        if self.config is None and self.config_file is not None:
            # If unassigned, we load the config from the file to store its
            # contents with the overall arguments
            config = AgentConfig.load_yaml(self.config_file)
            object.__setattr__(self, "config", config)
        assert self.config is not None  # mypy

        if self.model is None:
            # Nothing to propagate (e.g. sub-agent arguments built without a model).
            return

        # The registry lives in ``_subroutines``; the previous lookup of a
        # non-existent ``subroutines`` attribute made this loop dead code.
        for subroutine in getattr(self.config, "_subroutines", {}).values():
            model_args = getattr(subroutine, "model")
            if model_args is None:
                continue
            # ModelArguments may be frozen, hence object.__setattr__.
            object.__setattr__(
                model_args,
                "per_instance_cost_limit",
                self.model.per_instance_cost_limit,
            )
            object.__setattr__(
                model_args, "total_cost_limit", self.model.total_cost_limit
            )
- class Agent:
- """Agent handles the behaviour of the model and how it interacts with the environment."""
def __init__(self, name: str, args: AgentArguments):
    """Build the model and the command patterns for the agent called ``name``."""
    self.name = name
    # The model receives both the parsed commands and the subroutine definitions.
    model_commands = args.config._commands + args.config.subroutine_types
    self.model = get_model(args.model, model_commands)
    self.config = args.config
    assert self.config is not None  # mypy

    # Values available to the system template (env variables may override).
    system_args = {"command_docs": self.config.command_docs}
    system_args.update(self.config.env_variables)
    self.system_args = system_args

    self.instance_args = None
    self._parse_command_patterns()
    self.history = []
    self.last_container_id = None
def setup(self, instance_args, init_model_stats=None) -> None:
    """Setup the agent for a new instance.

    Resets the model statistics, starts a fresh history with the formatted
    system message and, when demonstrations are configured and the model
    exposes ``history_to_messages``, appends them to the history.

    Raises:
        ValueError: if demonstrations are configured but there is neither a
            demonstration template nor ``put_demos_in_history=True``.
    """
    assert self.config is not None  # mypy
    self.model.reset_stats(init_model_stats)
    self.instance_args = instance_args

    system_msg = self.config.system_template.format(**self.system_args)
    logger.info(f"SYSTEM ({self.name})\n{system_msg}")
    self.history: List[Dict[str, Any]] = [
        {"role": "system", "content": system_msg, "agent": self.name},
    ]

    if len(self.config.demonstrations) == 0 or "history_to_messages" not in dir(
        self.model
    ):
        return

    for demonstration_path in self.config.demonstrations:
        if (
            self.config.demonstration_template is None
            and not self.config.put_demos_in_history
        ):
            raise ValueError(
                "Cannot use demonstrations without a demonstration template or put_demos_in_history=True"
            )

        # Load history (context manager so the file handle is always closed)
        logger.info(f"DEMONSTRATION: {demonstration_path}")
        with open(demonstration_path, "r") as demo_file:
            demo_history = json.load(demo_file)["history"]
        # Keep entries without an owner and entries owned by this agent.
        demo_history = [
            entry
            for entry in demo_history
            if ("agent" not in entry) or entry["agent"] == self.name
        ]

        if self.config.put_demos_in_history:
            if self.config.demonstration_template is not None:
                logger.warning(
                    "Demonstration template is ignored for put_demos_in_history=True"
                )
            # Add demonstration to history directly as separate messages
            for entry in demo_history:
                if entry["role"] != "system":
                    entry["is_demo"] = True
                    self.history.append(entry)
        else:
            # Add demonstration as single message to history
            demo_message = self.model.history_to_messages(
                demo_history,
                is_demonstration=True,
            )
            demonstration = self.config.demonstration_template.format(
                **{"demonstration": demo_message}
            )
            self.history.append(
                {
                    "agent": self.name,
                    "content": demonstration,
                    "is_demo": True,
                    "role": "user",
                }
            )
@property
def state_command(self) -> str:
    """Name of the configured state command (sent to the environment as-is)."""
    command = self.config.state_command
    return command.name
@property
def local_history(self) -> list[dict[str, str]]:
    """History entries owned by this agent, passed through the history processor."""
    own_entries = []
    for entry in self.history:
        if entry["agent"] == self.name:
            own_entries.append(entry)
    return self.config.history_processor(own_entries)
def save_trajectory(self, trajectory, traj_dir, env, info):
    """Write the trajectory, full history and info as JSON to ``<instance_id>.traj``."""
    instance_id = env.record["instance_id"]
    log_path = traj_dir / (instance_id + ".traj")
    payload = dict(
        environment=env.name,
        trajectory=trajectory,
        history=self.history,
        info=info,
    )
    with log_path.open("w") as traj_file:
        json.dump(payload, traj_file, indent=2)
    logger.info(f"Saved trajectory to {log_path}")
def _get_first_match(self, action: str, pattern_type: str) -> Optional[re.Match]:
    """Return the first match of a command pattern in the action string.

    Args:
        action: text to search.
        pattern_type: which patterns to try -- ``"subroutine"`` (all
            subroutine patterns), ``"multi_line"`` (multi-line commands, the
            submit command and multi-line subroutines) or
            ``"multi_line_no_subroutines"`` (multi-line commands only).

    Returns:
        The match that starts earliest in ``action``, or ``None``.

    Raises:
        ValueError: for an unknown ``pattern_type``.
    """
    assert self.config is not None  # mypy
    if pattern_type == "subroutine":
        patterns = dict(self.subroutine_patterns)
    elif pattern_type == "multi_line":
        endings = self.config.multi_line_command_endings
        patterns = {
            k: v
            for k, v in self.command_patterns.items()
            if k in endings or k == self.config.submit_command
        }
        # dicts do not support ``+=``; merge explicitly.
        patterns.update(
            {k: v for k, v in self.subroutine_patterns.items() if k in endings}
        )
    elif pattern_type == "multi_line_no_subroutines":
        endings = self.config.multi_line_command_endings
        patterns = {k: v for k, v in self.command_patterns.items() if k in endings}
    else:
        raise ValueError(f"Unknown pattern type: {pattern_type}")

    matches = [
        match
        for match in (pat.search(action) for pat in patterns.values())
        if match is not None
    ]
    if not matches:
        return None
    # min() keeps the first of several matches starting at the same offset,
    # like the stable sort it replaces.
    return min(matches, key=lambda match: match.start())
def _guard_multiline_input(self, action: str) -> str:
    """Split action by multiline commands, then append the first line in each multiline command with "<< '{end_name}'".

    Multiline commands (which are specified by an end_name) are commands that span multiple lines and are terminated by a specific end_name.
    Their multi-line argument is sent using a heredoc, which is a way to send a multi-line string to a command in bash.

    Returns:
        The pieces of ``action`` (text before/after and the guarded commands)
        joined with newlines. Commands already carrying the heredoc marker
        are left untouched.
    """
    parsed_action = []
    rem_action = action
    while rem_action.strip():
        first_match = self._get_first_match(rem_action, "multi_line_no_subroutines")
        if first_match is None:
            # No further multi-line command: keep the rest verbatim.
            parsed_action.append(rem_action)
            break

        pre_action = rem_action[: first_match.start()]
        match_action = rem_action[first_match.start() : first_match.end()]
        rem_action = rem_action[first_match.end() :]
        if pre_action.strip():
            parsed_action.append(pre_action)
        if not match_action.strip():
            continue

        eof = first_match.group(3).strip()
        heredoc = f"<< '{eof}'"
        # The pattern allows leading whitespace (including blank lines), so
        # locate the command name relative to the matched text instead of
        # re-applying the absolute match offset to the slice.
        name_offset = first_match.start(1) - first_match.start()
        leading = match_action[:name_offset]
        first_line, newline, body = match_action[name_offset:].partition("\n")
        if first_line.rstrip().endswith(heredoc):
            parsed_action.append(match_action)
        else:
            parsed_action.append(f"{leading}{first_line} {heredoc}{newline}{body}")
    return "\n".join(parsed_action)
def split_actions(self, action: str, pattern_type="subroutine") -> List[Dict[str, Any]]:
    """Greedily cut ``action`` into plain commands and subroutine calls.

    Each returned dict has ``agent``, ``action`` and ``cmd_name``; subroutine
    calls additionally carry ``args``. Plain text and the submit command are
    attributed to this agent, subroutine calls to the subroutine's name.
    """

    def own_step(text, cmd_name=None):
        # Entry executed by this agent itself.
        return {"agent": self.name, "action": text, "cmd_name": cmd_name}

    pieces = []
    remaining = action
    while remaining.strip():
        found = self._get_first_match(remaining, pattern_type)
        if not found:
            pieces.append(own_step(remaining))
            remaining = ""
            continue

        begin, end = found.start(), found.end()
        before, matched = remaining[:begin], remaining[begin:end]
        remaining = remaining[end:]
        if before.strip():
            pieces.append(own_step(before))
        if not matched.strip():
            continue
        if matched.split()[0] == self.config.submit_command:
            # submit command is not a subroutine
            pieces.append(own_step(matched, found.group(1)))
        else:
            pieces.append(
                {
                    "agent": found.group(1),
                    "args": found.group(2),
                    "action": matched,
                    "cmd_name": found.group(1),
                }
            )
    return pieces
def _parse_command_patterns(self):
    """Compile the regexes used to find commands and subroutines in an action.

    Fills ``self.command_patterns`` and ``self.subroutine_patterns`` (both
    keyed by name). Group 1 of every pattern is the name, group 2 the
    arguments and, for multi-line entries, group 3 the terminator. The
    submit pattern is registered in both tables.
    """
    assert self.config is not None  # mypy

    def compile_pattern(name, end_name):
        # Single-line entries run to the end of their line; multi-line ones
        # run up to a line consisting of the terminator.
        if end_name is None:
            return re.compile(rf"^\s*({name})\s*(.*?)$", re.MULTILINE)
        return re.compile(
            rf"^\s*({name})\s*(.*?)^({end_name})\s*$",
            re.DOTALL | re.MULTILINE,
        )

    self.command_patterns = {
        command.name: compile_pattern(command.name, command.end_name)
        for command in self.config._commands
    }
    # Keys are plain names (a stray trailing comma used to create 1-tuple
    # keys for single-line subroutines).
    self.subroutine_patterns = {
        subroutine.name: compile_pattern(subroutine.name, subroutine.end_name)
        for subroutine in self.config._subroutines.values()
    }

    submit_name = self.config.submit_command
    if hasattr(self.config, "submit_command_end_name"):
        # NOTE(review): the attribute may hold None when the submit command
        # has no end_name, which would make this pattern look for a literal
        # "None" line -- confirm whether that is intended.
        submit_pat = re.compile(
            rf"^\s*({submit_name})\s*(.*?)^({self.config.submit_command_end_name})\s*$",
            re.DOTALL | re.MULTILINE,
        )
    else:
        submit_pat = re.compile(
            rf"^\s*({submit_name})(\s*)$", re.MULTILINE
        )  # group 2 is nothing
    self.subroutine_patterns[submit_name] = submit_pat
    self.command_patterns[submit_name] = submit_pat
def forward(
    self, observation: str, available_actions: list[str], state: str
) -> Tuple[str, str, str]:
    """Query the model once and record its answer in the history.

    ``available_actions`` is accepted but not used here.
    Returns ``(thought, action, raw_output)``.
    """
    debug_time()
    thought, action, output = self.forward_with_error_check(observation, state)

    assistant_entry = dict(
        role="assistant",
        content=output,
        thought=thought,
        action=action,
        agent=self.name,
    )
    self.history.append(assistant_entry)

    logger.info(f"💭 THOUGHT ({self.name})\n{thought}")
    logger.info(f"🎬 ACTION ({self.name})\n{action}")
    return thought, action, output
def forward_model(self, observation: str, state: str) -> str:
    """Format the next user message, append it to the history and query the model.

    ``state`` must be a JSON object string; its keys are available to the
    templates together with the instance and system arguments.
    Returns the raw model output.
    """
    assert self.config is not None  # mypy
    debug_time()
    state_vars = json.loads(state)

    # Pick the template(s) from what the previous history entry was.
    previous = self.history[-1]
    templates: List[str]
    if previous["role"] == "system" or previous.get("is_demo", False):
        # First real turn: show the instance (and optional strategy) template.
        templates = [self.config.instance_template]
        if self.config.strategy_template is not None:
            templates.append(self.config.strategy_template)
    elif observation is None or not observation.strip():
        # Last action produced no output.
        templates = [self.config.next_step_no_output_template]
    else:
        templates = [self.config.next_step_template]

    # Populate selected template(s) with information (e.g., issue, arguments, state)
    shown_observation = "" if observation is None else observation
    messages = [
        template.format(
            **self.instance_args,
            **self.system_args,
            **state_vars,
            observation=shown_observation,
        )
        for template in templates
    ]
    message = "\n".join(messages)

    logger.info(f"🤖 MODEL INPUT\n{message}")
    self.history.append({"role": "user", "content": message, "agent": self.name})
    debug_time()
    return self.model.query(self.local_history)
def retry_after_format_fail(self, output):
    """Re-query the model after a malformed output; the exchange is not stored in the history."""
    error_message = self.config.format_error_template

    logger.warning(f"MALFORMED OUTPUT\n{output}")
    logger.warning(f"FORMAT ERROR\n{error_message}")

    # Temporary conversation: the bad answer followed by the error prompt.
    retry_history = list(self.local_history)
    retry_history.append({"role": "assistant", "content": output, "agent": self.name})
    retry_history.append({"role": "user", "content": error_message, "agent": self.name})
    return self.model.query(retry_history)
def retry_after_blocklist_fail(self, output, action):
    """Re-query the model after a blocked command; the exchange is not stored in the history."""
    blocked_name = action.strip().split()[0]
    error_message = self.config.blocklist_error_template.format(name=blocked_name)

    logger.warning(f"BLOCKLISTED OUTPUT\n{output}")
    logger.warning(f"BLOCKLIST ERROR\n{error_message}")

    # Temporary conversation: the blocked answer followed by the error prompt.
    retry_history = list(self.local_history)
    retry_history.append({"role": "assistant", "content": output, "agent": self.name})
    retry_history.append({"role": "user", "content": error_message, "agent": self.name})
    return self.model.query(retry_history)
def should_block_action(self, action):
    """Tell whether ``action`` starts with a blocked command.

    The first word is checked against ``blocklist``; words in
    ``blocklist_standalone`` only block when they are the entire action.
    """
    stripped = action.strip()
    words = stripped.split()
    if not words:
        return False
    first_word = words[0]
    always_blocked = first_word in self.config.blocklist
    blocked_alone = (
        first_word in self.config.blocklist_standalone and first_word == stripped
    )
    return always_blocked or blocked_alone
def check_format_and_requery(
    self,
    output: str,
) -> Tuple[str, str, str]:
    """Parse a model output into thought and action, re-querying on failure.

    The model is asked again when the output cannot be parsed or the action
    is blocked; after more than two failures in total the method gives up
    with the ``exit_format`` action.
    Returns the thought, action, and raw model output.
    """
    debug_time()
    model_name = self.model.args.model_name
    # Condition for handling outputs with no thought (just action)
    if model_name == "human":
        return "", output, output
    if model_name == "human_thought":
        parser = ParseFunction.get("ThoughtActionParser")
        thought, action = parser(
            output,
            self.config._commands + self.config.subroutine_types,
            strict=False,
        )
        return thought, action, output

    format_fails = 0
    blocklist_fails = 0
    while format_fails + blocklist_fails <= 2:
        try:
            thought, action = self.config.parse_function(
                output,
                self.config._commands + self.config.subroutine_types,
                strict=False,
            )
        except KeyboardInterrupt:
            raise
        except FormatError:
            format_fails += 1
            output = self.retry_after_format_fail(output)
            continue
        if not self.should_block_action(action):
            return thought, action, output
        blocklist_fails += 1
        output = self.retry_after_blocklist_fail(output, action)

    logger.debug("format fails %d, blocklist fails %d", format_fails, blocklist_fails)
    logger.warning(f"Malformat limit reached: \n{output}")
    debug_time()
    return "Exit due to format error", "exit_format", output
def forward_with_error_check(
    self, observation: str, state: str
) -> Tuple[str, str, str]:
    """Call ``forward_model`` and turn known failures into exit actions.

    Runtime, context-window, cost-limit and retry errors are logged and
    reported as ``(thought, exit_action, output)``; a successful output is
    handed to ``check_format_and_requery``.
    """
    try:
        output = self.forward_model(observation, state)
    except KeyboardInterrupt:
        raise
    except RuntimeError as e:
        logger.warning(f"Runtime error: {e}")
        thought = f"Exit due to runtime error: {e}"
        return thought, "exit_error", f"exit due to runtime error: {e}"
    except ContextWindowExceededError:
        logger.warning("Context window exceeded")
        reason = "Exit due to context window"
        return reason, "exit_context", reason
    except CostLimitExceededError:
        logger.warning("Cost limit exceeded")
        reason = "Exit due to cost limit"
        return reason, "exit_cost", reason
    except RetryError as e:
        logger.warning(f"Retry error: {e}")
        thought = f"Exit due to retry error: {e}"
        return thought, "exit_api", f"exit due to retry error: {e}"
    return self.check_format_and_requery(output)
def init_environment_vars(self, env):
    """Apply the environment variables from the config to ``env``."""
    configured = self.config.env_variables
    self.set_environment_vars(env, configured)
def set_environment_vars(self, env, env_variables):
    """Define the state command and variables in ``env`` and install the command files.

    Args:
        env: environment offering ``communicate``, ``returncode`` and
            ``add_commands``.
        env_variables: mapping of variable names to values, sent as
            ``NAME=value`` lines.

    Raises:
        RuntimeError: if the environment reports a nonzero return code.
        ValueError: if a command file has no shebang, does not end in
            ``.sh`` and does not start with an underscore.
    """
    assert self.config is not None  # mypy
    commands_to_execute = [self.config.state_command.code]
    commands_to_execute += [f"{k}={v}" for k, v in env_variables.items()]
    commands = "\n".join(commands_to_execute)
    try:
        output = env.communicate(commands)
        if env.returncode != 0:
            raise RuntimeError(
                f"Nonzero return code: {env.returncode}\nOutput: {output}"
            )
    except KeyboardInterrupt:
        raise
    except Exception:
        logger.warning("Failed to set environment variables")
        raise  # bare raise keeps the original traceback

    command_files = []
    for file in self.config.command_files:
        # Context manager so the handle is closed even if reading fails.
        with open(file, "r") as handle:
            contents = handle.read()
        filename = Path(file).name
        datum = {"contents": contents}
        if contents.strip().startswith("#!"):
            # scripts are made executable
            datum["name"] = filename.rsplit(".", 1)[0]
            datum["type"] = "script"
        elif filename.endswith(".sh"):
            # files are sourced, so they are not executable
            datum["name"] = filename
            datum["type"] = "source_file"
        elif filename.startswith("_"):
            # files are sourced, so they are not executable
            datum["name"] = filename
            datum["type"] = "utility"
        else:
            raise ValueError(
                f"Non-shell script file {file} does not start with shebang.\n"
                "Either add a shebang (#!) or change the file extension to .sh if you want to source it.\n"
                "You can override this behavior by adding an underscore to the file name (e.g. _utils.py)."
            )
        command_files.append(datum)
    env.add_commands(command_files)
def get_environment_vars(self, env):
    """Read the current value of every configured variable by echoing it in ``env``."""
    assert self.config is not None  # mypy
    return {
        var: env.communicate(f"echo ${var}").strip()
        for var in self.config.env_variables
    }
def call_subroutine(self, agent_name, sub_action, env):
    """Run the subroutine ``agent_name`` as a sub-agent and return its result.

    The environment variables and working directory are captured first and
    restored after the sub-agent finishes; the sub-agent's history and
    model statistics are merged back into this agent.

    Args:
        agent_name: key into the config's subroutine registry.
        sub_action: parsed action dict; its ``args`` entry is used as the
            sub-agent's issue.
        env: environment shared with the sub-agent.

    Raises:
        RuntimeError: if the environment reports a nonzero return code
            after the initial observation step.
    """
    import shlex  # local import: only needed to quote the restored directory

    assert self.config is not None  # mypy
    subroutine = self.config._subroutines[agent_name]
    env_vars = self.get_environment_vars(env)
    cwd = env.communicate("pwd -P").strip()

    init_observation = subroutine.init_observation
    if init_observation is not None:
        obs, _, _, _ = env.step(init_observation.format(args=sub_action["args"]))
    else:
        obs = None
    if env.returncode != 0:
        self.history.append({"role": "user", "content": obs, "agent": agent_name})
        raise RuntimeError(
            f"Nonzero return code: {env.returncode} for init_observation in {agent_name}.\n{obs}"
        )

    sub_agent = Agent(agent_name, subroutine.agent_args)
    sub_agent_output = sub_agent.run(
        {"issue": sub_action["args"]},
        env,
        observation=obs,
        return_type=subroutine.return_type,
        init_model_stats=self.model.stats,
    )
    self.history += sub_agent.history

    # Restore the caller's variables and directory; quote the path so
    # directories with spaces or shell metacharacters survive.
    self.set_environment_vars(env, env_vars)
    env.communicate(f"cd {shlex.quote(cwd)}")
    self.model.stats.replace(sub_agent.model.stats)
    return sub_agent_output
def run(
    self,
    setup_args: Dict[str, Any],
    env: SWEEnv,
    observation: Optional[str] = None,
    traj_dir: Optional[Path] = None,
    return_type: Optional[str] = "info",
    init_model_stats: Optional[APIStats] = None,
):
    """
    Run the agent on an environment.

    Loops over model queries and environment steps until a step reports
    completion or the submit command is issued.
    Returns ``info`` for ``return_type="info"``, ``(info, trajectory)`` for
    ``"info_trajectory"``, otherwise that field of the last trajectory step.
    """
    done = False
    assert env.container_obj is not None
    assert self.config is not None  # mypy

    # Environment variables are (re-)initialized only for a new container.
    if env.container_obj.id != self.last_container_id:
        logger.info(
            f"Initializing agent settings for container {env.container_obj.id}"
        )
        self.init_environment_vars(env)
        self.last_container_id = env.container_obj.id

    # Re-initialize primary
    self.setup(setup_args, init_model_stats)

    # Run action/observation loop
    trajectory = []
    info = {}
    while not done:
        debug_time()
        state = env.communicate(self.state_command) if self.state_command else None
        debug_time("state cmd")
        thought, action, output = self.forward(
            observation, env.get_available_actions(), state
        )
        debug_time("reset before postproc")
        observations = []
        run_action = self._guard_multiline_input(action)
        debug_time("guard multiline")

        for sub_action in self.split_actions(run_action):
            is_submit = sub_action["cmd_name"] == self.config.submit_command
            if sub_action["agent"] != self.name and not is_submit:
                # Delegate to the named subroutine.
                observations.append(
                    self.call_subroutine(sub_action["agent"], sub_action, env)
                )
                continue
            obs, _, done, info = env.step(sub_action["action"])
            observations.append(obs)
            if is_submit:
                done = True
            if done:
                break

        observation = "\n".join(obs for obs in observations if obs is not None)
        debug_time("got observation")
        step_record = {
            "action": action,
            "observation": observation,
            "response": output,
            "state": state,
            "thought": thought,
        }
        trajectory.append(step_record)
        info["model_stats"] = self.model.stats.to_dict()
        if traj_dir:
            self.save_trajectory(trajectory, traj_dir, env, info)
        debug_time("traj saved")

    if return_type == "info":
        return info
    if return_type == "info_trajectory":
        return info, trajectory
    return trajectory[-1][return_type]
|