import logging
from time import perf_counter

from baseHandler import BaseHandler
from funasr import AutoModel
import numpy as np
from rich.console import Console
import torch

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class ParaformerSTTHandler(BaseHandler):
    """
    Handles speech-to-text generation using a Paraformer model.
    The default model ("paraformer-zh") is Chinese.
    This model was contributed by @wuhongsheng.
    """

    def setup(
        self,
        model_name="paraformer-zh",
        device="cuda",
        gen_kwargs={},  # accepted for handler-interface compatibility; unused here
    ):
        logger.info(f"Loading Paraformer model: {model_name}")
        # funasr expects a bare model name (e.g. "paraformer-zh"), so strip any
        # hub namespace prefix such as "org/".
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
        self.model = AutoModel(model=model_name, device=device)
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # One warmup pass is enough here: there is no torch.compile or CUDA
        # graph capture in this handler that would require a second step.
        n_steps = 1

        dummy_input = np.zeros(512, dtype=np.float32)
        for _ in range(n_steps):
            _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")

    def process(self, spoken_prompt):
        logger.debug("inferring paraformer...")

        # Record the pipeline start time globally so downstream handlers can
        # measure end-to-end latency from this point.
        global pipeline_start
        pipeline_start = perf_counter()

        pred_text = (
            self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
        )
        # Only release the MPS cache when actually running on Apple silicon;
        # calling it on other backends is unnecessary and can fail.
        if self.device == "mps":
            torch.mps.empty_cache()

        logger.debug("finished paraformer inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text
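

# Minimal smoke-test sketch, not part of the original handler. Assumptions:
# the BaseHandler constructor arguments below (stop_event/queue_in/queue_out)
# are hypothetical, and 16 kHz mono float32 audio is an illustrative input
# format, not confirmed by this file.
if __name__ == "__main__":
    import queue
    import threading

    stt = ParaformerSTTHandler(
        stop_event=threading.Event(),  # hypothetical BaseHandler signature
        queue_in=queue.Queue(),
        queue_out=queue.Queue(),
    )
    stt.setup(model_name="paraformer-zh", device="cpu")  # CPU for a local test
    # One second of silence at 16 kHz stands in for a real utterance.
    for text in stt.process(np.zeros(16000, dtype=np.float32)):
        print(text)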