test_direct_transfer.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import pickle
  2. import threading
  3. import time
  4. import ray
  5. import ray.streaming._streaming as _streaming
  6. import ray.streaming.runtime.transfer as transfer
  7. from ray._raylet import PythonFunctionDescriptor
  8. from ray.streaming.config import Config
  9. @ray.remote
  10. class Worker:
  11. def __init__(self):
  12. self.writer_client = _streaming.WriterClient()
  13. self.reader_client = _streaming.ReaderClient()
  14. self.writer = None
  15. self.output_channel_id = None
  16. self.reader = None
  17. def init_writer(self, output_channel, reader_actor):
  18. conf = {Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL}
  19. reader_async_func = PythonFunctionDescriptor(
  20. __name__, self.on_reader_message.__name__, self.__class__.__name__)
  21. reader_sync_func = PythonFunctionDescriptor(
  22. __name__, self.on_reader_message_sync.__name__,
  23. self.__class__.__name__)
  24. transfer.ChannelCreationParametersBuilder.\
  25. set_python_reader_function_descriptor(
  26. reader_async_func, reader_sync_func)
  27. self.writer = transfer.DataWriter([output_channel],
  28. [pickle.loads(reader_actor)], conf)
  29. self.output_channel_id = transfer.ChannelID(output_channel)
  30. def init_reader(self, input_channel, writer_actor):
  31. conf = {Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL}
  32. writer_async_func = PythonFunctionDescriptor(
  33. __name__, self.on_writer_message.__name__, self.__class__.__name__)
  34. writer_sync_func = PythonFunctionDescriptor(
  35. __name__, self.on_writer_message_sync.__name__,
  36. self.__class__.__name__)
  37. transfer.ChannelCreationParametersBuilder.\
  38. set_python_writer_function_descriptor(
  39. writer_async_func, writer_sync_func)
  40. self.reader = transfer.DataReader([input_channel],
  41. [pickle.loads(writer_actor)], conf)
  42. def start_write(self, msg_nums):
  43. self.t = threading.Thread(
  44. target=self.run_writer, args=[msg_nums], daemon=True)
  45. self.t.start()
  46. def run_writer(self, msg_nums):
  47. for i in range(msg_nums):
  48. self.writer.write(self.output_channel_id, pickle.dumps(i))
  49. print("WriterWorker done.")
  50. def start_read(self, msg_nums):
  51. self.t = threading.Thread(
  52. target=self.run_reader, args=[msg_nums], daemon=True)
  53. self.t.start()
  54. def run_reader(self, msg_nums):
  55. count = 0
  56. msg = None
  57. while count != msg_nums:
  58. item = self.reader.read(100)
  59. if item is None:
  60. time.sleep(0.01)
  61. else:
  62. msg = pickle.loads(item.body)
  63. count += 1
  64. assert msg == msg_nums - 1
  65. print("ReaderWorker done.")
  66. def is_finished(self):
  67. return not self.t.is_alive()
  68. def on_reader_message(self, buffer: bytes):
  69. """used in direct call mode"""
  70. self.reader_client.on_reader_message(buffer)
  71. def on_reader_message_sync(self, buffer: bytes):
  72. """used in direct call mode"""
  73. if self.reader_client is None:
  74. return b" " * 4 # special flag to indicate this actor not ready
  75. result = self.reader_client.on_reader_message_sync(buffer)
  76. return result.to_pybytes()
  77. def on_writer_message(self, buffer: bytes):
  78. """used in direct call mode"""
  79. self.writer_client.on_writer_message(buffer)
  80. def on_writer_message_sync(self, buffer: bytes):
  81. """used in direct call mode"""
  82. if self.writer_client is None:
  83. return b" " * 4 # special flag to indicate this actor not ready
  84. result = self.writer_client.on_writer_message_sync(buffer)
  85. return result.to_pybytes()
  86. def test_queue():
  87. ray.init()
  88. writer = Worker._remote()
  89. reader = Worker._remote()
  90. channel_id_str = transfer.ChannelID.gen_random_id()
  91. inits = [
  92. writer.init_writer.remote(channel_id_str, pickle.dumps(reader)),
  93. reader.init_reader.remote(channel_id_str, pickle.dumps(writer))
  94. ]
  95. ray.get(inits)
  96. msg_nums = 1000
  97. print("start read/write")
  98. reader.start_read.remote(msg_nums)
  99. writer.start_write.remote(msg_nums)
  100. while not ray.get(reader.is_finished.remote()):
  101. time.sleep(0.1)
  102. ray.shutdown()
  103. if __name__ == "__main__":
  104. test_queue()