http_server_agent.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import asyncio
  2. import logging
  3. from distutils.version import LooseVersion
  4. import ray.dashboard.optional_utils as dashboard_optional_utils
  5. from ray.dashboard.optional_deps import aiohttp, aiohttp_cors, hdrs
  6. logger = logging.getLogger(__name__)
  7. routes = dashboard_optional_utils.ClassMethodRouteTable
  8. class HttpServerAgent:
  9. def __init__(self, ip, listen_port):
  10. self.ip = ip
  11. self.listen_port = listen_port
  12. self.http_host = None
  13. self.http_port = None
  14. self.http_session = None
  15. self.runner = None
  16. # Create a http session for all modules.
  17. # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
  18. if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
  19. self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
  20. else:
  21. self.http_session = aiohttp.ClientSession()
  22. async def start(self, modules):
  23. # Bind routes for every module so that each module
  24. # can use decorator-style routes.
  25. for c in modules:
  26. dashboard_optional_utils.ClassMethodRouteTable.bind(c)
  27. app = aiohttp.web.Application()
  28. app.add_routes(routes=routes.bound_routes())
  29. # Enable CORS on all routes.
  30. cors = aiohttp_cors.setup(
  31. app,
  32. defaults={
  33. "*": aiohttp_cors.ResourceOptions(
  34. allow_credentials=True,
  35. expose_headers="*",
  36. allow_methods="*",
  37. allow_headers=("Content-Type", "X-Header"),
  38. )
  39. },
  40. )
  41. for route in list(app.router.routes()):
  42. cors.add(route)
  43. self.runner = aiohttp.web.AppRunner(app)
  44. await self.runner.setup()
  45. site = aiohttp.web.TCPSite(
  46. self.runner,
  47. "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0",
  48. self.listen_port,
  49. )
  50. await site.start()
  51. self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname()
  52. logger.info(
  53. "Dashboard agent http address: %s:%s", self.http_host, self.http_port
  54. )
  55. # Dump registered http routes.
  56. dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
  57. for r in dump_routes:
  58. logger.info(r)
  59. logger.info("Registered %s routes.", len(dump_routes))
  60. async def cleanup(self):
  61. # Wait for finish signal.
  62. await self.runner.cleanup()