123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- # This file contains the abstract base class for each llm kernel, providing a
- # common interface for all LLMs to implement.
- import json
- import re
- from aios.context.simple_context import SimpleContextManager
- # abc allows to make abstract classes
- from abc import ABC, abstractmethod
- from aios.utils.logger import LLMKernelLogger
- from aios.utils.id_generator import generator_tool_call_id
class BaseLLM(ABC):
    """Abstract base class for LLM kernels.

    Defines the common interface every concrete LLM backend implements:
    model/tokenizer loading, logging, translation of tool-call messages for
    models without native tool-call support, and extraction of JSON payloads
    from raw model output.
    """

    def __init__(self,
                 llm_name: str,
                 max_gpu_memory: dict = None,
                 eval_device: str = None,
                 max_new_tokens: int = 256,
                 log_mode: str = "console"
                 ):
        """Initialize the kernel and load the underlying model.

        Args:
            llm_name: identifier of the underlying model.
            max_gpu_memory: per-GPU memory budget; keys may arrive as strings
                from JSON config (see convert_map). May be None.
            eval_device: device used for evaluation (e.g. "cuda:0"). May be None.
            max_new_tokens: generation cap handed to the backend.
            log_mode: output target for LLMKernelLogger (e.g. "console").
        """
        self.max_gpu_memory = max_gpu_memory
        self.eval_device = eval_device
        self.max_new_tokens = max_new_tokens
        self.log_mode = log_mode
        self.model_name = llm_name
        self.context_manager = SimpleContextManager()
        # Create the logger BEFORE loading the model so subclass
        # load_llm_and_tokenizer implementations can use self.logger.
        self.logger = self.setup_logger()
        self.load_llm_and_tokenizer()
        self.logger.log(
            "AIOS has been successfully initialized.\n",
            level="info"
        )

    def convert_map(self, map: dict) -> dict:
        """Return a copy of *map* with every key cast to int.

        Useful for JSON-loaded configs (e.g. max_gpu_memory) whose integer
        keys are deserialized as strings.
        """
        return {int(k): v for k, v in map.items()}

    def check_model_type(self, model_name):
        # TODO add more model types
        return "causal_lm"

    def setup_logger(self):
        """Build the per-kernel logger, named after the model."""
        return LLMKernelLogger(self.model_name, self.log_mode)

    @abstractmethod
    def load_llm_and_tokenizer(self) -> None:  # load model from config
        """Load model and tokenizers for each type of LLMs."""
        return

    # only use for open-sourced LLM
    def tool_calling_input_format(self, messages: list, tools: list) -> list:
        """Integrate tool information into the messages for open-sourced LLMs.

        Rewrites messages in place: assistant "tool_calls" entries become
        plain JSON content, "tool" result messages become "user" messages,
        and the tool schema plus calling instructions are appended to the
        last message's content.

        Args:
            messages (list): messages with different roles (must be non-empty).
            tools (list): tool information.
        """
        prefix_prompt = "In and only in current step, you need to call tools. Available tools are: "
        tool_prompt = json.dumps(tools)
        suffix_prompt = "".join(
            [
                'Must call functions that are available. To call a function, respond '
                'immediately and only with a list of JSON object of the following format:'
                '{[{"name":"function_name_value","parameters":{"parameter_name1":"parameter_value1",'
                '"parameter_name2":"parameter_value2"}}]}'
            ]
        )
        # translate tool call message for models don't support tool call
        for message in messages:
            if "tool_calls" in message:
                message["content"] = json.dumps(message.pop("tool_calls"))
            elif message["role"] == "tool":
                message["role"] = "user"
                tool_call_id = message.pop("tool_call_id")
                content = message.pop("content")
                message["content"] = f"The result of the execution of function(id :{tool_call_id}) is: {content}. "
        messages[-1]["content"] += (prefix_prompt + tool_prompt + suffix_prompt)
        return messages

    def parse_json_format(self, message: str) -> str:
        """Extract the first JSON array (preferred) or object in *message*.

        Returns the extracted payload re-serialized via json.dumps, or the
        string '[]' when nothing parseable is found.
        """
        json_array_pattern = r'\[\s*\{.*?\}\s*\]'
        json_object_pattern = r'\{\s*.*?\s*\}'
        # re.DOTALL lets '.' cross newlines so pretty-printed / multiline
        # JSON emitted by the model is still captured (plain search missed it).
        match_array = re.search(json_array_pattern, message, re.DOTALL)
        if match_array:
            json_array_substring = match_array.group(0)
            try:
                json_array_data = json.loads(json_array_substring)
                return json.dumps(json_array_data)
            except json.JSONDecodeError:
                pass
        match_object = re.search(json_object_pattern, message, re.DOTALL)
        if match_object:
            json_object_substring = match_object.group(0)
            try:
                json_object_data = json.loads(json_object_substring)
                return json.dumps(json_object_data)
            except json.JSONDecodeError:
                pass
        return '[]'

    def parse_tool_calls(self, message):
        """Parse tool calls out of a raw model response.

        Adds the "id" and "type" fields expected by the tool-call schema,
        for models without native tool-call support.
        """
        parsed = json.loads(self.parse_json_format(message))
        # parse_json_format may yield a single JSON object; normalize to a
        # list, otherwise iterating the dict would yield string keys and
        # the item assignments below would raise TypeError.
        tool_calls = [parsed] if isinstance(parsed, dict) else parsed
        for tool_call in tool_calls:
            tool_call["id"] = generator_tool_call_id()
            tool_call["type"] = "function"
        return tool_calls

    def address_request(self,
                        agent_request,
                        temperature=0.0
                        ):
        """Serve *agent_request*, forwarding *temperature* to process().

        (The temperature argument was previously accepted but never passed
        through; the default 0.0 matches process(), so behavior for callers
        that omitted it is unchanged.)
        """
        return self.process(agent_request, temperature=temperature)

    @abstractmethod
    def process(self,
                agent_request,
                temperature=0.0) -> None:
        raise NotImplementedError
|