locust_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. from dataclasses import asdict, dataclass
  2. from itertools import chain
  3. import json
  4. import logging
  5. import time
  6. from tqdm import tqdm
  7. from typing import Any, Dict, List
  8. import ray
  9. from ray.serve._private.utils import generate_request_id
# Module-level logger, named after this file's path.
# NOTE(review): `__name__` is the conventional logger name — confirm nothing
# depends on the path-based name before changing it.
logger = logging.getLogger(__file__)
# Configure root logging at INFO level as soon as this module is imported.
logging.basicConfig(level=logging.INFO)
  12. @dataclass
  13. class LocustStage:
  14. duration_s: int
  15. users: int
  16. spawn_rate: float
  17. @dataclass
  18. class LocustLoadTestConfig:
  19. num_workers: int
  20. host_url: str
  21. auth_token: str
  22. data: Any
  23. stages: List[LocustStage]
  24. wait_for_workers_timeout_s: float = 600
  25. @dataclass
  26. class PerformanceStats:
  27. p50_latency: float
  28. p90_latency: float
  29. p99_latency: float
  30. rps: float
  31. @dataclass
  32. class LocustTestResults:
  33. history: List[Dict]
  34. total_requests: int
  35. num_failures: int
  36. avg_latency: float
  37. p50_latency: float
  38. p90_latency: float
  39. p99_latency: float
  40. avg_rps: float
  41. stats_in_stages: List[PerformanceStats]
  42. @dataclass
  43. class FailedRequest:
  44. request_id: str
  45. status_code: int
  46. exception: str
  47. response_time_s: float
  48. start_time_s: float
  49. class LocustClient:
  50. def __init__(
  51. self,
  52. host_url: str,
  53. token: str,
  54. data: Dict[str, Any] = None,
  55. ):
  56. from locust import task, constant, events, FastHttpUser
  57. from locust.contrib.fasthttp import FastResponse
  58. self.errors = []
  59. class EndpointUser(FastHttpUser):
  60. wait_time = constant(0)
  61. failed_requests = []
  62. host = host_url
  63. @task
  64. def test(self):
  65. request_id = generate_request_id()
  66. headers = (
  67. {"Authorization": f"Bearer {token}", "X-Request-ID": request_id}
  68. if token
  69. else None
  70. )
  71. with self.client.get(
  72. "", headers=headers, json=data, catch_response=True
  73. ) as r:
  74. r.request_meta["context"]["request_id"] = request_id
  75. @events.request.add_listener
  76. def on_request(
  77. response: FastResponse,
  78. exception,
  79. context,
  80. start_time: float,
  81. response_time: float,
  82. **kwargs,
  83. ):
  84. if exception:
  85. request_id = context["request_id"]
  86. response.encoding = "utf-8"
  87. err = FailedRequest(
  88. request_id=request_id,
  89. status_code=response.status_code,
  90. exception=response.text,
  91. response_time_s=response_time,
  92. start_time_s=start_time,
  93. )
  94. self.errors.append(err)
  95. print(
  96. f"Request '{request_id}' failed with exception: {response.text}"
  97. )
  98. self.user_class = EndpointUser
  99. @ray.remote(num_cpus=1)
  100. class LocustWorker(LocustClient):
  101. def __init__(
  102. self,
  103. host_url: str,
  104. token: str,
  105. master_address: str,
  106. data: Dict[str, Any] = None,
  107. ):
  108. # NOTE(zcin): We need to lazily import locust because the driver
  109. # script won't connect to ray properly otherwise.
  110. import locust
  111. from locust.env import Environment
  112. from locust.log import setup_logging
  113. super().__init__(host_url=host_url, token=token, data=data)
  114. setup_logging("INFO")
  115. self.env = Environment(user_classes=[self.user_class], events=locust.events)
  116. self.master_address = master_address
  117. def run(self) -> List[Dict]:
  118. runner = self.env.create_worker_runner(
  119. master_host=self.master_address, master_port=5557
  120. )
  121. runner.greenlet.join()
  122. return self.errors
  123. @ray.remote(num_cpus=1)
  124. class LocustMaster(LocustClient):
  125. def __init__(
  126. self,
  127. host_url: str,
  128. token: str,
  129. expected_num_workers: int,
  130. stages: List[LocustStage],
  131. wait_for_workers_timeout_s: float,
  132. ):
  133. # NOTE(zcin): We need to lazily import locust because the driver
  134. # script won't connect to ray properly otherwise.
  135. import locust
  136. from locust import LoadTestShape
  137. from locust.env import Environment
  138. from locust.log import setup_logging
  139. super().__init__(host_url=host_url, token=token)
  140. setup_logging("INFO")
  141. self.stats_in_stages: List[PerformanceStats] = []
  142. class StagesShape(LoadTestShape):
  143. curr_stage_ix = 0
  144. def tick(cls):
  145. run_time = cls.get_run_time()
  146. prefix_time = 0
  147. for i, stage in enumerate(stages):
  148. prefix_time += stage.duration_s
  149. if run_time < prefix_time:
  150. if i != cls.curr_stage_ix:
  151. self.on_stage_finished()
  152. cls.curr_stage_ix = i
  153. current_stage = stages[cls.curr_stage_ix]
  154. return current_stage.users, current_stage.spawn_rate
  155. # End of stage test
  156. self.on_stage_finished()
  157. self.master_env = Environment(
  158. user_classes=[self.user_class],
  159. shape_class=StagesShape(),
  160. events=locust.events,
  161. )
  162. self.expected_num_workers = expected_num_workers
  163. self.wait_for_workers_timeout_s = wait_for_workers_timeout_s
  164. self.master_runner = None
  165. def on_stage_finished(self):
  166. stats_entry_key = ("", "GET")
  167. stats_entry = self.master_runner.stats.entries.get(stats_entry_key)
  168. self.stats_in_stages.append(
  169. PerformanceStats(
  170. p50_latency=stats_entry.get_current_response_time_percentile(0.5),
  171. p90_latency=stats_entry.get_current_response_time_percentile(0.9),
  172. p99_latency=stats_entry.get_current_response_time_percentile(0.99),
  173. rps=stats_entry.current_rps,
  174. )
  175. )
  176. def run(self):
  177. import gevent
  178. from locust.stats import (
  179. get_stats_summary,
  180. get_percentile_stats_summary,
  181. get_error_report_summary,
  182. stats_history,
  183. stats_printer,
  184. )
  185. self.master_runner = self.master_env.create_master_runner("*", 5557)
  186. start = time.time()
  187. while len(self.master_runner.clients.ready) < self.expected_num_workers:
  188. if time.time() - start > self.wait_for_workers_timeout_s:
  189. raise RuntimeError(
  190. f"Timed out waiting for {self.expected_num_workers} workers to "
  191. "connect to Locust master."
  192. )
  193. print(
  194. f"Waiting for workers to be ready, "
  195. f"{len(self.master_runner.clients.ready)} "
  196. f"of {self.expected_num_workers} ready."
  197. )
  198. time.sleep(1)
  199. # Periodically output current stats (each entry is aggregated
  200. # stats over the past 10 seconds, by default)
  201. gevent.spawn(stats_printer(self.master_env.stats))
  202. gevent.spawn(stats_history, self.master_runner)
  203. # Start test & wait for the shape test to finish
  204. self.master_runner.start_shape()
  205. self.master_runner.shape_greenlet.join()
  206. # Send quit signal to all locust workers
  207. self.master_runner.quit()
  208. # Print stats
  209. for line in get_stats_summary(self.master_runner.stats, current=False):
  210. print(line)
  211. # Print percentile stats
  212. for line in get_percentile_stats_summary(self.master_runner.stats):
  213. print(line)
  214. # Print error report
  215. if self.master_runner.stats.errors:
  216. for line in get_error_report_summary(self.master_runner.stats):
  217. print(line)
  218. stats_entry_key = ("", "GET")
  219. stats_entry = self.master_runner.stats.entries.get(stats_entry_key)
  220. return LocustTestResults(
  221. history=self.master_runner.stats.history,
  222. total_requests=self.master_runner.stats.num_requests,
  223. num_failures=self.master_runner.stats.num_failures,
  224. avg_latency=stats_entry.avg_response_time,
  225. p50_latency=stats_entry.get_response_time_percentile(0.5),
  226. p90_latency=stats_entry.get_response_time_percentile(0.9),
  227. p99_latency=stats_entry.get_response_time_percentile(0.99),
  228. avg_rps=stats_entry.total_rps,
  229. stats_in_stages=self.stats_in_stages,
  230. )
  231. def run_locust_load_test(config: LocustLoadTestConfig) -> LocustTestResults:
  232. """Runs a Locust load test against a service.
  233. Returns:
  234. Performance results (e.g. throughput and latency) from the test.
  235. Raises:
  236. RuntimeError if any requests failed during the load test.
  237. """
  238. logger.info(f"Spawning {config.num_workers} Locust worker Ray tasks.")
  239. master_address = ray.util.get_node_ip_address()
  240. worker_refs = []
  241. # Start Locust workers
  242. for _ in tqdm(range(config.num_workers)):
  243. locust_worker = LocustWorker.remote(
  244. host_url=config.host_url,
  245. token=config.auth_token,
  246. master_address=master_address,
  247. data=config.data,
  248. )
  249. worker_refs.append(locust_worker.run.remote())
  250. # Start Locust master
  251. master_worker = LocustMaster.remote(
  252. host_url=config.host_url,
  253. token=config.auth_token,
  254. expected_num_workers=config.num_workers,
  255. stages=config.stages,
  256. wait_for_workers_timeout_s=config.wait_for_workers_timeout_s,
  257. )
  258. master_ref = master_worker.run.remote()
  259. # Collect results and metrics
  260. stats: LocustTestResults = ray.get(master_ref)
  261. errors = sorted(chain(*ray.get(worker_refs)), key=lambda e: e.start_time_s)
  262. # If there were any requests that failed, raise error.
  263. if stats.num_failures > 0:
  264. errors_json = [asdict(err) for err in errors]
  265. raise RuntimeError(
  266. f"There were failed requests: {json.dumps(errors_json, indent=4)}"
  267. )
  268. return stats