# summarizer.py
  1. from __future__ import annotations
  2. import tempfile
  3. import textwrap
  4. import traceback
  5. from abc import abstractmethod
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Any
  9. from simple_parsing.helpers.serialization.serializable import FrozenSerializable
  10. from sweagent.agent.models import APIStats, BaseModel, ContextWindowExceededError, ModelArguments
  11. from sweagent.environment.swe_env import SWEEnv
  12. from sweagent.environment.utils import copy_anything_to_container
  13. from sweagent.utils.log import get_logger
  14. @dataclass(frozen=True)
  15. class SummarizerConfig(FrozenSerializable):
  16. """The configuration for the summarizer"""
  17. function: str = "Identity"
  18. window_length: int = 105
  19. template: str | None = None
  20. model: ModelArguments | None = None
  21. system_template: str | None = None
  22. instance_template: str | None = None
  23. def __post_init__(self):
  24. object.__setattr__(self, "function", SummarizeFunction.get(self.function, self.window_length)) # type: ignore
  25. if isinstance(self.model, dict):
  26. object.__setattr__(self, "model", ModelArguments.from_dict(self.summarizer_model)) # type: ignore
# ABSTRACT BASE CLASSES
  28. class SummarizeFunctionMeta(type):
  29. """
  30. Registry maps all inherited classes to their names.
  31. """
  32. _warning_message = None
  33. _registry = {}
  34. def __new__(cls, name, bases, attrs):
  35. new_cls = super().__new__(cls, name, bases, attrs)
  36. if name != "SummarizeFunction":
  37. cls._registry[name] = new_cls
  38. return new_cls
  39. @dataclass
  40. class SummarizeFunction(metaclass=SummarizeFunctionMeta):
  41. """
  42. Abstract class for summarizing functions.
  43. We use get to generate the right summarizer based on the name of the summarizer.
  44. """
  45. def __init__(self, window_length: int):
  46. self._window_length = window_length
  47. self.logger = get_logger("summarizer")
  48. def setup(self, instance_args: dict[str, Any], config):
  49. """
  50. Additional setup function for the summarizer.
  51. """
  52. pass
  53. @staticmethod
  54. def _slugify_action(action: str) -> str:
  55. return "".join(c if c.isalnum() else "_" for c in action)[:50]
  56. @staticmethod
  57. def _upload_file_to_container(file_content: str, file_path_on_container: str, env: SWEEnv):
  58. assert env.container_obj is not None
  59. env.communicate(f'mkdir -p "{Path(file_path_on_container).parent}"')
  60. with tempfile.NamedTemporaryFile() as fp:
  61. fp.write(file_content.encode("utf-8"))
  62. fp.flush()
  63. copy_anything_to_container(env.container_obj, fp.name, file_path_on_container)
  64. @abstractmethod
  65. def __call__(self, input: str, observation, env: SWEEnv, model: type[BaseModel]) -> tuple[str, APIStats]:
  66. """
  67. Abstract method for getting an observation and summarize it.
  68. The returned value should be a summation of the given observation.
  69. """
  70. raise NotImplementedError
  71. @classmethod
  72. def get(cls, name: str, window_length: int):
  73. try:
  74. return cls._registry[name](window_length)
  75. except KeyError:
  76. msg = f"Model output summarizer ({name}) not found."
  77. raise ValueError(msg)
# DEFINE NEW SUMMARIZE FUNCTIONS BELOW THIS LINE
  79. class SimpleSummarizer(SummarizeFunction):
  80. """
  81. Saves the output of the command to a file and uses the open command to show the output.
  82. """
  83. _warning_message = """\
  84. Warning: Command output exceeded window, saved command to a file {command_file_name} and opened the file at line 1.
  85. """
  86. block_list_input = [
  87. "create",
  88. "open",
  89. "edit",
  90. "scroll_up",
  91. "scroll_down",
  92. "goto",
  93. "search_file",
  94. "search_dir",
  95. ]
  96. def __call__(self, input: str, observation: str, env: SWEEnv, model: BaseModel) -> tuple[str, APIStats]:
  97. try:
  98. if (
  99. any(input.startswith(s) for s in self.block_list_input)
  100. or len(observation.splitlines()) <= self._window_length
  101. ):
  102. return observation, APIStats()
  103. self.logger.debug(f"Summarizing current observation for input {input}")
  104. command_file_name = "/output/" + self._slugify_action(input)
  105. self._upload_file_to_container(observation, command_file_name, env)
  106. return textwrap.dedent(self._warning_message.format(command_file_name=command_file_name)) + env.communicate(
  107. f"open {command_file_name}"
  108. ), APIStats()
  109. except Exception:
  110. self.logger.warning(
  111. f"Unhandled exception occurred when trying to summarize observation for input {input}: {traceback.format_exc()}"
  112. )
  113. return observation, APIStats()
  114. class Identity(SummarizeFunction):
  115. """
  116. This summarizer does not do any summation. It returns the environment observation as is.
  117. """
  118. def __call__(self, input: str, observation: str, env: SWEEnv, model: type[BaseModel]) -> tuple[str, APIStats]:
  119. """
  120. This doesn't do any summarization. It just returns the environment observation.
  121. """
  122. return observation, APIStats()
  123. class LMSummarizer(SummarizeFunction):
  124. _warning_message = """\
  125. Warning: Command output exceeded window size, saved command to a file {command_file_name} and summarized the command output for you.
  126. If you still want to view the output of the command, use the following command `open {command_file_name}`.
  127. SUMMARY:
  128. """
  129. _warning_message_summarization_failed = """\
  130. Warning: Command output exceeded window size, saved command to a file {command_file_name}.
  131. If you still want to view the output of the command, use the following command `open {command_file_name}`.
  132. """
  133. block_list_input = [
  134. "create",
  135. "open",
  136. "edit",
  137. "scroll_up",
  138. "scroll_down",
  139. "goto",
  140. "search_file",
  141. "search_dir",
  142. ]
  143. fail_back_to_simple_summarizer_input = [
  144. "xxd",
  145. "hexdump",
  146. "strings",
  147. ]
  148. lm_summarizer_char_limit = 200000
  149. def __init__(self, window_length: int):
  150. super().__init__(window_length)
  151. self.history = []
  152. self._simple_summarizer = SimpleSummarizer(window_length=window_length)
  153. def setup(self, instance_args: dict[str, Any], config):
  154. self.name = "ctf_summarizer"
  155. self.system_args = config.__dict__
  156. self.system_args.update({f"summarizer_{k}": v for k, v in config.summarizer_config.__dict__.items()})
  157. system_msg = config.summarizer_config.system_template.format(**self.system_args)
  158. self.history.append({"role": "system", "content": system_msg, "agent": self.name})
  159. self.logger.info(f"SYSTEM ({self.name})\n{system_msg}")
  160. self.instance_template = config.summarizer_config.instance_template
  161. self.instance_args = instance_args
  162. def __call__(self, input: str, observation: str, env: SWEEnv, model: BaseModel) -> tuple[str, APIStats]:
  163. try:
  164. if (
  165. any(input.startswith(s) for s in self.block_list_input)
  166. or len(observation.splitlines()) <= self._window_length
  167. ):
  168. return observation, APIStats()
  169. if len(observation) > self.lm_summarizer_char_limit or any(
  170. input.startswith(s) for s in self.fail_back_to_simple_summarizer_input
  171. ):
  172. self.logger.warning("Observation is too long for LMSummarizer, using SimpleSummarizer instead")
  173. return self._simple_summarizer(input, observation, env, model)
  174. self.logger.debug(f"Summarizing current observation for input {input}")
  175. command_file_name = "/output/" + self._slugify_action(input)
  176. self._upload_file_to_container(observation, command_file_name, env)
  177. self.history.append(
  178. {
  179. "role": "user",
  180. "content": self.instance_template.format(
  181. **self.instance_args, **self.system_args, command=input, observation=observation
  182. ),
  183. "agent": self.name,
  184. }
  185. )
  186. response = model.query(history=self.history)
  187. stats = model.stats
  188. model.reset_stats(APIStats())
  189. self.history.pop()
  190. return textwrap.dedent(self._warning_message.format(command_file_name=command_file_name)) + response, stats
  191. except ContextWindowExceededError:
  192. return textwrap.dedent(
  193. self._warning_message_summarization_failed.format(command_file_name=command_file_name)
  194. ), APIStats()
  195. except Exception:
  196. self.logger.warning(
  197. f"Unhandled exception occurred when trying to summarize observation for input {input}: {traceback.format_exc()}"
  198. )
  199. return observation, APIStats()