# paraformer_handler.py
  1. import logging
  2. from time import perf_counter
  3. from baseHandler import BaseHandler
  4. from funasr import AutoModel
  5. import numpy as np
  6. from rich.console import Console
  7. import torch
# Module-level logging/console setup, applied once at import time.
# basicConfig is a no-op if the root logger was already configured elsewhere.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Shared by the handler class below: structured logs via `logger`,
# colored user-facing transcription output via rich's `console`.
logger = logging.getLogger(__name__)
console = Console()
  13. class ParaformerSTTHandler(BaseHandler):
  14. """
  15. Handles the Speech To Text generation using a Paraformer model.
  16. The default for this model is set to Chinese.
  17. This model was contributed by @wuhongsheng.
  18. """
  19. def setup(
  20. self,
  21. model_name="paraformer-zh",
  22. device="cuda",
  23. gen_kwargs={},
  24. ):
  25. print(model_name)
  26. if len(model_name.split("/")) > 1:
  27. model_name = model_name.split("/")[-1]
  28. self.device = device
  29. self.model = AutoModel(model=model_name, device=device)
  30. self.warmup()
  31. def warmup(self):
  32. logger.info(f"Warming up {self.__class__.__name__}")
  33. # 2 warmup steps for no compile or compile mode with CUDA graphs capture
  34. n_steps = 1
  35. dummy_input = np.array([0] * 512, dtype=np.float32)
  36. for _ in range(n_steps):
  37. _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")
  38. def process(self, spoken_prompt):
  39. logger.debug("infering paraformer...")
  40. global pipeline_start
  41. pipeline_start = perf_counter()
  42. pred_text = (
  43. self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
  44. )
  45. torch.mps.empty_cache()
  46. logger.debug("finished paraformer inference")
  47. console.print(f"[yellow]USER: {pred_text}")
  48. yield pred_text