# listen_and_play.py (4.0 KB)

  1. import socket
  2. import threading
  3. from queue import Queue
  4. from dataclasses import dataclass, field
  5. import sounddevice as sd
  6. from transformers import HfArgumentParser
  7. @dataclass
  8. class ListenAndPlayArguments:
  9. send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
  10. recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
  11. list_play_chunk_size: int = field(
  12. default=1024,
  13. metadata={"help": "The size of data chunks (in bytes). Default is 1024."},
  14. )
  15. host: str = field(
  16. default="localhost",
  17. metadata={
  18. "help": "The hostname or IP address for listening and playing. Default is 'localhost'."
  19. },
  20. )
  21. send_port: int = field(
  22. default=12345,
  23. metadata={"help": "The network port for sending data. Default is 12345."},
  24. )
  25. recv_port: int = field(
  26. default=12346,
  27. metadata={"help": "The network port for receiving data. Default is 12346."},
  28. )
  29. def listen_and_play(
  30. send_rate=16000,
  31. recv_rate=44100,
  32. list_play_chunk_size=1024,
  33. host="localhost",
  34. send_port=12345,
  35. recv_port=12346,
  36. ):
  37. send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  38. send_socket.connect((host, send_port))
  39. recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  40. recv_socket.connect((host, recv_port))
  41. print("Recording and streaming...")
  42. stop_event = threading.Event()
  43. recv_queue = Queue()
  44. send_queue = Queue()
  45. def callback_recv(outdata, frames, time, status):
  46. if not recv_queue.empty():
  47. data = recv_queue.get()
  48. outdata[: len(data)] = data
  49. outdata[len(data) :] = b"\x00" * (len(outdata) - len(data))
  50. else:
  51. outdata[:] = b"\x00" * len(outdata)
  52. def callback_send(indata, frames, time, status):
  53. if recv_queue.empty():
  54. data = bytes(indata)
  55. send_queue.put(data)
  56. def send(stop_event, send_queue):
  57. while not stop_event.is_set():
  58. data = send_queue.get()
  59. send_socket.sendall(data)
  60. def recv(stop_event, recv_queue):
  61. def receive_full_chunk(conn, chunk_size):
  62. data = b""
  63. while len(data) < chunk_size:
  64. packet = conn.recv(chunk_size - len(data))
  65. if not packet:
  66. return None # Connection has been closed
  67. data += packet
  68. return data
  69. while not stop_event.is_set():
  70. data = receive_full_chunk(recv_socket, list_play_chunk_size * 2)
  71. if data:
  72. recv_queue.put(data)
  73. try:
  74. send_stream = sd.RawInputStream(
  75. samplerate=send_rate,
  76. channels=1,
  77. dtype="int16",
  78. blocksize=list_play_chunk_size,
  79. callback=callback_send,
  80. )
  81. recv_stream = sd.RawOutputStream(
  82. samplerate=recv_rate,
  83. channels=1,
  84. dtype="int16",
  85. blocksize=list_play_chunk_size,
  86. callback=callback_recv,
  87. )
  88. threading.Thread(target=send_stream.start).start()
  89. threading.Thread(target=recv_stream.start).start()
  90. send_thread = threading.Thread(target=send, args=(stop_event, send_queue))
  91. send_thread.start()
  92. recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue))
  93. recv_thread.start()
  94. input("Press Enter to stop...")
  95. except KeyboardInterrupt:
  96. print("Finished streaming.")
  97. finally:
  98. stop_event.set()
  99. # Given that socket::recv is blocking in receive_data_chunk, shut it down to allow the thread to continue.
  100. recv_socket.shutdown(socket.SHUT_RDWR)
  101. recv_thread.join()
  102. send_thread.join()
  103. send_socket.close()
  104. recv_socket.close()
  105. print("Connection closed.")
  106. if __name__ == "__main__":
  107. parser = HfArgumentParser((ListenAndPlayArguments,))
  108. (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
  109. listen_and_play(**vars(listen_and_play_kwargs))