head.py

import os
import sys
import socket
import asyncio
import logging
import threading
from concurrent.futures import Future
from queue import Queue

import grpc

try:
    from grpc import aio as aiogrpc
except ImportError:
    from grpc.experimental import aio as aiogrpc

import ray.experimental.internal_kv as internal_kv
import ray._private.utils
from ray._private.gcs_utils import GcsClient, use_gcs_for_bootstrap
import ray._private.services
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
from ray import ray_constants
from ray._private.gcs_pubsub import (
    gcs_pubsub_enabled,
    GcsAioErrorSubscriber,
    GcsAioLogSubscriber,
)
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataOrganizer
from ray.dashboard.utils import async_loop_forever

logger = logging.getLogger(__name__)

aiogrpc.init_grpc_aio()

GRPC_CHANNEL_OPTIONS = (
    ("grpc.enable_http_proxy", 0),
    ("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
    ("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)
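

# Legacy (Redis-based) bootstrap only: poll Redis until the GCS server address
# has been published, retrying forever on any failure.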
async def get_gcs_address_with_retry(redis_client) -> str:
    while True:
        try:
            gcs_address = (
                await redis_client.get(dashboard_consts.GCS_SERVER_ADDRESS)
            ).decode()
            if not gcs_address:
                raise Exception("GCS address not found.")
            logger.info("Connect to GCS at %s", gcs_address)
            return gcs_address
        except Exception as ex:
            logger.error("Connect to GCS failed: %s, retry...", ex)
            await asyncio.sleep(dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)
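

# The health check uses the blocking (synchronous) gRPC stub, so it runs on a
# dedicated daemon thread. The asyncio side submits a concurrent.futures.Future
# through `work_queue` and awaits it via `asyncio.wrap_future`, which keeps the
# event loop free while the CheckAlive RPC is in flight.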
class GCSHealthCheckThread(threading.Thread):
    def __init__(self, gcs_address: str):
        self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
            gcs_address, options=GRPC_CHANNEL_OPTIONS
        )
        self.gcs_heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
            self.grpc_gcs_channel
        )
        self.work_queue = Queue()

        super().__init__(daemon=True)

    def run(self) -> None:
        while True:
            future = self.work_queue.get()
            check_result = self._check_once_synchronously()
            future.set_result(check_result)

    def _check_once_synchronously(self) -> bool:
        request = gcs_service_pb2.CheckAliveRequest()
        try:
            reply = self.gcs_heartbeat_info_stub.CheckAlive(
                request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT
            )
            if reply.status.code != 0:
                logger.exception(f"Failed to CheckAlive: {reply.status.message}")
                return False
        except grpc.RpcError:  # Deadline Exceeded
            logger.exception("Got RpcError when checking GCS is alive")
            return False
        return True

    async def check_once(self) -> bool:
        """Ask the thread to perform a healthcheck."""
        assert (
            threading.current_thread() != self
        ), "caller shouldn't be from the same thread as GCSHealthCheckThread."
        future = Future()
        self.work_queue.put(future)
        return await asyncio.wrap_future(future)
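

# Entry point of the dashboard head process: connects to the GCS, loads all
# DashboardHeadModule implementations, and serves their gRPC (and, in the
# non-minimal install, HTTP) endpoints.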
class DashboardHead:
    def __init__(
        self,
        http_host,
        http_port,
        http_port_retries,
        gcs_address,
        redis_address,
        redis_password,
        log_dir,
        temp_dir,
        minimal,
    ):
        self.minimal = minimal
        self.health_check_thread: GCSHealthCheckThread = None
        self._gcs_rpc_error_counter = 0
        # Public attributes are accessible for all head modules.
        # Workaround for issue: https://github.com/ray-project/ray/issues/7084
        self.http_host = "127.0.0.1" if http_host == "localhost" else http_host
        self.http_port = http_port
        self.http_port_retries = http_port_retries

        self.gcs_address = None
        self.redis_address = None
        self.redis_password = None
        if use_gcs_for_bootstrap():
            assert gcs_address is not None
            self.gcs_address = gcs_address
        else:
            self.redis_address = dashboard_utils.address_tuple(redis_address)
            self.redis_password = redis_password

        self.log_dir = log_dir
        self.temp_dir = temp_dir
        self.aioredis_client = None
        self.aiogrpc_gcs_channel = None
        self.gcs_error_subscriber = None
        self.gcs_log_subscriber = None
        self.ip = ray.util.get_node_ip_address()
        if not use_gcs_for_bootstrap():
            ip, port = redis_address.split(":")
        else:
            ip, port = gcs_address.split(":")

        self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
        grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
        self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
            self.server, f"{grpc_ip}:0"
        )
        logger.info("Dashboard head grpc address: %s:%s", grpc_ip, self.grpc_port)

        # If the dashboard is started as the non-minimal version, the http
        # server should be configured to expose APIs.
        self.http_server = None
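
    # The HTTP server is imported lazily here; in a minimal Ray install the
    # dashboard's HTTP dependencies may be absent, and this method is only
    # called when `minimal` is False.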
    async def _configure_http_server(self, modules):
        from ray.dashboard.http_server_head import HttpServerDashboardHead

        http_server = HttpServerDashboardHead(
            self.ip, self.http_host, self.http_port, self.http_port_retries
        )
        await http_server.run(modules)
        return http_server

    @property
    def http_session(self):
        assert self.http_server, "Accessing unsupported API in a minimal ray."
        return self.http_server.http_session
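
    # Runs forever (via async_loop_forever) every GCS_CHECK_ALIVE_INTERVAL_SECONDS.
    # Consecutive failures are counted; once the counter exceeds
    # GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR the process exits via os._exit(-1).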
    @async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS)
    async def _gcs_check_alive(self):
        check_future = self.health_check_thread.check_once()

        # NOTE(simon): making sure the check procedure doesn't time out itself.
        # Otherwise, the dashboard will always think that gcs is alive.
        try:
            is_alive = await asyncio.wait_for(
                check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1
            )
        except asyncio.TimeoutError:
            logger.error("Failed to check gcs health, client timed out.")
            is_alive = False

        if is_alive:
            self._gcs_rpc_error_counter = 0
        else:
            self._gcs_rpc_error_counter += 1
            if (
                self._gcs_rpc_error_counter
                > dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR
            ):
                logger.error(
                    "Dashboard exiting because it received too many GCS RPC "
                    "errors; count: %s, threshold: %s.",
                    self._gcs_rpc_error_counter,
                    dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR,
                )
                # TODO(fyrestone): Do not use ray.state in
                # PrometheusServiceDiscoveryWriter.
                # Currently, we use os._exit() here to avoid hanging at the ray
                # shutdown(). Please refer to:
                # https://github.com/ray-project/ray/issues/16328
                os._exit(-1)
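
    # Discovers every DashboardHeadModule subclass known to dashboard_utils and
    # instantiates each one with a reference to this head object.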
    def _load_modules(self):
        """Load dashboard head modules."""
        modules = []
        head_cls_list = dashboard_utils.get_all_modules(
            dashboard_utils.DashboardHeadModule
        )
        for cls in head_cls_list:
            logger.info(
                "Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls
            )
            c = cls(self)
            modules.append(c)
        logger.info("Loaded %d modules.", len(modules))
        return modules
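
    # Two bootstrap modes: with GCS bootstrap the address is passed in directly;
    # otherwise it is looked up through Redis (creating the shared aioredis
    # client as a side effect).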
    async def get_gcs_address(self):
        # Create an aioredis client for all modules.
        if use_gcs_for_bootstrap():
            return self.gcs_address
        else:
            try:
                self.aioredis_client = await dashboard_utils.get_aioredis_client(
                    self.redis_address,
                    self.redis_password,
                    dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
                    dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
                )
            except (socket.gaierror, ConnectionError):
                logger.error(
                    "Dashboard head exiting: Failed to connect to redis at %s",
                    self.redis_address,
                )
                sys.exit(-1)
            return await get_gcs_address_with_retry(self.aioredis_client)
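
    # Startup sequence: resolve the GCS address, set up the GCS client/channel
    # and (optionally) pubsub subscribers, start the health-check thread and the
    # gRPC server, load the head modules, optionally start the HTTP server,
    # publish the dashboard addresses to the GCS KV store, then run all
    # background loops and modules until the gRPC server terminates.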
    async def run(self):
        gcs_address = await self.get_gcs_address()

        # Dashboard will handle connection failure automatically.
        self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
        internal_kv._initialize_internal_kv(self.gcs_client)
        self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
            gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True
        )

        if gcs_pubsub_enabled():
            self.gcs_error_subscriber = GcsAioErrorSubscriber(address=gcs_address)
            self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
            await self.gcs_error_subscriber.subscribe()
            await self.gcs_log_subscriber.subscribe()

        self.health_check_thread = GCSHealthCheckThread(gcs_address)
        self.health_check_thread.start()

        # Start a grpc asyncio server.
        await self.server.start()

        async def _async_notify():
            """Notify signals from queue."""
            while True:
                co = await dashboard_utils.NotifyQueue.get()
                try:
                    await co
                except Exception:
                    logger.exception(f"Error notifying coroutine {co}")

        modules = self._load_modules()

        http_host, http_port = self.http_host, self.http_port
        if not self.minimal:
            self.http_server = await self._configure_http_server(modules)
            http_host, http_port = self.http_server.get_address()
        internal_kv._internal_kv_put(
            ray_constants.DASHBOARD_ADDRESS,
            f"{http_host}:{http_port}",
            namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        )

        # TODO: Use async version if performance is an issue.
        # Write the dashboard head port to gcs kv.
        internal_kv._internal_kv_put(
            dashboard_consts.DASHBOARD_RPC_ADDRESS,
            f"{self.ip}:{self.grpc_port}",
            namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        )

        # Freeze signal after all modules loaded.
        dashboard_utils.SignalManager.freeze()
        concurrent_tasks = [
            self._gcs_check_alive(),
            _async_notify(),
            DataOrganizer.purge(),
            DataOrganizer.organize(),
        ]
        await asyncio.gather(*concurrent_tasks, *(m.run(self.server) for m in modules))
        await self.server.wait_for_termination()

        if self.http_server:
            await self.http_server.cleanup()