# mlx_language_model.py

import logging
from LLM.chat import Chat
from baseHandler import BaseHandler
from mlx_lm import load, stream_generate, generate
from rich.console import Console
import torch

logger = logging.getLogger(__name__)
console = Console()

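# Maps Whisper language codes to the language names used when prompting the
# model to reply in the language Whisper auto-detected (see `process` below).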
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
}


class MLXLanguageModelHandler(BaseHandler):
    """
    Handles the language model part of the pipeline, using a model served
    locally through mlx-lm.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="mps",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        # `device` and `torch_dtype` are accepted for interface parity with
        # the other LM handlers but are not used here: mlx-lm manages device
        # placement and precision itself.
        self.model_name = model_name
        self.model, self.tokenizer = load(self.model_name)
        self.gen_kwargs = gen_kwargs

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Repeat the word 'home'."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]

        # Run a couple of short generations so the model weights are loaded
        # and caches are primed before the first real request.
        n_steps = 2
        for _ in range(n_steps):
            prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
            generate(
                self.model,
                self.tokenizer,
                prompt=prompt,
                max_tokens=self.gen_kwargs["max_new_tokens"],
                verbose=False,
            )

    def process(self, prompt):
        logger.debug("inferring language model...")
        language_code = None

        if isinstance(prompt, tuple):
            prompt, language_code = prompt
            if language_code.endswith("-auto"):
                # Whisper auto-detected the input language; strip the marker
                # and ask the model to answer in that language.
                language_code = language_code[: -len("-auto")]
                prompt = (
                    f"Please reply to my message in "
                    f"{WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. "
                    + prompt
                )

        self.chat.append({"role": self.user_role, "content": prompt})

        # Gemma chat templates reject the system role, so drop any system
        # messages before applying the template.
        if "gemma" in self.model_name.lower():
            chat_messages = [
                msg for msg in self.chat.to_list() if msg["role"] != "system"
            ]
        else:
            chat_messages = self.chat.to_list()

        prompt = self.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )

        output = ""
        curr_output = ""
        for t in stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            max_tokens=self.gen_kwargs["max_new_tokens"],
        ):
            output += t
            curr_output += t
            # Yield sentence-sized chunks so downstream TTS can start
            # speaking before the full answer is generated.
            if curr_output.endswith((".", "?", "!", "<|end|>")):
                yield (curr_output.replace("<|end|>", ""), language_code)
                curr_output = ""

        generated_text = output.replace("<|end|>", "")
        torch.mps.empty_cache()
        self.chat.append({"role": "assistant", "content": generated_text})
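

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original pipeline: `setup` is
    # normally invoked by BaseHandler, which also wires the handler to the
    # pipeline's queues. The bare instance below and the gen_kwargs values
    # are assumptions for demonstration only.
    handler = MLXLanguageModelHandler.__new__(MLXLanguageModelHandler)
    handler.setup(
        model_name="microsoft/Phi-3-mini-4k-instruct",
        gen_kwargs={"max_new_tokens": 128},
    )
    for sentence, _lang in handler.process("Tell me a short joke."):
        console.print(sentence)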