vad_handler.py

import logging

import numpy as np
import torch
import torchaudio
from df.enhance import enhance, init_df
from rich.console import Console

from baseHandler import BaseHandler
from utils.utils import int2float
from VAD.vad_iterator import VADIterator

logger = logging.getLogger(__name__)
console = Console()


class VADHandler(BaseHandler):
    """
    Handles voice activity detection. When voice activity is detected, audio
    is accumulated until the end of speech is detected, then passed on to the
    next handler in the pipeline.
    """

    def setup(
        self,
        should_listen,
        thresh=0.3,
        sample_rate=16000,
        min_silence_ms=1000,
        min_speech_ms=500,
        max_speech_ms=float("inf"),
        speech_pad_ms=30,
        audio_enhancement=False,
    ):
        self.should_listen = should_listen
        self.sample_rate = sample_rate
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.max_speech_ms = max_speech_ms
        # Silero VAD model from torch hub; the second return value (the utils
        # bundle) is unused because this repo ships its own VADIterator.
        self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
        self.iterator = VADIterator(
            self.model,
            threshold=thresh,
            sampling_rate=sample_rate,
            min_silence_duration_ms=min_silence_ms,
            speech_pad_ms=speech_pad_ms,
        )
        self.audio_enhancement = audio_enhancement
        if audio_enhancement:
            # DeepFilterNet model and state for optional speech enhancement.
            self.enhanced_model, self.df_state, _ = init_df()

    def process(self, audio_chunk):
        audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
        audio_float32 = int2float(audio_int16)
        vad_output = self.iterator(torch.from_numpy(audio_float32))
        if vad_output is not None and len(vad_output) != 0:
            logger.debug("VAD: end of speech detected")
            array = torch.cat(vad_output).cpu().numpy()
            duration_ms = len(array) / self.sample_rate * 1000
            if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
                logger.debug(
                    f"audio input of duration: {duration_ms / 1000:.2f}s, skipping"
                )
            else:
                self.should_listen.clear()
                logger.debug("Stop listening")
                if self.audio_enhancement:
                    if self.sample_rate != self.df_state.sr():
                        # Resample to the enhancement model's rate, enhance,
                        # then resample back to the pipeline rate.
                        audio_float32 = torchaudio.functional.resample(
                            torch.from_numpy(array),
                            orig_freq=self.sample_rate,
                            new_freq=self.df_state.sr(),
                        )
                        enhanced = enhance(
                            self.enhanced_model,
                            self.df_state,
                            audio_float32.unsqueeze(0),
                        )
                        enhanced = torchaudio.functional.resample(
                            enhanced,
                            orig_freq=self.df_state.sr(),
                            new_freq=self.sample_rate,
                        )
                    else:
                        # Enhance the full accumulated utterance (not the last
                        # raw chunk); `enhance` expects a (channels, time)
                        # tensor.
                        enhanced = enhance(
                            self.enhanced_model,
                            self.df_state,
                            torch.from_numpy(array).unsqueeze(0),
                        )
                    array = enhanced.numpy().squeeze()
                yield array

    @property
    def min_time_to_debug(self):
        return 0.00001
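

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original handler. Assumptions:
    # `should_listen` is a threading.Event (suggested by the .clear() call in
    # `process` above), and BaseHandler's constructor can be bypassed with
    # __new__ so that `setup`/`process` can be exercised in isolation. The
    # first run needs network access to fetch the Silero VAD weights.
    from threading import Event

    logging.basicConfig(level=logging.DEBUG)

    handler = VADHandler.__new__(VADHandler)  # skip BaseHandler.__init__
    handler.setup(should_listen=Event())

    # Feed 512-sample chunks of int16 PCM at 16 kHz (a common chunk size for
    # Silero VAD); pure silence should never yield an utterance.
    silence = np.zeros(512, dtype=np.int16).tobytes()
    for _ in range(20):
        for utterance in handler.process(silence):
            print(f"utterance of {len(utterance)} samples")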