import json
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, fields
from typing import Optional, Union

import config
import together
from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic
from openai import AzureOpenAI, BadRequestError, OpenAI
from simple_parsing.helpers.serialization.serializable import FrozenSerializable, Serializable
from tenacity import (
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from sweagent.agent.commands import Command

logger = logging.getLogger("api_models")


@dataclass(frozen=True)
class ModelArguments(FrozenSerializable):
    """Arguments configuring the model and its behavior."""
    model_name: str
    per_instance_cost_limit: float = 0.0
    total_cost_limit: float = 0.0
    temperature: float = 1.0
    top_p: float = 1.0
    replay_path: Optional[str] = None
    host_url: str = "localhost:11434"
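

# Field notes (illustrative, not exhaustive): `model_name` accepts a full API
# model name, a SHORTCUTS key of a model class (e.g. "gpt4"), or a prefixed
# name such as "azure:<deployment-backed model>", "ft:<fine-tuned model>", or
# "ollama:<local model>"; `host_url` is only used by OllamaModel, and the two
# cost limits are enforced in BaseModel.update_stats (0.0 disables a limit).
# A hedged example instantiation:
#
#     args = ModelArguments(model_name="claude-sonnet", temperature=0.0)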


@dataclass
class APIStats(Serializable):
    total_cost: float = 0
    instance_cost: float = 0
    tokens_sent: int = 0
    tokens_received: int = 0
    api_calls: int = 0

    def __add__(self, other):
        if not isinstance(other, APIStats):
            raise TypeError("Can only add APIStats with APIStats")
        return APIStats(**{
            field.name: getattr(self, field.name) + getattr(other, field.name)
            for field in fields(self)
        })

    def replace(self, other):
        if not isinstance(other, APIStats):
            raise TypeError("Can only replace APIStats with APIStats")
        return APIStats(**{
            field.name: getattr(other, field.name)
            for field in fields(self)
        })
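

# How APIStats composes (a minimal sketch; the numbers are made up):
#
#     >>> a = APIStats(total_cost=0.50, api_calls=2)
#     >>> b = APIStats(total_cost=0.25, api_calls=1)
#     >>> (a + b).total_cost, (a + b).api_calls
#     (0.75, 3)
#
# `__add__` sums every field pairwise, while `replace` returns a copy of
# `other`, which is how per-instance stats are swapped in and out.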


class ContextWindowExceededError(Exception):
    """Raised when a query would not fit in the model's context window."""


class CostLimitExceededError(Exception):
    """Raised when the total or per-instance cost limit is reached."""


class BaseModel:
    MODELS = {}
    SHORTCUTS = {}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        self.args = args
        self.commands = commands
        self.model_metadata = {}
        self.stats = APIStats()

        # Map `model_name` to API-compatible name `api_model`
        self.api_model = (
            self.SHORTCUTS[self.args.model_name]
            if self.args.model_name in self.SHORTCUTS
            else self.args.model_name
        )

        # Map model name to metadata (cost, context info)
        MODELS = {
            **{dest: self.MODELS[src] for dest, src in self.SHORTCUTS.items()},
            **self.MODELS,
        }
        if args.model_name in MODELS:
            self.model_metadata = MODELS[args.model_name]
        elif args.model_name.startswith("ft:"):
            ft_model = args.model_name.split(":")[1]
            self.model_metadata = MODELS[ft_model]
        elif args.model_name.startswith("ollama:"):
            self.api_model = args.model_name.split("ollama:", 1)[1]
            self.model_metadata = self.MODELS[self.api_model]
        elif args.model_name.startswith("azure:"):
            azure_model = args.model_name.split("azure:", 1)[1]
            self.model_metadata = MODELS[azure_model]
        else:
            raise ValueError(
                f"Unregistered model ({args.model_name}). "
                f"Add the model name and its metadata to the MODELS dict of {self.__class__}."
            )

    def reset_stats(self, other: Optional[APIStats] = None):
        if other is None:
            self.stats = APIStats(total_cost=self.stats.total_cost)
            logger.info("Resetting model stats")
        else:
            self.stats = other

    def update_stats(self, input_tokens: int, output_tokens: int) -> float:
        """
        Calculates the cost of a response from the model API and updates
        the cumulative statistics.

        Args:
            input_tokens (int): The number of tokens in the prompt.
            output_tokens (int): The number of tokens in the response.

        Returns:
            float: The cost of the response.
        """
        # Calculate cost and update cost-related fields
        cost = (
            self.model_metadata["cost_per_input_token"] * input_tokens
            + self.model_metadata["cost_per_output_token"] * output_tokens
        )
        self.stats.total_cost += cost
        self.stats.instance_cost += cost
        self.stats.tokens_sent += input_tokens
        self.stats.tokens_received += output_tokens
        self.stats.api_calls += 1

        # Log updated cost values to stdout
        logger.info(
            f"input_tokens={input_tokens:_}, "
            f"output_tokens={output_tokens:_}, "
            f"instance_cost={self.stats.instance_cost:.2f}, "
            f"cost={cost:.2f}"
        )
        logger.info(
            f"total_tokens_sent={self.stats.tokens_sent:_}, "
            f"total_tokens_received={self.stats.tokens_received:_}, "
            f"total_cost={self.stats.total_cost:.2f}, "
            f"total_api_calls={self.stats.api_calls:_}"
        )

        # Check whether total cost or instance cost limits have been exceeded
        if 0 < self.args.total_cost_limit <= self.stats.total_cost:
            logger.warning(
                f"Cost {self.stats.total_cost:.2f} exceeds limit {self.args.total_cost_limit:.2f}"
            )
            raise CostLimitExceededError("Total cost limit exceeded")
        if 0 < self.args.per_instance_cost_limit <= self.stats.instance_cost:
            logger.warning(
                f"Cost {self.stats.instance_cost:.2f} exceeds limit {self.args.per_instance_cost_limit:.2f}"
            )
            raise CostLimitExceededError("Instance cost limit exceeded")
        return cost
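
    # Worked cost example (rates taken from the gpt-4-1106-preview entry
    # below; the token counts are hypothetical): with
    # cost_per_input_token=1e-05 and cost_per_output_token=3e-05, a call that
    # sends 1_000 tokens and receives 200 tokens costs
    # 1_000 * 1e-05 + 200 * 3e-05 = 0.016 dollars.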

    def query(self, history: list[dict[str, str]]) -> str:
        raise NotImplementedError("Use a subclass of BaseModel")


class OpenAIModel(BaseModel):
    MODELS = {
        "gpt-3.5-turbo-0125": {
            "max_context": 16_385,
            "cost_per_input_token": 5e-07,
            "cost_per_output_token": 1.5e-06,
        },
        "gpt-3.5-turbo-1106": {
            "max_context": 16_385,
            "cost_per_input_token": 1.5e-06,
            "cost_per_output_token": 2e-06,
        },
        "gpt-3.5-turbo-16k-0613": {
            "max_context": 16_385,
            "cost_per_input_token": 1.5e-06,
            "cost_per_output_token": 2e-06,
        },
        "gpt-4-32k-0613": {
            "max_context": 32_768,
            "cost_per_input_token": 6e-05,
            "cost_per_output_token": 0.00012,
        },
        "gpt-4-0613": {
            "max_context": 8_192,
            "cost_per_input_token": 3e-05,
            "cost_per_output_token": 6e-05,
        },
        "gpt-4-1106-preview": {
            "max_context": 128_000,
            "cost_per_input_token": 1e-05,
            "cost_per_output_token": 3e-05,
        },
        "gpt-4-0125-preview": {
            "max_context": 128_000,
            "cost_per_input_token": 1e-05,
            "cost_per_output_token": 3e-05,
        },
    }

    SHORTCUTS = {
        "gpt3": "gpt-3.5-turbo-1106",
        "gpt3-legacy": "gpt-3.5-turbo-16k-0613",
        "gpt4": "gpt-4-1106-preview",
        "gpt4-legacy": "gpt-4-0613",
        "gpt4-0125": "gpt-4-0125-preview",
        "gpt3-0125": "gpt-3.5-turbo-0125",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set OpenAI key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        if self.args.model_name.startswith("azure"):
            self.api_model = cfg["AZURE_OPENAI_DEPLOYMENT"]
            self.client = AzureOpenAI(
                api_key=cfg["AZURE_OPENAI_API_KEY"],
                azure_endpoint=cfg["AZURE_OPENAI_ENDPOINT"],
                api_version=cfg.get("AZURE_OPENAI_API_VERSION", "2024-02-01"),
            )
        else:
            api_base_url: Optional[str] = cfg.get("OPENAI_API_BASE_URL", None)
            self.client = OpenAI(api_key=cfg["OPENAI_API_KEY"], base_url=api_base_url)

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])

        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]
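
    # Shape sketch (hypothetical turns, including a hypothetical bookkeeping
    # key): a history like
    #     [{"role": "user", "content": "Fix the bug.", "agent": "primary"}]
    # becomes
    #     [{"role": "user", "content": "Fix the bug."}]
    # i.e. only the role/content fields survive the filtering above.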

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type(
            (CostLimitExceededError, ContextWindowExceededError, RuntimeError)
        ),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the OpenAI API with the given `history` and return the response.
        """
        try:
            # Perform OpenAI API call
            response = self.client.chat.completions.create(
                messages=self.history_to_messages(history),
                model=self.api_model,
                temperature=self.args.temperature,
                top_p=self.args.top_p,
            )
        except BadRequestError:
            raise ContextWindowExceededError(
                f"Context window ({self.model_metadata['max_context']} tokens) exceeded"
            )
        # Calculate + update costs, return response
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        self.update_stats(input_tokens, output_tokens)
        return response.choices[0].message.content


class AnthropicModel(BaseModel):
    MODELS = {
        "claude-instant": {
            "max_context": 100_000,
            "cost_per_input_token": 1.63e-06,
            "cost_per_output_token": 5.51e-06,
        },
        "claude-2": {
            "max_context": 100_000,
            "cost_per_input_token": 1.102e-05,
            "cost_per_output_token": 3.268e-05,
        },
        "claude-2.1": {
            "max_context": 100_000,
            "cost_per_input_token": 1.102e-05,
            "cost_per_output_token": 3.268e-05,
        },
        "claude-3-opus-20240229": {
            "max_context": 200_000,
            "max_tokens": 4096,  # Max tokens to generate for Claude 3 models
            "cost_per_input_token": 1.5e-05,
            "cost_per_output_token": 7.5e-05,
        },
        "claude-3-sonnet-20240229": {
            "max_context": 200_000,
            "max_tokens": 4096,
            "cost_per_input_token": 3e-06,
            "cost_per_output_token": 1.5e-05,
        },
        "claude-3-haiku-20240307": {
            "max_context": 200_000,
            "max_tokens": 4096,
            "cost_per_input_token": 2.5e-07,
            "cost_per_output_token": 1.25e-06,
        },
    }

    SHORTCUTS = {
        "claude": "claude-2",
        "claude-opus": "claude-3-opus-20240229",
        "claude-sonnet": "claude-3-sonnet-20240229",
        "claude-haiku": "claude-3-haiku-20240307",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set Anthropic key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        self.api = Anthropic(api_key=cfg["ANTHROPIC_API_KEY"])

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `prompt` by filtering out all keys except for role/content per `history` turn
        Reference: https://docs.anthropic.com/claude/reference/complete_post
        """
        # Preserve behavior for older models served by the completions API
        # (claude-2.1 is included because its metadata carries no `max_tokens`
        # entry for the messages API path used below)
        if self.api_model in ["claude-instant", "claude-2", "claude-2.1"]:
            # Remove system messages if it is a demonstration
            if is_demonstration:
                history = [entry for entry in history if entry["role"] != "system"]
            # Map history to Claude format
            prompt = "\n\n"
            for entry in history:
                if entry["role"] in {"user", "system"}:
                    prompt += f'{HUMAN_PROMPT} {entry["content"]}\n\n'
                elif entry["role"] == "assistant":
                    prompt += f'{AI_PROMPT} {entry["content"]}\n\n'
            prompt += AI_PROMPT
            return prompt

        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])

        # Return history components with just role, content fields (no system message)
        messages = [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
            if entry["role"] != "system"
        ]

        # Combine messages from the same role
        compiled_messages = []
        last_role = None
        for message in reversed(messages):
            if last_role == message["role"]:
                compiled_messages[-1]["content"] = message["content"] + "\n" + compiled_messages[-1]["content"]
            else:
                compiled_messages.append(message)
            last_role = message["role"]
        compiled_messages = list(reversed(compiled_messages))

        # Replace any empty content values with a "(No output)"
        for message in compiled_messages:
            if message["content"].strip() == "":
                message["content"] = "(No output)"
        return compiled_messages
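
    # Merging sketch (hypothetical turns): consecutive same-role entries are
    # collapsed so the messages API sees strictly alternating roles, e.g.
    #     [{"role": "user", "content": "a"},
    #      {"role": "user", "content": "b"},
    #      {"role": "assistant", "content": ""}]
    # becomes
    #     [{"role": "user", "content": "a\nb"},
    #      {"role": "assistant", "content": "(No output)"}]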

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Anthropic API with the given `history` and return the response.
        """
        # Preserve behavior for older models served by the completions API
        if self.api_model in ["claude-instant", "claude-2", "claude-2.1"]:
            # Perform Anthropic API call
            prompt = self.history_to_messages(history)
            input_tokens = self.api.count_tokens(prompt)
            completion = self.api.completions.create(
                model=self.api_model,
                prompt=prompt,
                max_tokens_to_sample=self.model_metadata["max_context"] - input_tokens,
                temperature=self.args.temperature,
                top_p=self.args.top_p,
            )
            # Calculate + update costs, return response
            response = completion.completion
            output_tokens = self.api.count_tokens(response)
            self.update_stats(input_tokens, output_tokens)
            return response

        # Get system message(s)
        system_message = "\n".join([
            entry["content"] for entry in history if entry["role"] == "system"
        ])
        messages = self.history_to_messages(history)

        # Perform Anthropic API call
        response = self.api.messages.create(
            messages=messages,
            max_tokens=self.model_metadata["max_tokens"],
            model=self.api_model,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            system=system_message,
        )

        # Calculate + update costs, return response
        self.update_stats(
            response.usage.input_tokens,
            response.usage.output_tokens,
        )
        return "\n".join([x.text for x in response.content])


class OllamaModel(BaseModel):
    MODELS = defaultdict(lambda: {
        "max_context": 128_000,
        "cost_per_input_token": 0,
        "cost_per_output_token": 0,
    })

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)
        # Lazy import so `ollama` is only required when this model is used
        from ollama import Client

        self.client = Client(host=args.host_url)

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])

        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Ollama API with the given `history` and return the response.
        """
        response = self.client.chat(
            model=self.api_model,
            messages=self.history_to_messages(history),
            options={
                "temperature": self.args.temperature,
                "top_p": self.args.top_p,
            },
        )
        # Calculate + update costs, return response
        if "prompt_eval_count" in response:
            input_tokens = response["prompt_eval_count"]
        else:
            logger.warning(
                "Prompt eval count not found in response. Using 0. "
                "This might be because the prompt has been cached. "
                "See https://github.com/princeton-nlp/SWE-agent/issues/44 "
                "and https://github.com/ollama/ollama/issues/3427."
            )
            input_tokens = 0
        output_tokens = response["eval_count"]
        self.update_stats(input_tokens, output_tokens)
        return response["message"]["content"]


class TogetherModel(BaseModel):
    # Check https://docs.together.ai/docs/inference-models for model names, context
    # Check https://www.together.ai/pricing for pricing
    MODELS = {
        "meta-llama/Llama-2-13b-chat-hf": {
            "max_context": 4096,
            "cost_per_input_token": 2.25e-07,
            "cost_per_output_token": 2.25e-07,
        },
        "meta-llama/Llama-2-70b-chat-hf": {
            "max_context": 4096,
            "cost_per_input_token": 9e-07,
            "cost_per_output_token": 9e-07,
        },
        "mistralai/Mistral-7B-Instruct-v0.2": {
            "max_context": 32768,
            "cost_per_input_token": 2e-07,
            "cost_per_output_token": 2e-07,
        },
        "togethercomputer/RedPajama-INCITE-7B-Chat": {
            "max_context": 2048,
            "cost_per_input_token": 2e-07,
            "cost_per_output_token": 2e-07,
        },
        "mistralai/Mixtral-8x7B-Instruct-v0.1": {
            "max_context": 32768,
            "cost_per_input_token": 6e-07,
            "cost_per_output_token": 6e-07,
        },
    }

    SHORTCUTS = {
        "llama13b": "meta-llama/Llama-2-13b-chat-hf",
        "llama70b": "meta-llama/Llama-2-70b-chat-hf",
        "mistral7b": "mistralai/Mistral-7B-Instruct-v0.2",
        "mixtral8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "redpajama7b": "togethercomputer/RedPajama-INCITE-7B-Chat",
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Set Together key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        together.api_key = cfg.TOGETHER_API_KEY

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> str:
        """
        Create `prompt` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
        # Map history to TogetherAI format
        mapping = {"user": "human", "assistant": "bot", "system": "bot"}
        prompt = [f'<{mapping[d["role"]]}>: {d["content"]}' for d in history]
        prompt = "\n".join(prompt)
        prompt = f"{prompt}\n<bot>:"
        return prompt
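
    # Prompt sketch (hypothetical turns): a history of
    #     system: "You are helpful." / user: "Hi" / assistant: "Hello"
    # renders as
    #     <bot>: You are helpful.
    #     <human>: Hi
    #     <bot>: Hello
    #     <bot>:
    # System turns are mapped to <bot>, and the trailing "<bot>:" cues the
    # model to continue as the assistant.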

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the Together API with the given `history` and return the response.
        """
        # Perform Together API call
        prompt = self.history_to_messages(history)
        completion = together.Complete.create(
            model=self.api_model,
            prompt=prompt,
            max_tokens=self.model_metadata["max_context"],
            stop="<human>",
            temperature=self.args.temperature,
            top_p=self.args.top_p,
        )
        # Calculate + update costs, return response
        response = completion["output"]["choices"][0]["text"].split("<human>")[0]
        input_tokens = completion["output"]["usage"]["prompt_tokens"]
        output_tokens = completion["output"]["usage"]["completion_tokens"]
        self.update_stats(input_tokens, output_tokens)
        return response


class HumanModel(BaseModel):
    MODELS = {"human": {}}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        # Determine which commands require multi-line input
        self.multi_line_command_endings = {
            command.name: command.end_name
            for command in commands
            if command.end_name is not None
        }

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return "\n".join([entry["content"] for entry in history])

        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ["role", "content"]}
            for entry in history
        ]

    def query(self, history: list[dict[str, str]], action_prompt: str = "> ") -> str:
        """
        Logic for handling user input to pass to SWEEnv
        """
        action = input(action_prompt)
        # Guard with `action.split()` (not just `action`) so that
        # whitespace-only input does not raise an IndexError
        command_name = action.split()[0] if action.split() else ""

        # Special handling for multi-line input actions (i.e. edit)
        if command_name in self.multi_line_command_endings:
            buffer = [action]
            end_keyword = self.multi_line_command_endings[command_name]
            # Continue reading input until the terminating keyword is entered
            while True:
                action = input("... ")
                buffer.append(action)
                if action.rstrip() == end_keyword:
                    break
            action = "\n".join(buffer)
        elif action.strip() == "start_multiline_command":  # do arbitrary multi-line input
            buffer = []
            while True:
                action = input("... ")
                if action.rstrip() == "end_multiline_command":
                    break
                buffer.append(action)
            action = "\n".join(buffer)
        return action
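
    # Interaction sketch (hypothetical session): for a multi-line command
    # whose `Command.end_name` is, say, "end_of_edit", input is buffered
    # until the terminator line appears:
    #
    #     > edit 10:12
    #     ... print("hello")
    #     ... end_of_edit
    #
    # and the returned action is those three lines joined with "\n".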


class HumanThoughtModel(HumanModel):
    MODELS = {"human_thought": {}}

    def query(self, history: list[dict[str, str]]) -> str:
        """
        Logic for handling user input (both thought + action) to pass to SWEEnv
        """
        thought_all = ""
        thought = input("Thought (end w/ END_THOUGHT): ")
        while True:
            if "END_THOUGHT" in thought:
                thought = thought.split("END_THOUGHT")[0]
                thought_all += thought
                break
            thought_all += thought
            thought = input("... ")

        action = super().query(history, action_prompt="Action: ")
        return f"{thought_all}\n```\n{action}\n```"


class ReplayModel(BaseModel):
    MODELS = {"replay": {}}

    def __init__(self, args: ModelArguments, commands: list[Command]):
        super().__init__(args, commands)

        if self.args.replay_path is None or not os.path.exists(self.args.replay_path):
            raise ValueError(
                "--replay_path must point to a file that exists to run a replay policy"
            )
        with open(self.args.replay_path, "r") as f:
            self.replays = [list(json.loads(x).values())[0] for x in f]
        self.replay_idx = 0
        self.action_idx = 0

    def query(self, history: list[dict[str, str]]) -> str:
        """
        Logic for tracking which replay action to pass to SWEEnv
        """
        action = self.replays[self.replay_idx][self.action_idx]
        self.action_idx += 1

        # Assuming `submit` is always the last action of a replay trajectory
        if action == "submit":
            self.replay_idx += 1
            self.action_idx = 0
        return action


def get_model(args: ModelArguments, commands: Optional[list[Command]] = None):
    """
    Returns the correct model object given arguments and commands
    """
    if commands is None:
        commands = []

    if args.model_name == "human":
        return HumanModel(args, commands)
    if args.model_name == "human_thought":
        return HumanThoughtModel(args, commands)
    if args.model_name == "replay":
        return ReplayModel(args, commands)
    if args.model_name.startswith(("gpt", "ft:gpt", "azure:gpt")):
        return OpenAIModel(args, commands)
    if args.model_name.startswith("claude"):
        return AnthropicModel(args, commands)
    if args.model_name.startswith("ollama"):
        return OllamaModel(args, commands)
    if args.model_name in TogetherModel.SHORTCUTS:
        return TogetherModel(args, commands)
    raise ValueError(f"Invalid model name: {args.model_name}")
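

# Minimal end-to-end sketch (hedged: assumes a `keys.cfg` with the relevant
# API key in the working directory; "gpt4" is a SHORTCUTS entry and the cost
# limit is an example value, not a recommendation):
#
#     args = ModelArguments(model_name="gpt4", per_instance_cost_limit=2.0)
#     model = get_model(args)
#     response = model.query([
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Say hello."},
#     ])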