dataset_utils.py

import os
import sys
import traceback
import types
from functools import wraps
from itertools import chain

import numpy as np
import torch.utils.data
from torch.utils.data import ConcatDataset

from utils.commons.hparams import hparams

def collate_xd(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Dispatch to collate_1d/2d/3d based on the rank of the input tensors."""
    if len(values[0].shape) == 1:
        return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)
    elif len(values[0].shape) == 2:
        return collate_2d(values, pad_idx, left_pad, shift_right, max_len)
    elif len(values[0].shape) == 3:
        return collate_3d(values, pad_idx, left_pad, shift_right, max_len)
    # Fail loudly instead of silently returning None for unsupported ranks.
    raise ValueError(f'collate_xd only supports 1d/2d/3d tensors, got shape {values[0].shape}')

def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    if len(values[0].shape) == 1:
        return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)
    else:
        return collate_2d(values, pad_idx, left_pad, shift_right, max_len)

def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
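
# Illustrative sketch (not part of the original module): collate_1d pads a
# ragged batch to the longest sequence; shift_right produces teacher-forcing
# inputs that begin with shift_id. Tensor values below are made up.
def _demo_collate_1d():
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([4, 5])
    padded = collate_1d([a, b], pad_idx=0)
    # padded  -> [[1, 2, 3], [4, 5, 0]]
    shifted = collate_1d([a, b], shift_right=True, shift_id=1)
    # shifted -> [[1, 1, 2], [1, 4, 0]]
    return padded, shifted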

def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Convert a list of 2d tensors into a padded 3d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res

def collate_3d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Convert a list of 3d tensors into a padded 4d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size, values[0].shape[1], values[0].shape[2]).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
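
# Illustrative sketch: collate_xd dispatches on tensor rank, so 2d features
# such as [T, H] mel frames are padded along the time axis. Shapes are
# hypothetical.
def _demo_collate_xd():
    mel_a = torch.randn(5, 80)  # 5 frames, 80 mel bins
    mel_b = torch.randn(3, 80)
    batch = collate_xd([mel_a, mel_b], pad_idx=0)
    assert batch.shape == (2, 5, 80)  # padded to the longest sequence
    return batch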

def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
    if len(batch) == 0:
        return False
    if len(batch) == max_sentences:
        return True
    if num_tokens > max_tokens:
        return True
    return False

def batch_by_size(
        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
        required_batch_size_multiple=1, distributed=False
):
    """
    Return mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
    """
    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
    bsz_mult = required_batch_size_multiple

    if isinstance(indices, types.GeneratorType):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    sample_len = 0
    sample_lens = []
    batch = []
    batches = []
    for i in range(len(indices)):
        idx = indices[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)
        assert sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        num_tokens = (len(batch) + 1) * sample_len
        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
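
# Illustrative sketch: bucketing six toy items by token count with
# max_tokens=8. A batch's cost is (batch size) * (longest item in it), which
# is what the (len(batch) + 1) * sample_len line above computes.
def _demo_batch_by_size():
    lengths = [3, 3, 3, 4, 4, 8]
    batches = batch_by_size(
        list(range(len(lengths))),
        num_tokens_fn=lambda i: lengths[i],
        max_tokens=8,
    )
    # batches -> [[0, 1], [2, 3], [4], [5]]
    return batches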

def unpack_dict_to_list(samples):
    samples_ = []
    bsz = samples.get('outputs').size(0)
    for i in range(bsz):
        res = {}
        for k, v in samples.items():
            try:
                res[k] = v[i]
            except Exception:
                # Skip values that cannot be indexed per-sample (e.g. scalars).
                pass
        samples_.append(res)
    return samples_

def remove_padding(x, padding_idx=0):
    if x is None:
        return None
    assert len(x.shape) in [1, 2]
    if len(x.shape) == 2:  # [T, H]
        return x[np.abs(x).sum(-1) != padding_idx]
    elif len(x.shape) == 1:  # [T]
        return x[x != padding_idx]
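
# Illustrative sketch: remove_padding drops every entry equal to padding_idx
# from a 1d sequence (and, with the default padding_idx of 0, all-zero rows
# from a 2d one). Values are made up.
def _demo_remove_padding():
    x = torch.tensor([7, 8, 9, 0, 0])
    return remove_padding(x)  # -> tensor([7, 8, 9])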

def data_loader(fn):
    """Decorator that evaluates a data-loader method lazily and memoizes the
    result on the instance (under '_lazy_<name>')."""
    attr_name = '_lazy_' + fn.__name__

    @wraps(fn)
    def _get_data_loader(self):
        try:
            value = getattr(self, attr_name)
        except AttributeError:
            try:
                value = fn(self)  # Lazy evaluation, done only once.
            except AttributeError as e:
                # Guard against AttributeError suppression. (Issue #142)
                traceback.print_exc()
                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
                raise RuntimeError(error) from e
            setattr(self, attr_name, value)  # Memoize evaluation.
        return value

    return _get_data_loader
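
# Illustrative sketch: a method decorated with data_loader is built on first
# call and cached under '_lazy_<name>' afterwards. The class below is
# hypothetical.
class _DemoLoaders:
    @data_loader
    def train_dataloader(self):
        print('building the loader once')
        return [1, 2, 3]  # stands in for a real DataLoader
# loaders = _DemoLoaders(); loaders.train_dataloader(); loaders.train_dataloader()
# prints once: later calls return the memoized '_lazy_train_dataloader' value.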

class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, shuffle):
        super().__init__()
        self.hparams = hparams
        self.shuffle = shuffle
        self.sort_by_len = hparams['sort_by_len']
        self.sizes = None

    @property
    def _sizes(self):
        return self.sizes

    def __getitem__(self, index):
        raise NotImplementedError

    def collater(self, samples):
        raise NotImplementedError

    def __len__(self):
        return len(self._sizes)

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return min(self._sizes[index], hparams['max_frames'])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
            if self.sort_by_len:
                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
        else:
            indices = np.arange(len(self))
        return indices

    @property
    def num_workers(self):
        return int(os.getenv('NUM_WORKERS', hparams['num_workers']))
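
# Illustrative sketch: a minimal BaseDataset subclass. The item layout is
# hypothetical, and hparams must already be populated (normally done by the
# surrounding training framework) before instantiation.
class _ToyDataset(BaseDataset):
    def __init__(self, items, shuffle=False):
        super().__init__(shuffle)
        self.items = items  # e.g. a list of {'mel': FloatTensor[T, 80]} dicts
        self.sizes = [len(item['mel']) for item in items]

    def __getitem__(self, index):
        return self.items[index]

    def collater(self, samples):
        return {'mel': collate_2d([s['mel'] for s in samples])}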

class BaseConcatDataset(ConcatDataset):
    def collater(self, samples):
        return self.datasets[0].collater(samples)

    @property
    def _sizes(self):
        if not hasattr(self, 'sizes'):
            self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets]))
        return self.sizes

    def size(self, index):
        return min(self._sizes[index], hparams['max_frames'])

    def num_tokens(self, index):
        return self.size(index)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.datasets[0].shuffle:
            indices = np.random.permutation(len(self))
            if self.datasets[0].sort_by_len:
                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
        else:
            indices = np.arange(len(self))
        return indices

    @property
    def num_workers(self):
        return self.datasets[0].num_workers
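
# Illustrative sketch: how the pieces above compose into a token-bucketed
# DataLoader. 'dataset' is any BaseDataset subclass; the max_tokens default
# here is arbitrary.
def _build_demo_dataloader(dataset, max_tokens=20000):
    indices = dataset.ordered_indices()
    batches = batch_by_size(indices, dataset.num_tokens, max_tokens=max_tokens)
    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collater,
        batch_sampler=batches,  # each element is a list of dataset indices
        num_workers=dataset.num_workers,
    )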