123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- from melo.api import TTS
- import logging
- from baseHandler import BaseHandler
- import librosa
- import numpy as np
- from rich.console import Console
- import torch
# Module-level logger, named after this module so log records can be filtered
# per pipeline stage.
logger = logging.getLogger(__name__)

# Shared rich console used to echo the assistant's sentences and user-facing
# warnings to the terminal.
console = Console()
# Maps the lowercase language codes received by the handler to the language
# identifiers passed to the Melo ``TTS`` constructor.
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
    "en": "EN",
    "fr": "FR",
    "es": "ES",
    "zh": "ZH",
    "ja": "JP",
    "ko": "KR",
}

# Maps the same language codes to the speaker key looked up in the model's
# ``spk2id`` table. Every language uses the speaker named after its Melo
# language identifier, except English, which uses the "EN-BR" speaker.
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
    **WHISPER_LANGUAGE_TO_MELO_LANGUAGE,
    "en": "EN-BR",
}
class MeloTTSHandler(BaseHandler):
    """Pipeline stage that turns sentences into audio blocks using MeloTTS.

    ``process`` receives one sentence at a time (optionally paired with a
    language code) and yields fixed-size ``int16`` NumPy blocks resampled to
    16 kHz.
    """

    def setup(
        self,
        should_listen,
        device="mps",
        language="en",
        speaker_to_id="en",
        gen_kwargs=None,  # Unused
        blocksize=512,
    ):
        """Load the Melo model and prepare the handler.

        Args:
            should_listen: Event-like object; its ``set()`` method is called
                once the audio for a sentence has been fully yielded.
            device: Device string forwarded to the Melo ``TTS`` constructor.
            language: Language code; must be a key of
                ``WHISPER_LANGUAGE_TO_MELO_LANGUAGE``.
            speaker_to_id: Language code selecting the speaker; must be a key
                of ``WHISPER_LANGUAGE_TO_MELO_SPEAKER``.
            gen_kwargs: Accepted for interface compatibility and ignored.
                Defaults to ``None`` rather than a shared mutable ``{}``.
            blocksize: Number of samples in every yielded audio block.

        Raises:
            KeyError: If ``language`` or ``speaker_to_id`` is not supported.
        """
        self.should_listen = should_listen
        self.device = device
        self.language = language
        self.model = TTS(
            language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
        )
        self.speaker_id = self.model.hps.data.spk2id[
            WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
        ]
        self.blocksize = blocksize
        self.warmup()

    def warmup(self):
        """Run one throwaway synthesis so the first real request is not the first model call."""
        logger.info(f"Warming up {self.__class__.__name__}")
        _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)

    def process(self, llm_sentence):
        """Synthesize one sentence and yield it as fixed-size audio blocks.

        Args:
            llm_sentence: Either the sentinel ``b"DONE"``, a sentence, or a
                ``(sentence, language_code)`` tuple. When the language code
                differs from the current one, the handler tries to switch to
                a model for that language first.

        Yields:
            ``b"DONE"`` for the sentinel; otherwise ``int16`` NumPy arrays of
            exactly ``self.blocksize`` samples (the last one zero-padded).
        """
        if llm_sentence == b"DONE":
            self.should_listen.set()
            yield b"DONE"
            return

        language_code = None
        if isinstance(llm_sentence, tuple):
            llm_sentence, language_code = llm_sentence

        console.print(f"[green]ASSISTANT: {llm_sentence}")

        if language_code is not None and self.language != language_code:
            try:
                # Build everything into locals first: if any lookup fails, the
                # handler must keep its previous model/speaker/language triple
                # instead of ending up with a new model and a stale speaker id.
                melo_language = WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code]
                melo_speaker = WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
                new_model = TTS(language=melo_language, device=self.device)
                new_speaker_id = new_model.hps.data.spk2id[melo_speaker]
            except KeyError:
                console.print(
                    f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
                )
            else:
                self.model = new_model
                self.speaker_id = new_speaker_id
                self.language = language_code

        if self.device == "mps":
            import time

            start = time.time()
            torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
            torch.mps.empty_cache()  # Frees all memory allocated by the MPS device.
            _ = (
                time.time() - start
            )  # Removing this line makes it fail more often. I'm looking into it.

        try:
            audio_chunk = self.model.tts_to_file(
                llm_sentence, self.speaker_id, quiet=True
            )
        except (AssertionError, RuntimeError) as e:
            # Best effort: emit one block worth of silence instead of crashing
            # the pipeline.
            logger.error(f"Error in MeloTTSHandler: {e}")
            audio_chunk = np.zeros([self.blocksize])
        except Exception as e:
            logger.error(f"Unknown error in MeloTTSHandler: {e}")
            audio_chunk = np.zeros([self.blocksize])

        # NOTE(review): the source rate is hard-coded to 44.1 kHz; presumably
        # this matches the Melo model's output rate — confirm against the
        # model configuration.
        audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)

        # Clip before casting: a sample at or above 1.0 scales to >= 32768,
        # which does not fit in int16 and would otherwise overflow on the cast.
        audio_chunk = np.clip(audio_chunk * 32768, -32768, 32767).astype(np.int16)

        for offset in range(0, len(audio_chunk), self.blocksize):
            block = audio_chunk[offset : offset + self.blocksize]
            # Zero-pad the final, possibly short, block up to blocksize.
            yield np.pad(block, (0, self.blocksize - len(block)))

        self.should_listen.set()
|