faster_whisper_handler.py

import logging
import os
from time import perf_counter

from faster_whisper import WhisperModel
from rich.console import Console

from baseHandler import BaseHandler

console = Console()
logger = logging.getLogger(__name__)


class FasterWhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """

    def setup(
        self,
        model_name: str = "tiny.en",
        device: str = "auto",
        compute_type: str = "auto",
        gen_kwargs={},
    ):
        self.gen_kwargs = self.adapt_gen_kwargs(gen_kwargs)
        # Allow duplicate OpenMP runtimes (e.g. from torch and ctranslate2)
        # to coexist instead of aborting the process.
        os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
        self.model = WhisperModel(model_name, device=device, compute_type=compute_type)

    def process(self, audio):
        logger.debug("inferring faster whisper...")
        # Mark when transcription of this utterance starts, for latency measurement.
        global pipeline_start
        pipeline_start = perf_counter()

        segments, info = self.model.transcribe(audio, **self.gen_kwargs)
        output_text = []
        for segment in segments:
            logger.debug(
                "[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text)
            )
            output_text.append(segment.text)
        pred_text = " ".join(output_text).strip()
        logger.debug("finished whisper inference")

        if pred_text:
            console.print(f"[yellow]USER: {pred_text}")
            yield pred_text
        else:
            logger.debug("no text detected. skipping...")

    def cleanup(self):
        print("Stopping FasterWhisperSTTHandler")
        del self.model

    def adapt_gen_kwargs(self, gen_kwargs: dict):
        # Map the `return_timestamps` flag onto faster-whisper's inverse
        # `without_timestamps` argument.
        gen_kwargs["without_timestamps"] = not gen_kwargs.pop("return_timestamps", True)
        return gen_kwargs
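
# Usage sketch: BaseHandler's constructor (defined elsewhere in the project)
# wires this handler into the pipeline queues, so rather than guess that
# signature, the example below exercises the same faster-whisper calls that
# process() wraps. The "tiny.en" model and the silent 16 kHz placeholder
# waveform are illustrative assumptions, not values the handler requires.
if __name__ == "__main__":
    import numpy as np

    model = WhisperModel("tiny.en", device="auto", compute_type="auto")

    # faster-whisper accepts a file path or a 16 kHz float32 waveform;
    # one second of silence stands in for real microphone audio here.
    audio = np.zeros(16000, dtype=np.float32)

    # transcribe() returns a lazy generator of segments plus run metadata.
    segments, info = model.transcribe(audio, without_timestamps=True)
    text = " ".join(segment.text for segment in segments).strip()
    print(f"language={info.language!r} transcript={text!r}")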