operator.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
import copy
import enum
import importlib
import logging
from abc import ABC, abstractmethod

from ray.streaming import function
from ray.streaming import message
from ray.streaming.collector import Collector
from ray.streaming.collector import CollectionCollector
from ray.streaming.function import SourceFunction
from ray.streaming.runtime import gateway_client
  11. logger = logging.getLogger(__name__)
  12. class OperatorType(enum.Enum):
  13. SOURCE = 0 # Sources are where your program reads its input from
  14. ONE_INPUT = 1 # This operator has one data stream as it's input stream.
  15. TWO_INPUT = 2 # This operator has two data stream as it's input stream.
  16. class Operator(ABC):
  17. """
  18. Abstract base class for all operators.
  19. An operator is used to run a :class:`function.Function`.
  20. """
  21. @abstractmethod
  22. def open(self, collectors, runtime_context):
  23. pass
  24. @abstractmethod
  25. def finish(self):
  26. pass
  27. @abstractmethod
  28. def close(self):
  29. pass
  30. @abstractmethod
  31. def operator_type(self) -> OperatorType:
  32. pass
  33. @abstractmethod
  34. def save_checkpoint(self):
  35. pass
  36. @abstractmethod
  37. def load_checkpoint(self, checkpoint_obj):
  38. pass
  39. class OneInputOperator(Operator, ABC):
  40. """Interface for stream operators with one input."""
  41. @abstractmethod
  42. def process_element(self, record):
  43. pass
  44. def operator_type(self):
  45. return OperatorType.ONE_INPUT
  46. class TwoInputOperator(Operator, ABC):
  47. """Interface for stream operators with two input"""
  48. @abstractmethod
  49. def process_element(self, record1, record2):
  50. pass
  51. def operator_type(self):
  52. return OperatorType.TWO_INPUT
  53. class StreamOperator(Operator, ABC):
  54. """
  55. Basic interface for stream operators. Implementers would implement one of
  56. :class:`OneInputOperator` or :class:`TwoInputOperator` to to create
  57. operators that process elements.
  58. """
  59. def __init__(self, func):
  60. self.func = func
  61. self.collectors = None
  62. self.runtime_context = None
  63. def open(self, collectors, runtime_context):
  64. self.collectors = collectors
  65. self.runtime_context = runtime_context
  66. self.func.open(runtime_context)
  67. def finish(self):
  68. pass
  69. def close(self):
  70. self.func.close()
  71. def collect(self, record):
  72. for collector in self.collectors:
  73. collector.collect(record)
  74. def save_checkpoint(self):
  75. self.func.save_checkpoint()
  76. def load_checkpoint(self, checkpoint_obj):
  77. self.func.load_checkpoint(checkpoint_obj)
  78. class SourceOperator(Operator, ABC):
  79. @abstractmethod
  80. def fetch(self):
  81. pass
  82. class SourceOperatorImpl(SourceOperator, StreamOperator):
  83. """
  84. Operator to run a :class:`function.SourceFunction`
  85. """
  86. class SourceContextImpl(function.SourceContext):
  87. def __init__(self, collectors):
  88. self.collectors = collectors
  89. def collect(self, value):
  90. for collector in self.collectors:
  91. collector.collect(message.Record(value))
  92. def __init__(self, func: SourceFunction):
  93. assert isinstance(func, function.SourceFunction)
  94. super().__init__(func)
  95. self.source_context = None
  96. def open(self, collectors, runtime_context):
  97. super().open(collectors, runtime_context)
  98. self.source_context = SourceOperatorImpl.SourceContextImpl(collectors)
  99. self.func.init(runtime_context.get_parallelism(),
  100. runtime_context.get_task_index())
  101. def fetch(self):
  102. self.func.fetch(self.source_context)
  103. def operator_type(self):
  104. return OperatorType.SOURCE
  105. class MapOperator(StreamOperator, OneInputOperator):
  106. """
  107. Operator to run a :class:`function.MapFunction`
  108. """
  109. def __init__(self, map_func: function.MapFunction):
  110. assert isinstance(map_func, function.MapFunction)
  111. super().__init__(map_func)
  112. def process_element(self, record):
  113. self.collect(message.Record(self.func.map(record.value)))
  114. class FlatMapOperator(StreamOperator, OneInputOperator):
  115. """
  116. Operator to run a :class:`function.FlatMapFunction`
  117. """
  118. def __init__(self, flat_map_func: function.FlatMapFunction):
  119. assert isinstance(flat_map_func, function.FlatMapFunction)
  120. super().__init__(flat_map_func)
  121. self.collection_collector = None
  122. def open(self, collectors, runtime_context):
  123. super().open(collectors, runtime_context)
  124. self.collection_collector = CollectionCollector(collectors)
  125. def process_element(self, record):
  126. self.func.flat_map(record.value, self.collection_collector)
  127. class FilterOperator(StreamOperator, OneInputOperator):
  128. """
  129. Operator to run a :class:`function.FilterFunction`
  130. """
  131. def __init__(self, filter_func: function.FilterFunction):
  132. assert isinstance(filter_func, function.FilterFunction)
  133. super().__init__(filter_func)
  134. def process_element(self, record):
  135. if self.func.filter(record.value):
  136. self.collect(record)
  137. class KeyByOperator(StreamOperator, OneInputOperator):
  138. """
  139. Operator to run a :class:`function.KeyFunction`
  140. """
  141. def __init__(self, key_func: function.KeyFunction):
  142. assert isinstance(key_func, function.KeyFunction)
  143. super().__init__(key_func)
  144. def process_element(self, record):
  145. key = self.func.key_by(record.value)
  146. self.collect(message.KeyRecord(key, record.value))
  147. class ReduceOperator(StreamOperator, OneInputOperator):
  148. """
  149. Operator to run a :class:`function.ReduceFunction`
  150. """
  151. def __init__(self, reduce_func: function.ReduceFunction):
  152. assert isinstance(reduce_func, function.ReduceFunction)
  153. super().__init__(reduce_func)
  154. self.reduce_state = {}
  155. def open(self, collectors, runtime_context):
  156. super().open(collectors, runtime_context)
  157. def process_element(self, record: message.KeyRecord):
  158. key = record.key
  159. value = record.value
  160. if key in self.reduce_state:
  161. old_value = self.reduce_state[key]
  162. new_value = self.func.reduce(old_value, value)
  163. self.reduce_state[key] = new_value
  164. self.collect(message.Record(new_value))
  165. else:
  166. self.reduce_state[key] = value
  167. self.collect(record)
  168. class SinkOperator(StreamOperator, OneInputOperator):
  169. """
  170. Operator to run a :class:`function.SinkFunction`
  171. """
  172. def __init__(self, sink_func: function.SinkFunction):
  173. assert isinstance(sink_func, function.SinkFunction)
  174. super().__init__(sink_func)
  175. def process_element(self, record):
  176. self.func.sink(record.value)
  177. class UnionOperator(StreamOperator, OneInputOperator):
  178. """Operator for union operation"""
  179. def __init__(self):
  180. super().__init__(function.EmptyFunction())
  181. def process_element(self, record):
  182. self.collect(record)
  183. class ChainedOperator(StreamOperator, ABC):
  184. class ForwardCollector(Collector):
  185. def __init__(self, succeeding_operator):
  186. self.succeeding_operator = succeeding_operator
  187. def collect(self, record):
  188. self.succeeding_operator.process_element(record)
  189. def __init__(self, operators, configs):
  190. super().__init__(operators[0].func)
  191. self.operators = operators
  192. self.configs = configs
  193. def open(self, collectors, runtime_context):
  194. # Dont' call super.open() as we `open` every operator separately.
  195. num_operators = len(self.operators)
  196. succeeding_collectors = [
  197. ChainedOperator.ForwardCollector(operator)
  198. for operator in self.operators[1:]
  199. ]
  200. for i in range(0, num_operators - 1):
  201. forward_collectors = [succeeding_collectors[i]]
  202. self.operators[i].open(
  203. forward_collectors,
  204. self.__create_runtime_context(runtime_context, i))
  205. self.operators[-1].open(
  206. collectors,
  207. self.__create_runtime_context(runtime_context, num_operators - 1))
  208. def operator_type(self) -> OperatorType:
  209. return self.operators[0].operator_type()
  210. def __create_runtime_context(self, runtime_context, index):
  211. def get_config():
  212. return self.configs[index]
  213. runtime_context.get_config = get_config
  214. return runtime_context
  215. @staticmethod
  216. def new_chained_operator(operators, configs):
  217. operator_type = operators[0].operator_type()
  218. logger.info(
  219. "Building ChainedOperator from operators {} and configs {}."
  220. .format(operators, configs))
  221. if operator_type == OperatorType.SOURCE:
  222. return ChainedSourceOperator(operators, configs)
  223. elif operator_type == OperatorType.ONE_INPUT:
  224. return ChainedOneInputOperator(operators, configs)
  225. elif operator_type == OperatorType.TWO_INPUT:
  226. return ChainedTwoInputOperator(operators, configs)
  227. else:
  228. raise Exception("Current operator type is not supported")
  229. class ChainedSourceOperator(SourceOperator, ChainedOperator):
  230. def __init__(self, operators, configs):
  231. super().__init__(operators, configs)
  232. def fetch(self):
  233. self.operators[0].fetch()
  234. class ChainedOneInputOperator(ChainedOperator):
  235. def __init__(self, operators, configs):
  236. super().__init__(operators, configs)
  237. def process_element(self, record):
  238. self.operators[0].process_element(record)
  239. class ChainedTwoInputOperator(ChainedOperator):
  240. def __init__(self, operators, configs):
  241. super().__init__(operators, configs)
  242. def process_element(self, record1, record2):
  243. self.operators[0].process_element(record1, record2)
  244. def load_chained_operator(chained_operator_bytes: bytes):
  245. """Load chained operator from serialized operators and configs"""
  246. serialized_operators, configs = gateway_client.deserialize(
  247. chained_operator_bytes)
  248. operators = [
  249. load_operator(desc_bytes) for desc_bytes in serialized_operators
  250. ]
  251. return ChainedOperator.new_chained_operator(operators, configs)
  252. def load_operator(descriptor_operator_bytes: bytes):
  253. """
  254. Deserialize `descriptor_operator_bytes` to get operator info, then
  255. create streaming operator.
  256. Note that this function must be kept in sync with
  257. `io.ray.streaming.runtime.python.GraphPbBuilder.serializeOperator`
  258. Args:
  259. descriptor_operator_bytes: serialized operator info
  260. Returns:
  261. a streaming operator
  262. """
  263. assert len(descriptor_operator_bytes) > 0
  264. function_desc_bytes, module_name, class_name \
  265. = gateway_client.deserialize(descriptor_operator_bytes)
  266. if function_desc_bytes:
  267. return create_operator_with_func(
  268. function.load_function(function_desc_bytes))
  269. else:
  270. assert module_name
  271. assert class_name
  272. mod = importlib.import_module(module_name)
  273. cls = getattr(mod, class_name)
  274. assert issubclass(cls, Operator)
  275. print("cls", cls)
  276. return cls()
  277. _function_to_operator = {
  278. function.SourceFunction: SourceOperatorImpl,
  279. function.MapFunction: MapOperator,
  280. function.FlatMapFunction: FlatMapOperator,
  281. function.FilterFunction: FilterOperator,
  282. function.KeyFunction: KeyByOperator,
  283. function.ReduceFunction: ReduceOperator,
  284. function.SinkFunction: SinkOperator,
  285. }
  286. def create_operator_with_func(func: function.Function):
  287. """Create an operator according to a :class:`function.Function`
  288. Args:
  289. func: a subclass of function.Function
  290. Returns:
  291. an operator
  292. """
  293. operator_class = None
  294. super_classes = func.__class__.mro()
  295. for super_class in super_classes:
  296. operator_class = _function_to_operator.get(super_class, None)
  297. if operator_class is not None:
  298. break
  299. assert operator_class is not None
  300. return operator_class(func)