# audio_streaming_client.py
import ssl
import threading
import time
from dataclasses import dataclass, field
from queue import Empty, Queue

import numpy as np
import sounddevice as sd
import websocket
  9. @dataclass
  10. class AudioStreamingClientArguments:
  11. sample_rate: int = field(
  12. default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."}
  13. )
  14. chunk_size: int = field(
  15. default=512,
  16. metadata={"help": "The size of audio chunks in samples. Default is 512."},
  17. )
  18. api_url: str = field(
  19. default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud",
  20. metadata={"help": "The URL of the API endpoint."},
  21. )
  22. auth_token: str = field(
  23. default="your_auth_token",
  24. metadata={"help": "Authentication token for the API."},
  25. )
  26. class AudioStreamingClient:
  27. def __init__(self, args: AudioStreamingClientArguments):
  28. self.args = args
  29. self.stop_event = threading.Event()
  30. self.send_queue = Queue()
  31. self.recv_queue = Queue()
  32. self.session_id = None
  33. self.headers = {
  34. "Accept": "application/json",
  35. "Authorization": f"Bearer {self.args.auth_token}",
  36. "Content-Type": "application/json",
  37. }
  38. self.session_state = (
  39. "idle" # Possible states: idle, sending, processing, waiting
  40. )
  41. self.ws_ready = threading.Event()
  42. def start(self):
  43. print("Starting audio streaming...")
  44. ws_url = self.args.api_url.replace("http", "ws") + "/ws"
  45. self.ws = websocket.WebSocketApp(
  46. ws_url,
  47. header=[f"{key}: {value}" for key, value in self.headers.items()],
  48. on_open=self.on_open,
  49. on_message=self.on_message,
  50. on_error=self.on_error,
  51. on_close=self.on_close,
  52. )
  53. self.ws_thread = threading.Thread(
  54. target=self.ws.run_forever, kwargs={"sslopt": {"cert_reqs": ssl.CERT_NONE}}
  55. )
  56. self.ws_thread.start()
  57. # Wait for the WebSocket to be ready
  58. self.ws_ready.wait()
  59. self.start_audio_streaming()
  60. def start_audio_streaming(self):
  61. self.send_thread = threading.Thread(target=self.send_audio)
  62. self.play_thread = threading.Thread(target=self.play_audio)
  63. with sd.InputStream(
  64. samplerate=self.args.sample_rate,
  65. channels=1,
  66. dtype="int16",
  67. callback=self.audio_input_callback,
  68. blocksize=self.args.chunk_size,
  69. ):
  70. self.send_thread.start()
  71. self.play_thread.start()
  72. input("Press Enter to stop streaming... \n")
  73. self.on_shutdown()
  74. def on_open(self, ws):
  75. print("WebSocket connection opened.")
  76. self.ws_ready.set() # Signal that the WebSocket is ready
  77. def on_message(self, ws, message):
  78. # message is bytes
  79. if message == b"DONE":
  80. print("listen")
  81. self.session_state = "listen"
  82. else:
  83. if self.session_state != "processing":
  84. print("processing")
  85. self.session_state = "processing"
  86. audio_np = np.frombuffer(message, dtype=np.int16)
  87. self.recv_queue.put(audio_np)
  88. def on_error(self, ws, error):
  89. print(f"WebSocket error: {error}")
  90. def on_close(self, ws, close_status_code, close_msg):
  91. print("WebSocket connection closed.")
  92. def on_shutdown(self):
  93. self.stop_event.set()
  94. self.send_thread.join()
  95. self.play_thread.join()
  96. self.ws.close()
  97. self.ws_thread.join()
  98. print("Service shutdown.")
  99. def send_audio(self):
  100. while not self.stop_event.is_set():
  101. if not self.send_queue.empty():
  102. chunk = self.send_queue.get()
  103. if self.session_state != "processing":
  104. self.ws.send(chunk.tobytes(), opcode=websocket.ABNF.OPCODE_BINARY)
  105. else:
  106. self.ws.send([], opcode=websocket.ABNF.OPCODE_BINARY) # handshake
  107. time.sleep(0.01)
  108. def audio_input_callback(self, indata, frames, time, status):
  109. self.send_queue.put(indata.copy())
  110. def audio_out_callback(self, outdata, frames, time, status):
  111. if not self.recv_queue.empty():
  112. chunk = self.recv_queue.get()
  113. # Ensure chunk is int16 and clip to valid range
  114. chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16)
  115. if len(chunk_int16) < len(outdata):
  116. outdata[: len(chunk_int16), 0] = chunk_int16
  117. outdata[len(chunk_int16) :] = 0
  118. else:
  119. outdata[:, 0] = chunk_int16[: len(outdata)]
  120. else:
  121. outdata[:] = 0
  122. def play_audio(self):
  123. with sd.OutputStream(
  124. samplerate=self.args.sample_rate,
  125. channels=1,
  126. dtype="int16",
  127. callback=self.audio_out_callback,
  128. blocksize=self.args.chunk_size,
  129. ):
  130. while not self.stop_event.is_set():
  131. time.sleep(0.1)
  132. if __name__ == "__main__":
  133. import argparse
  134. parser = argparse.ArgumentParser(description="Audio Streaming Client")
  135. parser.add_argument(
  136. "--sample_rate",
  137. type=int,
  138. default=16000,
  139. help="Audio sample rate in Hz. Default is 16000.",
  140. )
  141. parser.add_argument(
  142. "--chunk_size",
  143. type=int,
  144. default=1024,
  145. help="The size of audio chunks in samples. Default is 1024.",
  146. )
  147. parser.add_argument(
  148. "--api_url", type=str, required=True, help="The URL of the API endpoint."
  149. )
  150. parser.add_argument(
  151. "--auth_token",
  152. type=str,
  153. required=True,
  154. help="Authentication token for the API.",
  155. )
  156. args = parser.parse_args()
  157. client_args = AudioStreamingClientArguments(**vars(args))
  158. client = AudioStreamingClient(client_args)
  159. client.start()