# data_analyzer.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import sys
  6. from collections import defaultdict
  7. import csv
  8. import time
  9. from multiprocessing import Process, Manager
  10. import numpy as np
  11. import torch
  12. from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset
  13. import deepspeed.comm as dist
  14. from deepspeed.utils import logger
  15. from deepspeed.runtime.data_pipeline.data_sampling.indexed_dataset import MMapIndexedDataset, valid_dtypes
  16. from deepspeed.runtime.data_pipeline.data_sampling.utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype
  17. class DataAnalyzer(object):
  18. def __init__(self,
  19. dataset,
  20. num_workers=1,
  21. worker_id=0,
  22. num_threads=1,
  23. num_threads_reduce=1,
  24. specific_threads=[],
  25. batch_size=1,
  26. metric_names=[],
  27. metric_functions=[],
  28. metric_types=[],
  29. metric_dtypes=[],
  30. save_path="./",
  31. collate_fn=None,
  32. custom_map_init=None,
  33. custom_map_update=None,
  34. custom_map_finalize=None,
  35. custom_reduce=None,
  36. sample_indices=None):
  37. super().__init__()
  38. self.dataset = dataset
  39. self.num_workers = num_workers
  40. self.worker_id = worker_id
  41. self.num_threads = num_threads
  42. self.num_threads_reduce = num_threads_reduce
  43. self.specific_threads = specific_threads
  44. self.batch_size = batch_size
  45. self.metric_names = metric_names
  46. self.metric_functions = metric_functions
  47. self.metric_types = metric_types
  48. self.metric_dtypes = metric_dtypes
  49. self.save_path = save_path
  50. self.collate_fn = collate_fn
  51. self.custom_map_init = custom_map_init
  52. self.custom_map_update = custom_map_update
  53. self.custom_map_finalize = custom_map_finalize
  54. self.custom_reduce = custom_reduce
  55. self.sample_indices = sample_indices
    def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
        """Create the per-metric output state for one map thread.

        For each metric, a directory ``save_path/<metric>/worker<w>_thread<t>/``
        is created and, depending on the metric type, the state is either:

        - 'single_value_per_sample': an mmap builder mapping sample -> metric
          value plus an in-memory dict (metric value -> sample ids) that is
          periodically flushed to per-value csv files, or
        - 'accumulate_value_over_samples': a running value (initially None)
          and the filename it will be written to at finalize time.

        Returns a list with one state dict per metric, in metric order.
        """
        metric_results = []
        for m_idx in range(len(metric_names)):
            metric_name, metric_type, metric_dtype = metric_names[m_idx], \
                metric_types[m_idx], metric_dtypes[m_idx]
            assert metric_dtype in valid_dtypes, f"metric_dtype {metric_dtype} not supported. Supported dtypes {valid_dtypes}"
            metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/"
            os.makedirs(metric_save_path, exist_ok=True)
            if metric_type == 'single_value_per_sample':
                sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
                sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_dtype)
                metric_to_sample_fname = f"{metric_save_path}/{metric_name}_metric_to_sample"
                # Remove stale per-value csv files left over from a previous run.
                # NOTE(review): shell wildcard removal; assumes the save path has
                # no spaces/shell metacharacters — consider glob + os.remove.
                os.system(f"rm -rf {metric_to_sample_fname}*")
                metric_to_sample_dict = defaultdict(list)
                metric_results.append({
                    "sample_to_metric_fname": sample_to_metric_fname,
                    "sample_to_metric_builder": sample_to_metric_builder,
                    "metric_to_sample_fname": metric_to_sample_fname,
                    "metric_to_sample_dict": metric_to_sample_dict
                })
            elif metric_type == 'accumulate_value_over_samples':
                metric_value = None
                metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
                metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname})
        return metric_results
    def update_metric_results(self,
                              data,
                              metric_types,
                              metric_dtypes,
                              metric_functions,
                              metric_results,
                              batch_start_idx=0):
        """Apply every metric function to one batch and fold the results into
        the per-thread state produced by ``init_metric_results``.

        data: one batch as produced by the DataLoader; if it is a dict with an
            'index' field (Megatron use case) that field overrides the computed
            sample index.
        batch_start_idx: dataset index of the first sample in this batch.
        """
        for m_idx in range(len(metric_types)):
            metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \
                metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx]
            metric_values = metric_function(data)
            assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \
                "metric_function must return a tensor or array"
            # dtype is checked before the numpy->torch conversion on purpose:
            # metric_dtype is expressed in the returned container's own dtype.
            assert metric_values.dtype == metric_dtype, \
                f"metric_function result dtype {metric_values.dtype} does not match metric_dtype {metric_dtype}"
            if isinstance(metric_values, np.ndarray):
                metric_values = torch.from_numpy(metric_values)
            if metric_type == 'single_value_per_sample':
                for row in range(metric_values.size()[0]):
                    sample_idx = batch_start_idx + row  # sample idx following dataset iteration order
                    if isinstance(data, dict) and 'index' in data:  # Megatron use case, idx provided in 'index' field
                        sample_idx = data['index'][row][0].item()
                    elif self.sample_indices is not None:  # user defined shuffling of indices
                        sample_idx = self.sample_indices[sample_idx]
                    metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
                    metric_result["metric_to_sample_dict"][metric_values[row].item()].append(sample_idx)
                # Flush any metric value whose sample list grew past 100 entries
                # to its csv file, bounding the in-memory dict size.
                for m_value in metric_result["metric_to_sample_dict"]:
                    if len(metric_result["metric_to_sample_dict"][m_value]) > 100:
                        metric_fname = metric_result["metric_to_sample_fname"]
                        with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
                            writer = csv.writer(f)
                            writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
                        metric_result["metric_to_sample_dict"][m_value] = []
            elif metric_type == 'accumulate_value_over_samples':
                # Element-wise running sum across all batches seen by this thread.
                if metric_result["metric_value"] is None:
                    metric_result["metric_value"] = metric_values
                else:
                    metric_result["metric_value"].add_(metric_values)
    def finalize_metric_results(self, metric_types, metric_dtypes, metric_results):
        """Close all mmap builders and flush remaining in-memory state to disk."""
        for m_idx in range(len(metric_types)):
            metric_type, metric_dtype, metric_result = metric_types[m_idx], \
                metric_dtypes[m_idx], metric_results[m_idx]
            if metric_type == 'single_value_per_sample':
                metric_fname = metric_result["sample_to_metric_fname"]
                close_mmap_dataset_builder(metric_result["sample_to_metric_builder"], metric_fname)
                # Flush whatever is still buffered in the metric->samples dict
                # (entries below the flush threshold of update_metric_results).
                for m_value in metric_result["metric_to_sample_dict"]:
                    if len(metric_result["metric_to_sample_dict"][m_value]) > 0:
                        metric_fname = metric_result["metric_to_sample_fname"]
                        with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
                            writer = csv.writer(f)
                            writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
                        metric_result["metric_to_sample_dict"][m_value] = []
            elif metric_type == 'accumulate_value_over_samples':
                # The accumulated tensor is written out only now, as one row.
                if metric_result["metric_value"] is not None:
                    metric_value_builder = create_mmap_dataset_builder(metric_result["metric_value_fname"],
                                                                       metric_dtype)
                    metric_value_builder.add_item(metric_result["metric_value"].reshape(-1))
                    close_mmap_dataset_builder(metric_value_builder, metric_result["metric_value_fname"])
    def run_map_helper(self, thread_id):
        """Map body for one thread: iterate its dataset slice and compute/store
        all metrics for every batch.

        Reads ``self.thread_splits`` (set by ``run_map``) to find the
        [start_idx, end_idx) sample range owned by ``thread_id``.
        """
        start_idx, end_idx = self.thread_splits[thread_id][0], \
            self.thread_splits[thread_id][1]
        logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \
            f"on data subset {start_idx} to {end_idx}")
        thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx)))
        sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False)
        # num_workers=0: the surrounding Process already provides parallelism.
        iterator = iter(
            DataLoader(thread_dataset,
                       batch_sampler=sampler,
                       num_workers=0,
                       collate_fn=self.collate_fn,
                       pin_memory=False))
        # User hooks may replace the default init/update/finalize stages.
        if self.custom_map_init is None:
            metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types,
                                                      self.metric_dtypes, self.save_path, self.worker_id)
        else:
            metric_results = self.custom_map_init(thread_id, self.metric_names, self.metric_types, self.metric_dtypes,
                                                  self.save_path, self.worker_id)
        total_sample = len(thread_dataset)
        processed_sample = 0
        start = time.time()
        while True:
            try:
                data = next(iterator)
                batch_start_idx = start_idx + processed_sample
                if self.custom_map_update is None:
                    self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions,
                                               metric_results, batch_start_idx)
                else:
                    self.custom_map_update(data, self.metric_types, self.metric_dtypes, self.metric_functions,
                                           metric_results, batch_start_idx)
                # NOTE(review): len(data) is the batch size for sequence-style
                # batches; for dict batches it counts keys — confirm collate_fn.
                processed_sample += len(data)
                duration = (time.time() - start) / 3600.0
                # Linear extrapolation of remaining time, in hours.
                remain_duration = duration * total_sample / processed_sample - duration
                logger.info(
                    f"worker {self.worker_id} thread {thread_id}: {processed_sample} " \
                    f"out of {total_sample} processed in {duration:.2f} hr, " \
                    f"estimated to finish in {remain_duration:.2f} hr")
            except StopIteration:
                logger.info(f"worker {self.worker_id} thread {thread_id}: reach end of file")
                break
        if self.custom_map_finalize is None:
            self.finalize_metric_results(self.metric_types, self.metric_dtypes, metric_results)
        else:
            self.custom_map_finalize(self.metric_types, self.metric_dtypes, metric_results)
        logger.info(f"worker {self.worker_id} thread {thread_id}: finished")
  186. def run_map(self):
  187. self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id,
  188. self.num_threads)
  189. if len(self.specific_threads) > 0:
  190. threads_to_run = self.specific_threads
  191. else:
  192. threads_to_run = list(range(self.num_threads))
  193. if self.num_threads > 1:
  194. p = []
  195. for thread in threads_to_run:
  196. p.append(Process(target=self.run_map_helper, args=(thread, )))
  197. p[thread].start()
  198. for thread in threads_to_run:
  199. p[thread].join()
  200. else:
  201. assert self.num_threads == 1
  202. self.run_map_helper(0)
  203. def get_metric_value_percentiles(self, metric_name, num_sample_per_value, total_num_samples):
  204. logger.info(f"Checking the value percentiles of metric {metric_name}...")
  205. processed_samples = 0
  206. current_percentile = 5
  207. for key in sorted(num_sample_per_value.keys()):
  208. processed_samples += num_sample_per_value[key]
  209. if processed_samples >= total_num_samples * current_percentile / 100.0:
  210. logger.info(f"Metric {metric_name} {current_percentile}th percentile: {key}")
  211. current_percentile += 5
  212. def merge_gather_map_stats(self, num_workers, num_threads, num_threads_reduce, t_idx_reduce, metric_save_path,
  213. metric_name, return_dict):
  214. results = []
  215. for w_idx in range(num_workers):
  216. for t_idx in range(num_threads):
  217. if (w_idx * num_threads + t_idx) % num_threads_reduce == t_idx_reduce:
  218. w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
  219. w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
  220. w_sample_to_metric = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
  221. unique_v = list(np.unique(w_sample_to_metric))
  222. sample_to_metric_count = len(w_sample_to_metric)
  223. logger.info(f"Finished gathering map stats from worker {w_idx} thread {t_idx}.")
  224. results.append([unique_v, sample_to_metric_count])
  225. return_dict[t_idx_reduce] = results
  226. def merge_sample_to_metric(self, t_idx_reduce, metric_save_path, metric_name, metric_value_dtype,
  227. map_worker_thread):
  228. sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
  229. sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
  230. for w_t in map_worker_thread:
  231. w_metric_save_path = f"{metric_save_path}/worker{w_t[0]}_thread{w_t[1]}/"
  232. w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
  233. w_data = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
  234. for row in range(len(w_data)):
  235. sample_to_metric_builder.add_item(torch.tensor(w_data[row].astype(np.int64), dtype=torch.long))
  236. logger.info(f"Finished merge_sample_to_metric from worker {w_t[0]} thread {w_t[1]}.")
  237. close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
  238. def merge_metric_to_sample(self, t_idx_reduce, metric_save_path, metric_name, sample_idx_dtype, metric_value_dtype,
  239. unique_metric_values, num_workers, num_threads):
  240. index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
  241. index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
  242. index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
  243. index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
  244. for unique_v in unique_metric_values:
  245. samples = []
  246. for w_idx in range(num_workers):
  247. for t_idx in range(num_threads):
  248. w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
  249. w_metric_to_sample_fname = f"{w_metric_save_path}/{metric_name}_metric_to_sample_{unique_v}.csv"
  250. if os.path.isfile(w_metric_to_sample_fname):
  251. with open(w_metric_to_sample_fname, 'r') as f:
  252. datareader = csv.reader(f)
  253. for row in datareader:
  254. samples += [int(x) for x in row]
  255. index_to_sample_builder.add_item(torch.tensor(samples, dtype=torch.long))
  256. index_to_metric_builder.add_item(torch.tensor([unique_v], dtype=torch.long))
  257. logger.info(f"Finished reducing metric {metric_name} value {unique_v}.")
  258. close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
  259. close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
    def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_workers, num_threads,
                          num_threads_reduce):
        """Reduce phase: merge every worker/thread partial file into the final
        per-metric files under ``save_path/<metric>/``.

        For 'single_value_per_sample' metrics this produces:
          - ``<metric>_sample_to_metric``: metric value per sample id,
          - ``<metric>_index_to_metric`` / ``_index_to_sample``: sorted unique
            metric values and, per value, the sample ids with that value,
          - ``<metric>_index_to_sample_percentile_merged``: ~1% chunks for
            fast percentile access.
        For 'accumulate_value_over_samples' metrics it writes the element-wise
        sum of all partial values to ``<metric>_metric_value``.

        Reduction work is spread over ``num_threads_reduce`` processes.
        """
        total_num_samples = len(dataset)
        sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
        logger.info(
            f"Total number of data samples: {total_num_samples}. Will use {sample_idx_dtype} to store the sample indexes."
        )
        for m_idx in range(len(metric_names)):
            metric_name, metric_type = metric_names[m_idx], metric_types[m_idx]
            if metric_type == 'single_value_per_sample':
                metric_save_path = f"{save_path}/{metric_name}/"
                sample_to_metric_count = 0
                unique_metric_values = set([])
                # Stage 1: gather unique metric values and per-shard sample
                # counts from every map shard, in parallel reduce processes.
                manager = Manager()
                return_dict = manager.dict()
                p = []
                for t_idx_reduce in range(num_threads_reduce):
                    p.append(
                        Process(target=self.merge_gather_map_stats,
                                args=(
                                    num_workers,
                                    num_threads,
                                    num_threads_reduce,
                                    t_idx_reduce,
                                    metric_save_path,
                                    metric_name,
                                    return_dict,
                                )))
                    p[t_idx_reduce].start()
                for t_idx_reduce in range(num_threads_reduce):
                    p[t_idx_reduce].join()
                for t_idx_reduce in range(num_threads_reduce):
                    results = return_dict[t_idx_reduce]
                    for res in results:
                        unique_metric_values = unique_metric_values.union(set(res[0]))
                        sample_to_metric_count += res[1]
                value_max = max(unique_metric_values)
                value_min = min(unique_metric_values)
                # Sanity check: shard sample counts must add up to the dataset size.
                assert sample_to_metric_count == total_num_samples, "The number of samples in map result files are not correct. It's possible that some map worker didn't finish successfully."
                metric_value_dtype = find_fit_int_dtype(value_min, value_max)
                logger.info(
                    f"Metric {metric_name} has values between {value_min} and {value_max}. Will use {metric_value_dtype} to store the metric values."
                )
                # Stage 2: merge per-shard sample_to_metric files. Each reduce
                # process merges a contiguous slice of (worker, thread) pairs,
                # then the per-reduce-thread chunks are merged sequentially.
                # sample_to_metric
                map_worker_thread = []
                for w_idx in range(num_workers):
                    for t_idx in range(num_threads):
                        map_worker_thread.append([w_idx, t_idx])
                thread_splits = split_index(0, len(map_worker_thread), num_threads_reduce)
                p = []
                for t_idx_reduce in range(num_threads_reduce):
                    start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
                    p.append(
                        Process(target=self.merge_sample_to_metric,
                                args=(
                                    t_idx_reduce,
                                    metric_save_path,
                                    metric_name,
                                    metric_value_dtype,
                                    map_worker_thread[start_idx:end_idx],
                                )))
                    p[t_idx_reduce].start()
                for t_idx_reduce in range(num_threads_reduce):
                    p[t_idx_reduce].join()
                sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
                sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
                for t_idx_reduce in range(num_threads_reduce):
                    chunk_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
                    logger.info(f"Merging file {chunk_fname}")
                    sample_to_metric_builder.merge_file_(chunk_fname)
                close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
                sample_to_metric = MMapIndexedDataset(sample_to_metric_fname, skip_warmup=True)
                assert len(sample_to_metric) == total_num_samples
                # Stage 3: build the value -> samples mapping. Each reduce
                # process handles a contiguous slice of the sorted unique
                # values, so the sequential chunk merge keeps rows sorted.
                # metric_to_sample
                unique_metric_values = list(sorted(unique_metric_values))
                thread_splits = split_index(0, len(unique_metric_values), num_threads_reduce)
                p = []
                for t_idx_reduce in range(num_threads_reduce):
                    start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
                    p.append(
                        Process(target=self.merge_metric_to_sample,
                                args=(
                                    t_idx_reduce,
                                    metric_save_path,
                                    metric_name,
                                    sample_idx_dtype,
                                    metric_value_dtype,
                                    unique_metric_values[start_idx:end_idx],
                                    num_workers,
                                    num_threads,
                                )))
                    p[t_idx_reduce].start()
                for t_idx_reduce in range(num_threads_reduce):
                    p[t_idx_reduce].join()
                index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample"
                index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
                index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric"
                index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
                for t_idx_reduce in range(num_threads_reduce):
                    chunk_is_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
                    logger.info(f"Merging file {chunk_is_fname}")
                    index_to_sample_builder.merge_file_(chunk_is_fname)
                    chunk_im_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
                    logger.info(f"Merging file {chunk_im_fname}")
                    index_to_metric_builder.merge_file_(chunk_im_fname)
                close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
                close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
                # Stage 4: write the percentile-chunked view and log percentiles.
                num_sample_per_value = DataAnalyzer.output_index_to_sample_percentile(
                    index_to_sample_fname, index_to_metric_fname, metric_name, metric_save_path, total_num_samples,
                    sample_idx_dtype)
                self.get_metric_value_percentiles(metric_name, num_sample_per_value, total_num_samples)
            elif metric_type == 'accumulate_value_over_samples':
                metric_save_path = f"{save_path}/{metric_name}/"
                metric_value = None
                # Element-wise sum of every shard's accumulated value.
                for w_idx in range(num_workers):
                    for t_idx in range(num_threads):
                        w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
                        w_metric_value_fname = f"{w_metric_save_path}/{metric_name}_metric_value"
                        w_metric_value = MMapIndexedDataset(w_metric_value_fname, skip_warmup=True)
                        if metric_value is None:
                            metric_value = np.copy(w_metric_value[0])
                        else:
                            metric_value += np.copy(w_metric_value[0])
                value_max = int(max(metric_value))
                value_min = int(min(metric_value))
                metric_value_dtype = find_fit_int_dtype(value_min, value_max)
                metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
                metric_value_builder = create_mmap_dataset_builder(metric_value_fname, metric_value_dtype)
                metric_value_builder.add_item(torch.tensor(metric_value.astype(np.int64), dtype=torch.long))
                close_mmap_dataset_builder(metric_value_builder, metric_value_fname)
    @staticmethod
    def output_index_to_sample_percentile(index_to_sample_fname, index_to_metric_fname, metric_name, metric_save_path,
                                          total_num_samples, sample_idx_dtype):
        """ read index_to_metric and index_to_sample files and write distribution to index_to_sample_percentage_merged """
        # Maps metric value -> number of samples with that value; returned to
        # the caller for percentile logging.
        num_sample_per_value = {}
        index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True)
        index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True)
        index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged"
        index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname, sample_idx_dtype)
        for v_idx in range(len(index_to_sample)):
            if v_idx > 0:
                # Invariant from the reduce: metric values strictly increase per row.
                assert index_to_metric[v_idx] > index_to_metric[v_idx - 1]
            num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx])
        assert sum(list(num_sample_per_value.values())) == total_num_samples
        # Merge rows into ~100 chunks (about 1% of the values each) so that
        # percentile-based readers touch few rows.
        merge_step = max(1, len(index_to_sample) // 100)
        for v_idx in range(0, len(index_to_sample), merge_step):
            merged_samples = np.copy(
                np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))], axis=None))
            index_to_sample_merged_builder.add_item(torch.tensor(merged_samples.astype(np.int64), dtype=torch.long))
            logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.")
        close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname)
        return num_sample_per_value
  412. def run_reduce(self):
  413. if self.custom_reduce is None:
  414. self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path,
  415. self.num_workers, self.num_threads, self.num_threads_reduce)
  416. else:
  417. self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers,
  418. self.num_threads, self.num_threads_reduce)
    def run_map_reduce(self, comm_group=None):
        """Run the full map-reduce pipeline across all workers.

        comm_group: optional communication group passed to the barriers.
        """
        self.run_map()
        # wait for the mapping operation, where all nodes outputs their own (partial) result files
        dist.barrier(group=comm_group)
        # Only rank 0 performs the reduce; others wait at the barrier below.
        if self.worker_id == 0:
            self.run_reduce()
        # wait for the reduce, where rank 0 merges all (partial) files. Dataset can then be used by all nodes.
        dist.barrier(group=comm_group)
  427. class DistributedDataAnalyzer(object):
  428. def __init__(
  429. self,
  430. dataset,
  431. num_workers=1,
  432. num_threads=1,
  433. worker_id=0,
  434. batch_size=1,
  435. metric_names=[],
  436. metric_functions=[],
  437. metric_types=[],
  438. save_path="./",
  439. collate_fn=None,
  440. device='cuda',
  441. comm_group=None,
  442. sample_indices=None,
  443. ) -> None:
  444. self.dataset = dataset
  445. self.batch_size = batch_size
  446. self.metric_names = metric_names
  447. self.metric_functions = metric_functions
  448. self.metric_types = metric_types
  449. self.save_path = save_path
  450. self.collate_fn = collate_fn
  451. self.device = device
  452. self.sample_indices = sample_indices
  453. self.num_threads = num_threads
  454. self.worker_id = worker_id
  455. if not dist.is_initialized():
  456. dist.init_distributed()
  457. # comm_group and worker_id+num_workers are mutually exclusive
  458. self.comm_group = comm_group
  459. if self.comm_group is None:
  460. # self.comm_group = deepspeed.utils.groups._clone_world_group()
  461. self.num_workers = num_workers
  462. self.worker_id = worker_id
  463. else:
  464. self.num_workers = self.comm_group.size()
  465. self.worker_id = self.comm_group.rank()
  466. if self.worker_id == 0:
  467. logger.info(f"Distributed data analyzer initialized with {self.num_workers} workers.")
    def run_map_helper(self, thread_id=0, metric_queues=None):
        """Compute metric results for this thread's slice of the dataset.

        Returns the per-metric results directly when running single-threaded;
        otherwise pushes (thread_id, results) onto the per-metric shared
        queues so the parent process can gather them.
        """
        thread_start_idx, thread_end_idx = self.thread_splits[thread_id][0], self.thread_splits[thread_id][1]
        worker_dataset = Subset(self.dataset, list(range(thread_start_idx, thread_end_idx)))
        sampler = BatchSampler(SequentialSampler(worker_dataset), batch_size=self.batch_size, drop_last=False)
        dataloader = DataLoader(dataset=worker_dataset,
                                batch_sampler=sampler,
                                num_workers=0,
                                collate_fn=self.collate_fn,
                                pin_memory=False)
        # set initial results list: list of (value, sample_idx) pairs for
        # per-sample metrics, a running tensor (initially None) for accumulators.
        metric_results = []
        for metric_type in self.metric_types:
            assert metric_type in ['single_value_per_sample', 'accumulate_value_over_samples'], \
                f"metric_type {metric_type} not implemented."
            metric_results.append([] if metric_type == 'single_value_per_sample' else None)
        # iterate dataloader and store metric results
        batch_start_idx = thread_start_idx
        for data in dataloader:
            for m_idx in range(len(self.metric_names)):
                metric_type, metric_function = self.metric_types[m_idx], self.metric_functions[m_idx]
                metric_values = metric_function(data)
                assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \
                    "metric_function must return a tensor or array"
                if isinstance(metric_values, np.ndarray):
                    metric_values = torch.from_numpy(metric_values)
                assert metric_values.dtype in valid_dtypes, \
                    f"metric_function result dtype {metric_values.dtype} not supported. Supported dtypes {valid_dtypes}"
                if metric_type == 'single_value_per_sample':
                    for row in range(metric_values.size()[0]):
                        value = metric_values[row].item()
                        sample_idx = batch_start_idx + row  # sample idx following dataset iteration order
                        if isinstance(data, dict) and 'index' in data:  # Megatron use case
                            sample_idx = data['index'][row][0].item()
                        elif self.sample_indices is not None:  # user defined shuffling of indices
                            sample_idx = self.sample_indices[sample_idx]
                        metric_results[m_idx].append((value, sample_idx))
                elif metric_type == 'accumulate_value_over_samples':
                    if metric_results[m_idx] is None:
                        metric_results[m_idx] = metric_values
                    else:
                        metric_results[m_idx].add_(metric_values)
            # NOTE(review): len(data) counts keys when data is a dict batch —
            # confirm collate_fn output is sequence-like.
            batch_start_idx += len(data)
        if self.num_threads == 1:
            return metric_results
        # copy metric_results to the shared queue
        assert metric_queues
        for m_idx in range(len(self.metric_names)):
            results = metric_results[m_idx]
            if torch.is_tensor(results):
                # Tensors cannot cross the process boundary here; convert to
                # plain Python scalars/lists first.
                results = results.item() if results.dim() == 0 else results.tolist()
            try:
                metric_queues[m_idx].put((thread_id, results))
            except Exception as e:
                logger.error(f"Error putting metric results to queue: {e}")
                sys.exit(1)
  523. def run_map_reduce(self):
  524. # setup individual dataloaders
  525. self.worker_splits, self.thread_splits = split_dataset(self.dataset,
  526. self.num_workers,
  527. self.worker_id,
  528. num_threads=self.num_threads)
  529. node_start_idx, node_end_idx = self.worker_splits[self.worker_id]
  530. logger.info(f"worker {self.worker_id} working on data subset {node_start_idx} to {node_end_idx}.")
  531. if self.num_threads in [0, 1, None]:
  532. metric_results = self.run_map_helper()
  533. metric_results = [torch.tensor(m).to(self.device) for m in metric_results]
  534. else:
  535. # create a shared queue of results per metric to be populated by individual threads
  536. with Manager() as manager:
  537. metric_queues = [manager.Queue() for _ in self.metric_names]
  538. threads = [
  539. Process(target=self.run_map_helper, args=(t, metric_queues)) for t in range(self.num_threads)
  540. ]
  541. for thread in threads:
  542. thread.start()
  543. for thread in threads:
  544. thread.join()
  545. # gather results from shared queues into metric_results
  546. metric_results = [None for _ in self.metric_names]
  547. for m_idx, (queue, metric_type) in enumerate(zip(metric_queues, self.metric_types)):
  548. while not queue.empty():
  549. t_idx, t_results = queue.get()
  550. t_start_idx, t_end_idx = self.thread_splits[t_idx]
  551. if t_start_idx >= t_end_idx: # no results from this thread
  552. continue #corner case for small datasets and high thread count
  553. t_results = torch.tensor(t_results)
  554. if metric_type == 'single_value_per_sample':
  555. # add thread results to the metric_results list, ordered by thread idx
  556. if metric_results[m_idx] is None: # initialize if needed
  557. metric_results[m_idx] = torch.zeros(node_end_idx - node_start_idx,
  558. t_results.size(1)).to(self.device)
  559. metric_results[m_idx][t_start_idx - node_start_idx:t_end_idx - node_start_idx] = t_results
  560. else:
  561. if metric_results[m_idx] is None: # initialize if needed
  562. metric_results[m_idx] = torch.zeros(t_results.size()).to(self.device)
  563. metric_results[m_idx].add_(t_results)
  564. # compute dtype for sample ids
  565. total_num_samples = len(self.dataset)
  566. sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
  567. logger.info(f"Total number of data samples: {total_num_samples}.")
  568. logger.info(f"Will use {sample_idx_dtype} to store the sample indexes.")
  569. for m_idx in range(len(self.metric_names)):
  570. metric_values, metric_name, metric_type = \
  571. metric_results[m_idx], self.metric_names[m_idx], self.metric_types[m_idx]
  572. metric_save_path = f"{self.save_path}/{metric_name}/"
  573. os.makedirs(metric_save_path, exist_ok=True)
  574. if metric_type == 'single_value_per_sample':
  575. # Compute sample and metric value dtypes based on range
  576. values, samples = metric_values[:, 0], metric_values[:, 1]
  577. value_min, value_max = Dist.min_max(values, self.comm_group)
  578. sample_min, sample_max = Dist.min_max(samples, self.comm_group)
  579. metric_value_dtype = find_fit_int_dtype(value_min, value_max)
  580. sample_value_dtype = find_fit_int_dtype(sample_min, sample_max)
  581. # sample_to_metric maps sample ids to metric values, as a list of metric values
  582. sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
  583. values = [torch.tensor([x]) for x in metric_values[:, 0]]
  584. self.file_write_ordered(values, sample_to_metric_fname, metric_value_dtype)
  585. # distributed sorting by values, gives an ordered disjoint subset of keys on nodes
  586. metric_values = Dist.sample_sort(metric_values, self.comm_group, self.num_workers)
  587. metric_to_samples_dict = {}
  588. if len(metric_values) > 0:
  589. for value, sample in metric_values:
  590. if value.item() not in metric_to_samples_dict:
  591. metric_to_samples_dict[value.item()] = []
  592. metric_to_samples_dict[value.item()].append(sample.item())
  593. # index_to_metric and index_to_sample serialize a dicitonary from metric to samples
  594. # index_to_metric stores a key per row, index_to_sample stores the values per row
  595. values = [torch.tensor([x]) for x in metric_to_samples_dict.keys()]
  596. samples = [torch.tensor(metric_to_samples_dict[x]) for x in metric_to_samples_dict.keys()]
  597. index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric" #dict keys
  598. index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample" #dict values
  599. self.file_write_ordered(values, index_to_metric_fname, metric_value_dtype)
  600. self.file_write_ordered(samples, index_to_sample_fname, sample_value_dtype)
  601. if self.worker_id == 0:
  602. DataAnalyzer.output_index_to_sample_percentile(index_to_sample_fname, index_to_metric_fname,
  603. metric_name, metric_save_path, total_num_samples,
  604. sample_idx_dtype)
  605. dist.barrier(self.comm_group)
  606. elif metric_type == 'accumulate_value_over_samples':
  607. metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
  608. dist.reduce(metric_values, dst=0, op=dist.ReduceOp.SUM, group=self.comm_group)
  609. metric_value_dtype = find_fit_int_dtype(metric_values.min(), metric_values.max())
  610. if self.worker_id == 0:
  611. builder = create_mmap_dataset_builder(metric_value_fname, metric_value_dtype)
  612. builder.add_item(metric_values.cpu())
  613. close_mmap_dataset_builder(builder, metric_value_fname)
  614. dist.barrier(self.comm_group)
def file_write_ordered(self, tensor_list, fname, numpy_dtype):
    """MPI_file_write_ordered extended to write a list of tensors, by one rank, iteratively.

    Collective call: every rank in ``self.comm_group`` must enter it together.
    Each rank owns a list of rows (1-D tensors of possibly different lengths);
    all rows are funneled to rank 0, which appends them to a single mmap
    dataset file in rank order (rank 0's rows first, then rank 1's, etc.).

    Args:
        tensor_list: list of 1-D tensors (rows) owned by the calling rank.
        fname: output path prefix for the mmap dataset builder.
        numpy_dtype: numpy dtype used to store the values on disk.
    """
    # each node has a list of rows (tensors) to be written to the file.
    # we will serialize it in order to communicate it in one comm step.
    tkwargs = dict(dtype=torch.int64, device=self.device)
    # 1. gather on rank 0 the number of rows to be sent/recv
    row_count = torch.tensor([len(tensor_list)], **tkwargs)
    row_counts = torch.zeros(self.num_workers, **tkwargs)
    dist.all_gather_into_tensor(row_counts, row_count, group=self.comm_group)
    assert row_counts[self.worker_id] == row_count == len(tensor_list), "all_gather failed"
    # 2. gather on rank 0 the sizes of the rows to be sent/recv
    # (rows have variable length, so a plain all_gather is not enough -> gather_v)
    row_len = torch.tensor([len(l) for l in tensor_list], **tkwargs)
    row_lens = Dist.gather_v(row_len, 0, self.comm_group, self.num_workers, self.worker_id)
    # 4. gather on rank 0 of the total size (sum of all row lengths) to be received
    size = torch.tensor([sum(row_len).item()], **tkwargs)
    sizes = torch.zeros(self.num_workers, **tkwargs)
    dist.all_gather_into_tensor(sizes, size, group=self.comm_group)
    assert sizes[self.worker_id] == size.item(), "all_gather did not return the same sizes"  #sanity check

    # method to deserializes a buffer into rows of different lengths and write them to file
    def write_buffer_to_file(buff, src, builder):
        assert self.worker_id == 0, "only rank 0 can write to file"
        # collect all buffers and write them at once
        buff = buff.cpu().detach().numpy()
        # row_offsets[i]:row_offsets[i+1] delimits row i inside the flat buffer,
        # reconstructed from the per-row lengths gathered above
        row_offsets = np.cumsum([0] + row_lens[src].tolist())
        arr_list = []
        for i in range(len(row_lens[src])):
            arr_list.append(buff[row_offsets[i]:row_offsets[i + 1]])
        builder.add_items(arr_list)

    # 5. rank 0 prepares output folder and file
    if self.worker_id == 0:
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        builder = create_mmap_dataset_builder(fname, numpy_dtype)
    # iterate through ranks that have data to be sent/recv/written
    # (ranks with zero rows are skipped entirely)
    for src in [rank for rank, count in enumerate(row_counts) if count > 0]:
        dist.barrier(group=self.comm_group)
        if self.worker_id == 0 and src == 0:  # rank 0's write its own data
            buffer = torch.cat(tensor_list, dim=0).to(self.device)
            write_buffer_to_file(buffer, 0, builder)
        elif self.worker_id == 0 and src > 0:  # rank 0 receives other rank's data and writes it
            # NOTE(review): `buffer` is only defined by the src == 0 branch above; if rank 0
            # contributed no rows (src 0 skipped), buffer.dtype/device below is a NameError — confirm
            buffer = torch.empty(sizes[src].item(), dtype=buffer.dtype, device=buffer.device)
            err = dist.recv(buffer, src=src, group=self.comm_group, tag=src)
            # dist.recv is expected to return the sender's rank — verified by the assert
            assert err == src and len(buffer) > 0, "recv failed"
            write_buffer_to_file(buffer, src, builder)
        elif self.worker_id == src:  # current rank sends data to rank 0
            buffer = torch.cat(tensor_list, dim=0).to(self.device)
            dist.send(buffer, 0, group=self.comm_group, tag=src)
    # rank 0 closes the file
    if self.worker_id == 0:
        close_mmap_dataset_builder(builder, fname)  # close file
    dist.barrier(self.comm_group)
class Dist:
    """Auxiliary class to perform distributed operations on tensors."""

    @staticmethod
    def min_max(tensor, comm_group):
        """Given a distributed tensor, return the min/max values across all ranks.

        NOTE(review): uses dist.reduce to rank 0, so only rank 0 is guaranteed to
        hold the globally reduced values; other ranks return their local min/max —
        confirm callers only rely on rank 0's result (or switch to all_reduce).
        """
        value_min, value_max = tensor.min(), tensor.max()
        dist.reduce(value_min, 0, op=dist.ReduceOp.MIN, group=comm_group)
        dist.reduce(value_max, 0, op=dist.ReduceOp.MAX, group=comm_group)
        return value_min.item(), value_max.item()

    @staticmethod
    def gather_v(tensor, dst, comm_group, num_workers, worker_id):
        """MPI_Gatherv. Gather 1-D tensors of variable sizes in a single rank.

        Returns a list of per-rank tensors (padding removed) on rank 0, None elsewhere.
        """
        # gather the number of rows to be sent/recv
        size = torch.tensor([len(tensor)], dtype=torch.int64, device=tensor.device)
        sizes = torch.zeros(num_workers, dtype=torch.int64, device=tensor.device)
        dist.all_gather_into_tensor(sizes, size, group=comm_group)
        assert sizes[worker_id] == size, "all_gather failed"
        # all_gather requires all tensors to be of same size so we need to pad them
        max_size = max(sizes).item()
        buffer = torch.empty(max_size, dtype=tensor.dtype, device=tensor.device)
        buffer[0:size] = tensor.data
        buffer_list = None
        if worker_id == 0:  # create padded recv buffers
            buffer_list = [torch.empty(max_size, dtype=tensor.dtype, device=tensor.device) for _ in range(num_workers)]
        dist.gather(buffer, buffer_list, dst=dst, group=comm_group)
        # revert padding and return value
        if worker_id == 0:
            buffer_list = [r[:s.item()] for r, s in zip(buffer_list, sizes)]
        return buffer_list

    @staticmethod
    def sample_sort(tensor, comm_group, num_workers, n_samples=100):
        """Perform a distributed sample sort of a 2-D tensor; each rank returns a
        sorted, disjoint partition of the rows (partitioned on the first column).

        Args:
            tensor: 2-D tensor of rows to sort (rows sorted lexicographically by column).
            n_samples: number of pivot samples collected per rank.
        """
        device, dims = tensor.device, tensor.size()[1]
        # 1 - sort rows by first column, then second column, then third, etc...
        tensor = torch.tensor(sorted(tensor.tolist()), dtype=tensor.dtype, device=tensor.device)
        # 2 - collect few samples per rank (first column only) as pivot candidates
        idx = torch.round(torch.linspace(0, len(tensor) - 1, n_samples)).to(int)
        samples = tensor[idx][:, 0].contiguous().to(device)  # only first column
        # 2 - Allgather samples
        all_samples = [torch.zeros(n_samples, dtype=samples.dtype, device=device) for _ in range(num_workers)]
        dist.all_gather(all_samples, samples, group=comm_group)
        all_samples = torch.cat(all_samples, dim=0).to(device)
        # 3 - Sort all samples and collect the ranges of each rank as equidistant
        all_samples = all_samples.sort()[0]
        idx = torch.round(torch.linspace(0, len(all_samples) - 1, num_workers + 1)).to(int)
        ranges = all_samples[idx]  # range of each rank r as ranges[r] <= x < ranges[r+1]
        ranges[-1] += 1  # increase upper limit of last rank so that x < ranges[r+1].
        # 4 - collect elements to send to each rank, based on the rank ranges
        send = []
        for rank in range(num_workers):
            mask = (tensor[:, 0] >= ranges[rank]) & (tensor[:, 0] < ranges[rank + 1])
            send.append(tensor[mask])
        # 5. all to all to communicate the sizes to be sent/recv
        # (counts are in scalar elements, hence the * dims)
        send_count = [torch.tensor([len(s) * dims], dtype=torch.int64, device=device) for s in send]
        recv_count = list(torch.empty([num_workers], dtype=torch.int64, device=device).chunk(num_workers))
        dist.all_to_all(recv_count, send_count, group=comm_group)
        # 6. all-to-all-v to communicate the elements to be sent/recv as a single tensor
        send = torch.cat(send, dim=0).flatten().to(device)
        recv = torch.zeros(sum(recv_count), dtype=send.dtype).to(device)
        send_count = [s.item() for s in send_count]  # convert to list of ints
        recv_count = [r.item() for r in recv_count]
        dist.all_to_all_single(recv, send, recv_count, send_count, group=comm_group)
        del send
        # 7. the received tensor is the 1D disjoint subset of the distributed tensor.
        # We will recover the original dimensionality and sort it by columns again.
        recv = recv.view(-1, dims)
        recv = torch.tensor(sorted(recv.tolist()), dtype=recv.dtype, device=recv.device)
        return recv
  733. def test_compare_both_data_analyzers(dataset):
  734. """ given a dataset, compare file and memory based data analyser"""
  735. id = lambda t: t.to(torch.int64) # identity
  736. batch_sum = lambda t: id(t).sum() #sum batch
  737. num_threads = 4
  738. kwargs = dict(
  739. dataset=dataset,
  740. batch_size=2**10,
  741. worker_id=int(os.environ['RANK']),
  742. num_workers=int(os.environ['WORLD_SIZE']),
  743. metric_names=["mod", "batch_sum"],
  744. metric_functions=[id, batch_sum],
  745. metric_types=['single_value_per_sample', 'accumulate_value_over_samples'],
  746. num_threads=num_threads,
  747. )
  748. dda = DistributedDataAnalyzer(
  749. save_path="./output_dist",
  750. device=f"cuda:{int(os.environ['LOCAL_RANK'])}",
  751. **kwargs,
  752. )
  753. start_time = time.time()
  754. dda.run_map_reduce()
  755. if dda.worker_id == 0:
  756. print("DistributedDataAnalyzer runtime: %s seconds " % (time.time() - start_time))
  757. da = DataAnalyzer(num_threads_reduce=num_threads,
  758. save_path="./output_disk",
  759. metric_dtypes=[torch.int64, torch.int64],
  760. **kwargs)
  761. start_time = time.time()
  762. da.run_map_reduce()
  763. if da.worker_id == 0:
  764. print("DataAnalyzer runtime: %s seconds " % (time.time() - start_time))
  765. output_paths = [
  766. "batch_sum/batch_sum_metric_value.bin", "batch_sum/batch_sum_metric_value.idx", \
  767. "mod/mod_index_to_metric.bin", "mod/mod_index_to_metric.idx", \
  768. "mod/mod_index_to_sample.bin", "mod/mod_index_to_sample.idx", \
  769. "mod/mod_index_to_sample_percentile_merged.bin", "mod/mod_index_to_sample_percentile_merged.idx", \
  770. "mod/mod_sample_to_metric.bin", "mod/mod_sample_to_metric.idx"
  771. ]
  772. if dda.worker_id == 0:
  773. for path in output_paths:
  774. with open(os.path.join(da.save_path, path), 'rb') as f1, \
  775. open(os.path.join(dda.save_path, path), 'rb') as f2:
  776. if f1.read() != f2.read():
  777. print(f"files {path} are not identical.")
  778. if __name__ == "__main__":
  779. class TestDataset(torch.utils.data.Dataset):
  780. def __init__(self, size=10_000_000):
  781. self.values = [(x + 7) % 10_000 for x in range(size)]
  782. self.size = size
  783. __len__ = lambda self: self.size
  784. __getitem__ = lambda self, idx: self.values[idx]
  785. test_compare_both_data_analyzers(TestDataset())