models.py

import config
import json
import logging
import os
import together
from collections import defaultdict
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from dataclasses import dataclass, fields
from openai import BadRequestError, OpenAI
from simple_parsing.helpers import FrozenSerializable, Serializable
from sweagent.agent.commands import Command
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
    retry_if_not_exception_type,
)
from typing import Optional

logger = logging.getLogger("api_models")


@dataclass(frozen=True)
class ModelArguments(FrozenSerializable):
    model_name: str
    per_instance_cost_limit: float = 0.0
    total_cost_limit: float = 0.0
    temperature: float = 1.0
    top_p: float = 1.0
    replay_path: Optional[str] = None
    host_url: str = "localhost:11434"


@dataclass
class APIStats(Serializable):
    total_cost: float = 0
    instance_cost: float = 0
    tokens_sent: int = 0
    tokens_received: int = 0
    api_calls: int = 0

    def __add__(self, other):
        if not isinstance(other, APIStats):
            raise TypeError("Can only add APIStats with APIStats")
        return APIStats(**{
            field.name: getattr(self, field.name) + getattr(other, field.name)
            for field in fields(self)
        })

    def replace(self, other):
        if not isinstance(other, APIStats):
            raise TypeError("Can only replace APIStats with APIStats")
        return APIStats(**{
            field.name: getattr(other, field.name)
            for field in fields(self)
        })


class ContextWindowExceededError(Exception):
    pass


class CostLimitExceededError(Exception):
    pass
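
# Usage sketch (illustrative only): APIStats accumulates field-wise via `+`, so
# per-call usage can be folded into running totals, e.g.
#   APIStats(tokens_sent=10, api_calls=1) + APIStats(tokens_sent=5, api_calls=1)
#   == APIStats(tokens_sent=15, api_calls=2)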


class BaseModel:
    MODELS = {}
    SHORTCUTS = {}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        self.args = args
        self.commands = commands
        self.model_metadata = {}
        self.stats = APIStats()

        # Map `model_name` to API-compatible name `api_model`
        self.api_model = (
            self.SHORTCUTS[self.args.model_name]
            if self.args.model_name in self.SHORTCUTS
            else self.args.model_name
        )

        # Map model name to metadata (cost, context info)
        MODELS = {
            **{dest: self.MODELS[src] for dest, src in self.SHORTCUTS.items()},
            **self.MODELS,
        }
        if args.model_name in MODELS:
            self.model_metadata = MODELS[args.model_name]
        elif args.model_name.startswith("ft:"):
            ft_model = args.model_name.split(":")[1]
            self.model_metadata = MODELS[ft_model]
        elif args.model_name.startswith("ollama:"):
            self.api_model = args.model_name.split("ollama:", 1)[1]
            self.model_metadata = self.MODELS[self.api_model]
        else:
            raise ValueError(
                f"Unregistered model ({args.model_name}). "
                f"Add the model name to the MODELS metadata of {self.__class__}"
            )

    def reset_stats(self, other: Optional[APIStats] = None):
        if other is None:
            self.stats = APIStats(total_cost=self.stats.total_cost)
            logger.info("Resetting model stats")
        else:
            self.stats = other

    def update_stats(self, input_tokens: int, output_tokens: int) -> float:
        """
        Calculates the cost of a response from the API and updates usage stats.

        Args:
            input_tokens (int): The number of tokens in the prompt.
            output_tokens (int): The number of tokens in the response.

        Returns:
            float: The cost of the response.
        """
        # Calculate cost and update cost related fields
        cost = (
            self.model_metadata["cost_per_input_token"] * input_tokens
            + self.model_metadata["cost_per_output_token"] * output_tokens
        )
        self.stats.total_cost += cost
        self.stats.instance_cost += cost
        self.stats.tokens_sent += input_tokens
        self.stats.tokens_received += output_tokens
        self.stats.api_calls += 1

        # Log updated cost values to stdout
        logger.info(
            f"input_tokens={input_tokens:_}, "
            f"output_tokens={output_tokens:_}, "
            f"instance_cost={self.stats.instance_cost:.2f}, "
            f"cost={cost:.2f}"
        )
        logger.info(
            f"total_tokens_sent={self.stats.tokens_sent:_}, "
            f"total_tokens_received={self.stats.tokens_received:_}, "
            f"total_cost={self.stats.total_cost:.2f}, "
            f"total_api_calls={self.stats.api_calls:_}"
        )

        # Check whether total cost or instance cost limits have been exceeded
        if (
            self.args.total_cost_limit > 0
            and self.stats.total_cost >= self.args.total_cost_limit
        ):
            logger.warning(
                f"Cost {self.stats.total_cost:.2f} exceeds limit {self.args.total_cost_limit:.2f}"
            )
            raise CostLimitExceededError("Total cost limit exceeded")
        if (
            self.args.per_instance_cost_limit > 0
            and self.stats.instance_cost >= self.args.per_instance_cost_limit
        ):
            logger.warning(
                f"Cost {self.stats.instance_cost:.2f} exceeds limit {self.args.per_instance_cost_limit:.2f}"
            )
            raise CostLimitExceededError("Instance cost limit exceeded")
        return cost

    def query(self, history: list[dict[str, str]]) -> str:
        raise NotImplementedError("Use a subclass of BaseModel")


class OpenAIModel(BaseModel):
    MODELS = {
        "gpt-3.5-turbo-0125": {
            "max_context": 16_385,
            "cost_per_input_token": 5e-07,
            "cost_per_output_token": 1.5e-06,
        },
        "gpt-3.5-turbo-1106": {
            "max_context": 16_385,
            "cost_per_input_token": 1.5e-06,
            "cost_per_output_token": 2e-06,
        },
        "gpt-3.5-turbo-16k-0613": {
            "max_context": 16_385,
            "cost_per_input_token": 1.5e-06,
            "cost_per_output_token": 2e-06,
        },
        "gpt-4-32k-0613": {
            "max_context": 32_768,
            "cost_per_input_token": 6e-05,
            "cost_per_output_token": 0.00012,
        },
        "gpt-4-0613": {
            "max_context": 8_192,
            "cost_per_input_token": 3e-05,
            "cost_per_output_token": 6e-05,
        },
        "gpt-4-1106-preview": {
            "max_context": 128_000,
            "cost_per_input_token": 1e-05,
            "cost_per_output_token": 3e-05,
        },
        "gpt-4-0125-preview": {
            "max_context": 128_000,
            "cost_per_input_token": 1e-05,
            "cost_per_output_token": 3e-05,
        },
    }

    SHORTCUTS = {
        "gpt3": "gpt-3.5-turbo-1106",
        "gpt3-legacy": "gpt-3.5-turbo-16k-0613",
        "gpt4": "gpt-4-1106-preview",
        "gpt4-legacy": "gpt-4-0613",
        "gpt4-0125": "gpt-4-0125-preview",
        "gpt3-0125": "gpt-3.5-turbo-0125",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set OpenAI key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        self.client = OpenAI(api_key=cfg["OPENAI_API_KEY"])

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> list[dict[str, str]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])
        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type(
            (CostLimitExceededError, ContextWindowExceededError, RuntimeError)
        ),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the OpenAI API with the given `history` and return the response.
        """
        try:
            # Perform OpenAI API call
            response = self.client.chat.completions.create(
                messages=self.history_to_messages(history),
                model=self.api_model,
                temperature=self.args.temperature,
                top_p=self.args.top_p,
            )
        except BadRequestError as e:
            # Raise the dedicated context-window error rather than a cost error
            raise ContextWindowExceededError(
                f"Context window ({self.model_metadata['max_context']} tokens) exceeded"
            ) from e
        # Calculate + update costs, return response
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        self.update_stats(input_tokens, output_tokens)
        return response.choices[0].message.content


class AnthropicModel(BaseModel):
    MODELS = {
        "claude-instant": {
            "max_context": 100_000,
            "cost_per_input_token": 1.63e-06,
            "cost_per_output_token": 5.51e-06,
        },
        "claude-2": {
            "max_context": 100_000,
            "cost_per_input_token": 1.102e-05,
            "cost_per_output_token": 3.268e-05,
        },
        "claude-2.1": {
            "max_context": 100_000,
            "cost_per_input_token": 1.102e-05,
            "cost_per_output_token": 3.268e-05,
        },
        "claude-3-opus-20240229": {
            "max_context": 200_000,
            "max_tokens": 4096,  # Max tokens to generate for Claude 3 models
            "cost_per_input_token": 1.5e-05,
            "cost_per_output_token": 7.5e-05,
        },
        "claude-3-sonnet-20240229": {
            "max_context": 200_000,
            "max_tokens": 4096,
            "cost_per_input_token": 3e-06,
            "cost_per_output_token": 1.5e-05,
        },
    }

    SHORTCUTS = {
        "claude": "claude-2",
        "claude-opus": "claude-3-opus-20240229",
        "claude-sonnet": "claude-3-sonnet-20240229",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set Anthropic key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        self.api = Anthropic(api_key=cfg["ANTHROPIC_API_KEY"])

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> list[dict[str, str]]:
        """
        Create `prompt` by filtering out all keys except for role/content per `history` turn
        Reference: https://docs.anthropic.com/claude/reference/complete_post
        """
        # Preserve behavior for older models
        if self.api_model in ["claude-instant", "claude-2"]:
            # Remove system messages if it is a demonstration
            if is_demonstration:
                history = [entry for entry in history if entry["role"] != "system"]
            # Map history to Claude format
            prompt = "\n\n"
            for entry in history:
                if entry["role"] in {"user", "system"}:
                    prompt += f'{HUMAN_PROMPT} {entry["content"]}\n\n'
                elif entry["role"] == "assistant":
                    prompt += f'{AI_PROMPT} {entry["content"]}\n\n'
            prompt += AI_PROMPT
            return prompt

        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])

        # Return history components with just role, content fields (no system message)
        messages = [
            {
                k: v for k, v in entry.items()
                if k in ["role", "content"]
            }
            for entry in history if entry["role"] != "system"
        ]
        compiled_messages = []  # Combine messages from the same role
        last_role = None
        for message in reversed(messages):
            if last_role == message["role"]:
                compiled_messages[-1]["content"] = message["content"] + "\n" + compiled_messages[-1]["content"]
            else:
                compiled_messages.append(message)
            last_role = message["role"]
        compiled_messages = list(reversed(compiled_messages))
        # Replace any empty content values with a "(No output)"
        for message in compiled_messages:
            if message["content"].strip() == "":
                message["content"] = "(No output)"
        return compiled_messages

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Anthropic API with the given `history` and return the response.
        """
        # Preserve behavior for older models
        if self.api_model in ["claude-instant", "claude-2"]:
            # Perform Anthropic API call
            prompt = self.history_to_messages(history)
            input_tokens = self.api.count_tokens(prompt)
            completion = self.api.completions.create(
                model=self.api_model,
                prompt=prompt,
                max_tokens_to_sample=self.model_metadata["max_context"] - input_tokens,
                temperature=self.args.temperature,
                top_p=self.args.top_p,
            )
            # Calculate + update costs, return response
            response = completion.completion
            output_tokens = self.api.count_tokens(response)
            self.update_stats(input_tokens, output_tokens)
            return response

        # Get system message(s)
        system_message = "\n".join([
            entry["content"] for entry in history if entry["role"] == "system"
        ])
        messages = self.history_to_messages(history)
        # Perform Anthropic API call
        response = self.api.messages.create(
            messages=messages,
            max_tokens=self.model_metadata["max_tokens"],
            model=self.api_model,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            system=system_message,
        )
        # Calculate + update costs, return response
        self.update_stats(
            response.usage.input_tokens,
            response.usage.output_tokens
        )
        response = "\n".join([x.text for x in response.content])
        return response


class OllamaModel(BaseModel):
    MODELS = defaultdict(lambda: {
        "max_context": 128_000,
        "cost_per_input_token": 0,
        "cost_per_output_token": 0,
    })

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)
        from ollama import Client

        self.client = Client(host=args.host_url)

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> list[dict[str, str]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])
        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Ollama API with the given `history` and return the response.
        """
        response = self.client.chat(
            model=self.api_model,
            messages=self.history_to_messages(history),
            options={
                "temperature": self.args.temperature,
                "top_p": self.args.top_p,
            },
        )
        # Calculate + update costs, return response
        input_tokens = response["prompt_eval_count"]
        output_tokens = response["eval_count"]
        self.update_stats(input_tokens, output_tokens)
        return response["message"]["content"]


class TogetherModel(BaseModel):
    # Check https://docs.together.ai/docs/inference-models for model names, context
    # Check https://www.together.ai/pricing for pricing
    MODELS = {
        "meta-llama/Llama-2-13b-chat-hf": {
            "max_context": 4096,
            "cost_per_input_token": 2.25e-07,
            "cost_per_output_token": 2.25e-07,
        },
        "meta-llama/Llama-2-70b-chat-hf": {
            "max_context": 4096,
            "cost_per_input_token": 9e-07,
            "cost_per_output_token": 9e-07,
        },
        "mistralai/Mistral-7B-Instruct-v0.2": {
            "max_context": 32768,
            "cost_per_input_token": 2e-07,
            "cost_per_output_token": 2e-07,
        },
        "togethercomputer/RedPajama-INCITE-7B-Chat": {
            "max_context": 2048,
            "cost_per_input_token": 2e-07,
            "cost_per_output_token": 2e-07,
        },
        "mistralai/Mixtral-8x7B-Instruct-v0.1": {
            "max_context": 32768,
            "cost_per_input_token": 6e-07,
            "cost_per_output_token": 6e-07,
        },
    }

    SHORTCUTS = {
        "llama13b": "meta-llama/Llama-2-13b-chat-hf",
        "llama70b": "meta-llama/Llama-2-70b-chat-hf",
        "mistral7b": "mistralai/Mistral-7B-Instruct-v0.2",
        "mixtral8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "redpajama7b": "togethercomputer/RedPajama-INCITE-7B-Chat",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set Together key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        together.api_key = cfg.TOGETHER_API_KEY

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> str:
        """
        Create `prompt` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
        # Map history to TogetherAI format
        mapping = {"user": "human", "assistant": "bot", "system": "bot"}
        prompt = [f'<{mapping[d["role"]]}>: {d["content"]}' for d in history]
        prompt = "\n".join(prompt)
        prompt = f"{prompt}\n<bot>:"
        return prompt

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Together API with the given `history` and return the response.
        """
        # Perform Together API call
        prompt = self.history_to_messages(history)
        completion = together.Complete.create(
            model=self.api_model,
            prompt=prompt,
            max_tokens=self.model_metadata["max_context"],
            stop="<human>",
            temperature=self.args.temperature,
            top_p=self.args.top_p,
        )
        # Calculate + update costs, return response
        response = completion["output"]["choices"][0]["text"].split("<human>")[0]
        input_tokens = completion["output"]["usage"]["prompt_tokens"]
        output_tokens = completion["output"]["usage"]["completion_tokens"]
        self.update_stats(input_tokens, output_tokens)
        return response


class HumanModel(BaseModel):
    MODELS = {"human": {}}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Determine which commands require multi-line input
        self.multi_line_command_endings = {
            command.name: command.end_name
            for command in commands
            if command.end_name is not None
        }

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> list[dict[str, str]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])
        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]

    def query(self, history: list[dict[str, str]], action_prompt: str = "> ") -> str:
        """
        Logic for handling user input to pass to SWEEnv
        """
        action = input(action_prompt)
        command_name = action.split()[0] if action else ""

        # Special handling for multi-line input actions (i.e. edit)
        if command_name in self.multi_line_command_endings:
            buffer = [action]
            end_keyword = self.multi_line_command_endings[command_name]
            # Continue reading input until the terminating keyword is entered
            while True:
                action = input("... ")
                buffer.append(action)
                if action.rstrip() == end_keyword:
                    break
            action = "\n".join(buffer)
        elif action.strip() == "start_multiline_command":  # do arbitrary multi-line input
            buffer = []
            while True:
                action = input("... ")
                if action.rstrip() == "end_multiline_command":
                    break
                buffer.append(action)
            action = "\n".join(buffer)
        return action


class HumanThoughtModel(HumanModel):
    MODELS = {"human_thought": {}}

    def query(self, history: list[dict[str, str]]) -> str:
        """
        Logic for handling user input (both thought + action) to pass to SWEEnv
        """
        thought_all = ""
        thought = input("Thought (end w/ END_THOUGHT): ")
        while True:
            if "END_THOUGHT" in thought:
                thought = thought.split("END_THOUGHT")[0]
                thought_all += thought
                break
            thought_all += thought
            thought = input("... ")
        action = super().query(history, action_prompt="Action: ")
        return f"{thought_all}\n```\n{action}\n```"


class ReplayModel(BaseModel):
    MODELS = {"replay": {}}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        if self.args.replay_path is None or not os.path.exists(self.args.replay_path):
            raise ValueError(
                "--replay_path must point to a file that exists to run a replay policy"
            )

        with open(self.args.replay_path, "r") as f:
            self.replays = [
                list(json.loads(x).values())[0]
                for x in f.readlines()
            ]
        self.replay_idx = 0
        self.action_idx = 0

    def query(self, history: list[dict[str, str]]) -> str:
        """
        Logic for tracking which replay action to pass to SWEEnv
        """
        action = self.replays[self.replay_idx][self.action_idx]
        self.action_idx += 1

        # Assuming `submit` is always the last action of a replay trajectory
        if action == "submit":
            self.replay_idx += 1
            self.action_idx = 0

        return action
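

# Note on the replay file format (inferred from ReplayModel above, not a spec):
# each line of `--replay_path` is expected to be a standalone JSON object whose
# single value is one trajectory's list of actions, ending with `submit`, e.g.
#   {"<instance_id>": ["ls", "cat foo.py", "submit"]}
# where `<instance_id>` is a hypothetical key used only for illustration.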


def get_model(args: ModelArguments, commands: Optional[list[Command]] = None):
    """
    Returns the correct model object given arguments and commands
    """
    if commands is None:
        commands = []

    if args.model_name == "human":
        return HumanModel(args, commands)
    elif args.model_name == "human_thought":
        return HumanThoughtModel(args, commands)
    elif args.model_name == "replay":
        return ReplayModel(args, commands)
    elif args.model_name.startswith("gpt") or args.model_name.startswith("ft:gpt"):
        return OpenAIModel(args, commands)
    elif args.model_name.startswith("claude"):
        return AnthropicModel(args, commands)
    elif args.model_name.startswith("ollama"):
        return OllamaModel(args, commands)
    elif args.model_name in TogetherModel.MODELS or args.model_name in TogetherModel.SHORTCUTS:
        # Together-hosted models are registered by full name or shortcut
        return TogetherModel(args, commands)
    else:
        raise ValueError(f"Invalid model name: {args.model_name}")
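

# Minimal usage sketch (illustrative only; assumes a `keys.cfg` with the relevant
# API key exists in the working directory and uses the shortcuts registered above):
#
#   args = ModelArguments(model_name="gpt4", per_instance_cost_limit=2.0)
#   model = get_model(args)
#   output = model.query([
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"},
#   ])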