test_event.py

import os
import sys
import time
import json
import copy
import logging
import logging.handlers
import requests
import asyncio
import random
import tempfile
import pytest
import numpy as np
import ray
from ray._private.utils import binary_to_hex
from ray.dashboard.tests.conftest import *  # noqa
from ray.dashboard.modules.event import event_consts
from ray.core.generated import event_pb2
from ray._private.test_utils import (
    format_web_url,
    wait_until_server_available,
    wait_for_condition,
)
from ray.dashboard.modules.event.event_utils import (
    monitor_events,
)

logger = logging.getLogger(__name__)
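

# Build a synthetic dashboard event. Any field the caller doesn't pin
# (event_id, source_type, pid, job_id) is randomized.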
def _get_event(msg="empty message", job_id=None, source_type=None):
    return {
        "event_id": binary_to_hex(np.random.bytes(18)),
        "source_type": random.choice(event_pb2.Event.SourceType.keys())
        if source_type is None
        else source_type,
        "host_name": "po-dev.inc.alipay.net",
        "pid": random.randint(1, 65536),
        "label": "",
        "message": msg,
        "time_stamp": time.time(),
        "severity": "INFO",
        "custom_fields": {
            "job_id": ray.JobID.from_int(random.randint(1, 100)).hex()
            if job_id is None
            else job_id,
            "node_id": "",
            "task_id": "",
        },
    }
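

# Create an isolated logger backed by a RotatingFileHandler, mirroring how
# Ray components append newline-delimited JSON events to the event log files.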
def _test_logger(name, log_file, max_bytes, backup_count):
    handler = logging.handlers.RotatingFileHandler(
        log_file, maxBytes=max_bytes, backupCount=backup_count
    )
    formatter = logging.Formatter("%(message)s")
    handler.setFormatter(formatter)
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    return logger
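

# Write events for the GCS and raylet sources straight into the session's
# event log directory, then verify the dashboard's /events endpoint serves
# all of them back, grouped by job.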
def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
    session_dir = ray_start_with_dashboard["session_dir"]
    event_dir = os.path.join(session_dir, "logs", "events")
    job_id = ray.JobID.from_int(100).hex()
    source_type_gcs = event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
    source_type_raylet = event_pb2.Event.SourceType.Name(event_pb2.Event.RAYLET)
    test_count = 20

    for source_type in [source_type_gcs, source_type_raylet]:
        test_log_file = os.path.join(event_dir, f"event_{source_type}.log")
        test_logger = _test_logger(
            __name__ + str(random.random()),
            test_log_file,
            max_bytes=2000,
            backup_count=1000,
        )
        for i in range(test_count):
            sample_event = _get_event(str(i), job_id=job_id, source_type=source_type)
            test_logger.info("%s", json.dumps(sample_event))
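
    # Poll the dashboard until both sources report every written message;
    # wait_for_condition retries while this returns False.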
    def _check_events():
        try:
            resp = requests.get(f"{webui_url}/events")
            resp.raise_for_status()
            result = resp.json()
            all_events = result["data"]["events"]
            job_events = all_events[job_id]
            assert len(job_events) >= test_count * 2
            source_messages = {}
            for e in job_events:
                source_type = e["sourceType"]
                message = e["message"]
                source_messages.setdefault(source_type, set()).add(message)
            assert len(source_messages[source_type_gcs]) >= test_count
            assert len(source_messages[source_type_raylet]) >= test_count
            data = {str(i) for i in range(test_count)}
            assert data & source_messages[source_type_gcs] == data
            assert data & source_messages[source_type_raylet] == data
            return True
        except Exception as ex:
            logger.exception(ex)
            return False

    wait_for_condition(_check_events, timeout=15)
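

# Verify the limits applied when the dashboard reads event log files: lines
# exactly at the length limit are kept, longer lines are dropped, and up to
# event_consts.EVENT_READ_LINE_COUNT_LIMIT lines are read per file.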
def test_event_message_limit(
    small_event_line_limit, disable_aiohttp_cache, ray_start_with_dashboard
):
    event_read_line_length_limit = small_event_line_limit
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
    session_dir = ray_start_with_dashboard["session_dir"]
    event_dir = os.path.join(session_dir, "logs", "events")
    job_id = ray.JobID.from_int(100).hex()
    events = []
    # Sample events whose serialized length equals the limit exactly.
    sample_event = _get_event("", job_id=job_id)
    message_len = event_read_line_length_limit - len(json.dumps(sample_event))
    for i in range(10):
        sample_event = copy.deepcopy(sample_event)
        sample_event["event_id"] = binary_to_hex(np.random.bytes(18))
        sample_event["message"] = str(i) * message_len
        assert len(json.dumps(sample_event)) == event_read_line_length_limit
        events.append(sample_event)
    # A sample event one byte longer than the limit.
    sample_event = copy.deepcopy(sample_event)
    sample_event["event_id"] = binary_to_hex(np.random.bytes(18))
    sample_event["message"] = "2" * (message_len + 1)
    assert len(json.dumps(sample_event)) > event_read_line_length_limit
    events.append(sample_event)
    # Pad with enough events to hit the read line count limit.
    for i in range(event_consts.EVENT_READ_LINE_COUNT_LIMIT):
        events.append(_get_event(str(i), job_id=job_id))
    # Build tmp.log first, then rename it into place as event_GCS.log so the
    # full file appears at once.
    with open(os.path.join(event_dir, "tmp.log"), "w") as f:
        f.writelines([(json.dumps(e) + "\n") for e in events])
    try:
        os.remove(os.path.join(event_dir, "event_GCS.log"))
    except Exception:
        pass
    os.rename(
        os.path.join(event_dir, "tmp.log"), os.path.join(event_dir, "event_GCS.log")
    )
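
    # The dashboard should report the 10 at-limit messages plus the padding
    # events, but not the over-limit message.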
    def _check_events():
        try:
            resp = requests.get(f"{webui_url}/events")
            resp.raise_for_status()
            result = resp.json()
            all_events = result["data"]["events"]
            assert (
                len(all_events[job_id]) >= event_consts.EVENT_READ_LINE_COUNT_LIMIT + 10
            )
            messages = [e["message"] for e in all_events[job_id]]
            for i in range(10):
                assert str(i) * message_len in messages
            assert "2" * (message_len + 1) not in messages
            assert str(event_consts.EVENT_READ_LINE_COUNT_LIMIT - 1) in messages
            return True
        except Exception as ex:
            logger.exception(ex)
            return False

    wait_for_condition(_check_events, timeout=15)
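

# Exercise monitor_events directly against a temporary directory: events
# appended to the log should reach the callback, a restarted monitor should
# re-read existing files, and rotated log files should be picked up.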
@pytest.mark.asyncio
async def test_monitor_events():
    with tempfile.TemporaryDirectory() as temp_dir:
        common = event_pb2.Event.SourceType.Name(event_pb2.Event.COMMON)
        common_log = os.path.join(temp_dir, f"event_{common}.log")
        test_logger = _test_logger(
            __name__ + str(random.random()), common_log, max_bytes=10, backup_count=10
        )
        test_events1 = []
        monitor_task = monitor_events(
            temp_dir, lambda x: test_events1.extend(x), scan_interval_seconds=0.01
        )
        assert not monitor_task.done()
        count = 10

        async def _writer(*args, read_events, spin=True):
            for x in range(*args):
                test_logger.info("%s", x)
                if spin:
                    while str(x) not in read_events:
                        await asyncio.sleep(0.01)

        async def _check_events(expect_events, read_events, timeout=10):
            start_time = time.time()
            while True:
                sorted_events = sorted(int(i) for i in read_events)
                sorted_events = [str(i) for i in sorted_events]
                if time.time() - start_time > timeout:
                    raise TimeoutError(
                        f"Timeout, read events: {sorted_events}, "
                        f"expect events: {expect_events}"
                    )
                if len(sorted_events) == len(expect_events):
                    if sorted_events == expect_events:
                        break
                await asyncio.sleep(1)

        await asyncio.gather(
            _writer(count, read_events=test_events1),
            _check_events([str(i) for i in range(count)], read_events=test_events1),
        )
        monitor_task.cancel()
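
        # Restart monitoring with a fresh consumer; the new monitor should
        # re-read the events already on disk before picking up new ones.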
        test_events2 = []
        monitor_task = monitor_events(
            temp_dir, lambda x: test_events2.extend(x), scan_interval_seconds=0.1
        )
        await _check_events([str(i) for i in range(count)], read_events=test_events2)
        await _writer(count, count * 2, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 2)], read_events=test_events2
        )
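
        # Swap in a logger with a larger max_bytes over the same file;
        # reattaching must not create new rollover files by itself.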
        log_file_count = len(os.listdir(temp_dir))
        test_logger = _test_logger(
            __name__ + str(random.random()), common_log, max_bytes=1000, backup_count=10
        )
        assert len(os.listdir(temp_dir)) == log_file_count
        await _writer(count * 2, count * 3, spin=False, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 3)], read_events=test_events2
        )
        await _writer(count * 3, count * 4, spin=False, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 4)], read_events=test_events2
        )

        # Test cancelling the monitor task.
        monitor_task.cancel()
        with pytest.raises(asyncio.CancelledError):
            await monitor_task
        assert monitor_task.done()
        assert len(os.listdir(temp_dir)) > 1, "Event log should have rollovers."


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))