# melo_handler.py
  1. from melo.api import TTS
  2. import logging
  3. from baseHandler import BaseHandler
  4. import librosa
  5. import numpy as np
  6. from rich.console import Console
  7. import torch
  8. logger = logging.getLogger(__name__)
  9. console = Console()
  10. WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
  11. "en": "EN",
  12. "fr": "FR",
  13. "es": "ES",
  14. "zh": "ZH",
  15. "ja": "JP",
  16. "ko": "KR",
  17. }
  18. WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
  19. "en": "EN-BR",
  20. "fr": "FR",
  21. "es": "ES",
  22. "zh": "ZH",
  23. "ja": "JP",
  24. "ko": "KR",
  25. }
  26. class MeloTTSHandler(BaseHandler):
  27. def setup(
  28. self,
  29. should_listen,
  30. device="mps",
  31. language="en",
  32. speaker_to_id="en",
  33. gen_kwargs={}, # Unused
  34. blocksize=512,
  35. ):
  36. self.should_listen = should_listen
  37. self.device = device
  38. self.language = language
  39. self.model = TTS(
  40. language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
  41. )
  42. self.speaker_id = self.model.hps.data.spk2id[
  43. WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
  44. ]
  45. self.blocksize = blocksize
  46. self.warmup()
  47. def warmup(self):
  48. logger.info(f"Warming up {self.__class__.__name__}")
  49. _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
  50. def process(self, llm_sentence):
  51. if llm_sentence == b"DONE":
  52. self.should_listen.set()
  53. yield b"DONE"
  54. return
  55. language_code = None
  56. if isinstance(llm_sentence, tuple):
  57. llm_sentence, language_code = llm_sentence
  58. console.print(f"[green]ASSISTANT: {llm_sentence}")
  59. if language_code is not None and self.language != language_code:
  60. try:
  61. self.model = TTS(
  62. language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
  63. device=self.device,
  64. )
  65. self.speaker_id = self.model.hps.data.spk2id[
  66. WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
  67. ]
  68. self.language = language_code
  69. except KeyError:
  70. console.print(
  71. f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
  72. )
  73. if self.device == "mps":
  74. import time
  75. start = time.time()
  76. torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
  77. torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
  78. _ = (
  79. time.time() - start
  80. ) # Removing this line makes it fail more often. I'm looking into it.
  81. try:
  82. audio_chunk = self.model.tts_to_file(
  83. llm_sentence, self.speaker_id, quiet=True
  84. )
  85. except (AssertionError, RuntimeError) as e:
  86. logger.error(f"Error in MeloTTSHandler: {e}")
  87. audio_chunk = np.zeros([self.blocksize])
  88. except Exception as e:
  89. logger.error(f"Unknown error in MeloTTSHandler: {e}")
  90. audio_chunk = np.zeros([self.blocksize])
  91. audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
  92. audio_chunk = (audio_chunk * 32768).astype(np.int16)
  93. for i in range(0, len(audio_chunk), self.blocksize):
  94. yield np.pad(
  95. audio_chunk[i : i + self.blocksize],
  96. (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
  97. )
  98. self.should_listen.set()