lightning_whisper_mlx_handler.py

import logging
from time import perf_counter

import numpy as np
import torch
from rich.console import Console

from baseHandler import BaseHandler
from lightning_whisper_mlx import LightningWhisperMLX

logger = logging.getLogger(__name__)
console = Console()
SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "es",
    "zh",
    "ja",
    "ko",
]


class LightningWhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """

    def setup(
        self,
        model_name="distil-large-v3",
        device="mps",
        torch_dtype="float16",
        compile_mode=None,
        language=None,
        gen_kwargs={},
    ):
        # Accept either a bare model name or a hub-style "org/model" id.
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
        self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
        self.start_language = language
        self.last_language = language
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # One dummy pass so the model weights are loaded and kernels are
        # warm before the first real request.
        n_steps = 1
        dummy_input = np.array([0] * 512)
        for _ in range(n_steps):
            _ = self.model.transcribe(dummy_input)["text"].strip()

    def process(self, spoken_prompt):
        logger.debug("inferring whisper...")

        # Global timestamp read elsewhere in the pipeline to measure
        # end-to-end latency from the start of transcription.
        global pipeline_start
        pipeline_start = perf_counter()

        if self.start_language != "auto":
            transcription_dict = self.model.transcribe(
                spoken_prompt, language=self.start_language
            )
        else:
            transcription_dict = self.model.transcribe(spoken_prompt)
            language_code = transcription_dict["language"]
            if language_code not in SUPPORTED_LANGUAGES:
                logger.warning(f"Whisper detected unsupported language: {language_code}")
                if self.last_language in SUPPORTED_LANGUAGES:
                    # Reprocess with the last supported language.
                    transcription_dict = self.model.transcribe(
                        spoken_prompt, language=self.last_language
                    )
                else:
                    transcription_dict = {"text": "", "language": "en"}
            else:
                self.last_language = language_code

        pred_text = transcription_dict["text"].strip()
        language_code = transcription_dict["language"]

        # Release cached MPS memory between requests.
        torch.mps.empty_cache()

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")
        logger.debug(f"Language Code Whisper: {language_code}")

        if self.start_language == "auto":
            # Tag the detected language so downstream handlers know it was
            # auto-detected rather than user-specified.
            language_code += "-auto"

        yield (pred_text, language_code)
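
A minimal usage sketch, not part of the original file: it assumes BaseHandler's constructor takes a stop event plus input/output queues and forwards setup_kwargs to setup() (verify against baseHandler.py in your checkout), and it feeds a second of silence in place of a real 16 kHz mono chunk.

if __name__ == "__main__":
    from queue import Queue
    from threading import Event

    # Assumed BaseHandler signature: (stop_event, queue_in, queue_out,
    # setup_args=(), setup_kwargs={}); setup() is expected to run here.
    handler = LightningWhisperSTTHandler(
        Event(),
        Queue(),
        Queue(),
        setup_kwargs={"model_name": "distil-large-v3", "language": "en"},
    )

    # Stand-in audio: one second of silence at 16 kHz.
    chunk = np.zeros(16000, dtype=np.float32)

    # process() is a generator yielding (text, language_code) tuples.
    for text, language_code in handler.process(chunk):
        print(f"{language_code}: {text!r}")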