# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from collections import defaultdict
import csv
import time
from multiprocessing import Process, Manager

import numpy as np
import torch
from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset

from deepspeed.utils import logger
from .indexed_dataset import MMapIndexedDataset
from .utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype

class DataAnalyzer(object):
    """Map-reduce style analyzer: map workers/threads compute per-sample metrics
    over shards of a dataset and write them to mmap-backed files under
    ``save_path``; a reduce step merges the per-worker outputs."""

    def __init__(self,
                 dataset,
                 num_workers=1,
                 worker_id=0,
                 num_threads=1,
                 num_threads_reduce=1,
                 specific_threads=[],
                 batch_size=1,
                 metric_names=[],
                 metric_functions=[],
                 metric_types=[],
                 metric_dtypes=[],
                 save_path="./",
                 collate_fn=None,
                 custom_map_init=None,
                 custom_map_update=None,
                 custom_map_finalize=None,
                 custom_reduce=None):
        super().__init__()
        self.dataset = dataset
        self.num_workers = num_workers
        self.worker_id = worker_id
        self.num_threads = num_threads
        self.num_threads_reduce = num_threads_reduce
        self.specific_threads = specific_threads
        self.batch_size = batch_size
        self.metric_names = metric_names
        self.metric_functions = metric_functions
        self.metric_types = metric_types
        self.metric_dtypes = metric_dtypes
        self.save_path = save_path
        self.collate_fn = collate_fn
        self.custom_map_init = custom_map_init
        self.custom_map_update = custom_map_update
        self.custom_map_finalize = custom_map_finalize
        self.custom_reduce = custom_reduce
    def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
        metric_results = []
        for m_idx in range(len(metric_names)):
            metric_name, metric_type, metric_dtype = metric_names[m_idx], \
                metric_types[m_idx], metric_dtypes[m_idx]
            assert metric_dtype not in [
                np.float64, np.double
            ], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply by a larger coefficient to keep the precision)."
            metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/"
            os.makedirs(metric_save_path, exist_ok=True)
            if metric_type == 'single_value_per_sample':
                sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
                sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_dtype)
                metric_to_sample_fname = f"{metric_save_path}/{metric_name}_metric_to_sample"
                os.system(f"rm -rf {metric_to_sample_fname}*")
                metric_to_sample_dict = defaultdict(list)
                metric_results.append({
                    "sample_to_metric_fname": sample_to_metric_fname,
                    "sample_to_metric_builder": sample_to_metric_builder,
                    "metric_to_sample_fname": metric_to_sample_fname,
                    "metric_to_sample_dict": metric_to_sample_dict
                })
            elif metric_type == 'accumulate_value_over_samples':
                metric_value = None
                metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
                metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname})
        return metric_results
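
    # Map-phase directory layout, derived from the code above (names in braces
    # are placeholders; the .bin/.idx pair is written by the mmap dataset builder):
    #   {save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/
    #       {metric_name}_sample_to_metric.{bin,idx}      one metric value per sample
    #       {metric_name}_metric_to_sample_{value}.csv    sample indexes per metric value
    #       {metric_name}_metric_value.{bin,idx}          running sum (accumulate type)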
    def update_metric_results(self, data, metric_types, metric_functions, metric_results):
        for m_idx in range(len(metric_types)):
            metric_type, metric_function, metric_result = metric_types[m_idx], \
                metric_functions[m_idx], metric_results[m_idx]
            if metric_type == 'single_value_per_sample':
                metric_values = metric_function(data)
                for row in range(metric_values.size()[0]):
                    metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
                    metric_result["metric_to_sample_dict"][metric_values[row].item()].append(
                        data['index'][row][0].item())
                for m_value in metric_result["metric_to_sample_dict"]:
                    # Flush buffered sample indexes to CSV once a metric value has accumulated more than 100 of them.
                    if len(metric_result["metric_to_sample_dict"][m_value]) > 100:
                        metric_fname = metric_result["metric_to_sample_fname"]
                        with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
                            writer = csv.writer(f)
                            writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
                        metric_result["metric_to_sample_dict"][m_value] = []
            elif metric_type == 'accumulate_value_over_samples':
                metric_values = metric_function(data)
                if metric_result["metric_value"] is None:
                    metric_result["metric_value"] = metric_values
                else:
                    metric_result["metric_value"].add_(metric_values)
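
    # Illustrative sketch (not part of the original file): metric_functions are
    # user-supplied callables over a batch. For 'single_value_per_sample' they must
    # return one integer value per sample (and each sample must carry an 'index'
    # field, used above); for 'accumulate_value_over_samples' they return a tensor
    # that is summed across batches. The names 'text', pad_token, and vocab_size
    # below are hypothetical and depend on your dataset/collate_fn:
    #
    #     def seq_len_metric(data):
    #         # number of non-padding tokens per sample in the batch
    #         return (data['text'] != pad_token).sum(dim=1)
    #
    #     def token_histogram_metric(data):
    #         # per-vocab token counts, accumulated over the whole dataset
    #         return torch.bincount(data['text'].flatten(), minlength=vocab_size)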
    def finalize_metric_results(self, metric_types, metric_dtypes, metric_results):
        for m_idx in range(len(metric_types)):
            metric_type, metric_dtype, metric_result = metric_types[m_idx], \
                metric_dtypes[m_idx], metric_results[m_idx]
            if metric_type == 'single_value_per_sample':
                metric_fname = metric_result["sample_to_metric_fname"]
                close_mmap_dataset_builder(metric_result["sample_to_metric_builder"], metric_fname)
                for m_value in metric_result["metric_to_sample_dict"]:
                    if len(metric_result["metric_to_sample_dict"][m_value]) > 0:
                        metric_fname = metric_result["metric_to_sample_fname"]
                        with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
                            writer = csv.writer(f)
                            writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
                        metric_result["metric_to_sample_dict"][m_value] = []
            elif metric_type == 'accumulate_value_over_samples':
                if metric_result["metric_value"] is not None:
                    metric_value_builder = create_mmap_dataset_builder(metric_result["metric_value_fname"],
                                                                       metric_dtype)
                    metric_value_builder.add_item(metric_result["metric_value"].reshape(-1))
                    close_mmap_dataset_builder(metric_value_builder, metric_result["metric_value_fname"])
    def run_map_helper(self, thread_id):
        start_idx, end_idx = self.thread_splits[thread_id][0], \
            self.thread_splits[thread_id][1]
        logger.info(f"worker {self.worker_id} thread {thread_id}: start working "
                    f"on data subset {start_idx} to {end_idx}")
        thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx)))
        sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False)
        if self.collate_fn is None:
            iterator = iter(DataLoader(thread_dataset, batch_sampler=sampler, num_workers=0, pin_memory=False))
        else:
            iterator = iter(
                DataLoader(thread_dataset,
                           batch_sampler=sampler,
                           num_workers=0,
                           collate_fn=self.collate_fn,
                           pin_memory=False))
        if self.custom_map_init is None:
            metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types,
                                                      self.metric_dtypes, self.save_path, self.worker_id)
        else:
            metric_results = self.custom_map_init(thread_id, self.metric_names, self.metric_types, self.metric_dtypes,
                                                  self.save_path, self.worker_id)
        total_sample = len(thread_dataset)
        processed_sample = 0
        start = time.time()
        while True:
            try:
                data = next(iterator)
                if self.custom_map_update is None:
                    self.update_metric_results(data, self.metric_types, self.metric_functions, metric_results)
                else:
                    self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results)
                processed_sample += self.batch_size
                duration = (time.time() - start) / 3600.0
                remain_duration = duration * total_sample / processed_sample - duration
                logger.info(
                    f"worker {self.worker_id} thread {thread_id}: {processed_sample} "
                    f"out of {total_sample} processed in {duration:.2f} hr, "
                    f"estimated to finish in {remain_duration:.2f} hr")
            except StopIteration:
                logger.info(f"worker {self.worker_id} thread {thread_id}: reached end of data")
                break
        if self.custom_map_finalize is None:
            self.finalize_metric_results(self.metric_types, self.metric_dtypes, metric_results)
        else:
            self.custom_map_finalize(self.metric_types, self.metric_dtypes, metric_results)
        logger.info(f"worker {self.worker_id} thread {thread_id}: finished")
    def run_map(self):
        self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id,
                                                               self.num_threads)
        if len(self.specific_threads) > 0:
            threads_to_run = self.specific_threads
        else:
            threads_to_run = list(range(self.num_threads))
        if self.num_threads > 1:
            # Track processes by list position rather than thread id: the ids in
            # specific_threads need not be contiguous or zero-based, so indexing
            # the list with the thread id could pick the wrong process or raise.
            p = []
            for thread in threads_to_run:
                proc = Process(target=self.run_map_helper, args=(thread, ))
                proc.start()
                p.append(proc)
            for proc in p:
                proc.join()
        else:
            assert self.num_threads == 1
            self.run_map_helper(0)
    def get_metric_value_percentiles(self, metric_name, num_sample_per_value, total_num_samples):
        logger.info(f"Checking the value percentiles of metric {metric_name}...")
        processed_samples = 0
        current_percentile = 5
        for key in sorted(num_sample_per_value.keys()):
            processed_samples += num_sample_per_value[key]
            # Use a while loop so a single value that spans several 5% buckets
            # reports every percentile it covers, not just the first one.
            while processed_samples >= total_num_samples * current_percentile / 100.0 and current_percentile <= 100:
                logger.info(f"Metric {metric_name} {current_percentile}th percentile: {key}")
                current_percentile += 5
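
    # Worked example (hypothetical numbers): with num_sample_per_value = {3: 10, 7: 90}
    # and total_num_samples = 100, value 3 crosses the 5th and 10th percentile
    # thresholds and value 7 crosses the 15th through 100th, so the log reports
    # "5th percentile: 3", "10th percentile: 3", then "15th percentile: 7" up to
    # "100th percentile: 7".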
    def merge_gather_map_stats(self, num_workers, num_threads, num_threads_reduce, t_idx_reduce, metric_save_path,
                               metric_name, return_dict):
        results = []
        for w_idx in range(num_workers):
            for t_idx in range(num_threads):
                if (w_idx * num_threads + t_idx) % num_threads_reduce == t_idx_reduce:
                    w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
                    w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
                    w_sample_to_metric = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
                    unique_v = list(np.unique(w_sample_to_metric))
                    sample_to_metric_count = len(w_sample_to_metric)
                    logger.info(f"Finished gathering map stats from worker {w_idx} thread {t_idx}.")
                    results.append([unique_v, sample_to_metric_count])
        return_dict[t_idx_reduce] = results
    def merge_sample_to_metric(self, t_idx_reduce, metric_save_path, metric_name, metric_value_dtype,
                               map_worker_thread):
        sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
        sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
        for w_t in map_worker_thread:
            w_metric_save_path = f"{metric_save_path}/worker{w_t[0]}_thread{w_t[1]}/"
            w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
            w_data = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
            for row in range(len(w_data)):
                sample_to_metric_builder.add_item(torch.tensor(w_data[row].astype(np.int64), dtype=torch.long))
            logger.info(f"Finished merge_sample_to_metric from worker {w_t[0]} thread {w_t[1]}.")
        close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
    def merge_metric_to_sample(self, t_idx_reduce, metric_save_path, metric_name, sample_idx_dtype, metric_value_dtype,
                               unique_metric_values, num_workers, num_threads):
        index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
        index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
        index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
        index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
        for unique_v in unique_metric_values:
            samples = []
            for w_idx in range(num_workers):
                for t_idx in range(num_threads):
                    w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
                    w_metric_to_sample_fname = f"{w_metric_save_path}/{metric_name}_metric_to_sample_{unique_v}.csv"
                    if os.path.isfile(w_metric_to_sample_fname):
                        with open(w_metric_to_sample_fname, 'r') as f:
                            datareader = csv.reader(f)
                            for row in datareader:
                                samples += [int(x) for x in row]
            index_to_sample_builder.add_item(torch.tensor(samples, dtype=torch.long))
            index_to_metric_builder.add_item(torch.tensor([unique_v], dtype=torch.long))
            logger.info(f"Finished reducing metric {metric_name} value {unique_v}.")
        close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
        close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
    def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_workers, num_threads,
                          num_threads_reduce):
        total_num_samples = len(dataset)
        sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
        logger.info(
            f"Total number of data samples: {total_num_samples}. Will use {sample_idx_dtype} to store the sample indexes."
        )
        for m_idx in range(len(metric_names)):
            metric_name, metric_type = metric_names[m_idx], metric_types[m_idx]
            if metric_type == 'single_value_per_sample':
                metric_save_path = f"{save_path}/{metric_name}/"
                sample_to_metric_count = 0
                unique_metric_values = set([])
                # Gather unique metric values and sample counts from all map outputs.
                manager = Manager()
                return_dict = manager.dict()
                p = []
                for t_idx_reduce in range(num_threads_reduce):
                    p.append(
                        Process(target=self.merge_gather_map_stats,
                                args=(
                                    num_workers,
                                    num_threads,
                                    num_threads_reduce,
                                    t_idx_reduce,
                                    metric_save_path,
                                    metric_name,
                                    return_dict,
                                )))
                    p[t_idx_reduce].start()
                for t_idx_reduce in range(num_threads_reduce):
                    p[t_idx_reduce].join()
                for t_idx_reduce in range(num_threads_reduce):
                    results = return_dict[t_idx_reduce]
                    for res in results:
                        unique_metric_values = unique_metric_values.union(set(res[0]))
                        sample_to_metric_count += res[1]
                value_max = max(unique_metric_values)
                value_min = min(unique_metric_values)
                assert sample_to_metric_count == total_num_samples, "The number of samples in the map result files is not correct. It's possible that some map worker did not finish successfully."
                metric_value_dtype = find_fit_int_dtype(value_min, value_max)
                logger.info(
                    f"Metric {metric_name} has values between {value_min} and {value_max}. Will use {metric_value_dtype} to store the metric values."
                )

                # sample_to_metric: merge the per-worker/per-thread files into one.
                map_worker_thread = []
                for w_idx in range(num_workers):
                    for t_idx in range(num_threads):
                        map_worker_thread.append([w_idx, t_idx])
                thread_splits = split_index(0, len(map_worker_thread), num_threads_reduce)
                p = []
                for t_idx_reduce in range(num_threads_reduce):
                    start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
                    p.append(
                        Process(target=self.merge_sample_to_metric,
                                args=(
                                    t_idx_reduce,
                                    metric_save_path,
                                    metric_name,
                                    metric_value_dtype,
                                    map_worker_thread[start_idx:end_idx],
                                )))
                    p[t_idx_reduce].start()
                for t_idx_reduce in range(num_threads_reduce):
                    p[t_idx_reduce].join()
                sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
                sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
                for t_idx_reduce in range(num_threads_reduce):
                    chunk_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
                    logger.info(f"Merging file {chunk_fname}")
                    sample_to_metric_builder.merge_file_(chunk_fname)
                close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
                sample_to_metric = MMapIndexedDataset(sample_to_metric_fname, skip_warmup=True)
                assert len(sample_to_metric) == total_num_samples

                # metric_to_sample: for each sorted unique metric value, collect the sample indexes.
                unique_metric_values = list(sorted(unique_metric_values))
                thread_splits = split_index(0, len(unique_metric_values), num_threads_reduce)
                p = []
                for t_idx_reduce in range(num_threads_reduce):
                    start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
                    p.append(
                        Process(target=self.merge_metric_to_sample,
                                args=(
                                    t_idx_reduce,
                                    metric_save_path,
                                    metric_name,
                                    sample_idx_dtype,
                                    metric_value_dtype,
                                    unique_metric_values[start_idx:end_idx],
                                    num_workers,
                                    num_threads,
                                )))
                    p[t_idx_reduce].start()
                for t_idx_reduce in range(num_threads_reduce):
                    p[t_idx_reduce].join()
                index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample"
                index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
                index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric"
                index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
                for t_idx_reduce in range(num_threads_reduce):
                    chunk_is_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
                    logger.info(f"Merging file {chunk_is_fname}")
                    index_to_sample_builder.merge_file_(chunk_is_fname)
                    chunk_im_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
                    logger.info(f"Merging file {chunk_im_fname}")
                    index_to_metric_builder.merge_file_(chunk_im_fname)
                close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
                close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
                num_sample_per_value = {}
                index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True)
                index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True)
                index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged"
                index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname,
                                                                             sample_idx_dtype)
                for v_idx in range(len(index_to_sample)):
                    if v_idx > 0:
                        assert index_to_metric[v_idx] > index_to_metric[v_idx - 1]
                    num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx])
                assert sum(num_sample_per_value.values()) == total_num_samples
                merge_step = max(1, len(index_to_sample) // 100)
                for v_idx in range(0, len(index_to_sample), merge_step):
                    merged_samples = np.copy(
                        np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))],
                                       axis=None))
                    index_to_sample_merged_builder.add_item(
                        torch.tensor(merged_samples.astype(np.int64), dtype=torch.long))
                    logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.")
                close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname)
                self.get_metric_value_percentiles(metric_name, num_sample_per_value, total_num_samples)
            elif metric_type == 'accumulate_value_over_samples':
                metric_save_path = f"{save_path}/{metric_name}/"
                metric_value = None
                for w_idx in range(num_workers):
                    for t_idx in range(num_threads):
                        w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
                        w_metric_value_fname = f"{w_metric_save_path}/{metric_name}_metric_value"
                        w_metric_value = MMapIndexedDataset(w_metric_value_fname, skip_warmup=True)
                        if metric_value is None:
                            metric_value = np.copy(w_metric_value[0])
                        else:
                            metric_value += np.copy(w_metric_value[0])
                value_max = int(max(metric_value))
                value_min = int(min(metric_value))
                metric_value_dtype = find_fit_int_dtype(value_min, value_max)
                metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
                metric_value_builder = create_mmap_dataset_builder(metric_value_fname, metric_value_dtype)
                metric_value_builder.add_item(torch.tensor(metric_value.astype(np.int64), dtype=torch.long))
                close_mmap_dataset_builder(metric_value_builder, metric_value_fname)
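
    # Reduce-phase outputs under {save_path}/{metric_name}/, derived from the code above:
    #   {metric_name}_sample_to_metric                     metric value for each sample index
    #   {metric_name}_index_to_metric                      sorted unique metric values
    #   {metric_name}_index_to_sample                      sample indexes grouped by metric value
    #   {metric_name}_index_to_sample_percentile_merged    same, chunked into ~1% buckets
    #   {metric_name}_metric_value                         summed value (accumulate type)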
    def run_reduce(self):
        if self.custom_reduce is None:
            self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path,
                                   self.num_workers, self.num_threads, self.num_threads_reduce)
        else:
            self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers,
                               self.num_threads, self.num_threads_reduce)
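
# Illustrative usage (a minimal sketch, not part of the original file; dataset,
# rank, seq_len_metric, and the save path are hypothetical). Each map worker runs
# run_map() on its shard; after all workers finish, a single rank runs run_reduce()
# to merge the per-worker/per-thread outputs:
#
#     analyzer = DataAnalyzer(dataset,
#                             num_workers=4,
#                             worker_id=rank,
#                             num_threads=2,
#                             batch_size=32,
#                             metric_names=["seqlen"],
#                             metric_functions=[seq_len_metric],
#                             metric_types=["single_value_per_sample"],
#                             metric_dtypes=[np.int32],
#                             save_path="/tmp/analysis")
#     analyzer.run_map()     # on every worker
#     analyzer.run_reduce()  # on one worker, after all map workers complete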