# gpt_llm.py
  1. import re
  2. from .base_llm import BaseLLM
  3. import time
  4. # could be dynamically imported similar to other models
  5. from openai import OpenAI
  6. import openai
  7. from pyopenagi.utils.chat_template import Response
  8. import json
  9. class GPTLLM(BaseLLM):
  10. def __init__(
  11. self,
  12. llm_name: str,
  13. max_gpu_memory: dict = None,
  14. eval_device: str = None,
  15. max_new_tokens: int = 1024,
  16. log_mode: str = "console",
  17. ):
  18. super().__init__(
  19. llm_name, max_gpu_memory, eval_device, max_new_tokens, log_mode
  20. )
  21. def load_llm_and_tokenizer(self) -> None:
  22. self.model = OpenAI()
  23. self.tokenizer = None
  24. def parse_tool_calls(self, tool_calls):
  25. if tool_calls:
  26. parsed_tool_calls = []
  27. for tool_call in tool_calls:
  28. function_name = tool_call.function.name
  29. function_args = json.loads(tool_call.function.arguments)
  30. parsed_tool_calls.append(
  31. {
  32. "name": function_name,
  33. "parameters": function_args,
  34. "type": tool_call.type,
  35. "id": tool_call.id,
  36. }
  37. )
  38. return parsed_tool_calls
  39. return None
  40. def process(self, agent_request, temperature=0.0):
  41. # ensures the model is the current one
  42. assert re.search(r"gpt", self.model_name, re.IGNORECASE)
  43. """ wrapper around openai api """
  44. agent_request.set_status("executing")
  45. agent_request.set_start_time(time.time())
  46. messages = agent_request.query.messages
  47. try:
  48. response = self.model.chat.completions.create(
  49. model=self.model_name,
  50. messages=messages,
  51. tools=agent_request.query.tools,
  52. # tool_choice = "required" if agent_request.query.tools else None,
  53. max_tokens=self.max_new_tokens,
  54. )
  55. response_message = response.choices[0].message.content
  56. # print(f"[Response] {response}")
  57. tool_calls = self.parse_tool_calls(response.choices[0].message.tool_calls)
  58. # print(tool_calls)
  59. # print(response.choices[0].message)
  60. response = Response(
  61. response_message=response_message, tool_calls=tool_calls
  62. )
  63. except openai.APIConnectionError as e:
  64. response = Response(
  65. response_message=f"Server connection error: {e.__cause__}"
  66. )
  67. except openai.RateLimitError as e:
  68. response = Response(
  69. response_message=f"OpenAI RATE LIMIT error {e.status_code}: (e.response)"
  70. )
  71. except openai.APIStatusError as e:
  72. response = Response(
  73. response_message=f"OpenAI STATUS error {e.status_code}: (e.response)"
  74. )
  75. except openai.BadRequestError as e:
  76. response = Response(
  77. response_message=f"OpenAI BAD REQUEST error {e.status_code}: (e.response)"
  78. )
  79. except Exception as e:
  80. response = Response(response_message=f"An unexpected error occurred: {e}")
  81. return response