# indexed_datasets.py

import gzip
import pickle
from bisect import bisect
from copy import deepcopy

import numpy as np
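
# On-disk layout written by IndexedDatasetBuilder and read back by IndexedDataset:
#   <path>.data     bytes [0, 32): little-endian length of the pickled index;
#                   bytes [32, 32 + length): pickled {'offsets', 'id2pos', 'meta'};
#                   bytes [default_idx_size, ...): concatenated (optionally gzipped) records.
#   <path>.<k>.data overflow chunks created once a chunk grows past max_size;
#                   meta['chunk_begin'] holds the absolute offset where each chunk starts,
#                   and 'offsets' holds the absolute start offset of every record.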


def int2bytes(i: int, *, signed: bool = False) -> bytes:
    # Encode an int with the minimum number of little-endian bytes.
    length = ((i + ((i * signed) < 0)).bit_length() + 7 + signed) // 8
    return i.to_bytes(length, byteorder='little', signed=signed)


def bytes2int(b: bytes, *, signed: bool = False) -> int:
    # Decode little-endian bytes back into an int.
    return int.from_bytes(b, byteorder='little', signed=signed)


def load_index_data(data_file):
    # The first 32 bytes hold the size of the pickled index that immediately follows.
    index_data_size = bytes2int(data_file.read(32))
    index_data = pickle.loads(data_file.read(index_data_size))
    data_offsets = deepcopy(index_data['offsets'])
    id2pos = deepcopy(index_data.get('id2pos', {}))
    meta = deepcopy(index_data.get('meta', {}))
    return data_offsets, id2pos, meta


class IndexedDataset:
    def __init__(self, path, unpickle=True):
        self.path = path
        self.root_data_file = open(f"{path}.data", 'rb', buffering=-1)
        try:
            # New single-file format: the index lives at the head of the .data file.
            self.byte_offsets, self.id2pos, self.meta = load_index_data(self.root_data_file)
            self.data_files = [self.root_data_file]
        except Exception:
            # Fall back to the legacy layout with a separate .idx file.
            self.root_data_file.close()
            self.__init__old(path)
            self.meta = {}
        self.gzip = self.meta.get('gzip', False)
        if 'chunk_begin' not in self.meta:
            self.meta['chunk_begin'] = [0]
        # Open any overflow chunk files referenced by the index.
        for i in range(len(self.meta['chunk_begin'][1:])):
            self.data_files.append(open(f"{self.path}.{i + 1}.data", 'rb'))
        self.unpickle = unpickle

    def __init__old(self, path):
        # Legacy format: the index is stored in a separate numpy pickle next to the data.
        self.path = path
        index_data = np.load(f"{path}.idx", allow_pickle=True).item()
        self.byte_offsets = index_data['offsets']
        self.id2pos = index_data.get('id2pos', {})
        self.data_files = [open(f"{path}.data", 'rb', buffering=-1)]

    def __getitem__(self, i):
        # If explicit ids were registered, translate the key into a position first.
        if self.id2pos is not None and len(self.id2pos) > 0:
            i = self.id2pos[i]
        self.check_index(i)
        # Locate the chunk holding this record and read its byte span.
        chunk_id = bisect(self.meta['chunk_begin'][1:], self.byte_offsets[i])
        data_file = self.data_files[chunk_id]
        data_file.seek(self.byte_offsets[i] - self.meta['chunk_begin'][chunk_id])
        b = data_file.read(self.byte_offsets[i + 1] - self.byte_offsets[i])
        if self.unpickle:
            if self.gzip:
                b = gzip.decompress(b)
            item = pickle.loads(b)
        else:
            item = b
        return item

    def __del__(self):
        # Guard with getattr in case __init__ failed before data_files was set.
        for data_file in getattr(self, 'data_files', []):
            data_file.close()

    def check_index(self, i):
        if i < 0 or i >= len(self.byte_offsets) - 1:
            raise IndexError('index out of range')

    def __len__(self):
        return len(self.byte_offsets) - 1

    def __iter__(self):
        self.iter_i = 0
        return self

    def __next__(self):
        if self.iter_i == len(self):
            raise StopIteration
        item = self[self.iter_i]
        self.iter_i += 1
        return item


class IndexedDatasetBuilder:
    def __init__(self, path, append=False, max_size=1024 * 1024 * 1024 * 64,
                 default_idx_size=1024 * 1024 * 16, gzip=False):
        self.path = self.root_path = path
        # NOTE: when appending, default_idx_size must match the value the dataset
        # was originally created with, since that is where the records start.
        self.default_idx_size = default_idx_size
        self.max_size = max_size
        if append:
            self.root_data_file = open(f"{path}.data", 'r+b')
            self.root_data_file.seek(0)
            self.byte_offsets, self.id2pos, self.meta = load_index_data(self.root_data_file)
            # Blank the reserved header; finalize() rewrites it with the merged index.
            self.root_data_file.seek(0)
            self.root_data_file.write(bytes(default_idx_size))
            self.gzip = self.meta.get('gzip', False)
            self.meta.setdefault('chunk_begin', [0])
            # Resume writing at the end of the most recent chunk.
            self.data_chunk_id = len(self.meta['chunk_begin']) - 1
            if self.data_chunk_id == 0:
                self.data_file = self.root_data_file
            else:
                self.data_file = open(f"{path}.{self.data_chunk_id}.data", 'r+b')
            self.data_file.seek(self.byte_offsets[-1] - self.meta['chunk_begin'][-1])
        else:
            # Reserve default_idx_size bytes at the head of the root file for the
            # index that finalize() will write.
            self.data_file = self.root_data_file = open(f"{path}.data", 'wb')
            self.data_file.seek(default_idx_size)
            self.byte_offsets = [default_idx_size]
            self.id2pos = {}
            self.meta = {'chunk_begin': [0]}
            self.gzip = self.meta['gzip'] = gzip
            self.data_chunk_id = 0

    def add_item(self, item, id=None, use_pickle=True):
        # Roll over to a new chunk file once the current one exceeds max_size.
        if self.byte_offsets[-1] > self.meta['chunk_begin'][-1] + self.max_size:
            if self.data_file != self.root_data_file:
                self.data_file.close()
            self.data_chunk_id += 1
            self.data_file = open(f"{self.path}.{self.data_chunk_id}.data", 'wb')
            self.data_file.seek(0)
            self.meta['chunk_begin'].append(self.byte_offsets[-1])
        # Items may be passed pre-serialized (use_pickle=False).
        s = pickle.dumps(item) if use_pickle else item
        if self.gzip:
            s = gzip.compress(s, 1)
        num_bytes = self.data_file.write(s)
        if id is not None:
            self.id2pos[id] = len(self.byte_offsets) - 1
        self.byte_offsets.append(self.byte_offsets[-1] + num_bytes)

    def finalize(self):
        # Write the pickled index into the reserved header of the root file:
        # 32 bytes of length, then the index itself.
        self.root_data_file.seek(0)
        s = pickle.dumps({'offsets': self.byte_offsets, 'id2pos': self.id2pos, 'meta': self.meta})
        # The index (plus its 32-byte length prefix) must fit in the reserved region.
        assert len(s) + 32 <= self.default_idx_size, (len(s), self.default_idx_size)
        self.root_data_file.write(int2bytes(len(s)))
        self.root_data_file.seek(32)
        self.root_data_file.write(s)
        self.root_data_file.close()
        if self.data_file is not self.root_data_file:
            self.data_file.close()


if __name__ == "__main__":
    import random
    from tqdm import tqdm

    # Build a small dataset of pre-pickled records and read it back.
    ds_path = '/tmp/indexed_ds_example'
    size = 100
    items = [{"a": np.random.normal(size=[10000, 10]),
              "b": np.random.normal(size=[10000, 10])} for i in range(size)]
    builder = IndexedDatasetBuilder(ds_path, max_size=1024 * 1024 * 40)
    builder.meta['lengths'] = [1, 2, 3]
    for i in tqdm(range(size)):
        builder.add_item(pickle.dumps(items[i]), i, use_pickle=False)
    builder.finalize()
    ds = IndexedDataset(ds_path)
    assert ds.meta['lengths'] == [1, 2, 3]
    for i in tqdm(range(1000)):
        idx = random.randint(0, size - 1)
        assert (ds[idx]['a'] == items[idx]['a']).all()

    # Example of appending to an existing dataset:
    # builder = IndexedDatasetBuilder(ds_path, append=True)
    # builder.meta['lengths'] = [1, 2, 3, 5, 6, 7]
    # for i in tqdm(range(size)):
    #     builder.add_item(items[i], i + size)
    # builder.finalize()
    # ds = IndexedDataset(ds_path)
    # assert ds.meta['lengths'] == [1, 2, 3, 5, 6, 7]
    # for i in tqdm(range(1000)):
    #     idx = random.randint(size, 2 * size - 1)
    #     assert (ds[idx]['a'] == items[idx - size]['a']).all()
    #     idx = random.randint(0, size - 1)
    #     assert (ds[idx]['a'] == items[idx]['a']).all()
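
    # A minimal extra sketch (not part of the original demo): looking records up by an
    # arbitrary key through id2pos, with gzip compression enabled. The path and the
    # kv_* names below are purely illustrative.
    kv_path = '/tmp/indexed_ds_keyed_example'
    kv_builder = IndexedDatasetBuilder(kv_path, gzip=True)
    kv_builder.add_item({"x": np.arange(5)}, id='first')
    kv_builder.add_item({"x": np.arange(10)}, id='second')
    kv_builder.finalize()
    kv_ds = IndexedDataset(kv_path)
    # Once id2pos is non-empty, __getitem__ resolves keys through it.
    assert len(kv_ds['second']['x']) == 10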