# _prometheus_metrics.py
  1. import asyncio
  2. import aiohttp
  3. import os
  4. import time
  5. import traceback
  6. from urllib.parse import quote
  7. from typing import Optional
  8. import logging
  9. import json
  10. import argparse
  11. logger = logging.getLogger(__name__)
  12. DEFAULT_PROMETHEUS_HOST = "http://localhost:9090"
  13. PROMETHEUS_HOST_ENV_VAR = "RAY_PROMETHEUS_HOST"
  14. RETRIES = 3
  15. class PrometheusQueryError(Exception):
  16. def __init__(self, status, message):
  17. self.message = (
  18. "Error fetching data from prometheus. "
  19. f"status: {status}, message: {message}"
  20. )
  21. super().__init__(self.message)
  22. class PrometheusClient:
  23. def __init__(self) -> None:
  24. self.http_session = aiohttp.ClientSession()
  25. self.prometheus_host = os.environ.get(
  26. PROMETHEUS_HOST_ENV_VAR, DEFAULT_PROMETHEUS_HOST
  27. )
  28. async def query_prometheus(self, query_type, **kwargs):
  29. url = f"{self.prometheus_host}/api/v1/{query_type}?" + "&".join(
  30. [f"{k}={quote(str(v), safe='')}" for k, v in kwargs.items()]
  31. )
  32. logger.debug(f"Running Prometheus query {url}")
  33. async with self.http_session.get(url) as resp:
  34. for _ in range(RETRIES):
  35. if resp.status == 200:
  36. prom_data = await resp.json()
  37. return prom_data["data"]["result"]
  38. time.sleep(1)
  39. return None
  40. async def close(self):
  41. await self.http_session.close()
  42. # Metrics here mirror what we have in Grafana.
  43. async def _get_prometheus_metrics(start_time: float, end_time: float) -> dict:
  44. client = PrometheusClient()
  45. kwargs = {
  46. "query_type": "query_range",
  47. "start": int(start_time),
  48. "end": int(end_time),
  49. "step": 15,
  50. }
  51. metrics = {
  52. "cpu_utilization": client.query_prometheus(
  53. query="ray_node_cpu_utilization * ray_node_cpu_count / 100", **kwargs
  54. ),
  55. "cpu_count": client.query_prometheus(query="ray_node_cpu_count", **kwargs),
  56. "gpu_utilization": client.query_prometheus(
  57. query="ray_node_gpus_utilization / 100", **kwargs
  58. ),
  59. "gpu_count": client.query_prometheus(query="ray_node_gpus_available", **kwargs),
  60. "disk_usage": client.query_prometheus(query="ray_node_disk_usage", **kwargs),
  61. "disk_space": client.query_prometheus(
  62. query="sum(ray_node_disk_free) + sum(ray_node_disk_usage)", **kwargs
  63. ),
  64. "memory_usage": client.query_prometheus(query="ray_node_mem_used", **kwargs),
  65. "total_memory": client.query_prometheus(query="ray_node_mem_total", **kwargs),
  66. "gpu_memory_usage": client.query_prometheus(
  67. query="ray_node_gram_used * 1024 * 1024", **kwargs
  68. ),
  69. "gpu_total_memory": client.query_prometheus(
  70. query=(
  71. "(sum(ray_node_gram_available) + sum(ray_node_gram_used)) * 1024 * 1024"
  72. ),
  73. **kwargs,
  74. ),
  75. "network_receive_speed": client.query_prometheus(
  76. query="ray_node_network_receive_speed", **kwargs
  77. ),
  78. "network_send_speed": client.query_prometheus(
  79. query="ray_node_network_send_speed", **kwargs
  80. ),
  81. "cluster_active_nodes": client.query_prometheus(
  82. query="ray_cluster_active_nodes", **kwargs
  83. ),
  84. "cluster_failed_nodes": client.query_prometheus(
  85. query="ray_cluster_failed_nodes", **kwargs
  86. ),
  87. "cluster_pending_nodes": client.query_prometheus(
  88. query="ray_cluster_pending_nodes", **kwargs
  89. ),
  90. }
  91. metrics = {k: await v for k, v in metrics.items()}
  92. await client.close()
  93. return metrics
  94. def get_prometheus_metrics(start_time: float, end_time: float) -> dict:
  95. try:
  96. return asyncio.run(_get_prometheus_metrics(start_time, end_time))
  97. except Exception:
  98. logger.error(
  99. "Couldn't obtain Prometheus metrics. "
  100. f"Exception below:\n{traceback.format_exc()}"
  101. )
  102. return {}
  103. def save_prometheus_metrics(
  104. start_time: float,
  105. end_time: Optional[float] = None,
  106. path: Optional[str] = None,
  107. use_ray: bool = False,
  108. ) -> bool:
  109. path = path or os.environ.get("METRICS_OUTPUT_JSON", None)
  110. if path:
  111. if not end_time:
  112. end_time = time.time()
  113. if use_ray:
  114. import ray
  115. from ray.air.util.node import _force_on_current_node
  116. addr = os.environ.get("RAY_ADDRESS", None)
  117. ray.init(addr)
  118. @ray.remote(num_cpus=0)
  119. def get_metrics():
  120. end_time = time.time()
  121. return get_prometheus_metrics(start_time, end_time)
  122. remote_run = _force_on_current_node(get_metrics)
  123. ref = remote_run.remote()
  124. metrics = ray.get(ref, timeout=900)
  125. else:
  126. metrics = get_prometheus_metrics(start_time, end_time)
  127. with open(path, "w") as metrics_output_file:
  128. json.dump(metrics, metrics_output_file)
  129. return path
  130. return None
  131. if __name__ == "__main__":
  132. parser = argparse.ArgumentParser()
  133. parser.add_argument("start_time", type=float, help="Start time")
  134. parser.add_argument(
  135. "--path", default="", type=str, help="Where to save the metrics json"
  136. )
  137. parser.add_argument(
  138. "--use_ray",
  139. default=False,
  140. action="store_true",
  141. help="Whether to run this script in a ray.remote call (for Ray Client)",
  142. )
  143. args = parser.parse_args()
  144. save_prometheus_metrics(args.start_time, path=args.path, use_ray=args.use_ray)