123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from torch.utils.data import DataLoader, RandomSampler
- from torch.utils.data.distributed import DistributedSampler
- from deepspeed.accelerator import get_accelerator
- from deepspeed.runtime.data_pipeline.data_sampling.data_sampler import DeepSpeedDataSampler
- from deepspeed.runtime.data_pipeline.constants import CURRICULUM_LEARNING, \
- DATA_EFFICIENCY, DATA_SAMPLING_NUM_WORKERS
- from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, \
- DATA_PARALLEL_GROUP, GLOBAL_RANK
class RepeatingLoader:
    """Infinite wrapper around a finite iterable.

    Whenever the wrapped loader is exhausted, iteration transparently
    restarts from the beginning — useful for DataLoaders that should
    cycle forever during training.
    """

    def __init__(self, loader):
        """Wraps an iterator to allow for infinite iteration. This is especially useful
        for DataLoader types that we wish to automatically restart upon completion.
        Args:
            loader (iterator): The data loader to repeat.
        """
        self.loader = loader
        self.data_iter = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            # Underlying loader exhausted: restart it and fetch again.
            # (A loader that is empty on restart still raises StopIteration.)
            self.data_iter = iter(self.loader)
            return next(self.data_iter)
class DeepSpeedDataLoader(object):
    """Thin wrapper around :class:`torch.utils.data.DataLoader` that wires in
    DeepSpeed-specific sampling (distributed sharding or curriculum-learning
    batch sampling) plus optional throughput timing.

    The underlying ``DataLoader`` is (re)built lazily on every ``iter()`` call
    so samplers whose state changes between epochs are honored.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False,
                 deepspeed_dataloader_config=None):
        """
        Args:
            dataset: Dataset to load from.
            batch_size (int): Per-device batch size; scaled by the local
                device count on the single-process, no-distributed-sampler path.
            pin_memory (bool): Forwarded to ``DataLoader``.
            local_rank (int): ``>= 0`` selects the one-process-per-device
                (DistributedSampler) path; ``< 0`` means a single process
                drives all local devices.
            tput_timer: Optional timer; ``start()`` is called before each batch.
            collate_fn: Optional batch collation function.
            num_local_io_workers (int): DataLoader worker count; defaults to
                ``2 * device_count`` on the non-curriculum path.
            data_sampler: Pre-built sampler; constructed here when ``None``.
            data_parallel_world_size (int): Replica count for ``DistributedSampler``.
            data_parallel_rank (int): This rank's index for ``DistributedSampler``.
            dataloader_drop_last (bool): Drop the trailing partial batch.
            deepspeed_dataloader_config (dict): DeepSpeed data-efficiency
                configuration; ``None`` is treated as empty.
        """
        # Fix: the previous default ``deepspeed_dataloader_config={}`` was a
        # single dict shared across every call; use a None sentinel instead.
        if deepspeed_dataloader_config is None:
            deepspeed_dataloader_config = {}
        self.deepspeed_dataloader_config = deepspeed_dataloader_config
        self.tput_timer = tput_timer
        self.batch_size = batch_size

        self.curriculum_learning_enabled = False
        if CURRICULUM_LEARNING in deepspeed_dataloader_config:
            self.curriculum_learning_enabled = deepspeed_dataloader_config[CURRICULUM_LEARNING]

        if self.curriculum_learning_enabled:
            # Curriculum learning supplies its own batch sampler, so batch_size
            # is NOT scaled by the device count on this path.
            data_sampler = DeepSpeedDataSampler(self.deepspeed_dataloader_config[DATA_EFFICIENCY],
                                               len(dataset),
                                               self.batch_size,
                                               data_parallel_rank,
                                               data_parallel_world_size,
                                               self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
                                               self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
                                               self.deepspeed_dataloader_config[GLOBAL_RANK],
                                               drop_last=dataloader_drop_last)
            device_count = get_accelerator().device_count()
            num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
        else:
            if local_rank >= 0:
                # One process per device: each rank reads only its own shard.
                if data_sampler is None:
                    data_sampler = DistributedSampler(dataset=dataset,
                                                      num_replicas=data_parallel_world_size,
                                                      rank=data_parallel_rank)
                device_count = 1
            else:
                # Single process feeds every local device, so the effective
                # batch size grows with the device count.
                if data_sampler is None:
                    data_sampler = RandomSampler(dataset)
                device_count = get_accelerator().device_count()
                batch_size *= device_count

            if num_local_io_workers is None:
                num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size  # possibly scaled above
        self.pin_memory = pin_memory
        self.data = None
        self.dataloader_drop_last = dataloader_drop_last
        self.post_process_func = None

        # Number of batches this loader will yield per epoch.
        if self.dataloader_drop_last:
            self.len = len(self.data_sampler) // self.batch_size
        else:
            from math import ceil
            self.len = ceil(len(self.data_sampler) / self.batch_size)

    def __iter__(self):
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        if self.curriculum_learning_enabled:
            data = next(self.data_iterator)
            # Optional hook (e.g. sequence truncation) driven by sampler state.
            if self.post_process_func is not None:
                data = self.post_process_func(data, self.data_sampler.state_dict())
            return data
        else:
            return next(self.data)

    def _create_dataloader(self):
        """(Re)build the underlying ``DataLoader`` and its iterator.

        Fix: the four near-identical ``DataLoader(...)`` constructions were
        collapsed — ``collate_fn`` is forwarded only when supplied so torch's
        default collation applies otherwise (same behavior, no duplication).
        """
        extra_kwargs = {} if self.collate_fn is None else {'collate_fn': self.collate_fn}
        if self.curriculum_learning_enabled:
            # ``batch_sampler`` path: batch size and drop_last live inside the
            # sampler, so they must not also be passed to DataLoader.
            self.dataloader = DataLoader(self.dataset,
                                         pin_memory=self.pin_memory,
                                         batch_sampler=self.data_sampler,
                                         num_workers=self.num_local_io_workers,
                                         **extra_kwargs)
            self.data_iterator = iter(self.dataloader)
        else:
            self.dataloader = DataLoader(self.dataset,
                                         batch_size=self.batch_size,
                                         pin_memory=self.pin_memory,
                                         sampler=self.data_sampler,
                                         num_workers=self.num_local_io_workers,
                                         drop_last=self.dataloader_drop_last,
                                         **extra_kwargs)
            self.data = iter(self.dataloader)
        return self.dataloader
- # DataLoader([(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)], batch_size=2))
|