base_llm.py

# This file contains the abstract base class for each llm kernel, providing a
# common interface for all LLMs to implement.
import json
import re

from aios.context.simple_context import SimpleContextManager
# abc allows defining abstract classes
from abc import ABC, abstractmethod
from aios.utils.logger import LLMKernelLogger
from aios.utils.id_generator import generator_tool_call_id


class BaseLLM(ABC):
    def __init__(self,
                 llm_name: str,
                 max_gpu_memory: dict = None,
                 eval_device: str = None,
                 max_new_tokens: int = 256,
                 log_mode: str = "console"
                 ):
        self.max_gpu_memory = max_gpu_memory
        self.eval_device = eval_device
        self.max_new_tokens = max_new_tokens
        self.log_mode = log_mode
        self.model_name = llm_name
        self.context_manager = SimpleContextManager()
        self.load_llm_and_tokenizer()
        self.logger = self.setup_logger()
        self.logger.log(
            "AIOS has been successfully initialized.\n",
            level="info"
        )

    def convert_map(self, map: dict) -> dict:
        """Helper utility to convert the keys of a map to int."""
        new_map = {}
        for k, v in map.items():
            new_map[int(k)] = v
        return new_map

    def check_model_type(self, model_name):
        # TODO add more model types
        return "causal_lm"

    def setup_logger(self):
        logger = LLMKernelLogger(self.model_name, self.log_mode)
        return logger

    @abstractmethod
    def load_llm_and_tokenizer(self) -> None:  # load model from config
        """Load the model and tokenizer for each type of LLM."""
        return

    # only used for open-sourced LLMs
    def tool_calling_input_format(self, messages: list, tools: list) -> list:
        """Integrate tool information into the messages for open-sourced LLMs.

        Args:
            messages (list): messages with different roles
            tools (list): tool information
        """
        prefix_prompt = "In and only in current step, you need to call tools. Available tools are: "
        tool_prompt = json.dumps(tools)
        suffix_prompt = "".join(
            [
                'Must call functions that are available. To call a function, respond '
                'immediately and only with a list of JSON object of the following format:'
                '{[{"name":"function_name_value","parameters":{"parameter_name1":"parameter_value1",'
                '"parameter_name2":"parameter_value2"}}]}'
            ]
        )
        # translate tool call messages for models that don't support native tool calling
        for message in messages:
            if "tool_calls" in message:
                message["content"] = json.dumps(message.pop("tool_calls"))
            elif message["role"] == "tool":
                message["role"] = "user"
                tool_call_id = message.pop("tool_call_id")
                content = message.pop("content")
                message["content"] = f"The result of the execution of function(id :{tool_call_id}) is: {content}. "
        messages[-1]["content"] += (prefix_prompt + tool_prompt + suffix_prompt)
        return messages
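
    # Illustrative sketch (not part of the class API): how a "tool" role message is
    # rewritten by tool_calling_input_format for models without native tool calling.
    # The tool name "get_weather" and the values below are made-up example data.
    #
    #   messages = [
    #       {"role": "tool", "tool_call_id": "call_0", "content": "72F"},
    #       {"role": "user", "content": "Summarize the weather."},
    #   ]
    #   tools = [{"name": "get_weather", "parameters": {"city": "string"}}]
    #   formatted = llm.tool_calling_input_format(messages, tools)
    #   # The "tool" message becomes a "user" message embedding the execution result,
    #   # and the last message gains the prefix/tool/suffix prompt describing the
    #   # expected JSON call format.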

    def parse_json_format(self, message: str) -> str:
        """Extract the first JSON array (preferred) or JSON object found in a raw
        model response and return it re-serialized; return '[]' if nothing parses."""
        json_array_pattern = r'\[\s*\{.*?\}\s*\]'
        json_object_pattern = r'\{\s*.*?\s*\}'
        match_array = re.search(json_array_pattern, message)
        if match_array:
            json_array_substring = match_array.group(0)
            try:
                json_array_data = json.loads(json_array_substring)
                return json.dumps(json_array_data)
            except json.JSONDecodeError:
                pass
        match_object = re.search(json_object_pattern, message)
        if match_object:
            json_object_substring = match_object.group(0)
            try:
                json_object_data = json.loads(json_object_substring)
                return json.dumps(json_object_data)
            except json.JSONDecodeError:
                pass
        return '[]'

    def parse_tool_calls(self, message):
        # add tool call id and type for models that don't support native tool calling
        tool_calls = json.loads(self.parse_json_format(message))
        for tool_call in tool_calls:
            tool_call["id"] = generator_tool_call_id()
            tool_call["type"] = "function"
        return tool_calls
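
    # Illustrative sketch (not part of the class API): parse_tool_calls pulls the
    # first JSON array or object out of raw model output via parse_json_format,
    # then tags each call with a generated id and a "function" type. Example values
    # are made up.
    #
    #   raw = 'Sure. [{"name": "get_weather", "parameters": {"city": "Paris"}}]'
    #   calls = llm.parse_tool_calls(raw)
    #   # -> [{"name": "get_weather", "parameters": {"city": "Paris"},
    #   #      "id": "<generated id>", "type": "function"}]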

    def address_request(self,
                        agent_request,
                        temperature=0.0
                        ):
        return self.process(agent_request)

    @abstractmethod
    def process(self,
                agent_request,
                temperature=0.0) -> None:
        raise NotImplementedError
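

# Minimal subclass sketch (illustrative, not part of this file): a concrete kernel
# only needs to implement load_llm_and_tokenizer() and process(). The HuggingFace-style
# loading and the agent_request attributes used below (.query.messages, .set_response)
# are assumptions made for the example, not guaranteed by BaseLLM.
class ExampleHfLLM(BaseLLM):
    def load_llm_and_tokenizer(self) -> None:
        # lazy import so the base module itself does not require transformers
        from transformers import AutoModelForCausalLM, AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)

    def process(self, agent_request, temperature=0.0):
        # assumed request shape: a list of chat messages carried on the request object
        messages = agent_request.query.messages
        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        output_ids = self.model.generate(input_ids, max_new_tokens=self.max_new_tokens)
        # decode only the newly generated tokens, not the prompt
        text = self.tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
        )
        agent_request.set_response(text)  # assumed response-setting API
        return text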