collector.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import logging
  2. import typing
  3. from abc import ABC, abstractmethod
  4. from ray import Language
  5. from ray.actor import ActorHandle
  6. from ray.streaming import function
  7. from ray.streaming import message
  8. from ray.streaming import partition
  9. from ray.streaming.runtime import serialization
  10. from ray.streaming.runtime.transfer import ChannelID, DataWriter
  11. logger = logging.getLogger(__name__)
  12. class Collector(ABC):
  13. """
  14. The collector that collects data from an upstream operator,
  15. and emits data to downstream operators.
  16. """
  17. @abstractmethod
  18. def collect(self, record):
  19. pass
  20. class CollectionCollector(Collector):
  21. def __init__(self, collector_list):
  22. self._collector_list = collector_list
  23. def collect(self, value):
  24. for collector in self._collector_list:
  25. collector.collect(message.Record(value))
  26. class OutputCollector(Collector):
  27. def __init__(self, writer: DataWriter, channel_ids: typing.List[str],
  28. target_actors: typing.List[ActorHandle],
  29. partition_func: partition.Partition):
  30. self._writer = writer
  31. self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
  32. self._target_languages = []
  33. for actor in target_actors:
  34. if actor._ray_actor_language == Language.PYTHON:
  35. self._target_languages.append(function.Language.PYTHON)
  36. elif actor._ray_actor_language == Language.JAVA:
  37. self._target_languages.append(function.Language.JAVA)
  38. else:
  39. raise Exception("Unsupported language {}"
  40. .format(actor._ray_actor_language))
  41. self._partition_func = partition_func
  42. self.python_serializer = serialization.PythonSerializer()
  43. self.cross_lang_serializer = serialization.CrossLangSerializer()
  44. logger.info(
  45. "Create OutputCollector, channel_ids {}, partition_func {}".format(
  46. channel_ids, partition_func))
  47. def collect(self, record):
  48. partitions = self._partition_func \
  49. .partition(record, len(self._channel_ids))
  50. python_buffer = None
  51. cross_lang_buffer = None
  52. for partition_index in partitions:
  53. if self._target_languages[partition_index] == \
  54. function.Language.PYTHON:
  55. # avoid repeated serialization
  56. if python_buffer is None:
  57. python_buffer = self.python_serializer.serialize(record)
  58. self._writer.write(
  59. self._channel_ids[partition_index],
  60. bytes([serialization.PYTHON_TYPE_ID]) + python_buffer)
  61. else:
  62. # avoid repeated serialization
  63. if cross_lang_buffer is None:
  64. cross_lang_buffer = self.cross_lang_serializer.serialize(
  65. record)
  66. self._writer.write(
  67. self._channel_ids[partition_index],
  68. bytes([serialization.CROSS_LANG_TYPE_ID]) +
  69. cross_lang_buffer)