import logging

import numpy as np
import torch
import torchaudio
from df.enhance import enhance, init_df
from rich.console import Console

from baseHandler import BaseHandler
from utils.utils import int2float
from VAD.vad_iterator import VADIterator

logger = logging.getLogger(__name__)

console = Console()

class VADHandler(BaseHandler):
    """
    Handles voice activity detection. When voice activity is detected, audio is
    accumulated until the end of speech is detected, and the full utterance is then
    passed on to the next handler in the pipeline. A minimal usage sketch sits under
    the __main__ guard at the bottom of this module.
    """

    def setup(
        self,
        should_listen,
        thresh=0.3,
        sample_rate=16000,
        min_silence_ms=1000,
        min_speech_ms=500,
        max_speech_ms=float("inf"),
        speech_pad_ms=30,
        audio_enhancement=False,
    ):
        self.should_listen = should_listen
        self.sample_rate = sample_rate
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.max_speech_ms = max_speech_ms
        # Silero VAD model fetched from torch.hub; the iterator wraps it to
        # accumulate speech chunks and report segment boundaries.
        self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
        self.iterator = VADIterator(
            self.model,
            threshold=thresh,
            sampling_rate=sample_rate,
            min_silence_duration_ms=min_silence_ms,
            speech_pad_ms=speech_pad_ms,
        )
        self.audio_enhancement = audio_enhancement
        if audio_enhancement:
            # DeepFilterNet model and state, used to denoise detected speech.
            self.enhanced_model, self.df_state, _ = init_df()

    def process(self, audio_chunk):
        # Raw bytes arrive as 16-bit PCM; convert to float32 in [-1, 1] for the VAD.
        audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
        audio_float32 = int2float(audio_int16)
        vad_output = self.iterator(torch.from_numpy(audio_float32))
        if vad_output is not None and len(vad_output) != 0:
            logger.debug("VAD: end of speech detected")
            # The iterator returns the accumulated speech chunks; stitch them together.
            array = torch.cat(vad_output).cpu().numpy()
            duration_ms = len(array) / self.sample_rate * 1000
            if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
                logger.debug(
                    f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
                )
            else:
                # Pause listening while the rest of the pipeline handles this utterance.
                self.should_listen.clear()
                logger.debug("Stop listening")
                if self.audio_enhancement:
                    # Denoise the full utterance with DeepFilterNet, resampling to
                    # and from its expected sample rate when necessary.
                    if self.sample_rate != self.df_state.sr():
                        audio_float32 = torchaudio.functional.resample(
                            torch.from_numpy(array),
                            orig_freq=self.sample_rate,
                            new_freq=self.df_state.sr(),
                        )
                        enhanced = enhance(
                            self.enhanced_model,
                            self.df_state,
                            audio_float32.unsqueeze(0),
                        )
                        enhanced = torchaudio.functional.resample(
                            enhanced,
                            orig_freq=self.df_state.sr(),
                            new_freq=self.sample_rate,
                        )
                    else:
                        # Enhance the accumulated utterance, not just the last chunk.
                        enhanced = enhance(
                            self.enhanced_model,
                            self.df_state,
                            torch.from_numpy(array).unsqueeze(0),
                        )
                    array = enhanced.numpy().squeeze()
                yield array

    @property
    def min_time_to_debug(self):
        return 0.00001
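

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline: it assumes that calling
    # setup() and process() by hand on a bare instance is enough for a smoke
    # test, and it bypasses BaseHandler.__init__, whose wiring (queues, stop
    # event) lives outside this file. The 512-sample chunk size and the
    # synthetic silence input are illustrative choices, not pipeline defaults.
    from threading import Event

    class _BareVADHandler(VADHandler):
        def __init__(self):
            # Skip BaseHandler.__init__; setup() provides everything process() needs.
            pass

    handler = _BareVADHandler()
    should_listen = Event()
    should_listen.set()
    handler.setup(should_listen)

    # Feed int16 PCM bytes chunk by chunk; process() yields a float32 numpy
    # array only once an utterance longer than min_speech_ms has ended.
    silent_chunk = np.zeros(512, dtype=np.int16).tobytes()
    for _ in range(20):
        for utterance in handler.process(silent_chunk):
            print(f"speech segment of {len(utterance) / 16000:.2f}s")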