# large_scale_test.py

  1. """
  2. @author jennakwon06
  3. """
  4. import argparse
  5. import json
  6. import logging
  7. import math
  8. import os
  9. import random
  10. import time
  11. from copy import copy, deepcopy
  12. from typing import List, Tuple
  13. import dask.array
  14. import numpy as np
  15. import xarray
  16. import ray
  17. from ray._private.test_utils import monitor_memory_usage
  18. from ray.util.dask import ray_dask_get
  19. """
  20. We simulate a real-life usecase where we process a time-series
  21. data of 1 month, using Dask/Xarray on a Ray cluster.
  22. Processing is as follows:
  23. (1) Load the 1-month Xarray lazily.
  24. Perform decimation to reduce data size.
  25. (2) Compute FFT on the 1-year Xarray lazily.
  26. Perform decimation to reduce data size.
  27. (3) Segment the Xarray from (2) into 30-minute Xarrays;
  28. at this point, we have 4380 / 30 = 146 Xarrays.
  29. (4) Trigger save to disk for each of the 30-minute Xarrays.
  30. This triggers Dask computations; there will be 146 graphs.
  31. Since 1460 graphs is too much to process at once,
  32. we determine the batch_size based on script parameters.
  33. (e.g. if batch_size is 100, we'll have 15 batches).
  34. """

MINUTES_IN_A_MONTH = 500
NUM_MINS_PER_OUTPUT_FILE = 30
SAMPLING_RATE = 1000000
SECONDS_IN_A_MIN = 60
INPUT_SHAPE = (3, SAMPLING_RATE * SECONDS_IN_A_MIN)
PEAK_MEMORY_CONSUMPTION_IN_GB = 6
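
# Sizing sketch derived from the constants above (illustrative estimates only;
# np.random.random produces float64 arrays):
#   one minute of data: 3 channels * 60,000,000 samples * 8 bytes ~= 1.44 GB
#   output segments: MINUTES_IN_A_MONTH // NUM_MINS_PER_OUTPUT_FILE = 16 files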

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


class TestSpec:
    def __init__(
        self,
        num_workers: int,
        worker_obj_store_size_in_gb: int,
        trigger_object_spill: bool,
        error_rate: float,
    ):
        """
        `batch_size` is the number of Dask graphs sent to the cluster
        simultaneously for processing.
        One element in the batch represents 1 Dask graph.
        The Dask graph involves reading 30 arrays (one is 1.44GB)
        and concatenating them into a Dask array.
        Then, it does FFT computations across chunks of the Dask array.
        It saves the FFT-ed version of the Dask array as an output file.

        If `trigger_object_spill` is True, we send enough graphs to each
        worker to exceed its object store memory, triggering object spills.
        We use the estimated peak memory consumption per graph to determine
        how many graphs should be sent.

        If `error_rate` is non-zero, the data load layer raises an
        exception with that probability.
        """
        self.error_rate = error_rate
        if trigger_object_spill:
            num_graphs_per_worker = (
                int(
                    math.floor(
                        worker_obj_store_size_in_gb / PEAK_MEMORY_CONSUMPTION_IN_GB
                    )
                )
                + 1
            )
        else:
            num_graphs_per_worker = int(
                math.floor(worker_obj_store_size_in_gb / PEAK_MEMORY_CONSUMPTION_IN_GB)
            )
        self.batch_size = num_graphs_per_worker * num_workers

    def __str__(self):
        return "Error rate = {}, Batch Size = {}".format(
            self.error_rate, self.batch_size
        )
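
# Illustrative batch sizing (hypothetical arguments, not defaults):
#   num_workers=4, worker_obj_store_size_in_gb=20, trigger_object_spill=False
#     -> floor(20 / 6) = 3 graphs per worker -> batch_size = 3 * 4 = 12
#   with trigger_object_spill=True the per-worker count becomes 3 + 1 = 4
#     -> batch_size = 16, intentionally overflowing each worker's object store.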


class LoadRoutines:
    @staticmethod
    def lazy_load_xarray_one_month(test_spec: TestSpec) -> xarray.Dataset:
        """
        Lazily load an Xarray representing 1 month of data.
        The Xarray's data variable is a dask.array that's lazily constructed.
        Therefore, creating the Xarray object doesn't consume any memory.
        But computing the Xarray will.
        """
        dask_array_lists = list()
        array_dtype = np.float32

        # Create chunks with power-of-two sizes so that downstream
        # FFT computations will work correctly.
        rechunk_size = 2 << 23

        # Create MINUTES_IN_A_MONTH number of Delayed objects where
        # each Delayed object is loading an array.
        for i in range(0, MINUTES_IN_A_MONTH):
            dask_arr = dask.array.from_delayed(
                dask.delayed(LoadRoutines.load_array_one_minute)(test_spec),
                shape=INPUT_SHAPE,
                dtype=array_dtype,
            )
            dask_array_lists.append(dask_arr)

        # Return the final dask.array in an Xarray
        return xarray.Dataset(
            data_vars={
                "data_var": (
                    ["channel", "time"],
                    dask.array.rechunk(
                        dask.array.concatenate(dask_array_lists, axis=1),
                        chunks=(INPUT_SHAPE[0], rechunk_size),
                    ),
                )
            },
            coords={"channel": ("channel", np.arange(INPUT_SHAPE[0]))},
            attrs={"hello": "world"},
        )
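
    # Rough graph shape under the constants above (illustrative, not enforced):
    #   500 delayed loads of (3, 60,000,000) concatenate to (3, 30,000,000,000),
    #   then rechunk to time chunks of 2 << 23 = 16,777,216 samples,
    #   i.e. about 1,789 time chunks (the last one partial).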

    @staticmethod
    def load_array_one_minute(test_spec: TestSpec) -> np.ndarray:
        """
        Load an array representing 1 minute of data. Each load consumes
        ~1.44GB of memory (3 * 60,000,000 samples * 8 bytes per float64).
        In real life, this is loaded from cloud storage or disk.
        """
        if random.random() < test_spec.error_rate:
            raise Exception("Data error!")
        else:
            return np.random.random(INPUT_SHAPE)


class TransformRoutines:
    @staticmethod
    def fft_xarray(xr_input: xarray.Dataset, n_fft: int, hop_length: int):
        """
        Perform FFT on an Xarray and return it as another Xarray.
        """
        # Infer the output chunk shape since FFT does
        # not preserve input chunk shape.
        output_chunk_shape = TransformRoutines.infer_chunk_shape_after_fft(
            n_fft=n_fft,
            hop_length=hop_length,
            time_chunk_sizes=xr_input.chunks["time"],
        )
        transformed_audio = dask.array.map_overlap(
            TransformRoutines.fft_algorithm,
            xr_input.data_var.data,
            depth={0: 0, 1: (0, n_fft - hop_length)},
            boundary={0: "none", 1: "none"},
            chunks=output_chunk_shape,
            dtype=np.float32,
            trim=True,
            algorithm_params={"hop_length": hop_length, "n_fft": n_fft},
        )
        return xarray.Dataset(
            data_vars={
                "data_var": (
                    ["channel", "freq", "time"],
                    transformed_audio,
                ),
            },
            coords={
                "freq": ("freq", np.arange(transformed_audio.shape[1])),
                "channel": ("channel", np.arange(INPUT_SHAPE[0])),
            },
            attrs={"hello": "world2"},
        )

    @staticmethod
    def decimate_xarray_after_load(xr_input: xarray.Dataset, decimate_factor: int):
        """
        Downsample an Xarray.
        """
        # Derive new channels from the loaded data, stack them back into a
        # single dask array, and restore the original channel chunking.
        start_chunks = xr_input.data_var.data.chunks
        data_0 = xr_input.data_var.data[0] - xr_input.data_var.data[2]
        data_1 = xr_input.data_var.data[2]
        data_2 = xr_input.data_var.data[0]
        stacked_data = dask.array.stack([data_0, data_1, data_2], axis=0)
        stacked_chunks = stacked_data.chunks
        rechunking_to_chunks = (start_chunks[0], stacked_chunks[1])
        xr_input.data_var.data = stacked_data.rechunk(rechunking_to_chunks)

        # Decimation shrinks each time chunk by `decimate_factor`.
        in_chunks = xr_input.data_var.data.chunks
        out_chunks = (
            in_chunks[0],
            tuple([int(chunk / decimate_factor) for chunk in in_chunks[1]]),
        )
        data_ds_data = xr_input.data_var.data.map_overlap(
            TransformRoutines.decimate_raw_data,
            decimate_time=decimate_factor,
            overlap_time=10,
            depth=(0, decimate_factor * 10),
            trim=False,
            dtype="float32",
            chunks=out_chunks,
        )
        data_ds = copy(xr_input)
        data_ds = data_ds.isel(time=slice(0, data_ds_data.shape[1]))
        data_ds.data_var.data = data_ds_data
        return data_ds

    @staticmethod
    def decimate_raw_data(data: np.ndarray, decimate_time: int, overlap_time=0):
        """
        Decimate along the time axis and trim the overlap added by map_overlap.
        """
        from scipy.signal import decimate

        data = np.nan_to_num(data)
        if decimate_time > 1:
            data = decimate(data, q=decimate_time, axis=1)
        if overlap_time > 0:
            data = data[:, overlap_time:-overlap_time]
        return data

    @staticmethod
    def fft_algorithm(data: np.ndarray, algorithm_params: dict) -> np.ndarray:
        """
        Compute an STFT-based spectrogram (in dB) from an input ndarray.
        """
        from scipy import signal

        hop_length = algorithm_params["hop_length"]
        n_fft = algorithm_params["n_fft"]
        noverlap = n_fft - hop_length
        _, _, spectrogram = signal.stft(
            data,
            nfft=n_fft,
            nperseg=n_fft,
            noverlap=noverlap,
            return_onesided=False,
            boundary=None,
        )
        spectrogram = np.abs(spectrogram)
        spectrogram = 10 * np.log10(spectrogram**2)
        return spectrogram

    @staticmethod
    def infer_chunk_shape_after_fft(
        n_fft: int, hop_length: int, time_chunk_sizes: List
    ) -> tuple:
        """
        Infer the chunk shapes after applying the FFT transformation.
        Inference is necessary for lazy transformations in Dask when
        a transformation does not preserve the input chunk shape.
        """
        output_time_chunk_sizes = list()
        for time_chunk_size in time_chunk_sizes:
            output_time_chunk_sizes.append(math.ceil(time_chunk_size / hop_length))
        num_freq = int(n_fft / 2 + 1)
        return (INPUT_SHAPE[0],), (num_freq,), tuple(output_time_chunk_sizes)
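
    # Worked example (chunk sizes are illustrative only):
    #   n_fft=4096, hop_length=10000, time_chunk_sizes=(160000, 40000)
    #   -> num_freq = int(4096 / 2 + 1) = 2049
    #   -> output chunks = ((3,), (2049,), (16, 4))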

    @staticmethod
    def fix_last_chunk_error(xr_input: xarray.Dataset, n_overlap):
        time_chunks = list(xr_input.chunks["time"])
        # purging chunks that are too small
        if time_chunks[-1] < n_overlap:
            current_len = len(xr_input.time)
            xr_input = xr_input.isel(time=slice(0, current_len - time_chunks[-1]))
        if time_chunks[0] < n_overlap:
            current_len = len(xr_input.time)
            xr_input = xr_input.isel(time=slice(time_chunks[0], current_len))
        return xr_input


class SaveRoutines:
    @staticmethod
    def save_xarray(xarray_dataset, filename, dirpath):
        """
        Save Xarray in zarr format.
        """
        filepath = os.path.join(dirpath, filename)
        if os.path.exists(filepath):
            return "already_exists"
        try:
            xarray_dataset.to_zarr(filepath)
        except Exception as e:
            return "failure, exception = {}".format(e)
        return "success"

    @staticmethod
    def save_all_xarrays(
        xarray_filename_pairs: List[Tuple],
        ray_scheduler,
        dirpath: str,
        batch_size: int,
    ):
        def chunks(lst, n):
            """Yield successive n-sized chunks from lst."""
            for i in range(0, len(lst), n):
                yield lst[i : i + n]

        for batch_idx, batch in enumerate(chunks(xarray_filename_pairs, batch_size)):
            delayed_tasks = list()
            for xarray_filename_pair in batch:
                delayed_tasks.append(
                    dask.delayed(SaveRoutines.save_xarray)(
                        xarray_dataset=xarray_filename_pair[0],
                        filename=xarray_filename_pair[1],
                        dirpath=dirpath,
                    )
                )
            logging.info(
                "[Batch Index {}] Batch size {}: Sending work to Ray Cluster.".format(
                    batch_idx, batch_size
                )
            )
            res = []
            try:
                res = dask.compute(delayed_tasks, scheduler=ray_scheduler)
            except Exception:
                logging.warning(
                    "[Batch Index {}] Exception while computing batch!".format(
                        batch_idx
                    )
                )
            finally:
                logging.info("[Batch Index {}], Result = {}".format(batch_idx, res))
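
# Batching sketch under the constants above (the batch_size here is hypothetical):
#   int(500 / 30) = 16 (xarray, filename) pairs; with batch_size=12 they are
#   submitted as two dask.compute(..., scheduler=ray_dask_get) calls of
#   12 and 4 graphs respectively.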


def lazy_create_xarray_filename_pairs(
    test_spec: TestSpec,
) -> List[Tuple[xarray.Dataset, str]]:
    n_fft = 4096
    hop_length = int(SAMPLING_RATE / 100)
    decimate_factor = 100
    logging.info("Creating 1 month lazy Xarray with decimation and FFT")
    xr1 = LoadRoutines.lazy_load_xarray_one_month(test_spec)
    xr2 = TransformRoutines.decimate_xarray_after_load(
        xr_input=xr1, decimate_factor=decimate_factor
    )
    xr3 = TransformRoutines.fix_last_chunk_error(xr2, n_overlap=n_fft - hop_length)
    xr4 = TransformRoutines.fft_xarray(xr_input=xr3, n_fft=n_fft, hop_length=hop_length)
    num_segments = int(MINUTES_IN_A_MONTH / NUM_MINS_PER_OUTPUT_FILE)
    start_time = 0
    xarray_filename_pairs: List[Tuple[xarray.Dataset, str]] = list()
    timestamp = int(time.time())
    for step in range(num_segments):
        segment_start = start_time + (NUM_MINS_PER_OUTPUT_FILE * step)  # in minutes
        segment_start_index = int(
            SECONDS_IN_A_MIN
            * NUM_MINS_PER_OUTPUT_FILE
            * step
            * (SAMPLING_RATE / decimate_factor)
            / hop_length
        )
        segment_end = segment_start + NUM_MINS_PER_OUTPUT_FILE
        segment_len_sec = (segment_end - segment_start) * SECONDS_IN_A_MIN
        segment_end_index = int(
            segment_start_index + segment_len_sec * SAMPLING_RATE / hop_length
        )
        xr_segment = deepcopy(
            xr4.isel(time=slice(segment_start_index, segment_end_index))
        )
        xarray_filename_pairs.append(
            (xr_segment, "xarray_step_{}_{}.zarr".format(step, timestamp))
        )
    return xarray_filename_pairs


def parse_script_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_workers", type=int)
    parser.add_argument("--worker_obj_store_size_in_gb", type=int)
    parser.add_argument("--error_rate", type=float, default=0)
    parser.add_argument("--data_save_path", type=str)
    parser.add_argument(
        "--trigger_object_spill",
        dest="trigger_object_spill",
        action="store_true",
    )
    parser.set_defaults(trigger_object_spill=False)
    return parser.parse_known_args()
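
# Example invocation (values and path are illustrative only); main() also expects
# the TEST_OUTPUT_JSON environment variable to point to a writable file:
#   python large_scale_test.py --num_workers 4 --worker_obj_store_size_in_gb 20 \
#       --error_rate 0.01 --data_save_path /tmp/large_scale_test --trigger_object_spill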


def main():
    args, unknown = parse_script_args()
    logging.info("Received arguments: {}".format(args))

    # Create test spec
    test_spec = TestSpec(
        num_workers=args.num_workers,
        worker_obj_store_size_in_gb=args.worker_obj_store_size_in_gb,
        error_rate=args.error_rate,
        trigger_object_spill=args.trigger_object_spill,
    )
    logging.info("Created test spec: {}".format(test_spec))

    # Create the data save path if it doesn't exist.
    data_save_path = args.data_save_path
    if not os.path.exists(data_save_path):
        os.makedirs(data_save_path, mode=0o777, exist_ok=True)

    # Lazily construct Xarrays
    xarray_filename_pairs = lazy_create_xarray_filename_pairs(test_spec)

    # Connect to the Ray cluster
    ray.init(address="auto")
    monitor_actor = monitor_memory_usage()

    # Save all the Xarrays to disk; this will trigger
    # Dask computations on Ray.
    logging.info("Saving {} xarrays..".format(len(xarray_filename_pairs)))
    SaveRoutines.save_all_xarrays(
        xarray_filename_pairs=xarray_filename_pairs,
        dirpath=data_save_path,
        batch_size=test_spec.batch_size,
        ray_scheduler=ray_dask_get,
    )
    ray.get(monitor_actor.stop_run.remote())
    used_gb, usage = ray.get(monitor_actor.get_peak_memory_info.remote())
    print(f"Peak memory usage: {round(used_gb, 2)}GB")
    print(f"Peak memory usage per process:\n {usage}")
    try:
        print(ray._private.internal_api.memory_summary(stats_only=True))
    except Exception as e:
        print(f"Warning: query memory summary failed: {e}")

    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(
            json.dumps(
                {
                    "success": 1,
                    "_peak_memory": round(used_gb, 2),
                    "_peak_process_memory": usage,
                }
            )
        )


if __name__ == "__main__":
    main()