# reporter_head.py

import json
import logging
import os

import aiohttp.web
import yaml
from aioredis.pubsub import Receiver

import ray
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.experimental.internal_kv as internal_kv
import ray._private.services
import ray._private.utils
from ray.ray_constants import (
    DEBUG_AUTOSCALING_STATUS,
    DEBUG_AUTOSCALING_STATUS_LEGACY,
    DEBUG_AUTOSCALING_ERROR,
)
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioResourceUsageSubscriber
from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
from ray.dashboard.datacenter import DataSource

logger = logging.getLogger(__name__)

routes = dashboard_optional_utils.ClassMethodRouteTable


class ReportHead(dashboard_utils.DashboardHeadModule):
    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
        self._stubs = {}
        self._ray_config = None
        DataSource.agents.signal.append(self._update_stubs)
        # TODO(fyrestone): Avoid using ray.state in the dashboard; it is not
        # asynchronous and leads to poor performance, and ray.disconnect()
        # will hang when ray.state is connected and the GCS has exited.
        # Please refer to: https://github.com/ray-project/ray/issues/16328
        assert dashboard_head.gcs_address or dashboard_head.redis_address
        gcs_address = dashboard_head.gcs_address
        redis_address = dashboard_head.redis_address
        redis_password = dashboard_head.redis_password
        temp_dir = dashboard_head.temp_dir
        # Flatten a (host, port) redis address tuple into "host:port".
        if isinstance(dashboard_head.redis_address, tuple):
            redis_address = f"{redis_address[0]}:{redis_address[1]}"
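        # Assumption (not stated in this file): the writer periodically dumps
        # the metrics agents' addresses to a JSON file that Prometheus can
        # pick up via file-based service discovery; see
        # PrometheusServiceDiscoveryWriter for the authoritative behavior.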
        self.service_discovery = PrometheusServiceDiscoveryWriter(
            redis_address, redis_password, gcs_address, temp_dir
        )

    async def _update_stubs(self, change):
        if change.old:
            node_id, port = change.old
            ip = DataSource.node_id_to_ip[node_id]
            self._stubs.pop(ip)
        if change.new:
            node_id, ports = change.new
            ip = DataSource.node_id_to_ip[node_id]
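            # `ports` is assumed to be the agent's (http_port, grpc_port)
            # pair published via DataSource.agents; the channel below is
            # therefore dialed at ports[1], the gRPC port.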
            options = (("grpc.enable_http_proxy", 0),)
            channel = ray._private.utils.init_grpc_channel(
                f"{ip}:{ports[1]}", options=options, asynchronous=True
            )
            stub = reporter_pb2_grpc.ReporterServiceStub(channel)
            self._stubs[ip] = stub

    @routes.get("/api/launch_profiling")
    async def launch_profiling(self, req) -> aiohttp.web.Response:
        ip = req.query["ip"]
        pid = int(req.query["pid"])
        duration = int(req.query["duration"])
        reporter_stub = self._stubs[ip]
        reply = await reporter_stub.GetProfilingStats(
            reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration)
        )
        profiling_info = (
            json.loads(reply.profiling_stats)
            if reply.profiling_stats
            else reply.std_out
        )
        return dashboard_optional_utils.rest_response(
            success=True, message="Profiling success.", profiling_info=profiling_info
        )
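
    # Illustrative request for /api/launch_profiling above (host and query
    # values are hypothetical; 8265 is the default dashboard port):
    #   GET http://127.0.0.1:8265/api/launch_profiling?ip=10.0.0.1&pid=1234&duration=5
    # The reply carries the agent's profiling stats (parsed JSON when
    # available, otherwise the profiler's raw stdout) inside the standard
    # rest_response envelope.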

    @routes.get("/api/ray_config")
    async def get_ray_config(self, req) -> aiohttp.web.Response:
        if self._ray_config is None:
            try:
                config_path = os.path.expanduser("~/ray_bootstrap_config.yaml")
                with open(config_path) as f:
                    cfg = yaml.safe_load(f)
            except FileNotFoundError:
                return dashboard_optional_utils.rest_response(
                    success=False,
                    message=f"No config found at {config_path}.",
                )
            except yaml.YAMLError:
                return dashboard_optional_utils.rest_response(
                    success=False, message="Invalid config, could not load YAML."
                )
            payload = {
                "min_workers": cfg.get("min_workers", "unspecified"),
                "max_workers": cfg.get("max_workers", "unspecified"),
            }
            try:
                payload["head_type"] = cfg["head_node"]["InstanceType"]
            except KeyError:
                payload["head_type"] = "unknown"
            try:
                payload["worker_type"] = cfg["worker_nodes"]["InstanceType"]
            except KeyError:
                payload["worker_type"] = "unknown"
            self._ray_config = payload
        return dashboard_optional_utils.rest_response(
            success=True,
            message="Fetched ray config.",
            **self._ray_config,
        )
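
    # For reference, a minimal ~/ray_bootstrap_config.yaml exercising every
    # field read above might look like this (all values are illustrative):
    #
    #   min_workers: 0
    #   max_workers: 10
    #   head_node:
    #       InstanceType: m5.large
    #   worker_nodes:
    #       InstanceType: m5.large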

    @routes.get("/api/cluster_status")
    async def get_cluster_status(self, req):
        """Returns status information about the cluster.

        Currently contains two fields:
            autoscaling_status (str): a status message from the autoscaler.
            autoscaling_error (str): an error message from the autoscaler if
                anything has gone wrong during autoscaling.

        These fields are both read from the GCS; it is expected that the
        autoscaler writes them there.
        """
        assert ray.experimental.internal_kv._internal_kv_initialized()
        legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY)
        formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS)
        formatted_status = (
            json.loads(formatted_status_string.decode())
            if formatted_status_string
            else {}
        )
        error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR)
        return dashboard_optional_utils.rest_response(
            success=True,
            message="Got cluster status.",
            autoscaling_status=legacy_status.decode() if legacy_status else None,
            autoscaling_error=error.decode() if error else None,
            cluster_status=formatted_status if formatted_status else None,
        )
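
    # A quick way to exercise the endpoint above against a running dashboard
    # (the host is hypothetical; 8265 is the default dashboard port):
    #   curl http://127.0.0.1:8265/api/cluster_status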

    async def run(self, server):
        # Need daemon=True so the dashboard does not hang at exit.
        self.service_discovery.daemon = True
        self.service_discovery.start()
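        # Resource usage arrives over GCS pub/sub or, on older setups, over
        # Redis pub/sub; both branches below decode the same
        # b'RAY_REPORTER:{node id hex}' keys and feed
        # DataSource.node_physical_stats.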
        if gcs_pubsub_enabled():
            gcs_addr = await self._dashboard_head.get_gcs_address()
            subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
            await subscriber.subscribe()

            while True:
                try:
                    # The key is b'RAY_REPORTER:{node id hex}',
                    # e.g. b'RAY_REPORTER:2b4fbd...'
                    key, data = await subscriber.poll()
                    if key is None:
                        continue
                    data = json.loads(data)
                    node_id = key.split(":")[-1]
                    DataSource.node_physical_stats[node_id] = data
                except Exception:
                    logger.exception(
                        "Error receiving node physical stats from reporter agent."
                    )
        else:
            receiver = Receiver()
            aioredis_client = self._dashboard_head.aioredis_client
            reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
            await aioredis_client.psubscribe(receiver.pattern(reporter_key))
            logger.info(f"Subscribed to {reporter_key}")

            async for sender, msg in receiver.iter():
                try:
                    key, data = msg
                    data = json.loads(ray._private.utils.decode(data))
                    key = key.decode("utf-8")
                    node_id = key.split(":")[-1]
                    DataSource.node_physical_stats[node_id] = data
                except Exception:
                    logger.exception(
                        "Error receiving node physical stats from reporter agent."
                    )

    @staticmethod
    def is_minimal_module():
        return False