facebookmms_handler.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from transformers import VitsModel, AutoTokenizer
  2. import torch
  3. import numpy as np
  4. import librosa
  5. from rich.console import Console
  6. from baseHandler import BaseHandler
  7. import logging
  8. logging.basicConfig(
  9. format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  10. level=logging.DEBUG
  11. )
  12. logger = logging.getLogger(__name__)
  13. console = Console()
  14. WHISPER_LANGUAGE_TO_FACEBOOK_LANGUAGE = {
  15. "en": "eng", # English
  16. "fr": "fra", # French
  17. "es": "spa", # Spanish
  18. "ko": "kor", # Korean
  19. "hi": "hin", # Hindi
  20. "ar": "ara", # Arabic
  21. "ar": "hyw", # Armenian
  22. "az": "azb", # Azerbaijani
  23. "bu": "bul", # Bulgarian
  24. "ca": "cat", # Catalan
  25. "nl": "nld", # Dutch
  26. "fi": "fin", # Finnish
  27. "fr": "fra", # French
  28. "de": "deu", # German
  29. "el": "ell", # Greek
  30. "he": "heb", # Hebrew
  31. "hu": "hun", # Hungarian
  32. "is": "isl", # Icelandic
  33. "id": "ind", # Indonesian
  34. "ka": "kan", # Kannada
  35. "kk": "kaz", # Kazakh
  36. "lv": "lav", # Latvian
  37. "zl": "zlm", # Malay
  38. "ma": "mar", # Marathi
  39. "fa": "fas", # Persian
  40. "po": "pol", # Polish
  41. "pt": "por", # Portuguese
  42. "ro": "ron", # Romanian
  43. "ru": "rus", # Russian
  44. "sw": "swh", # Swahili
  45. "sv": "swe", # Swedish
  46. "tg": "tgl", # Tagalog
  47. "ta": "tam", # Tamil
  48. "th": "tha", # Thai
  49. "tu": "tur", # Turkish
  50. "uk": "ukr", # Ukrainian
  51. "ur": "urd", # Urdu
  52. "vi": "vie", # Vietnamese
  53. "cy": "cym", # Welsh
  54. }
  55. class FacebookMMSTTSHandler(BaseHandler):
  56. def setup(
  57. self,
  58. should_listen,
  59. device="cuda",
  60. torch_dtype="float32",
  61. language="en",
  62. stream=True,
  63. chunk_size=512,
  64. **kwargs
  65. ):
  66. self.should_listen = should_listen
  67. self.device = device
  68. self.torch_dtype = getattr(torch, torch_dtype)
  69. self.stream = stream
  70. self.chunk_size = chunk_size
  71. self.language = language
  72. self.load_model(self.language)
  73. self.warmup()
  74. def load_model(self, language_code):
  75. try:
  76. model_name = f"facebook/mms-tts-{WHISPER_LANGUAGE_TO_FACEBOOK_LANGUAGE[language_code]}"
  77. logger.info(f"Loading model: {model_name}")
  78. self.model = VitsModel.from_pretrained(model_name).to(self.device)
  79. self.tokenizer = AutoTokenizer.from_pretrained(model_name)
  80. self.language = language_code
  81. except KeyError:
  82. logger.warning(f"Unsupported language: {language_code}. Falling back to English.")
  83. self.load_model("en")
  84. def warmup(self):
  85. logger.info(f"Warming up {self.__class__.__name__}")
  86. output = self.generate_audio("Hello, this is a test")
  87. def generate_audio(self, text):
  88. if not text:
  89. logger.warning("Received empty text input")
  90. return None
  91. try:
  92. logger.debug(f"Tokenizing text: {text}")
  93. logger.debug(f"Current language: {self.language}")
  94. logger.debug(f"Tokenizer: {self.tokenizer}")
  95. inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
  96. input_ids = inputs.input_ids.to(self.device).long()
  97. attention_mask = inputs.attention_mask.to(self.device)
  98. logger.debug(f"Input IDs shape: {input_ids.shape}, dtype: {input_ids.dtype}")
  99. logger.debug(f"Input IDs: {input_ids}")
  100. if input_ids.numel() == 0:
  101. logger.error("Input IDs tensor is empty")
  102. return None
  103. with torch.no_grad():
  104. output = self.model(input_ids=input_ids, attention_mask=attention_mask)
  105. logger.debug(f"Output waveform shape: {output.waveform.shape}")
  106. return output.waveform
  107. except Exception as e:
  108. logger.error(f"Error in generate_audio: {str(e)}")
  109. logger.exception("Full traceback:")
  110. return None
  111. def process(self, llm_sentence):
  112. language_code = None
  113. if isinstance(llm_sentence, tuple):
  114. llm_sentence, language_code = llm_sentence
  115. console.print(f"[green]ASSISTANT: {llm_sentence}")
  116. logger.debug(f"Processing text: {llm_sentence}")
  117. logger.debug(f"Language code: {language_code}")
  118. if language_code is not None and self.language != language_code:
  119. try:
  120. logger.info(f"Switching language from {self.language} to {language_code}")
  121. self.load_model(language_code)
  122. except KeyError:
  123. console.print(f"[red]Language {language_code} not supported by Facebook MMS. Using {self.language} instead.")
  124. logger.warning(f"Unsupported language: {language_code}")
  125. audio_output = self.generate_audio(llm_sentence)
  126. if audio_output is None or audio_output.numel() == 0:
  127. logger.warning("No audio output generated")
  128. self.should_listen.set()
  129. return
  130. audio_numpy = audio_output.cpu().numpy().squeeze()
  131. logger.debug(f"Raw audio shape: {audio_numpy.shape}, dtype: {audio_numpy.dtype}")
  132. audio_resampled = librosa.resample(audio_numpy, orig_sr=self.model.config.sampling_rate, target_sr=16000)
  133. logger.debug(f"Resampled audio shape: {audio_resampled.shape}, dtype: {audio_resampled.dtype}")
  134. audio_int16 = (audio_resampled * 32768).astype(np.int16)
  135. logger.debug(f"Final audio shape: {audio_int16.shape}, dtype: {audio_int16.dtype}")
  136. if self.stream:
  137. for i in range(0, len(audio_int16), self.chunk_size):
  138. chunk = audio_int16[i:i + self.chunk_size]
  139. yield np.pad(chunk, (0, self.chunk_size - len(chunk)))
  140. else:
  141. for i in range(0, len(audio_int16), self.chunk_size):
  142. yield np.pad(
  143. audio_int16[i : i + self.chunk_size],
  144. (0, self.chunk_size - len(audio_int16[i : i + self.chunk_size])),
  145. )
  146. self.should_listen.set()