12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- import logging
- from time import perf_counter
- from baseHandler import BaseHandler
- from lightning_whisper_mlx import LightningWhisperMLX
- import numpy as np
- from rich.console import Console
- from copy import copy
- import torch
# Module-level logger and rich console shared by the handler below.
logger = logging.getLogger(__name__)
console = Console()

# Language codes this pipeline accepts; a detection outside this set triggers
# the fallback path in LightningWhisperSTTHandler.process.
SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "es",
    "zh",
    "ja",
    "ko",
]
class LightningWhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.

    Wraps LightningWhisperMLX: transcribes incoming audio chunks, optionally
    auto-detects the language (falling back to the last known supported
    language when detection yields an unsupported code), and yields
    ``(text, language_code)`` tuples downstream.
    """

    def setup(
        self,
        model_name="distil-large-v3",
        device="mps",
        torch_dtype="float16",
        compile_mode=None,
        language=None,
        gen_kwargs=None,
    ):
        """Load the MLX Whisper model and run a warmup pass.

        Args:
            model_name: Whisper checkpoint name; a HF-style "org/name" prefix
                is stripped since LightningWhisperMLX expects the bare name.
            device: target device string; only "mps" triggers cache clearing.
            torch_dtype, compile_mode, gen_kwargs: accepted for interface
                compatibility with sibling handlers; unused by the MLX backend.
            language: fixed language code, "auto" for detection, or None.
        """
        # Avoid the shared-mutable-default pitfall: normalize None to a fresh dict.
        if gen_kwargs is None:
            gen_kwargs = {}
        # Strip a namespace prefix, e.g. "distil-whisper/distil-large-v3" -> "distil-large-v3".
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
        self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
        self.start_language = language
        self.last_language = language
        self.warmup()

    def warmup(self):
        """Run one dummy transcription so first real request isn't slow."""
        logger.info(f"Warming up {self.__class__.__name__}")
        # Single warmup step: this MLX backend uses no torch.compile, so no
        # extra step for CUDA-graph capture is needed.
        n_steps = 1
        dummy_input = np.array([0] * 512)
        for _ in range(n_steps):
            _ = self.model.transcribe(dummy_input)["text"].strip()

    def process(self, spoken_prompt):
        """Transcribe one audio chunk and yield (text, language_code).

        When language is "auto", unsupported detections fall back to the last
        supported language, or to an empty English result if none is known.
        """
        logger.debug("infering whisper...")

        # Shared pipeline timing marker read by downstream handlers.
        global pipeline_start
        pipeline_start = perf_counter()

        if self.start_language != "auto":
            transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
        else:
            transcription_dict = self.model.transcribe(spoken_prompt)
            language_code = transcription_dict["language"]
            if language_code not in SUPPORTED_LANGUAGES:
                logger.warning(f"Whisper detected unsupported language: {language_code}")
                if self.last_language in SUPPORTED_LANGUAGES:  # reprocess with the last language
                    transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
                else:
                    # No usable language to fall back to: emit an empty result.
                    transcription_dict = {"text": "", "language": "en"}
            else:
                self.last_language = language_code

        pred_text = transcription_dict["text"].strip()
        language_code = transcription_dict["language"]

        # Only clear the MPS cache when actually running on MPS; the original
        # unconditional call crashes on non-MPS devices.
        if self.device == "mps":
            torch.mps.empty_cache()

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")
        logger.debug(f"Language Code Whisper: {language_code}")

        # Tag auto-detected languages so downstream handlers can tell them apart.
        if self.start_language == "auto":
            language_code += "-auto"

        yield (pred_text, language_code)
|