melo_handler.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from melo.api import TTS
  2. import logging
  3. from baseHandler import BaseHandler
  4. import librosa
  5. import numpy as np
  6. from rich.console import Console
  7. import torch
  8. logger = logging.getLogger(__name__)
  9. console = Console()
  10. WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
  11. "en": "EN",
  12. "fr": "FR",
  13. "es": "ES",
  14. "zh": "ZH",
  15. "ja": "JP",
  16. "ko": "KR",
  17. }
  18. WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
  19. "en": "EN-BR",
  20. "fr": "FR",
  21. "es": "ES",
  22. "zh": "ZH",
  23. "ja": "JP",
  24. "ko": "KR",
  25. }
  26. class MeloTTSHandler(BaseHandler):
  27. def setup(
  28. self,
  29. should_listen,
  30. device="mps",
  31. language="en",
  32. speaker_to_id="en",
  33. gen_kwargs={}, # Unused
  34. blocksize=512,
  35. ):
  36. self.should_listen = should_listen
  37. self.device = device
  38. self.language = language
  39. self.model = TTS(
  40. language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
  41. )
  42. self.speaker_id = self.model.hps.data.spk2id[
  43. WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
  44. ]
  45. self.blocksize = blocksize
  46. self.warmup()
  47. def warmup(self):
  48. logger.info(f"Warming up {self.__class__.__name__}")
  49. _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
  50. def process(self, llm_sentence):
  51. language_code = None
  52. if isinstance(llm_sentence, tuple):
  53. llm_sentence, language_code = llm_sentence
  54. console.print(f"[green]ASSISTANT: {llm_sentence}")
  55. if language_code is not None and self.language != language_code:
  56. try:
  57. self.model = TTS(
  58. language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
  59. device=self.device,
  60. )
  61. self.speaker_id = self.model.hps.data.spk2id[
  62. WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
  63. ]
  64. self.language = language_code
  65. except KeyError:
  66. console.print(
  67. f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
  68. )
  69. if self.device == "mps":
  70. import time
  71. start = time.time()
  72. torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
  73. torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
  74. _ = (
  75. time.time() - start
  76. ) # Removing this line makes it fail more often. I'm looking into it.
  77. try:
  78. audio_chunk = self.model.tts_to_file(
  79. llm_sentence, self.speaker_id, quiet=True
  80. )
  81. except (AssertionError, RuntimeError) as e:
  82. logger.error(f"Error in MeloTTSHandler: {e}")
  83. audio_chunk = np.array([])
  84. if len(audio_chunk) == 0:
  85. self.should_listen.set()
  86. return
  87. audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
  88. audio_chunk = (audio_chunk * 32768).astype(np.int16)
  89. for i in range(0, len(audio_chunk), self.blocksize):
  90. yield np.pad(
  91. audio_chunk[i : i + self.blocksize],
  92. (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
  93. )
  94. self.should_listen.set()