123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- import os
- import sys
- import traceback
- import types
- from functools import wraps
- from itertools import chain
- import numpy as np
- import torch.utils.data
- from torch.utils.data import ConcatDataset
- from utils.commons.hparams import hparams
def collate_xd(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Pad and stack a list of 1d, 2d or 3d tensors, dispatching on rank.

    Only ``values[0]`` is inspected to decide which collate routine is used.

    Args:
        values: non-empty sequence of tensors (presumably all of the same
            rank as ``values[0]`` -- this is not checked here).
        pad_idx: fill value for padded positions.
        left_pad: forwarded to the rank-specific collate function.
        shift_right: forwarded to the rank-specific collate function.
        max_len: forwarded to the rank-specific collate function.
        shift_id: forwarded to ``collate_1d`` only; the 2d/3d routines do
            not take it.

    Returns:
        Whatever ``collate_1d``, ``collate_2d`` or ``collate_3d`` returns.

    Raises:
        ValueError: if ``values[0]`` is not 1-, 2- or 3-dimensional.
    """
    ndim = len(values[0].shape)
    if ndim == 1:
        return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)
    if ndim == 2:
        return collate_2d(values, pad_idx, left_pad, shift_right, max_len)
    if ndim == 3:
        return collate_3d(values, pad_idx, left_pad, shift_right, max_len)
    # Falling through used to return None silently, which only surfaced later
    # as an unrelated error at the call site.
    raise ValueError(
        "collate_xd supports 1d, 2d and 3d tensors, got a {}d tensor".format(ndim)
    )
def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Collate with ``collate_1d`` for 1d inputs and ``collate_2d`` otherwise.

    The decision is made from the rank of ``values[0]`` alone; anything that
    is not 1d is handed to ``collate_2d``. ``shift_id`` is only forwarded on
    the 1d path.
    """
    first_is_1d = len(values[0].shape) == 1
    if not first_is_1d:
        return collate_2d(values, pad_idx, left_pad, shift_right, max_len)
    return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)
def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor.

    Args:
        values: non-empty sequence of 1d tensors.
        pad_idx: fill value for positions not covered by a sequence.
        left_pad: if True, sequences are right-aligned and padding goes in
            front; otherwise they are left-aligned.
        shift_right: if True, each sequence is written shifted by one
            position (its last element is dropped) and ``shift_id`` is
            written at its first position.
        max_len: width of the output; defaults to the longest sequence.
        shift_id: value placed at the first position when ``shift_right``
            is set.

    Returns:
        A tensor of shape ``(len(values), size)`` created from ``values[0]``
        (so it shares its dtype and device).
    """
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new_empty((len(values), size)).fill_(pad_idx)

    def copy_tensor(src, dst):
        # Fails when max_len is shorter than a sequence: the slice is then
        # smaller than the data that should be written into it.
        assert dst.numel() == src.numel()
        if not shift_right:
            dst.copy_(src)
            return
        if src.numel() == 0:
            # A zero-length sequence has no slot for shift_id; indexing
            # dst[0] would raise IndexError, so the row stays all padding.
            return
        dst[1:] = src[:-1]
        dst[0] = shift_id

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Stack a list of 2d tensors into one 3d tensor, padding the first axis.

    The output has shape ``(len(values), size, values[0].shape[1])`` where
    ``size`` is ``max_len`` if given, else the longest first dimension.
    Uncovered positions hold ``pad_idx``. With ``left_pad`` the data is
    aligned to the end of the row. With ``shift_right`` each item is written
    one step later and its last frame is dropped; the first frame keeps the
    padding value.
    """
    if max_len is None:
        size = max(v.size(0) for v in values)
    else:
        size = max_len
    width = values[0].shape[1]
    res = values[0].new_empty((len(values), size, width)).fill_(pad_idx)

    for i, item in enumerate(values):
        n = len(item)
        target = res[i][size - n:] if left_pad else res[i][:n]
        assert target.numel() == item.numel()
        if shift_right:
            target[1:] = item[:-1]
        else:
            target.copy_(item)
    return res
def collate_3d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Stack a list of 3d tensors into one 4d tensor, padding the first axis.

    The output has shape
    ``(len(values), size, values[0].shape[1], values[0].shape[2])`` where
    ``size`` is ``max_len`` if given, else the longest first dimension.
    Uncovered positions hold ``pad_idx``. With ``left_pad`` the data is
    aligned to the end of the row. With ``shift_right`` each item is written
    one step later and its last slice is dropped; the first slice keeps the
    padding value.
    """
    if max_len is None:
        size = max(v.size(0) for v in values)
    else:
        size = max_len
    dim1, dim2 = values[0].shape[1], values[0].shape[2]
    res = values[0].new_empty((len(values), size, dim1, dim2)).fill_(pad_idx)

    for i, item in enumerate(values):
        n = len(item)
        target = res[i][size - n:] if left_pad else res[i][:n]
        assert target.numel() == item.numel()
        if shift_right:
            target[1:] = item[:-1]
        else:
            target.copy_(item)
    return res
- def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
- if len(batch) == 0:
- return 0
- if len(batch) == max_sentences:
- return 1
- if num_tokens > max_tokens:
- return 1
- return 0
def batch_by_size(
        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
        required_batch_size_multiple=1, distributed=False
):
    """
    Group dataset indices into mini-batches bounded by size and return them
    as a list of lists. Batches may contain sequences of different lengths.

    The cost of a batch is ``(number of items) * (longest item)``, i.e. the
    size it would have after padding every item to the longest one.

    Args:
        indices (List[int]): ordered list of dataset indices; a generator is
            materialised into an int64 numpy array first
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None, meaning no limit).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None, meaning no limit).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
        distributed (bool, optional): accepted but not used by this function.

    Returns:
        List of batches, each a list of indices in their original order.
    """
    token_cap = sys.maxsize if max_tokens is None else max_tokens
    sentence_cap = sys.maxsize if max_sentences is None else max_sentences
    bsz_mult = required_batch_size_multiple

    if isinstance(indices, types.GeneratorType):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    longest = 0          # longest item among the pending ones
    pending_lens = []    # lengths of pending items, plus the current one
    pending = []         # indices of the batch being built
    finished = []

    for idx in indices:
        n_tok = num_tokens_fn(idx)
        pending_lens.append(n_tok)
        longest = max(longest, n_tok)
        assert longest <= token_cap, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, longest, token_cap)
        )
        # Padded size of the batch if the current item were added to it.
        padded_total = (len(pending) + 1) * longest
        if _is_batch_full(pending, padded_total, token_cap, sentence_cap):
            # Emit the largest prefix whose length is a multiple of bsz_mult;
            # if the batch is shorter than bsz_mult, emit all of it.
            whole, leftover = divmod(len(pending), bsz_mult)
            cut = max(bsz_mult * whole, leftover)
            finished.append(pending[:cut])
            pending = pending[cut:]
            pending_lens = pending_lens[cut:]
            longest = max(pending_lens, default=0)
        pending.append(idx)

    if len(pending) > 0:
        finished.append(pending)
    return finished
def unpack_dict_to_list(samples):
    """Split a batched dict into a list of per-item dicts.

    The batch size is read from ``samples.get('outputs').size(0)``, so
    ``samples`` must contain an ``'outputs'`` entry exposing ``size(0)``.
    For every position ``i`` a dict is built holding ``v[i]`` for each
    entry ``v`` that can be indexed at ``i``; entries that cannot (``None``,
    scalars, sequences shorter than the batch, ...) are left out of that
    item's dict.

    Args:
        samples: mapping from key to a batched value.

    Returns:
        List of ``bsz`` dicts, one per batch position.
    """
    samples_ = []
    bsz = samples.get('outputs').size(0)
    for i in range(bsz):
        res = {}
        for k, v in samples.items():
            try:
                res[k] = v[i]
            except (IndexError, KeyError, TypeError):
                # Best-effort: skip values that are not indexable per item.
                # Only indexing failures are ignored, so real errors and
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                continue
        samples_.append(res)
    return samples_
def remove_padding(x, padding_idx=0):
    """Drop padded positions along the first axis of a 1d or 2d array.

    For 1d input, elements equal to ``padding_idx`` are removed. For 2d
    input, a row is removed when the sum of its absolute values equals
    ``padding_idx``. ``None`` is passed through unchanged.
    """
    if x is None:
        return None
    rank = len(x.shape)
    assert rank in [1, 2]
    if rank == 1:  # [T]
        keep = x != padding_idx
    else:  # [T, H]
        keep = np.abs(x).sum(-1) != padding_idx
    return x[keep]
def data_loader(fn):
    """
    Decorator that evaluates a method lazily and caches the result per instance.

    The wrapped method is called as ``fn(self)`` on first use; the result is
    stored on the instance under ``'_lazy_' + fn.__name__`` and returned
    directly on later calls.

    :param fn: method taking only ``self``
    :return: wrapper with the same name and docstring as ``fn``
    :raises RuntimeError: if ``fn`` raises AttributeError; the original
        error is chained as the cause
    """
    attr_name = '_lazy_' + fn.__name__

    # wraps() must be applied as a decorator; calling it and discarding the
    # result left the wrapper without fn's __name__/__doc__.
    @wraps(fn)
    def _get_data_loader(self):
        try:
            return getattr(self, attr_name)
        except AttributeError:
            pass  # not cached yet, compute below

        try:
            value = fn(self)  # Lazy evaluation, done only once.
        except AttributeError as e:
            # Guard against AttributeError suppression. (Issue #142)
            traceback.print_exc()
            error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
            raise RuntimeError(error) from e
        setattr(self, attr_name, value)  # Memoize evaluation.
        return value

    return _get_data_loader
class BaseDataset(torch.utils.data.Dataset):
    """Base class for datasets that track a per-example size list.

    ``sizes`` starts as ``None`` here; code outside this class is expected
    to fill it before length or size queries are made. ``__getitem__`` and
    ``collater`` must be provided by subclasses.
    """

    def __init__(self, shuffle):
        super().__init__()
        self.hparams = hparams
        self.shuffle = shuffle
        self.sort_by_len = hparams['sort_by_len']
        self.sizes = None

    @property
    def _sizes(self):
        """The stored ``sizes`` attribute."""
        return self.sizes

    def __getitem__(self, index):
        raise NotImplementedError

    def collater(self, samples):
        raise NotImplementedError

    def __len__(self):
        """Number of entries in ``_sizes``."""
        return len(self._sizes)

    def num_tokens(self, index):
        """Same value as ``size(index)``."""
        return self.size(index)

    def size(self, index):
        """Return the stored size of example ``index``, capped at
        ``hparams['max_frames']``."""
        stored = self._sizes[index]
        return min(stored, hparams['max_frames'])

    def ordered_indices(self):
        """Return the index order in which examples should be visited.

        Without shuffling this is ``0..len(self)-1``. With shuffling it is a
        random permutation, which is additionally stable-sorted by stored
        size when ``sort_by_len`` is set."""
        if not self.shuffle:
            return np.arange(len(self))
        order = np.random.permutation(len(self))
        if self.sort_by_len:
            lengths = np.array(self._sizes)[order]
            # mergesort is stable: equal lengths keep their shuffled order.
            order = order[np.argsort(lengths, kind='mergesort')]
        return order

    @property
    def num_workers(self):
        """Worker count: the ``NUM_WORKERS`` environment variable if set,
        else ``hparams['num_workers']``, converted to int."""
        fallback = hparams['num_workers']
        return int(os.getenv('NUM_WORKERS', fallback))
class BaseConcatDataset(ConcatDataset):
    """Concatenation of datasets that exposes the size/ordering helpers.

    Collation, shuffle settings and worker count are all taken from the
    first dataset in ``self.datasets``.
    """

    def collater(self, samples):
        """Collate ``samples`` with the first dataset's ``collater``."""
        first = self.datasets[0]
        return first.collater(samples)

    @property
    def _sizes(self):
        """Sizes of all member datasets chained in order; built on first
        access and cached in ``self.sizes``."""
        if hasattr(self, 'sizes'):
            return self.sizes
        self.sizes = list(chain.from_iterable(d._sizes for d in self.datasets))
        return self.sizes

    def size(self, index):
        """Return the stored size of example ``index``, capped at
        ``hparams['max_frames']``."""
        stored = self._sizes[index]
        return min(stored, hparams['max_frames'])

    def num_tokens(self, index):
        """Same value as ``size(index)``."""
        return self.size(index)

    def ordered_indices(self):
        """Return the index order in which examples should be visited.

        Without shuffling this is ``0..len(self)-1``. With shuffling it is a
        random permutation, which is additionally stable-sorted by stored
        size when the first dataset has ``sort_by_len`` set."""
        first = self.datasets[0]
        if not first.shuffle:
            return np.arange(len(self))
        order = np.random.permutation(len(self))
        if first.sort_by_len:
            lengths = np.array(self._sizes)[order]
            # mergesort is stable: equal lengths keep their shuffled order.
            order = order[np.argsort(lengths, kind='mergesort')]
        return order

    @property
    def num_workers(self):
        """Worker count of the first dataset."""
        return self.datasets[0].num_workers
|