# s2s_handler.py
import base64
import logging
import os
import threading
import uuid
from queue import Queue, Empty
from typing import Any, Dict, Generator, List, Optional

import numpy as np
import torch

from s2s_pipeline import (
    main,
    prepare_all_args,
    get_default_arguments,
    setup_logger,
    initialize_queues_and_events,
    build_pipeline,
)
  12. class EndpointHandler:
  13. def __init__(self, path=""):
  14. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  15. lm_model_name = os.getenv('LM_MODEL_NAME', 'meta-llama/Meta-Llama-3.1-8B-Instruct')
  16. chat_size = int(os.getenv('CHAT_SIZE', 10))
  17. (
  18. self.module_kwargs,
  19. self.socket_receiver_kwargs,
  20. self.socket_sender_kwargs,
  21. self.vad_handler_kwargs,
  22. self.whisper_stt_handler_kwargs,
  23. self.paraformer_stt_handler_kwargs,
  24. self.faster_whisper_stt_handler_kwargs,
  25. self.language_model_handler_kwargs,
  26. self.open_api_language_model_handler_kwargs,
  27. self.mlx_language_model_handler_kwargs,
  28. self.parler_tts_handler_kwargs,
  29. self.melo_tts_handler_kwargs,
  30. self.chat_tts_handler_kwargs,
  31. self.facebook_mm_stts_handler_kwargs,
  32. ) = get_default_arguments(mode='none', log_level='DEBUG', lm_model_name=lm_model_name,
  33. tts="melo", device=device, chat_size=chat_size)
  34. setup_logger(self.module_kwargs.log_level)
  35. prepare_all_args(
  36. self.module_kwargs,
  37. self.whisper_stt_handler_kwargs,
  38. self.paraformer_stt_handler_kwargs,
  39. self.faster_whisper_stt_handler_kwargs,
  40. self.language_model_handler_kwargs,
  41. self.open_api_language_model_handler_kwargs,
  42. self.mlx_language_model_handler_kwargs,
  43. self.parler_tts_handler_kwargs,
  44. self.melo_tts_handler_kwargs,
  45. self.chat_tts_handler_kwargs,
  46. self.facebook_mm_stts_handler_kwargs,
  47. )
  48. self.queues_and_events = initialize_queues_and_events()
  49. self.pipeline_manager = build_pipeline(
  50. self.module_kwargs,
  51. self.socket_receiver_kwargs,
  52. self.socket_sender_kwargs,
  53. self.vad_handler_kwargs,
  54. self.whisper_stt_handler_kwargs,
  55. self.paraformer_stt_handler_kwargs,
  56. self.faster_whisper_stt_handler_kwargs,
  57. self.language_model_handler_kwargs,
  58. self.open_api_language_model_handler_kwargs,
  59. self.mlx_language_model_handler_kwargs,
  60. self.parler_tts_handler_kwargs,
  61. self.melo_tts_handler_kwargs,
  62. self.chat_tts_handler_kwargs,
  63. self.facebook_mm_stts_handler_kwargs,
  64. self.queues_and_events,
  65. )
  66. self.vad_chunk_size = 512 # Set the chunk size required by the VAD model
  67. self.sample_rate = 16000 # Set the expected sample rate
  68. def process_streaming_data(self, data: bytes) -> bytes:
  69. audio_array = np.frombuffer(data, dtype=np.int16)
  70. # Process the audio data in chunks
  71. chunks = [audio_array[i:i+self.vad_chunk_size] for i in range(0, len(audio_array), self.vad_chunk_size)]
  72. for chunk in chunks:
  73. if len(chunk) == self.vad_chunk_size:
  74. self.queues_and_events['recv_audio_chunks_queue'].put(chunk.tobytes())
  75. elif len(chunk) < self.vad_chunk_size:
  76. # Pad the last chunk if it's smaller than the required size
  77. padded_chunk = np.pad(chunk, (0, self.vad_chunk_size - len(chunk)), 'constant')
  78. self.queues_and_events['recv_audio_chunks_queue'].put(padded_chunk.tobytes())
  79. # Collect the output, if any
  80. try:
  81. output = self.queues_and_events['send_audio_chunks_queue'].get_nowait() # improvement idea, group all available output chunks
  82. if isinstance(output, np.ndarray):
  83. return output.tobytes()
  84. else:
  85. return output
  86. except Empty:
  87. return None
  88. def cleanup(self):
  89. # Stop the pipeline
  90. self.pipeline_manager.stop()
  91. # Stop the output collector thread
  92. self.queues_and_events['send_audio_chunks_queue'].put(b"END")
  93. self.output_collector_thread.join()