'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler


class RepeatingLoader:
    def __init__(self, loader):
        """Wraps an iterator to allow for infinite iteration. This is especially
        useful for DataLoader types that we wish to automatically restart upon
        completion.

        Args:
            loader (iterator): The data loader to repeat.
        """
        self.loader = loader
        self.data_iter = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.data_iter)
        except StopIteration:
            # The underlying loader is exhausted; restart it and fetch the
            # first batch of the new pass.
            self.data_iter = iter(self.loader)
            batch = next(self.data_iter)
        return batch
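

# Illustrative usage sketch (not part of the library API): wrapping a finite
# DataLoader so that iteration restarts transparently once it is exhausted.
# The toy dataset below is a placeholder chosen only for this example.
#
#   loader = DataLoader([(torch.randn(3), torch.tensor(i % 2)) for i in range(10)],
#                       batch_size=2)
#   repeating = RepeatingLoader(loader)
#   for step in range(12):  # more steps than len(loader); iteration wraps around
#       inputs, labels = next(repeating)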


class DeepSpeedDataLoader(object):
    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False):
        self.tput_timer = tput_timer

        if local_rank >= 0:
            # Distributed training: each rank reads its own shard of the
            # dataset, so the per-process batch size is used unchanged.
            if data_sampler is None:
                data_sampler = DistributedSampler(dataset=dataset,
                                                  num_replicas=data_parallel_world_size,
                                                  rank=data_parallel_rank)
            device_count = 1
        else:
            # Single-process training: sample randomly and scale the batch
            # size by the number of visible CUDA devices.
            if data_sampler is None:
                data_sampler = RandomSampler(dataset)
            device_count = torch.cuda.device_count()
            batch_size *= device_count

        if num_local_io_workers is None:
            num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.len = len(self.data_sampler)
        self.data = None
        self.dataloader_drop_last = dataloader_drop_last
    def __iter__(self):
        # Build a fresh DataLoader (and batch generator) for each new epoch.
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        return next(self.data)
    def _create_dataloader(self):
        if self.collate_fn is None:
            self.dataloader = DataLoader(self.dataset,
                                         batch_size=self.batch_size,
                                         pin_memory=self.pin_memory,
                                         sampler=self.data_sampler,
                                         num_workers=self.num_local_io_workers,
                                         drop_last=self.dataloader_drop_last)
        else:
            self.dataloader = DataLoader(self.dataset,
                                         batch_size=self.batch_size,
                                         pin_memory=self.pin_memory,
                                         sampler=self.data_sampler,
                                         collate_fn=self.collate_fn,
                                         num_workers=self.num_local_io_workers,
                                         drop_last=self.dataloader_drop_last)
        # Expose the batches through a generator so __next__ can pull from it.
        self.data = (x for x in self.dataloader)
        return self.dataloader
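

# Illustrative usage sketch (an assumption about a typical call site, not the
# engine's exact wiring): constructing a DeepSpeedDataLoader directly for a
# single process. local_rank=-1 selects the non-distributed branch, which
# assumes at least one visible CUDA device since the batch size is scaled by
# torch.cuda.device_count(). tput_timer=None disables throughput timing, and
# the toy dataset is a placeholder.
#
#   dataset = [(torch.randn(3), torch.tensor(i % 2)) for i in range(10)]
#   loader = DeepSpeedDataLoader(dataset,
#                                batch_size=2,
#                                pin_memory=False,
#                                local_rank=-1,
#                                tput_timer=None)
#   for inputs, labels in loader:  # one pass over the dataset per __iter__ call
#       pass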