optional_utils.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
"""
Utilities that require optional dependencies, i.e. third-party packages that
are not part of the minimal Ray installation (such as aiohttp).
"""
  5. import asyncio
  6. import collections
  7. import functools
  8. import inspect
  9. import json
  10. import logging
  11. import os
  12. import time
  13. import traceback
  14. from collections import namedtuple
  15. from typing import Any
  16. import ray.dashboard.consts as dashboard_consts
  17. from ray.ray_constants import env_bool
  18. try:
  19. create_task = asyncio.create_task
  20. except AttributeError:
  21. create_task = asyncio.ensure_future
  22. # All third-party dependencies that are not included in the minimal Ray
  23. # installation must be included in this file. This allows us to determine if
  24. # the agent has the necessary dependencies to be started.
  25. from ray.dashboard.optional_deps import aiohttp, hdrs, PathLike, RouteDef
  26. from ray.dashboard.utils import to_google_style, CustomEncoder
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  28. class ClassMethodRouteTable:
  29. """A helper class to bind http route to class method."""
  30. _bind_map = collections.defaultdict(dict)
  31. _routes = aiohttp.web.RouteTableDef()
  32. class _BindInfo:
  33. def __init__(self, filename, lineno, instance):
  34. self.filename = filename
  35. self.lineno = lineno
  36. self.instance = instance
  37. @classmethod
  38. def routes(cls):
  39. return cls._routes
  40. @classmethod
  41. def bound_routes(cls):
  42. bound_items = []
  43. for r in cls._routes._items:
  44. if isinstance(r, RouteDef):
  45. route_method = getattr(r.handler, "__route_method__")
  46. route_path = getattr(r.handler, "__route_path__")
  47. instance = cls._bind_map[route_method][route_path].instance
  48. if instance is not None:
  49. bound_items.append(r)
  50. else:
  51. bound_items.append(r)
  52. routes = aiohttp.web.RouteTableDef()
  53. routes._items = bound_items
  54. return routes
  55. @classmethod
  56. def _register_route(cls, method, path, **kwargs):
  57. def _wrapper(handler):
  58. if path in cls._bind_map[method]:
  59. bind_info = cls._bind_map[method][path]
  60. raise Exception(
  61. f"Duplicated route path: {path}, "
  62. f"previous one registered at "
  63. f"{bind_info.filename}:{bind_info.lineno}"
  64. )
  65. bind_info = cls._BindInfo(
  66. handler.__code__.co_filename, handler.__code__.co_firstlineno, None
  67. )
  68. @functools.wraps(handler)
  69. async def _handler_route(*args) -> aiohttp.web.Response:
  70. try:
  71. # Make the route handler as a bound method.
  72. # The args may be:
  73. # * (Request, )
  74. # * (self, Request)
  75. req = args[-1]
  76. return await handler(bind_info.instance, req)
  77. except Exception:
  78. logger.exception("Handle %s %s failed.", method, path)
  79. return rest_response(success=False, message=traceback.format_exc())
  80. cls._bind_map[method][path] = bind_info
  81. _handler_route.__route_method__ = method
  82. _handler_route.__route_path__ = path
  83. return cls._routes.route(method, path, **kwargs)(_handler_route)
  84. return _wrapper
  85. @classmethod
  86. def head(cls, path, **kwargs):
  87. return cls._register_route(hdrs.METH_HEAD, path, **kwargs)
  88. @classmethod
  89. def get(cls, path, **kwargs):
  90. return cls._register_route(hdrs.METH_GET, path, **kwargs)
  91. @classmethod
  92. def post(cls, path, **kwargs):
  93. return cls._register_route(hdrs.METH_POST, path, **kwargs)
  94. @classmethod
  95. def put(cls, path, **kwargs):
  96. return cls._register_route(hdrs.METH_PUT, path, **kwargs)
  97. @classmethod
  98. def patch(cls, path, **kwargs):
  99. return cls._register_route(hdrs.METH_PATCH, path, **kwargs)
  100. @classmethod
  101. def delete(cls, path, **kwargs):
  102. return cls._register_route(hdrs.METH_DELETE, path, **kwargs)
  103. @classmethod
  104. def view(cls, path, **kwargs):
  105. return cls._register_route(hdrs.METH_ANY, path, **kwargs)
  106. @classmethod
  107. def static(cls, prefix: str, path: PathLike, **kwargs: Any) -> None:
  108. cls._routes.static(prefix, path, **kwargs)
  109. @classmethod
  110. def bind(cls, instance):
  111. def predicate(o):
  112. if inspect.ismethod(o):
  113. return hasattr(o, "__route_method__") and hasattr(o, "__route_path__")
  114. return False
  115. handler_routes = inspect.getmembers(instance, predicate)
  116. for _, h in handler_routes:
  117. cls._bind_map[h.__func__.__route_method__][
  118. h.__func__.__route_path__
  119. ].instance = instance
  120. def rest_response(
  121. success, message, convert_google_style=True, **kwargs
  122. ) -> aiohttp.web.Response:
  123. # In the dev context we allow a dev server running on a
  124. # different port to consume the API, meaning we need to allow
  125. # cross-origin access
  126. if os.environ.get("RAY_DASHBOARD_DEV") == "1":
  127. headers = {"Access-Control-Allow-Origin": "*"}
  128. else:
  129. headers = {}
  130. return aiohttp.web.json_response(
  131. {
  132. "result": success,
  133. "msg": message,
  134. "data": to_google_style(kwargs) if convert_google_style else kwargs,
  135. },
  136. dumps=functools.partial(json.dumps, cls=CustomEncoder),
  137. headers=headers,
  138. )
  139. # The cache value type used by aiohttp_cache.
  140. _AiohttpCacheValue = namedtuple("AiohttpCacheValue", ["data", "expiration", "task"])
  141. # The methods with no request body used by aiohttp_cache.
  142. _AIOHTTP_CACHE_NOBODY_METHODS = {hdrs.METH_GET, hdrs.METH_DELETE}
  143. def aiohttp_cache(
  144. ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
  145. maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
  146. enable=not env_bool(dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False),
  147. ):
  148. assert maxsize > 0
  149. cache = collections.OrderedDict()
  150. def _wrapper(handler):
  151. if enable:
  152. @functools.wraps(handler)
  153. async def _cache_handler(*args) -> aiohttp.web.Response:
  154. # Make the route handler as a bound method.
  155. # The args may be:
  156. # * (Request, )
  157. # * (self, Request)
  158. req = args[-1]
  159. # Make key.
  160. if req.method in _AIOHTTP_CACHE_NOBODY_METHODS:
  161. key = req.path_qs
  162. else:
  163. key = (req.path_qs, await req.read())
  164. # Query cache.
  165. value = cache.get(key)
  166. if value is not None:
  167. cache.move_to_end(key)
  168. if not value.task.done() or value.expiration >= time.time():
  169. # Update task not done or the data is not expired.
  170. return aiohttp.web.Response(**value.data)
  171. def _update_cache(task):
  172. try:
  173. response = task.result()
  174. except Exception:
  175. response = rest_response(
  176. success=False, message=traceback.format_exc()
  177. )
  178. data = {
  179. "status": response.status,
  180. "headers": dict(response.headers),
  181. "body": response.body,
  182. }
  183. cache[key] = _AiohttpCacheValue(
  184. data, time.time() + ttl_seconds, task
  185. )
  186. cache.move_to_end(key)
  187. if len(cache) > maxsize:
  188. cache.popitem(last=False)
  189. return response
  190. task = create_task(handler(*args))
  191. task.add_done_callback(_update_cache)
  192. if value is None:
  193. return await task
  194. else:
  195. return aiohttp.web.Response(**value.data)
  196. suffix = f"[cache ttl={ttl_seconds}, max_size={maxsize}]"
  197. _cache_handler.__name__ += suffix
  198. _cache_handler.__qualname__ += suffix
  199. return _cache_handler
  200. else:
  201. return handler
  202. if inspect.iscoroutinefunction(ttl_seconds):
  203. target_func = ttl_seconds
  204. ttl_seconds = dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS
  205. return _wrapper(target_func)
  206. else:
  207. return _wrapper