# s2s_pipeline.py
  1. import logging
  2. import os
  3. import sys
  4. from copy import copy
  5. from pathlib import Path
  6. from queue import Queue
  7. from threading import Event
  8. from typing import Optional
  9. from sys import platform
  10. from VAD.vad_handler import VADHandler
  11. from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
  12. from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
  13. from arguments_classes.mlx_language_model_arguments import (
  14. MLXLanguageModelHandlerArguments,
  15. )
  16. from arguments_classes.module_arguments import ModuleArguments
  17. from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments
  18. from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
  19. from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
  20. from arguments_classes.socket_sender_arguments import SocketSenderArguments
  21. from arguments_classes.vad_arguments import VADHandlerArguments
  22. from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
  23. from arguments_classes.faster_whisper_stt_arguments import (
  24. FasterWhisperSTTHandlerArguments,
  25. )
  26. from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
  27. from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments
  28. from arguments_classes.facebookmms_tts_arguments import FacebookMMSTTSHandlerArguments
  29. import torch
  30. import nltk
  31. from rich.console import Console
  32. from transformers import (
  33. HfArgumentParser,
  34. )
  35. from utils.thread_manager import ThreadManager
  36. # Ensure that the necessary NLTK resources are available
  37. try:
  38. nltk.data.find("tokenizers/punkt_tab")
  39. except (LookupError, OSError):
  40. nltk.download("punkt_tab")
  41. try:
  42. nltk.data.find("tokenizers/averaged_perceptron_tagger_eng")
  43. except (LookupError, OSError):
  44. nltk.download("averaged_perceptron_tagger_eng")
  45. # caching allows ~50% compilation time reduction
  46. # see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
  47. CURRENT_DIR = Path(__file__).resolve().parent
  48. os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
  49. console = Console()
  50. logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs
  51. def rename_args(args, prefix):
  52. """
  53. Rename arguments by removing the prefix and prepares the gen_kwargs.
  54. """
  55. gen_kwargs = {}
  56. for key in copy(args.__dict__):
  57. if key.startswith(prefix):
  58. value = args.__dict__.pop(key)
  59. new_key = key[len(prefix) + 1 :] # Remove prefix and underscore
  60. if new_key.startswith("gen_"):
  61. gen_kwargs[new_key[4:]] = value # Remove 'gen_' and add to dict
  62. else:
  63. args.__dict__[new_key] = value
  64. args.__dict__["gen_kwargs"] = gen_kwargs
  65. def parse_arguments():
  66. parser = HfArgumentParser(
  67. (
  68. ModuleArguments,
  69. SocketReceiverArguments,
  70. SocketSenderArguments,
  71. VADHandlerArguments,
  72. WhisperSTTHandlerArguments,
  73. ParaformerSTTHandlerArguments,
  74. FasterWhisperSTTHandlerArguments,
  75. LanguageModelHandlerArguments,
  76. OpenApiLanguageModelHandlerArguments,
  77. MLXLanguageModelHandlerArguments,
  78. ParlerTTSHandlerArguments,
  79. MeloTTSHandlerArguments,
  80. ChatTTSHandlerArguments,
  81. FacebookMMSTTSHandlerArguments,
  82. )
  83. )
  84. if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
  85. # Parse configurations from a JSON file if specified
  86. return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
  87. else:
  88. # Parse arguments from command line if no JSON file is provided
  89. return parser.parse_args_into_dataclasses()
  90. def setup_logger(log_level):
  91. global logger
  92. logging.basicConfig(
  93. level=log_level.upper(),
  94. format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  95. )
  96. logger = logging.getLogger(__name__)
  97. # torch compile logs
  98. if log_level == "debug":
  99. torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
  100. def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
  101. if mac_optimal_settings:
  102. for kwargs in handler_kwargs:
  103. if hasattr(kwargs, "device"):
  104. kwargs.device = "mps"
  105. if hasattr(kwargs, "mode"):
  106. kwargs.mode = "local"
  107. if hasattr(kwargs, "stt"):
  108. kwargs.stt = "whisper-mlx"
  109. if hasattr(kwargs, "llm"):
  110. kwargs.llm = "mlx-lm"
  111. if hasattr(kwargs, "tts"):
  112. kwargs.tts = "melo"
  113. def check_mac_settings(module_kwargs):
  114. if platform == "darwin":
  115. if module_kwargs.device == "cuda":
  116. raise ValueError(
  117. "Cannot use CUDA on macOS. Please set the device to 'cpu' or 'mps'."
  118. )
  119. if module_kwargs.llm != "mlx-lm":
  120. logger.warning(
  121. "For macOS users, it is recommended to use mlx-lm. You can activate it by passing --llm mlx-lm."
  122. )
  123. if module_kwargs.tts != "melo":
  124. logger.warning(
  125. "If you experiences issues generating the voice, considering setting the tts to melo."
  126. )
  127. def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
  128. if common_device:
  129. for kwargs in handler_kwargs:
  130. if hasattr(kwargs, "lm_device"):
  131. kwargs.lm_device = common_device
  132. if hasattr(kwargs, "tts_device"):
  133. kwargs.tts_device = common_device
  134. if hasattr(kwargs, "stt_device"):
  135. kwargs.stt_device = common_device
  136. if hasattr(kwargs, "paraformer_stt_device"):
  137. kwargs.paraformer_stt_device = common_device
  138. if hasattr(kwargs, "facebook_mms_device"):
  139. kwargs.facebook_mms_device = common_device
  140. def prepare_module_args(module_kwargs, *handler_kwargs):
  141. optimal_mac_settings(module_kwargs.local_mac_optimal_settings, module_kwargs)
  142. if platform == "darwin":
  143. check_mac_settings(module_kwargs)
  144. overwrite_device_argument(module_kwargs.device, *handler_kwargs)
  145. def prepare_all_args(
  146. module_kwargs,
  147. whisper_stt_handler_kwargs,
  148. paraformer_stt_handler_kwargs,
  149. faster_whisper_stt_handler_kwargs,
  150. language_model_handler_kwargs,
  151. open_api_language_model_handler_kwargs,
  152. mlx_language_model_handler_kwargs,
  153. parler_tts_handler_kwargs,
  154. melo_tts_handler_kwargs,
  155. chat_tts_handler_kwargs,
  156. facebook_mms_tts_handler_kwargs,
  157. ):
  158. prepare_module_args(
  159. module_kwargs,
  160. whisper_stt_handler_kwargs,
  161. faster_whisper_stt_handler_kwargs,
  162. paraformer_stt_handler_kwargs,
  163. language_model_handler_kwargs,
  164. open_api_language_model_handler_kwargs,
  165. mlx_language_model_handler_kwargs,
  166. parler_tts_handler_kwargs,
  167. melo_tts_handler_kwargs,
  168. chat_tts_handler_kwargs,
  169. facebook_mms_tts_handler_kwargs,
  170. )
  171. rename_args(whisper_stt_handler_kwargs, "stt")
  172. rename_args(faster_whisper_stt_handler_kwargs, "faster_whisper_stt")
  173. rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
  174. rename_args(language_model_handler_kwargs, "lm")
  175. rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
  176. rename_args(open_api_language_model_handler_kwargs, "open_api")
  177. rename_args(parler_tts_handler_kwargs, "tts")
  178. rename_args(melo_tts_handler_kwargs, "melo")
  179. rename_args(chat_tts_handler_kwargs, "chat_tts")
  180. rename_args(facebook_mms_tts_handler_kwargs, "facebook_mms")
  181. def initialize_queues_and_events():
  182. return {
  183. "stop_event": Event(),
  184. "should_listen": Event(),
  185. "recv_audio_chunks_queue": Queue(),
  186. "send_audio_chunks_queue": Queue(),
  187. "spoken_prompt_queue": Queue(),
  188. "text_prompt_queue": Queue(),
  189. "lm_response_queue": Queue(),
  190. }
  191. def build_pipeline(
  192. module_kwargs,
  193. socket_receiver_kwargs,
  194. socket_sender_kwargs,
  195. vad_handler_kwargs,
  196. whisper_stt_handler_kwargs,
  197. faster_whisper_stt_handler_kwargs,
  198. paraformer_stt_handler_kwargs,
  199. language_model_handler_kwargs,
  200. open_api_language_model_handler_kwargs,
  201. mlx_language_model_handler_kwargs,
  202. parler_tts_handler_kwargs,
  203. melo_tts_handler_kwargs,
  204. chat_tts_handler_kwargs,
  205. facebook_mms_tts_handler_kwargs,
  206. queues_and_events,
  207. ):
  208. stop_event = queues_and_events["stop_event"]
  209. should_listen = queues_and_events["should_listen"]
  210. recv_audio_chunks_queue = queues_and_events["recv_audio_chunks_queue"]
  211. send_audio_chunks_queue = queues_and_events["send_audio_chunks_queue"]
  212. spoken_prompt_queue = queues_and_events["spoken_prompt_queue"]
  213. text_prompt_queue = queues_and_events["text_prompt_queue"]
  214. lm_response_queue = queues_and_events["lm_response_queue"]
  215. if module_kwargs.mode == "local":
  216. from connections.local_audio_streamer import LocalAudioStreamer
  217. local_audio_streamer = LocalAudioStreamer(
  218. input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue
  219. )
  220. comms_handlers = [local_audio_streamer]
  221. should_listen.set()
  222. else:
  223. from connections.socket_receiver import SocketReceiver
  224. from connections.socket_sender import SocketSender
  225. comms_handlers = [
  226. SocketReceiver(
  227. stop_event,
  228. recv_audio_chunks_queue,
  229. should_listen,
  230. host=socket_receiver_kwargs.recv_host,
  231. port=socket_receiver_kwargs.recv_port,
  232. chunk_size=socket_receiver_kwargs.chunk_size,
  233. ),
  234. SocketSender(
  235. stop_event,
  236. send_audio_chunks_queue,
  237. host=socket_sender_kwargs.send_host,
  238. port=socket_sender_kwargs.send_port,
  239. ),
  240. ]
  241. vad = VADHandler(
  242. stop_event,
  243. queue_in=recv_audio_chunks_queue,
  244. queue_out=spoken_prompt_queue,
  245. setup_args=(should_listen,),
  246. setup_kwargs=vars(vad_handler_kwargs),
  247. )
  248. stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, faster_whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
  249. lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
  250. tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, facebook_mms_tts_handler_kwargs)
  251. return ThreadManager([*comms_handlers, vad, stt, lm, tts])
  252. def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, faster_whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs):
  253. if module_kwargs.stt == "whisper":
  254. from STT.whisper_stt_handler import WhisperSTTHandler
  255. return WhisperSTTHandler(
  256. stop_event,
  257. queue_in=spoken_prompt_queue,
  258. queue_out=text_prompt_queue,
  259. setup_kwargs=vars(whisper_stt_handler_kwargs),
  260. )
  261. elif module_kwargs.stt == "whisper-mlx":
  262. from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
  263. return LightningWhisperSTTHandler(
  264. stop_event,
  265. queue_in=spoken_prompt_queue,
  266. queue_out=text_prompt_queue,
  267. setup_kwargs=vars(whisper_stt_handler_kwargs),
  268. )
  269. elif module_kwargs.stt == "paraformer":
  270. from STT.paraformer_handler import ParaformerSTTHandler
  271. return ParaformerSTTHandler(
  272. stop_event,
  273. queue_in=spoken_prompt_queue,
  274. queue_out=text_prompt_queue,
  275. setup_kwargs=vars(paraformer_stt_handler_kwargs),
  276. )
  277. elif module_kwargs.stt == "faster-whisper":
  278. from STT.faster_whisper_handler import FasterWhisperSTTHandler
  279. return FasterWhisperSTTHandler(
  280. stop_event,
  281. queue_in=spoken_prompt_queue,
  282. queue_out=text_prompt_queue,
  283. setup_kwargs=vars(faster_whisper_stt_handler_kwargs),
  284. )
  285. else:
  286. raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
  287. def get_llm_handler(
  288. module_kwargs,
  289. stop_event,
  290. text_prompt_queue,
  291. lm_response_queue,
  292. language_model_handler_kwargs,
  293. open_api_language_model_handler_kwargs,
  294. mlx_language_model_handler_kwargs
  295. ):
  296. if module_kwargs.llm == "transformers":
  297. from LLM.language_model import LanguageModelHandler
  298. return LanguageModelHandler(
  299. stop_event,
  300. queue_in=text_prompt_queue,
  301. queue_out=lm_response_queue,
  302. setup_kwargs=vars(language_model_handler_kwargs),
  303. )
  304. elif module_kwargs.llm == "open_api":
  305. from LLM.openai_api_language_model import OpenApiModelHandler
  306. return OpenApiModelHandler(
  307. stop_event,
  308. queue_in=text_prompt_queue,
  309. queue_out=lm_response_queue,
  310. setup_kwargs=vars(open_api_language_model_handler_kwargs),
  311. )
  312. elif module_kwargs.llm == "mlx-lm":
  313. from LLM.mlx_language_model import MLXLanguageModelHandler
  314. return MLXLanguageModelHandler(
  315. stop_event,
  316. queue_in=text_prompt_queue,
  317. queue_out=lm_response_queue,
  318. setup_kwargs=vars(mlx_language_model_handler_kwargs),
  319. )
  320. else:
  321. raise ValueError("The LLM should be either transformers or mlx-lm")
  322. def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, facebook_mms_tts_handler_kwargs):
  323. if module_kwargs.tts == "parler":
  324. from TTS.parler_handler import ParlerTTSHandler
  325. return ParlerTTSHandler(
  326. stop_event,
  327. queue_in=lm_response_queue,
  328. queue_out=send_audio_chunks_queue,
  329. setup_args=(should_listen,),
  330. setup_kwargs=vars(parler_tts_handler_kwargs),
  331. )
  332. elif module_kwargs.tts == "melo":
  333. try:
  334. from TTS.melo_handler import MeloTTSHandler
  335. except RuntimeError as e:
  336. logger.error(
  337. "Error importing MeloTTSHandler. You might need to run: python -m unidic download"
  338. )
  339. raise e
  340. return MeloTTSHandler(
  341. stop_event,
  342. queue_in=lm_response_queue,
  343. queue_out=send_audio_chunks_queue,
  344. setup_args=(should_listen,),
  345. setup_kwargs=vars(melo_tts_handler_kwargs),
  346. )
  347. elif module_kwargs.tts == "chatTTS":
  348. try:
  349. from TTS.chatTTS_handler import ChatTTSHandler
  350. except RuntimeError as e:
  351. logger.error("Error importing ChatTTSHandler")
  352. raise e
  353. return ChatTTSHandler(
  354. stop_event,
  355. queue_in=lm_response_queue,
  356. queue_out=send_audio_chunks_queue,
  357. setup_args=(should_listen,),
  358. setup_kwargs=vars(chat_tts_handler_kwargs),
  359. )
  360. elif module_kwargs.tts == "facebookMMS":
  361. from TTS.facebookmms_handler import FacebookMMSTTSHandler
  362. return FacebookMMSTTSHandler(
  363. stop_event,
  364. queue_in=lm_response_queue,
  365. queue_out=send_audio_chunks_queue,
  366. setup_args=(should_listen,),
  367. setup_kwargs=vars(facebook_mms_tts_handler_kwargs),
  368. )
  369. else:
  370. raise ValueError("The TTS should be either parler, melo or chatTTS")
  371. def main():
  372. (
  373. module_kwargs,
  374. socket_receiver_kwargs,
  375. socket_sender_kwargs,
  376. vad_handler_kwargs,
  377. whisper_stt_handler_kwargs,
  378. paraformer_stt_handler_kwargs,
  379. faster_whisper_stt_handler_kwargs, # Add this line
  380. language_model_handler_kwargs,
  381. open_api_language_model_handler_kwargs,
  382. mlx_language_model_handler_kwargs,
  383. parler_tts_handler_kwargs,
  384. melo_tts_handler_kwargs,
  385. chat_tts_handler_kwargs,
  386. facebook_mms_tts_handler_kwargs,
  387. ) = parse_arguments()
  388. setup_logger(module_kwargs.log_level)
  389. prepare_all_args(
  390. module_kwargs,
  391. whisper_stt_handler_kwargs,
  392. paraformer_stt_handler_kwargs,
  393. faster_whisper_stt_handler_kwargs, # Add this line
  394. language_model_handler_kwargs,
  395. open_api_language_model_handler_kwargs,
  396. mlx_language_model_handler_kwargs,
  397. parler_tts_handler_kwargs,
  398. melo_tts_handler_kwargs,
  399. chat_tts_handler_kwargs,
  400. facebook_mms_tts_handler_kwargs,
  401. )
  402. queues_and_events = initialize_queues_and_events()
  403. pipeline_manager = build_pipeline(
  404. module_kwargs,
  405. socket_receiver_kwargs,
  406. socket_sender_kwargs,
  407. vad_handler_kwargs,
  408. whisper_stt_handler_kwargs,
  409. faster_whisper_stt_handler_kwargs, # Add this line
  410. paraformer_stt_handler_kwargs,
  411. language_model_handler_kwargs,
  412. open_api_language_model_handler_kwargs,
  413. mlx_language_model_handler_kwargs,
  414. parler_tts_handler_kwargs,
  415. melo_tts_handler_kwargs,
  416. chat_tts_handler_kwargs,
  417. facebook_mms_tts_handler_kwargs,
  418. queues_and_events,
  419. )
  420. try:
  421. pipeline_manager.start()
  422. except KeyboardInterrupt:
  423. pipeline_manager.stop()
  424. if __name__ == "__main__":
  425. main()