import logging
from threading import Thread

import torch
from nltk import sent_tokenize  # needs the NLTK "punkt" tokenizer data (nltk.download("punkt"))
from rich.console import Console
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)

from baseHandler import BaseHandler
from LLM.chat import Chat

logger = logging.getLogger(__name__)

console = Console()
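
# process() consults this map when Whisper ran with automatic language
# detection (codes ending in "-auto"): the detected code is mapped to a
# language name and the model is asked to reply in that language.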
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
    "hi": "hindi",
}


class LanguageModelHandler(BaseHandler):
    """
    Handles the language-model stage of the pipeline: appends each incoming
    prompt to a rolling chat history, runs generation in a background thread,
    and streams the reply back sentence by sentence.
    """
    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )
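        # Producer/consumer pair: generation runs in a background Thread
        # (see process()), pushing decoded tokens into the streamer, while
        # the caller iterates over the streamer to pick up text chunks as
        # soon as they are ready.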
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs,
        }

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()
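
    # The first CUDA generations pay one-off costs (kernel compilation,
    # memory-allocator growth), so warmup() does two throwaway runs to keep
    # that latency out of the first real request and reports the time taken
    # using CUDA events.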
    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Repeat the word 'home'."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        # Spread self.gen_kwargs first so the explicit keys take precedence
        # (the original ordering let the spread shadow them).
        warmup_gen_kwargs = {
            **self.gen_kwargs,
            "min_new_tokens": self.gen_kwargs["min_new_tokens"],
            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
        }

        n_steps = 2

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            thread = Thread(
                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
            )
            thread.start()
            # Drain the streamer; generation is finished once iteration ends.
            for _ in self.streamer:
                pass

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()

            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )
    def process(self, prompt):
        logger.debug("Running language model inference...")
        language_code = None

        if isinstance(prompt, tuple):
            prompt, language_code = prompt
            if language_code.endswith("-auto"):
                # "-auto" marks a language auto-detected by Whisper; strip the
                # suffix and ask the model to answer in that language.
                language_code = language_code[: -len("-auto")]
                prompt = (
                    f"Please reply to my message in "
                    f"{WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
                )

        self.chat.append({"role": self.user_role, "content": prompt})
        thread = Thread(
            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
        )
        thread.start()

        if self.device == "mps":
            # On MPS, skip sentence chunking and emit the whole reply once at
            # the end, freeing the cached GPU memory afterwards.
            generated_text = ""
            for new_text in self.streamer:
                generated_text += new_text
            printable_text = generated_text
            torch.mps.empty_cache()
        else:
            generated_text, printable_text = "", ""
            for new_text in self.streamer:
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    yield (sentences[0], language_code)
                    # Carry over everything after the first sentence so no
                    # words are dropped between chunks (resetting to just
                    # `new_text` could lose the start of the next sentence).
                    printable_text = printable_text.split(sentences[0], 1)[-1].lstrip()

        self.chat.append({"role": "assistant", "content": generated_text})

        # Don't forget the last sentence.
        yield (printable_text, language_code)
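

if __name__ == "__main__":
    # Minimal standalone sketch of the Thread + TextIteratorStreamer pattern
    # that process() relies on. Illustrative only: it bypasses BaseHandler and
    # Chat, downloads the example model on first run, and assumes this
    # transformers version accepts chat-style message lists in the
    # text-generation pipeline.
    demo_pipe = pipeline(
        "text-generation",
        model="microsoft/Phi-3-mini-4k-instruct",
        trust_remote_code=True,
    )
    demo_streamer = TextIteratorStreamer(
        demo_pipe.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    demo_chat = [{"role": "user", "content": "Say hello in one short sentence."}]
    Thread(
        target=demo_pipe,
        args=(demo_chat,),
        kwargs={
            "streamer": demo_streamer,
            "return_full_text": False,
            "max_new_tokens": 32,
        },
    ).start()
    # Print chunks as they arrive, mirroring how process() streams sentences.
    for chunk in demo_streamer:
        console.print(chunk, end="")
    console.print()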