language_model.py

from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    TextIteratorStreamer,
)
import torch

from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console
import logging
from nltk import sent_tokenize

logger = logging.getLogger(__name__)
console = Console()

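# Maps Whisper language codes to the language names used when asking the
# model to reply in a specific language.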
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
    "hi": "hindi",
}


class LanguageModelHandler(BaseHandler):
    """
    Handles the language model part: generates replies with a local causal LM
    and streams them out sentence by sentence.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.device = device
        # Resolve the dtype from its string name, e.g. "float16" -> torch.float16.
        self.torch_dtype = getattr(torch, torch_dtype)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
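        # trust_remote_code lets checkpoints that ship custom modeling code
        # (as Phi-3 originally did) load it alongside the weights.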
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
        ).to(device)
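        # The text-generation pipeline applies the model's chat template to
        # the message list and runs the generation loop.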
        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )
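        # The streamer yields decoded text as soon as tokens are produced;
        # skip_prompt keeps the echoed input out of the stream.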
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
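        # Every generate call pushes its tokens through the streamer;
        # return_full_text=False keeps only the completion.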
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs,
        }

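        # Bounded conversation history: Chat keeps the last chat_size
        # exchanges, optionally seeded with an initial (system) message.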
        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Repeat the word 'home'."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        # Warmup uses the same generation kwargs as regular inference.
        warmup_gen_kwargs = {**self.gen_kwargs}
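
        # Run a couple of short dummy generations so kernels are compiled and
        # caches allocated before the first real request.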
        n_steps = 2

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            thread = Thread(
                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
            )
            thread.start()
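            # Drain the streamer so each warmup generation runs to completion.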
            for _ in self.streamer:
                pass

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, prompt):
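        """
        Generates a reply to `prompt` and yields it as (sentence,
        language_code) tuples so downstream TTS can start speaking before
        generation has finished.
        """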
        logger.debug("inferring language model...")
        language_code = None

        if isinstance(prompt, tuple):
            prompt, language_code = prompt
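            # A "-auto" suffix marks an auto-detected language; strip it and
            # explicitly ask the model to reply in that language.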
            if language_code.endswith("-auto"):
                language_code = language_code[: -len("-auto")]
                prompt = (
                    f"Please reply to my message in "
                    f"{WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. {prompt}"
                )

        self.chat.append({"role": self.user_role, "content": prompt})

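        # Generation runs on a background thread; this thread consumes the
        # streamer and forwards text downstream as it arrives.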
        thread = Thread(
            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
        )
        thread.start()

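        # On mps, sentence-level streaming is skipped: the whole reply is
        # accumulated, the cache is released, and the text is emitted once.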
        if self.device == "mps":
            generated_text = ""
            for new_text in self.streamer:
                generated_text += new_text
            printable_text = generated_text
            torch.mps.empty_cache()
        else:
            generated_text, printable_text = "", ""
            for new_text in self.streamer:
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
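                # As soon as a full sentence is available, emit it and keep
                # the newest chunk as the start of the next one.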
                if len(sentences) > 1:
                    yield (sentences[0], language_code)
                    printable_text = new_text

        self.chat.append({"role": "assistant", "content": generated_text})

        # don't forget last sentence
        yield (printable_text, language_code)
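
# A minimal usage sketch. The constructor signature and the gen_kwargs values
# below are assumptions for illustration; the real wiring (queues, stop event,
# setup kwargs) is done by the pipeline that owns this handler:
#
#   handler = LanguageModelHandler(
#       stop_event,
#       queue_in,
#       queue_out,
#       setup_kwargs={
#           "model_name": "microsoft/Phi-3-mini-4k-instruct",
#           "gen_kwargs": {"min_new_tokens": 10, "max_new_tokens": 128},
#       },
#   )
#   for sentence, language_code in handler.process("Hello, how are you?"):
#       print(sentence)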