# openai_api_language_model.py
import logging
import time

from nltk import sent_tokenize
from rich.console import Console
from openai import OpenAI

from baseHandler import BaseHandler
from LLM.chat import Chat

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# Rich console for terminal output.
# NOTE(review): `console` is not referenced anywhere in this file's visible
# code — presumably kept for parity with sibling handlers; confirm before removing.
console = Console()
  10. WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
  11. "en": "english",
  12. "fr": "french",
  13. "es": "spanish",
  14. "zh": "chinese",
  15. "ja": "japanese",
  16. "ko": "korean",
  17. }
  18. class OpenApiModelHandler(BaseHandler):
  19. """
  20. Handles the language model part.
  21. """
  22. def setup(
  23. self,
  24. model_name="deepseek-chat",
  25. device="cuda",
  26. gen_kwargs={},
  27. base_url =None,
  28. api_key=None,
  29. stream=False,
  30. user_role="user",
  31. chat_size=1,
  32. init_chat_role="system",
  33. init_chat_prompt="You are a helpful AI assistant.",
  34. ):
  35. self.model_name = model_name
  36. self.stream = stream
  37. self.chat = Chat(chat_size)
  38. if init_chat_role:
  39. if not init_chat_prompt:
  40. raise ValueError(
  41. "An initial promt needs to be specified when setting init_chat_role."
  42. )
  43. self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
  44. self.user_role = user_role
  45. self.client = OpenAI(api_key=api_key, base_url=base_url)
  46. self.warmup()
  47. def warmup(self):
  48. logger.info(f"Warming up {self.__class__.__name__}")
  49. start = time.time()
  50. response = self.client.chat.completions.create(
  51. model=self.model_name,
  52. messages=[
  53. {"role": "system", "content": "You are a helpful assistant"},
  54. {"role": "user", "content": "Hello"},
  55. ],
  56. stream=self.stream
  57. )
  58. end = time.time()
  59. logger.info(
  60. f"{self.__class__.__name__}: warmed up! time: {(end - start):.3f} s"
  61. )
  62. def process(self, prompt):
  63. logger.debug("call api language model...")
  64. self.chat.append({"role": self.user_role, "content": prompt})
  65. language_code = None
  66. if isinstance(prompt, tuple):
  67. prompt, language_code = prompt
  68. if language_code[-5:] == "-auto":
  69. language_code = language_code[:-5]
  70. prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
  71. response = self.client.chat.completions.create(
  72. model=self.model_name,
  73. messages=[
  74. {"role": self.user_role, "content": prompt},
  75. ],
  76. stream=self.stream
  77. )
  78. if self.stream:
  79. generated_text, printable_text = "", ""
  80. for chunk in response:
  81. new_text = chunk.choices[0].delta.content or ""
  82. generated_text += new_text
  83. printable_text += new_text
  84. sentences = sent_tokenize(printable_text)
  85. if len(sentences) > 1:
  86. yield sentences[0], language_code
  87. printable_text = new_text
  88. self.chat.append({"role": "assistant", "content": generated_text})
  89. # don't forget last sentence
  90. yield printable_text, language_code
  91. else:
  92. generated_text = response.choices[0].message.content
  93. self.chat.append({"role": "assistant", "content": generated_text})
  94. yield generated_text, language_code