# models.py

import config
import json
import logging
import os
import together

from collections import defaultdict
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from dataclasses import dataclass, fields
from openai import BadRequestError, OpenAI, AzureOpenAI
from simple_parsing.helpers.serialization.serializable import FrozenSerializable, Serializable
from sweagent.agent.commands import Command
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
    retry_if_not_exception_type,
)
from typing import Optional, Union

logger = logging.getLogger("api_models")


@dataclass(frozen=True)
class ModelArguments(FrozenSerializable):
    """Arguments configuring the model and its behavior."""
    model_name: str
    per_instance_cost_limit: float = 0.0
    total_cost_limit: float = 0.0
    temperature: float = 1.0
    top_p: float = 1.0
    replay_path: Optional[str] = None
    host_url: str = "localhost:11434"

@dataclass
class APIStats(Serializable):
    total_cost: float = 0
    instance_cost: float = 0
    tokens_sent: int = 0
    tokens_received: int = 0
    api_calls: int = 0

    def __add__(self, other):
        if not isinstance(other, APIStats):
            raise TypeError("Can only add APIStats with APIStats")
        return APIStats(**{
            field.name: getattr(self, field.name) + getattr(other, field.name)
            for field in fields(self)
        })

    def replace(self, other):
        if not isinstance(other, APIStats):
            raise TypeError("Can only replace APIStats with APIStats")
        return APIStats(**{
            field.name: getattr(other, field.name)
            for field in fields(self)
        })
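
    # Illustrative aggregation example (not part of the original module):
    # `__add__` sums every field, which makes it easy to combine stats across
    # model instances, e.g.
    #   APIStats(api_calls=1, tokens_sent=10) + APIStats(api_calls=2)
    # returns
    #   APIStats(total_cost=0, instance_cost=0, tokens_sent=10, tokens_received=0, api_calls=3)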


class ContextWindowExceededError(Exception):
    pass


class CostLimitExceededError(Exception):
    pass


class BaseModel:
    MODELS = {}
    SHORTCUTS = {}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        self.args = args
        self.commands = commands
        self.model_metadata = {}
        self.stats = APIStats()

        # Map `model_name` to API-compatible name `api_model`
        self.api_model = (
            self.SHORTCUTS[self.args.model_name]
            if self.args.model_name in self.SHORTCUTS
            else self.args.model_name
        )

        # Map model name to metadata (cost, context info)
        MODELS = {
            **{dest: self.MODELS[src] for dest, src in self.SHORTCUTS.items()},
            **self.MODELS,
        }
        if args.model_name in MODELS:
            self.model_metadata = MODELS[args.model_name]
        elif args.model_name.startswith("ft:"):
            ft_model = args.model_name.split(":")[1]
            self.model_metadata = MODELS[ft_model]
        elif args.model_name.startswith("ollama:"):
            self.api_model = args.model_name.split("ollama:", 1)[1]
            self.model_metadata = self.MODELS[self.api_model]
        elif args.model_name.startswith("azure:"):
            azure_model = args.model_name.split("azure:", 1)[1]
            self.model_metadata = MODELS[azure_model]
        else:
            raise ValueError(
                f"Unregistered model ({args.model_name}). "
                f"Add its name and metadata to the MODELS dict of {self.__class__}."
            )

    def reset_stats(self, other: Optional[APIStats] = None):
        if other is None:
            self.stats = APIStats(total_cost=self.stats.total_cost)
            logger.info("Resetting model stats")
        else:
            self.stats = other

    def update_stats(self, input_tokens: int, output_tokens: int) -> float:
        """
        Calculates the cost of a response from the API.

        Args:
            input_tokens (int): The number of tokens in the prompt.
            output_tokens (int): The number of tokens in the response.

        Returns:
            float: The cost of the response.
        """
        # Calculate cost and update cost-related fields
        cost = (
            self.model_metadata["cost_per_input_token"] * input_tokens
            + self.model_metadata["cost_per_output_token"] * output_tokens
        )
        self.stats.total_cost += cost
        self.stats.instance_cost += cost
        self.stats.tokens_sent += input_tokens
        self.stats.tokens_received += output_tokens
        self.stats.api_calls += 1

        # Log updated cost values to stdout
        logger.info(
            f"input_tokens={input_tokens:_}, "
            f"output_tokens={output_tokens:_}, "
            f"instance_cost={self.stats.instance_cost:.2f}, "
            f"cost={cost:.2f}"
        )
        logger.info(
            f"total_tokens_sent={self.stats.tokens_sent:_}, "
            f"total_tokens_received={self.stats.tokens_received:_}, "
            f"total_cost={self.stats.total_cost:.2f}, "
            f"total_api_calls={self.stats.api_calls:_}"
        )

        # Check whether total cost or instance cost limits have been exceeded
        if 0 < self.args.total_cost_limit <= self.stats.total_cost:
            logger.warning(
                f"Cost {self.stats.total_cost:.2f} exceeds limit {self.args.total_cost_limit:.2f}"
            )
            raise CostLimitExceededError("Total cost limit exceeded")
        if 0 < self.args.per_instance_cost_limit <= self.stats.instance_cost:
            logger.warning(
                f"Cost {self.stats.instance_cost:.2f} exceeds limit {self.args.per_instance_cost_limit:.2f}"
            )
            raise CostLimitExceededError("Instance cost limit exceeded")
        return cost
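
    # Worked cost example (illustrative, not part of the original module),
    # using the gpt-4-1106-preview rates registered below:
    #   1_000 input tokens and 500 output tokens cost
    #   1_000 * 1e-05 + 500 * 3e-05 = 0.025 (USD)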

    def query(self, history: list[dict[str, str]]) -> str:
        raise NotImplementedError("Use a subclass of BaseModel")


class OpenAIModel(BaseModel):
    MODELS = {
        "gpt-3.5-turbo-0125": {
            "max_context": 16_385,
            "cost_per_input_token": 5e-07,
            "cost_per_output_token": 1.5e-06,
        },
        "gpt-3.5-turbo-1106": {
            "max_context": 16_385,
            "cost_per_input_token": 1.5e-06,
            "cost_per_output_token": 2e-06,
        },
        "gpt-3.5-turbo-16k-0613": {
            "max_context": 16_385,
            "cost_per_input_token": 1.5e-06,
            "cost_per_output_token": 2e-06,
        },
        "gpt-4-32k-0613": {
            "max_context": 32_768,
            "cost_per_input_token": 6e-05,
            "cost_per_output_token": 0.00012,
        },
        "gpt-4-0613": {
            "max_context": 8_192,
            "cost_per_input_token": 3e-05,
            "cost_per_output_token": 6e-05,
        },
        "gpt-4-1106-preview": {
            "max_context": 128_000,
            "cost_per_input_token": 1e-05,
            "cost_per_output_token": 3e-05,
        },
        "gpt-4-0125-preview": {
            "max_context": 128_000,
            "cost_per_input_token": 1e-05,
            "cost_per_output_token": 3e-05,
        },
    }

    SHORTCUTS = {
        "gpt3": "gpt-3.5-turbo-1106",
        "gpt3-legacy": "gpt-3.5-turbo-16k-0613",
        "gpt4": "gpt-4-1106-preview",
        "gpt4-legacy": "gpt-4-0613",
        "gpt4-0125": "gpt-4-0125-preview",
        "gpt3-0125": "gpt-3.5-turbo-0125",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set OpenAI key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        if self.args.model_name.startswith("azure"):
            self.api_model = cfg["AZURE_OPENAI_DEPLOYMENT"]
            self.client = AzureOpenAI(
                api_key=cfg["AZURE_OPENAI_API_KEY"],
                azure_endpoint=cfg["AZURE_OPENAI_ENDPOINT"],
                api_version=cfg.get("AZURE_OPENAI_API_VERSION", "2024-02-01"),
            )
        else:
            api_base_url: Optional[str] = cfg.get("OPENAI_API_BASE_URL", None)
            self.client = OpenAI(api_key=cfg["OPENAI_API_KEY"], base_url=api_base_url)

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])
        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type(
            (CostLimitExceededError, ContextWindowExceededError, RuntimeError)
        ),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the OpenAI API with the given `history` and return the response.
        """
        try:
            # Perform OpenAI API call
            response = self.client.chat.completions.create(
                messages=self.history_to_messages(history),
                model=self.api_model,
                temperature=self.args.temperature,
                top_p=self.args.top_p,
            )
        except BadRequestError:
            # Raise the dedicated context-window error rather than a cost error
            raise ContextWindowExceededError(
                f"Context window ({self.model_metadata['max_context']} tokens) exceeded"
            )
        # Calculate + update costs, return response
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        self.update_stats(input_tokens, output_tokens)
        return response.choices[0].message.content

class AnthropicModel(BaseModel):
    MODELS = {
        "claude-instant": {
            "max_context": 100_000,
            "cost_per_input_token": 1.63e-06,
            "cost_per_output_token": 5.51e-06,
        },
        "claude-2": {
            "max_context": 100_000,
            "cost_per_input_token": 1.102e-05,
            "cost_per_output_token": 3.268e-05,
        },
        "claude-2.1": {
            "max_context": 100_000,
            "cost_per_input_token": 1.102e-05,
            "cost_per_output_token": 3.268e-05,
        },
        "claude-3-opus-20240229": {
            "max_context": 200_000,
            "max_tokens": 4096,  # Max tokens to generate for Claude 3 models
            "cost_per_input_token": 1.5e-05,
            "cost_per_output_token": 7.5e-05,
        },
        "claude-3-sonnet-20240229": {
            "max_context": 200_000,
            "max_tokens": 4096,
            "cost_per_input_token": 3e-06,
            "cost_per_output_token": 1.5e-05,
        },
        "claude-3-haiku-20240307": {
            "max_context": 200_000,
            "max_tokens": 4096,
            "cost_per_input_token": 2.5e-07,
            "cost_per_output_token": 1.25e-06,
        },
    }

    SHORTCUTS = {
        "claude": "claude-2",
        "claude-opus": "claude-3-opus-20240229",
        "claude-sonnet": "claude-3-sonnet-20240229",
        "claude-haiku": "claude-3-haiku-20240307",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set Anthropic key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        self.api = Anthropic(api_key=cfg["ANTHROPIC_API_KEY"])

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `prompt` by filtering out all keys except for role/content per `history` turn
        Reference: https://docs.anthropic.com/claude/reference/complete_post
        """
        # Preserve behavior for older models
        if self.api_model in ["claude-instant", "claude-2"]:
            # Remove system messages if it is a demonstration
            if is_demonstration:
                history = [entry for entry in history if entry["role"] != "system"]
            # Map history to Claude format
            prompt = "\n\n"
            for entry in history:
                if entry["role"] in {"user", "system"}:
                    prompt += f'{HUMAN_PROMPT} {entry["content"]}\n\n'
                elif entry["role"] == "assistant":
                    prompt += f'{AI_PROMPT} {entry["content"]}\n\n'
            prompt += AI_PROMPT
            return prompt

        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])

        # Return history components with just role, content fields (no system message)
        messages = [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
            if entry["role"] != "system"
        ]
        # Combine consecutive messages from the same role
        compiled_messages = []
        last_role = None
        for message in reversed(messages):
            if last_role == message["role"]:
                compiled_messages[-1]["content"] = message["content"] + "\n" + compiled_messages[-1]["content"]
            else:
                compiled_messages.append(message)
            last_role = message["role"]
        compiled_messages = list(reversed(compiled_messages))
        # Replace any empty content values with "(No output)"
        for message in compiled_messages:
            if message["content"].strip() == "":
                message["content"] = "(No output)"
        return compiled_messages
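
    # Illustrative example (not in the original module): given a history of
    #   [{"role": "system", "content": "sys"},
    #    {"role": "user", "content": "a"},
    #    {"role": "user", "content": "b"},
    #    {"role": "assistant", "content": ""}]
    # the method drops the system turn, merges the consecutive user turns, and
    # fills the empty assistant turn, returning
    #   [{"role": "user", "content": "a\nb"},
    #    {"role": "assistant", "content": "(No output)"}]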

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Anthropic API with the given `history` and return the response.
        """
        # Preserve behavior for older models
        if self.api_model in ["claude-instant", "claude-2"]:
            # Perform Anthropic API call
            prompt = self.history_to_messages(history)
            input_tokens = self.api.count_tokens(prompt)
            completion = self.api.completions.create(
                model=self.api_model,
                prompt=prompt,
                max_tokens_to_sample=self.model_metadata["max_context"] - input_tokens,
                temperature=self.args.temperature,
                top_p=self.args.top_p,
            )
            # Calculate + update costs, return response
            response = completion.completion
            output_tokens = self.api.count_tokens(response)
            self.update_stats(input_tokens, output_tokens)
            return response

        # Get system message(s)
        system_message = "\n".join([
            entry["content"] for entry in history if entry["role"] == "system"
        ])
        messages = self.history_to_messages(history)

        # Perform Anthropic API call
        response = self.api.messages.create(
            messages=messages,
            max_tokens=self.model_metadata["max_tokens"],
            model=self.api_model,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            system=system_message,
        )

        # Calculate + update costs, return response
        self.update_stats(
            response.usage.input_tokens,
            response.usage.output_tokens,
        )
        return "\n".join([x.text for x in response.content])

class OllamaModel(BaseModel):
    MODELS = defaultdict(lambda: {
        "max_context": 128_000,
        "cost_per_input_token": 0,
        "cost_per_output_token": 0,
    })

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)
        from ollama import Client
        self.client = Client(host=args.host_url)

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])
        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Ollama API with the given `history` and return the response.
        """
        response = self.client.chat(
            model=self.api_model,
            messages=self.history_to_messages(history),
            options={
                "temperature": self.args.temperature,
                "top_p": self.args.top_p,
            },
        )
        # Calculate + update costs, return response
        if "prompt_eval_count" in response:
            input_tokens = response["prompt_eval_count"]
        else:
            logger.warning(
                "Prompt eval count not found in response. Using 0. "
                "This might be because the prompt has been cached. "
                "See https://github.com/princeton-nlp/SWE-agent/issues/44 "
                "and https://github.com/ollama/ollama/issues/3427."
            )
            input_tokens = 0
        output_tokens = response["eval_count"]
        self.update_stats(input_tokens, output_tokens)
        return response["message"]["content"]

class TogetherModel(BaseModel):
    # Check https://docs.together.ai/docs/inference-models for model names, context
    # Check https://www.together.ai/pricing for pricing
    MODELS = {
        "meta-llama/Llama-2-13b-chat-hf": {
            "max_context": 4096,
            "cost_per_input_token": 2.25e-07,
            "cost_per_output_token": 2.25e-07,
        },
        "meta-llama/Llama-2-70b-chat-hf": {
            "max_context": 4096,
            "cost_per_input_token": 9e-07,
            "cost_per_output_token": 9e-07,
        },
        "mistralai/Mistral-7B-Instruct-v0.2": {
            "max_context": 32768,
            "cost_per_input_token": 2e-07,
            "cost_per_output_token": 2e-07,
        },
        "togethercomputer/RedPajama-INCITE-7B-Chat": {
            "max_context": 2048,
            "cost_per_input_token": 2e-07,
            "cost_per_output_token": 2e-07,
        },
        "mistralai/Mixtral-8x7B-Instruct-v0.1": {
            "max_context": 32768,
            "cost_per_input_token": 6e-07,
            "cost_per_output_token": 6e-07,
        },
    }

    SHORTCUTS = {
        "llama13b": "meta-llama/Llama-2-13b-chat-hf",
        "llama70b": "meta-llama/Llama-2-70b-chat-hf",
        "mistral7b": "mistralai/Mistral-7B-Instruct-v0.2",
        "mixtral8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "redpajama7b": "togethercomputer/RedPajama-INCITE-7B-Chat",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set Together key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        together.api_key = cfg.TOGETHER_API_KEY

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> str:
        """
        Create `prompt` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
        # Map history to TogetherAI format
        mapping = {"user": "human", "assistant": "bot", "system": "bot"}
        prompt = [f'<{mapping[d["role"]]}>: {d["content"]}' for d in history]
        prompt = "\n".join(prompt)
        prompt = f"{prompt}\n<bot>:"
        return prompt
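
    # Illustrative example (not in the original module): a history of
    #   [{"role": "user", "content": "Fix the bug"}]
    # is rendered as the prompt string
    #   "<human>: Fix the bug\n<bot>:"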

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Together API with the given `history` and return the response.
        """
        # Perform Together API call
        prompt = self.history_to_messages(history)
        completion = together.Complete.create(
            model=self.api_model,
            prompt=prompt,
            max_tokens=self.model_metadata["max_context"],
            stop="<human>",
            temperature=self.args.temperature,
            top_p=self.args.top_p,
        )
        # Calculate + update costs, return response
        response = completion["output"]["choices"][0]["text"].split("<human>")[0]
        input_tokens = completion["output"]["usage"]["prompt_tokens"]
        output_tokens = completion["output"]["usage"]["completion_tokens"]
        self.update_stats(input_tokens, output_tokens)
        return response

class HumanModel(BaseModel):
    MODELS = {"human": {}}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Determine which commands require multi-line input
        self.multi_line_command_endings = {
            command.name: command.end_name
            for command in commands
            if command.end_name is not None
        }

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])
        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]

    def query(self, history: list[dict[str, str]], action_prompt: str = "> ") -> str:
        """
        Logic for handling user input to pass to SWEEnv
        """
        action = input(action_prompt)
        command_name = action.split()[0] if action else ""

        # Special handling for multi-line input actions (i.e. edit)
        if command_name in self.multi_line_command_endings:
            buffer = [action]
            end_keyword = self.multi_line_command_endings[command_name]
            # Continue reading input until the terminating keyword is entered
            while True:
                action = input("... ")
                buffer.append(action)
                if action.rstrip() == end_keyword:
                    break
            action = "\n".join(buffer)
        elif action.strip() == "start_multiline_command":  # do arbitrary multi-line input
            buffer = []
            while True:
                action = input("... ")
                if action.rstrip() == "end_multiline_command":
                    break
                buffer.append(action)
            action = "\n".join(buffer)
        return action


class HumanThoughtModel(HumanModel):
    MODELS = {"human_thought": {}}

    def query(self, history: list[dict[str, str]]) -> str:
        """
        Logic for handling user input (both thought + action) to pass to SWEEnv
        """
        thought_all = ""
        thought = input("Thought (end w/ END_THOUGHT): ")
        while True:
            if "END_THOUGHT" in thought:
                thought = thought.split("END_THOUGHT")[0]
                thought_all += thought
                break
            thought_all += thought
            thought = input("... ")
        action = super().query(history, action_prompt="Action: ")
        return f"{thought_all}\n```\n{action}\n```"

class ReplayModel(BaseModel):
    MODELS = {"replay": {}}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)
        if self.args.replay_path is None or not os.path.exists(self.args.replay_path):
            raise ValueError(
                "--replay_path must point to a file that exists to run a replay policy"
            )
        with open(self.args.replay_path, "r") as f:
            self.replays = [
                list(json.loads(x).values())[0] for x in f.readlines()
            ]
        self.replay_idx = 0
        self.action_idx = 0
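
    # Expected replay file format (inferred from the parsing in __init__ and
    # shown here for illustration): one JSON object per line, each mapping an
    # instance id to its list of actions, e.g.
    #   {"instance_id_1": ["ls", "cat foo.py", "submit"]}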

    def query(self, history: list[dict[str, str]]) -> str:
        """
        Logic for tracking which replay action to pass to SWEEnv
        """
        action = self.replays[self.replay_idx][self.action_idx]
        self.action_idx += 1
        # Assume `submit` is always the last action of a replay trajectory
        if action == "submit":
            self.replay_idx += 1
            self.action_idx = 0
        return action

def get_model(args: ModelArguments, commands: Optional[list[Command]] = None):
    """
    Return the correct model object for the given arguments and commands
    """
    if commands is None:
        commands = []

    if args.model_name == "human":
        return HumanModel(args, commands)
    elif args.model_name == "human_thought":
        return HumanThoughtModel(args, commands)
    elif args.model_name == "replay":
        return ReplayModel(args, commands)
    elif args.model_name.startswith("gpt") or args.model_name.startswith("ft:gpt") or args.model_name.startswith("azure:gpt"):
        return OpenAIModel(args, commands)
    elif args.model_name.startswith("claude"):
        return AnthropicModel(args, commands)
    elif args.model_name.startswith("ollama"):
        return OllamaModel(args, commands)
    elif args.model_name in TogetherModel.SHORTCUTS:
        return TogetherModel(args, commands)
    else:
        raise ValueError(f"Invalid model name: {args.model_name}")
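

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Assumes a `keys.cfg` file with OPENAI_API_KEY exists in the working
    # directory; the `gpt4` shortcut resolves to gpt-4-1106-preview above.
    demo_args = ModelArguments(
        model_name="gpt4",
        temperature=0.0,
        per_instance_cost_limit=2.0,
    )
    demo_model = get_model(demo_args)
    print(demo_model.query([
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ]))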