- """
- @author jennakwon06
- """
- import argparse
- import json
- import logging
- import math
- import os
- import random
- import time
- from copy import copy, deepcopy
- from typing import List, Tuple
- import dask.array
- import numpy as np
- import xarray
- import ray
- from ray._private.test_utils import monitor_memory_usage
- from ray.util.dask import ray_dask_get
- """
- We simulate a real-life usecase where we process a time-series
- data of 1 month, using Dask/Xarray on a Ray cluster.
- Processing is as follows:
- (1) Load the 1-month Xarray lazily.
- Perform decimation to reduce data size.
- (2) Compute FFT on the 1-year Xarray lazily.
- Perform decimation to reduce data size.
- (3) Segment the Xarray from (2) into 30-minute Xarrays;
- at this point, we have 4380 / 30 = 146 Xarrays.
- (4) Trigger save to disk for each of the 30-minute Xarrays.
- This triggers Dask computations; there will be 146 graphs.
- Since 1460 graphs is too much to process at once,
- we determine the batch_size based on script parameters.
- (e.g. if batch_size is 100, we'll have 15 batches).
- """
- MINUTES_IN_A_MONTH = 500  # test-sized; a real month is ~43,800 minutes
- NUM_MINS_PER_OUTPUT_FILE = 30
- SAMPLING_RATE = 1000000  # samples per second
- SECONDS_IN_A_MIN = 60
- INPUT_SHAPE = (3, SAMPLING_RATE * SECONDS_IN_A_MIN)  # 3 channels x 1 minute
- PEAK_MEMORY_CONSUMPTION_IN_GB = 6  # estimated peak memory per Dask graph
- logging.basicConfig(
- format="%(asctime)s %(levelname)-8s %(message)s",
- level=logging.INFO,
- datefmt="%Y-%m-%d %H:%M:%S",
- )
- class TestSpec:
- def __init__(
- self,
- num_workers: int,
- worker_obj_store_size_in_gb: int,
- trigger_object_spill: bool,
- error_rate: float,
- ):
- """
- `batch_size` is the # of Dask graphs sent to the cluster
- simultaneously for processing.
- One element in the batch represents 1 Dask graph.
- The Dask graph involves reading 30 arrays (one is 1.44GB)
- and concatenating them into a Dask array.
- Then, it does FFT computations across chunks of the Dask array.
- It saves the FFT-ed version of the Dask array as an output file.
- If `trigger_object_spill` is True, then we send work to
- the cluster such that each worker gets the number of graphs
- that would exceed the worker memory, triggering object spills.
- We use the estimated peak memory consumption to determine
- how many graphs should be sent.
- If `error_rate` is True, we throw an exception at the Data
- load layer as per error rate.
- """
- self.error_rate = error_rate
- if trigger_object_spill:
- num_graphs_per_worker = (
- int(
- math.floor(
- worker_obj_store_size_in_gb / PEAK_MEMORY_CONSUMPTION_IN_GB
- )
- )
- + 1
- )
- else:
- num_graphs_per_worker = int(
- math.floor(worker_obj_store_size_in_gb / PEAK_MEMORY_CONSUMPTION_IN_GB)
- )
- self.batch_size = num_graphs_per_worker * num_workers
- def __str__(self):
- return "Error rate = {}, Batch Size = {}".format(
- self.error_rate, self.batch_size
- )
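- # A sketch of the batch-size arithmetic above, with hypothetical numbers:
- #   spec = TestSpec(num_workers=10, worker_obj_store_size_in_gb=20,
- #                   trigger_object_spill=False, error_rate=0.0)
- #   # floor(20 / 6) = 3 graphs fit per worker, so spec.batch_size == 30.
- #   # With trigger_object_spill=True, each worker gets one extra graph,
- #   # so spec.batch_size == 40, deliberately overflowing the object store.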
- class LoadRoutines:
- @staticmethod
- def lazy_load_xarray_one_month(test_spec: TestSpec) -> xarray.Dataset:
- """
- Lazily load an Xarray representing 1 month of data.
- The Xarray's data variable is a dask.array that's lazily constructed.
- Therefore, creating the Xarray object consumes almost no memory;
- computing it does.
- """
- dask_array_lists = list()
- array_dtype = np.float32
- # Rechunk to a power-of-two time length (2 << 23 == 2**24 samples)
- # so that downstream FFT computations will work correctly.
- rechunk_size = 2 << 23
- # Create MINUTES_IN_A_MONTH number of Delayed objects where
- # each Delayed object is loading an array.
- for i in range(0, MINUTES_IN_A_MONTH):
- dask_arr = dask.array.from_delayed(
- dask.delayed(LoadRoutines.load_array_one_minute)(test_spec),
- shape=INPUT_SHAPE,
- dtype=array_dtype,
- )
- dask_array_lists.append(dask_arr)
- # Return the final dask.array in an Xarray
- return xarray.Dataset(
- data_vars={
- "data_var": (
- ["channel", "time"],
- dask.array.rechunk(
- dask.array.concatenate(dask_array_lists, axis=1),
- chunks=(INPUT_SHAPE[0], rechunk_size),
- ),
- )
- },
- coords={"channel": ("channel", np.arange(INPUT_SHAPE[0]))},
- attrs={"hello": "world"},
- )
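- # Shape sketch, derived from the constants above: each delayed minute is
- # (3, 60_000_000) float32; concatenating 500 of them along axis=1 gives a
- # (3, 30_000_000_000) dask.array rechunked into (3, 2**24) blocks, e.g.:
- #   ds = LoadRoutines.lazy_load_xarray_one_month(test_spec)
- #   ds.data_var.data.shape  # (3, 30000000000); nothing is computed yet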
- @staticmethod
- def load_array_one_minute(test_spec: TestSpec) -> np.ndarray:
- """
- Load an array representing 1 minute of data. Each load consumes
- ~0.144GB of memory (3 * 200000 * 60 * 4 (bytes in a float)) = ~0.14GB
- In real life, this is loaded from cloud storage or disk.
- """
- if random.random() < test_spec.error_rate:
- raise Exception("Data error!")
- else:
- return np.random.random(INPUT_SHAPE)
- class TransformRoutines:
- @staticmethod
- def fft_xarray(xr_input: xarray.Dataset, n_fft: int, hop_length: int):
- """
- Perform FFT on an Xarray and return it as another Xarray.
- """
- # Infer the output chunk shape since FFT does
- # not preserve input chunk shape.
- output_chunk_shape = TransformRoutines.infer_chunk_shape_after_fft(
- n_fft=n_fft,
- hop_length=hop_length,
- time_chunk_sizes=xr_input.chunks["time"],
- )
- transformed_audio = dask.array.map_overlap(
- TransformRoutines.fft_algorithm,
- xr_input.data_var.data,
- # Overlap depth between time chunks; clamped at zero because Dask
- # requires non-negative depth (no overlap when hop_length >= n_fft).
- depth={0: 0, 1: (0, max(0, n_fft - hop_length))},
- boundary={0: "none", 1: "none"},
- chunks=output_chunk_shape,
- dtype=np.float32,
- trim=True,
- algorithm_params={"hop_length": hop_length, "n_fft": n_fft},
- )
- return xarray.Dataset(
- data_vars={
- "data_var": (
- ["channel", "freq", "time"],
- transformed_audio,
- ),
- },
- coords={
- "freq": ("freq", np.arange(transformed_audio.shape[1])),
- "channel": ("channel", np.arange(INPUT_SHAPE[0])),
- },
- attrs={"hello": "world2"},
- )
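- # Overlap sketch: map_overlap extends each time chunk by n_fft - hop_length
- # samples on the right so STFT windows straddling a chunk boundary see
- # complete data; with a hypothetical hop_length of 2_000 and n_fft of 4_096,
- # a 16_777_216-sample chunk maps to ceil(16_777_216 / 2_000) = 8_389 output
- # time bins (see infer_chunk_shape_after_fft).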
- @staticmethod
- def decimate_xarray_after_load(xr_input: xarray.Dataset, decimate_factor: int):
- """
- Downsample an Xarray.
- """
- # Recombine the three channels and restore the original channel
- # chunking before decimating along the time axis.
- start_chunks = xr_input.data_var.data.chunks
- data_0 = xr_input.data_var.data[0] - xr_input.data_var.data[2]
- data_1 = xr_input.data_var.data[2]
- data_2 = xr_input.data_var.data[0]
- stacked_data = dask.array.stack([data_0, data_1, data_2], axis=0)
- stacked_chunks = stacked_data.chunks
- rechunking_to_chunks = (start_chunks[0], stacked_chunks[1])
- xr_input.data_var.data = stacked_data.rechunk(rechunking_to_chunks)
- in_chunks = xr_input.data_var.data.chunks
- out_chunks = (
- in_chunks[0],
- tuple([int(chunk / decimate_factor) for chunk in in_chunks[1]]),
- )
- data_ds_data = xr_input.data_var.data.map_overlap(
- TransformRoutines.decimate_raw_data,
- decimate_time=decimate_factor,
- overlap_time=10,
- depth=(0, decimate_factor * 10),
- trim=False,
- dtype="float32",
- chunks=out_chunks,
- )
- data_ds = copy(xr_input)
- data_ds = data_ds.isel(time=slice(0, data_ds_data.shape[1]))
- data_ds.data_var.data = data_ds_data
- return data_ds
- @staticmethod
- def decimate_raw_data(data: np.ndarray, decimate_time: int, overlap_time=0):
- from scipy.signal import decimate  # local import: loaded in the worker task
- data = np.nan_to_num(data)
- if decimate_time > 1:
- data = decimate(data, q=decimate_time, axis=1)
- if overlap_time > 0:
- data = data[:, overlap_time:-overlap_time]
- return data
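- # Example with a hypothetical block: decimating a (3, 1_000_000) array with
- # decimate_time=100 applies an anti-aliasing filter and keeps every 100th
- # sample, giving (3, 10_000); overlap_time=10 then trims 10 samples from each
- # end -> (3, 9_980), removing the depth=(0, decimate_factor * 10) overlap
- # that map_overlap added above.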
- @staticmethod
- def fft_algorithm(data: np.ndarray, algorithm_params: dict) -> np.ndarray:
- """
- Apply FFT algorithm to an input xarray.
- """
- from scipy import signal
- hop_length = algorithm_params["hop_length"]
- n_fft = algorithm_params["n_fft"]
- noverlap = n_fft - hop_length
- _, _, spectrogram = signal.stft(
- data,
- nfft=n_fft,
- nperseg=n_fft,
- noverlap=noverlap,
- # One-sided output has n_fft // 2 + 1 frequency rows, matching
- # infer_chunk_shape_after_fft; two-sided would emit n_fft rows
- # and break the declared chunk shape.
- return_onesided=True,
- boundary=None,
- )
- spectrogram = np.abs(spectrogram)
- spectrogram = 10 * np.log10(spectrogram**2)
- return spectrogram
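- # Shape sketch: signal.stft with nperseg=n_fft=4096 and return_onesided=True
- # yields 4096 // 2 + 1 = 2049 frequency rows; with boundary=None, a block of
- # T samples produces roughly T // hop_length time columns. The final
- # 10 * log10(|S|**2) converts magnitude to decibels.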
- @staticmethod
- def infer_chunk_shape_after_fft(
- n_fft: int, hop_length: int, time_chunk_sizes: List
- ) -> tuple:
- """
- Infer the chunk shapes after applying FFT transformation.
- Inference is necessary for lazy Dask transformations that
- do not preserve chunk shape.
- """
- output_time_chunk_sizes = list()
- for time_chunk_size in time_chunk_sizes:
- output_time_chunk_sizes.append(math.ceil(time_chunk_size / hop_length))
- num_freq = int(n_fft / 2 + 1)
- return (INPUT_SHAPE[0],), (num_freq,), tuple(output_time_chunk_sizes)
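- # Example (hypothetical values): n_fft=4096, hop_length=2_000, and time
- # chunks (16_777_216, 16_777_216) -> ((3,), (2049,), (8389, 8389)).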
- @staticmethod
- def fix_last_chunk_error(xr_input: xarray.Dataset, n_overlap):
- time_chunks = list(xr_input.chunks["time"])
- # Drop boundary chunks smaller than the overlap depth so that
- # map_overlap does not fail on them.
- if time_chunks[-1] < n_overlap:
- current_len = len(xr_input.time)
- xr_input = xr_input.isel(time=slice(0, current_len - time_chunks[-1]))
- if time_chunks[0] < n_overlap:
- current_len = len(xr_input.time)
- xr_input = xr_input.isel(time=slice(time_chunks[0], current_len))
- return xr_input
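- # Example: with n_overlap = 2_096 (hypothetical) and time chunks ending in a
- # 1_000-sample chunk, the last 1_000 samples are sliced off so map_overlap
- # never receives a boundary chunk smaller than its overlap depth.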
- class SaveRoutines:
- @staticmethod
- def save_xarray(xarray_dataset, filename, dirpath):
- """
- Save Xarray in zarr format.
- """
- filepath = os.path.join(dirpath, filename)
- if os.path.exists(filepath):
- return "already_exists"
- try:
- xarray_dataset.to_zarr(filepath)
- except Exception as e:
- return "failure, exception = {}".format(e)
- return "success"
- @staticmethod
- def save_all_xarrays(
- xarray_filename_pairs: List[Tuple],
- ray_scheduler,
- dirpath: str,
- batch_size: int,
- ):
- def chunks(lst, n):
- """Yield successive n-sized chunks from lst."""
- for i in range(0, len(lst), n):
- yield lst[i : i + n]
- for batch_idx, batch in enumerate(chunks(xarray_filename_pairs, batch_size)):
- delayed_tasks = list()
- for xarray_filename_pair in batch:
- delayed_tasks.append(
- dask.delayed(SaveRoutines.save_xarray)(
- xarray_dataset=xarray_filename_pair[0],
- filename=xarray_filename_pair[1],
- dirpath=dirpath,
- )
- )
- logging.info(
- "[Batch Index {}] Batch size {}: Sending work to Ray Cluster.".format(
- batch_idx, batch_size
- )
- )
- res = []
- try:
- res = dask.compute(delayed_tasks, scheduler=ray_scheduler)
- except Exception:
- logging.warning(
- "[Batch Index {}] Exception while computing batch!".format(
- batch_idx
- )
- )
- finally:
- logging.info("[Batch Index {}], Result = {}".format(batch_idx, res))
- def lazy_create_xarray_filename_pairs(
- test_spec: TestSpec,
- ) -> List[Tuple[xarray.Dataset, str]]:
- n_fft = 4096
- hop_length = int(SAMPLING_RATE / 100)
- decimate_factor = 100
- logging.info("Creating 1 month lazy Xarray with decimation and FFT")
- xr1 = LoadRoutines.lazy_load_xarray_one_month(test_spec)
- xr2 = TransformRoutines.decimate_xarray_after_load(
- xr_input=xr1, decimate_factor=decimate_factor
- )
- xr3 = TransformRoutines.fix_last_chunk_error(xr2, n_overlap=n_fft - hop_length)
- xr4 = TransformRoutines.fft_xarray(xr_input=xr3, n_fft=n_fft, hop_length=hop_length)
- num_segments = int(MINUTES_IN_A_MONTH / NUM_MINS_PER_OUTPUT_FILE)
- start_time = 0
- xarray_filename_pairs: List[Tuple[xarray.Dataset, str]] = list()
- timestamp = int(time.time())
- for step in range(num_segments):
- segment_start = start_time + (NUM_MINS_PER_OUTPUT_FILE * step) # in minutes
- segment_start_index = int(
- SECONDS_IN_A_MIN
- * NUM_MINS_PER_OUTPUT_FILE
- * step
- * (SAMPLING_RATE / decimate_factor)
- / hop_length
- )
- segment_end = segment_start + NUM_MINS_PER_OUTPUT_FILE
- segment_len_sec = (segment_end - segment_start) * SECONDS_IN_A_MIN
- # Mirror segment_start_index: the time axis was decimated by
- # decimate_factor before the FFT.
- segment_end_index = int(
- segment_start_index
- + segment_len_sec * (SAMPLING_RATE / decimate_factor) / hop_length
- )
- xr_segment = deepcopy(
- xr4.isel(time=slice(segment_start_index, segment_end_index))
- )
- xarray_filename_pairs.append(
- (xr_segment, "xarray_step_{}_{}.zarr".format(step, timestamp))
- )
- return xarray_filename_pairs
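- # Index arithmetic, from the constants above: after decimation by 100 and an
- # FFT hop of 10_000, one second of input maps to 1_000_000 / 100 / 10_000 = 1
- # output time bin, so each 30-minute segment spans 1_800 bins: step 0 ->
- # slice(0, 1800), step 1 -> slice(1800, 3600), and so on.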
- def parse_script_args():
- parser = argparse.ArgumentParser()
- parser.add_argument("--num_workers", type=int)
- parser.add_argument("--worker_obj_store_size_in_gb", type=int)
- parser.add_argument("--error_rate", type=float, default=0)
- parser.add_argument("--data_save_path", type=str)
- parser.add_argument(
- "--trigger_object_spill",
- dest="trigger_object_spill",
- action="store_true",
- )
- parser.set_defaults(trigger_object_spill=False)
- return parser.parse_known_args()
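- # Example invocation (script name and values are illustrative):
- #   python large_scale_test.py --num_workers 10 \
- #       --worker_obj_store_size_in_gb 20 --error_rate 0.01 \
- #       --data_save_path /tmp/dask_xarray_output --trigger_object_spill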
- def main():
- args, unknown = parse_script_args()
- logging.info("Received arguments: {}".format(args))
- # Create test spec
- test_spec = TestSpec(
- num_workers=args.num_workers,
- worker_obj_store_size_in_gb=args.worker_obj_store_size_in_gb,
- error_rate=args.error_rate,
- trigger_object_spill=args.trigger_object_spill,
- )
- logging.info("Created test spec: {}".format(test_spec))
- # Create the data save path if it doesn't exist.
- data_save_path = args.data_save_path
- os.makedirs(data_save_path, mode=0o777, exist_ok=True)
- # Lazily construct Xarrays
- xarray_filename_pairs = lazy_create_xarray_filename_pairs(test_spec)
- # Connect to the Ray cluster
- ray.init(address="auto")
- monitor_actor = monitor_memory_usage()
- # Save all the Xarrays to disk; this will trigger
- # Dask computations on Ray.
- logging.info("Saving {} xarrays..".format(len(xarray_filename_pairs)))
- SaveRoutines.save_all_xarrays(
- xarray_filename_pairs=xarray_filename_pairs,
- dirpath=data_save_path,
- batch_size=test_spec.batch_size,
- ray_scheduler=ray_dask_get,
- )
- ray.get(monitor_actor.stop_run.remote())
- used_gb, usage = ray.get(monitor_actor.get_peak_memory_info.remote())
- print(f"Peak memory usage: {round(used_gb, 2)}GB")
- print(f"Peak memory usage per processes:\n {usage}")
- try:
- print(ray._private.internal_api.memory_summary(stats_only=True))
- except Exception as e:
- print(f"Warning: query memory summary failed: {e}")
- with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
- f.write(
- json.dumps(
- {
- "success": 1,
- "_peak_memory": round(used_gb, 2),
- "_peak_process_memory": usage,
- }
- )
- )
- if __name__ == "__main__":
- main()