test_function.py 690 B

12345678910111213141516171819202122
  1. from ray.streaming import function
  2. from ray.streaming.runtime import gateway_client
  3. def test_get_simple_function_class():
  4. simple_map_func_class = function._get_simple_function_class(
  5. function.MapFunction)
  6. assert simple_map_func_class is function.SimpleMapFunction
  7. class MapFunc(function.MapFunction):
  8. def map(self, value):
  9. return str(value)
  10. def test_load_function():
  11. # function_bytes, module_name, function_name/class_name,
  12. # function_interface
  13. descriptor_func_bytes = gateway_client.serialize(
  14. [None, __name__, MapFunc.__name__, "MapFunction"])
  15. func = function.load_function(descriptor_func_bytes)
  16. assert type(func) is MapFunc