# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Part of this code was adopted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
"""

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.

# Some of the fixes/improvements are adopted from
# https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/main/megatron/data/indexed_dataset.py

from functools import lru_cache
import os
import shutil
import struct
from itertools import accumulate

import numpy as np
import torch


def __best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    else:
        return np.int32


def get_available_dataset_impl():
    return ['lazy', 'cached', 'mmap']


def infer_dataset_impl(path):
    if IndexedDataset.exists(path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return 'cached'
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return 'mmap'
            else:
                return None
    else:
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename to which .idx and .bin can be appended to get the full filenames.")
        return None


def make_builder(out_file, impl, vocab_size=None):
    if impl == 'mmap':
        return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
    else:
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, skip_warmup=False):
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename to which .idx and .bin can be appended to get the full filenames.")
        return None
    if impl == 'infer':
        impl = infer_dataset_impl(path)
    if impl == 'lazy' and IndexedDataset.exists(path):
        return IndexedDataset(path)
    elif impl == 'cached' and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)
    print(f"Unknown dataset implementation: {impl}")
    return None


def dataset_exists(path, impl):
    if impl == 'mmap':
        return MMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)
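

# Usage sketch for the factory helpers above (illustrative only; the prefix
# 'my_corpus' and the token ids are hypothetical). A builder writes
# 'my_corpus.bin' and, on finalize(), 'my_corpus.idx'; make_dataset() then
# loads the pair back with the chosen implementation:
#
#     builder = make_builder('my_corpus.bin', impl='mmap', vocab_size=32000)
#     builder.add_item(torch.tensor([10, 11, 12], dtype=torch.int64))
#     builder.end_document()
#     builder.finalize('my_corpus.idx')
#
#     ds = make_dataset('my_corpus', impl='infer', skip_warmup=True)
#     tokens = ds[0]  # numpy array holding the first sentence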


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


# valid metric_dtypes as numpy and torch types
dtypes = {
    1: (np.uint8, torch.uint8),
    2: (np.int8, torch.int8),
    3: (np.int16, torch.int16),
    4: (np.int32, torch.int32),
    5: (np.int64, torch.int64),
    6: (np.uint16, None),
    7: (np.uint32, None),
    8: (np.uint64, None),
}

valid_dtypes = set([dt[0] for dt in dtypes.values()] + [dt[1] for dt in dtypes.values() if dt[1] is not None])


def code(dtype):
    for c, (np_dt, torch_dt) in dtypes.items():
        if dtype in [np_dt, torch_dt]:
            return c
    raise ValueError(f"{dtype} not supported. Supported types: {valid_dtypes}")


def index_file_path(prefix_path):
    return prefix_path + '.idx'


def data_file_path(prefix_path):
    return prefix_path + '.bin'


def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx
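

# Worked example (illustrative): a zero-length entry in `sizes` marks a document
# boundary, so create_doc_idx() returns the index of the first sentence of each
# document:
#
#     >>> create_doc_idx([3, 5, 0, 2, 4, 0, 7])
#     [0, 3, 6]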


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""
    _HDR_MAGIC = b'TNTIDX\x00\x00'

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
                                              'Make sure that --dataset-impl is configured properly.')
            version = f.read(8)
            assert struct.unpack('<Q', version) == (1, )
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code][0]  # numpy type
            self._len, self.s = struct.unpack('<QQ', f.read(16))
            self.doc_count = struct.unpack('<Q', f.read(8))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb', buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError('index out of range')

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if not self.data_file:
            self.read_data(self.path)
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            offsets = list(accumulate(sizes))
            sents = np.split(a, offsets[:-1])
            return sents

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory
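

# Layout of the legacy '.idx' file as consumed by read_index() above and produced
# by IndexedDatasetBuilder.finalize() below (all integers little-endian):
#
#     bytes  0-7  : magic  b'TNTIDX\x00\x00'
#     bytes  8-15 : version (<Q, currently 1)
#     bytes 16-31 : dtype code and element size in bytes (<QQ)
#     bytes 32-47 : number of items `len` and number of sizes `s` (<QQ)
#     bytes 48-55 : number of document-index entries (<Q)
#     then int64 arrays: dim_offsets[len + 1], data_offsets[len + 1],
#                        sizes[s], doc_idx[doc_count]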


class IndexedCachedDataset(IndexedDataset):

    def __init__(self, path):
        super().__init__(path)
        self.cache = None
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx:ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx:ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work; can optimize later if necessary
            sents = []
            for i in range(*idx.indices(len(self))):
                sents.append(self[i])
            return sents
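

# Usage sketch for the cached variant (illustrative; 'my_corpus' is a hypothetical
# prefix). prefetch() reads the requested items into one in-memory buffer and then
# closes the data file, so later __getitem__ calls are served from the cache and
# the dataset stays picklable:
#
#     ds = IndexedCachedDataset('my_corpus')
#     if ds.supports_prefetch:
#         ds.prefetch(range(len(ds)))
#     sample = ds[0]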


class IndexedDatasetBuilder(object):

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, 'wb')
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.dtype().itemsize
        self.doc_idx = [0]

    def add_item(self, tensor):
        bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def end_document(self):
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        doc_offset = len(self.sizes)

        begin = self.data_offsets[-1]
        for data_offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + data_offset)
        self.sizes.extend(index.sizes)

        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        self.doc_idx.extend((doc_offset + index.doc_idx)[1:])

        with open(data_file_path(another_file), 'rb') as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, 'wb')
        index.write(b'TNTIDX\x00\x00')
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack('<Q', len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()
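

# Merge sketch for the legacy builder (illustrative; the shard prefixes are
# hypothetical). merge_file_() appends another dataset's .bin contents and
# re-bases its offsets, so several shards can be combined into one dataset:
#
#     builder = IndexedDatasetBuilder('merged.bin', dtype=np.int32)
#     for shard in ['shard0', 'shard1']:
#         builder.merge_file_(shard)   # expects shard.idx / shard.bin on disk
#     builder.finalize('merged.idx')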


def _warmup_mmap_file(path):
    with open(path, 'rb') as stream:
        while stream.read(100 * 1024 * 1024):
            pass


def exscan_from_cumsum_(arr):
    # given an array holding the result of an inclusive scan (cumsum),
    # convert to an exclusive scan (shift to the right)
    # [10, 30, 35, 50] --> [0, 10, 30, 35]
    if arr.size > 1:
        arr[1:] = arr[:-1]
    if arr.size > 0:
        arr[0] = 0


def get_pointers_with_total(sizes, elemsize, dtype):
    """Return a numpy array of type np.dtype giving the byte offsets.

    Multiplies values in the sizes array by elemsize (bytes),
    and then computes an exclusive scan to get byte offsets.
    Returns the total number of bytes as second item in a tuple.
    """

    # scale values in sizes array by elemsize to get sizes in bytes
    pointers = np.array(sizes, dtype=dtype)
    pointers *= elemsize
    np.cumsum(pointers, axis=0, out=pointers)

    # get total number of bytes from all sizes (last element)
    bytes_last = pointers[-1] if len(sizes) > 0 else 0

    # convert to byte offsets
    exscan_from_cumsum_(pointers)

    return pointers, bytes_last
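

# Worked example (illustrative): for sizes [3, 5, 2] and 2-byte elements, the byte
# lengths are [6, 10, 4], the inclusive scan is [6, 16, 20], and the exclusive scan
# gives the starting byte offset of each item:
#
#     >>> get_pointers_with_total([3, 5, 2], 2, np.int64)
#     (array([ 0,  6, 16]), 20)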


class MMapIndexedDataset(torch.utils.data.Dataset):

    class Index(object):
        _HDR_MAGIC = b'MMIDIDX\x00\x00'

        @classmethod
        def writer(cls, path, dtype):

            class _Writer(object):

                def __enter__(self):
                    self._file = open(path, 'wb')

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack('<Q', 1))
                    self._file.write(struct.pack('<B', code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes, npdtype):
                    """Return a numpy array of byte offsets given a list of sizes.

                    Multiplies values in the sizes array by dtype size (bytes),
                    and then computes an exclusive scan to get byte offsets.
                    """

                    # compute element sizes in bytes
                    pointers, _ = get_pointers_with_total(sizes, dtype().itemsize, npdtype)
                    return pointers

                def write(self, sizes, doc_idx):
                    self._file.write(struct.pack('<Q', len(sizes)))
                    self._file.write(struct.pack('<Q', len(doc_idx)))

                    sizes32 = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes32.tobytes(order='C'))
                    del sizes32

                    pointers = self._get_pointers(sizes, np.int64)
                    del sizes

                    self._file.write(pointers.tobytes(order='C'))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order='C'))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, 'rb') as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
                                                       'Make sure that --dataset-impl is configured properly.')
                version = struct.unpack('<Q', stream.read(8))
                assert (1, ) == version

                dtype_code, = struct.unpack('<B', stream.read(1))
                self._dtype = dtypes[dtype_code][0]  # numpy type
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print("    warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print("    reading sizes...")
            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
            print("    reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer,
                                           dtype=np.int64,
                                           count=self._len,
                                           offset=offset + self._sizes.nbytes)
            print("    reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer,
                                          dtype=np.int64,
                                          count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print("    warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print("    creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
        print("    creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    def size(self, index):
        return self._index.sizes[index]

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))

    @property
    def dtype(self):
        return self._index.dtype
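

# Usage sketch for the mmap-backed reader (illustrative; 'my_corpus' is a
# hypothetical prefix). Items are zero-copy views into the memory-mapped .bin
# file; get() can return just a window of one item:
#
#     ds = MMapIndexedDataset('my_corpus', skip_warmup=True)
#     full = ds[3]                            # whole item 3
#     head = ds.get(3, offset=0, length=8)    # first 8 tokens of item 3
#     batch = ds[0:4]                         # list of items 0..3 (contiguous slice)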


class MMapIndexedDatasetBuilder(object):

    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, 'wb')
        self._dtype = [np_dt for np_dt, torch_dt in dtypes.values() if dtype in [np_dt, torch_dt]][0]
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, tensor):
        """Write the tensor to the file and update its size in the index."""
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order='C'))
        self._sizes.append(np_array.size)

    def add_items(self, arr_list):
        """Write a list of arrays to the file and update their sizes in the index."""
        np_arrays = [arr.astype(self._dtype) for arr in arr_list]
        self._data_file.writelines([arr.tobytes(order='C') for arr in np_arrays])
        for arr in np_arrays:
            self._sizes.append(arr.size)

    def add_item_numpy(self, np_array):
        if np_array.dtype != self._dtype:
            np_array = np_array.astype(self._dtype)
        self._data_file.write(np_array.tobytes(order='C'))
        self._sizes.append(np_array.size)

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        total_len = len(index.sizes) + len(self._sizes)
        print(f"    concat {another_file} size={len(index.sizes)} for a total size of {total_len}")

        offset = len(self._sizes)
        self._sizes.extend(index.sizes)
        self._doc_idx.extend((offset + index.doc_idx)[1:])

        # Concatenate data
        with open(data_file_path(another_file), 'rb') as f:
            shutil.copyfileobj(f, self._data_file)

        self._data_file.flush()
        assert os.stat(self._data_file.name).st_size != 0, f"Zero-sized file: {self._data_file.name}"

    def finalize(self, index_file):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes, self._doc_idx)
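

# End-to-end sketch (illustrative; the prefix 'out_prefix' and the token ids are
# hypothetical): stream tokenized sentences into the .bin file, mark document
# boundaries, write the .idx on finalize(), then read everything back via mmap:
#
#     builder = MMapIndexedDatasetBuilder(data_file_path('out_prefix'), dtype=np.int32)
#     builder.add_item(torch.tensor([1, 2, 3], dtype=torch.int32))
#     builder.add_item(torch.tensor([4, 5], dtype=torch.int32))
#     builder.end_document()
#     builder.finalize(index_file_path('out_prefix'))
#
#     ds = MMapIndexedDataset('out_prefix', skip_warmup=True)
#     assert list(ds.sizes) == [3, 2] and list(ds.doc_idx) == [0, 2]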