# event_head.py
  1. import os
  2. import asyncio
  3. import logging
  4. from typing import Union
  5. from collections import OrderedDict, defaultdict
  6. import aiohttp.web
  7. import ray.dashboard.utils as dashboard_utils
  8. import ray.dashboard.optional_utils as dashboard_optional_utils
  9. from ray.dashboard.modules.event import event_consts
  10. from ray.dashboard.modules.event.event_utils import (
  11. parse_event_strings,
  12. monitor_events,
  13. )
  14. from ray.core.generated import event_pb2
  15. from ray.core.generated import event_pb2_grpc
  16. from ray.dashboard.datacenter import DataSource
# Module-level logger for this dashboard module.
logger = logging.getLogger(__name__)
# Route table used by the @routes.get(...) handler decorators below.
routes = dashboard_optional_utils.ClassMethodRouteTable
# Per-job event container: {event_id: event}, insertion-ordered.
JobEvents = OrderedDict
# Register JobEvents so it serializes like a plain dict in JSON responses.
dashboard_utils._json_compatible_types.add(JobEvents)
  21. class EventHead(
  22. dashboard_utils.DashboardHeadModule, event_pb2_grpc.ReportEventServiceServicer
  23. ):
  24. def __init__(self, dashboard_head):
  25. super().__init__(dashboard_head)
  26. self._event_dir = os.path.join(self._dashboard_head.log_dir, "events")
  27. os.makedirs(self._event_dir, exist_ok=True)
  28. self._monitor: Union[asyncio.Task, None] = None
  29. @staticmethod
  30. def _update_events(event_list):
  31. # {job_id: {event_id: event}}
  32. all_job_events = defaultdict(JobEvents)
  33. for event in event_list:
  34. event_id = event["event_id"]
  35. custom_fields = event.get("custom_fields")
  36. system_event = False
  37. if custom_fields:
  38. job_id = custom_fields.get("job_id", "global") or "global"
  39. else:
  40. job_id = "global"
  41. if system_event is False:
  42. all_job_events[job_id][event_id] = event
  43. # TODO(fyrestone): Limit the event count per job.
  44. for job_id, new_job_events in all_job_events.items():
  45. job_events = DataSource.events.get(job_id, JobEvents())
  46. job_events.update(new_job_events)
  47. DataSource.events[job_id] = job_events
  48. async def ReportEvents(self, request, context):
  49. received_events = []
  50. if request.event_strings:
  51. received_events.extend(parse_event_strings(request.event_strings))
  52. logger.info("Received %d events", len(received_events))
  53. self._update_events(received_events)
  54. return event_pb2.ReportEventsReply(send_success=True)
  55. @routes.get("/events")
  56. @dashboard_optional_utils.aiohttp_cache(2)
  57. async def get_event(self, req) -> aiohttp.web.Response:
  58. job_id = req.query.get("job_id")
  59. if job_id is None:
  60. all_events = {
  61. job_id: list(job_events.values())
  62. for job_id, job_events in DataSource.events.items()
  63. }
  64. return dashboard_optional_utils.rest_response(
  65. success=True, message="All events fetched.", events=all_events
  66. )
  67. job_events = DataSource.events.get(job_id, {})
  68. return dashboard_optional_utils.rest_response(
  69. success=True,
  70. message="Job events fetched.",
  71. job_id=job_id,
  72. events=list(job_events.values()),
  73. )
  74. async def run(self, server):
  75. event_pb2_grpc.add_ReportEventServiceServicer_to_server(self, server)
  76. self._monitor = monitor_events(
  77. self._event_dir,
  78. lambda data: self._update_events(parse_event_strings(data)),
  79. source_types=event_consts.EVENT_HEAD_MONITOR_SOURCE_TYPES,
  80. )
  81. @staticmethod
  82. def is_minimal_module():
  83. return False