# dataloader.py

'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler


class RepeatingLoader:
    def __init__(self, loader):
        """Wraps an iterator to allow for infinite iteration. This is especially useful
        for DataLoader types that we wish to automatically restart upon completion.

        Args:
            loader (iterator): The data loader to repeat.
        """
        self.loader = loader
        self.data_iter = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.data_iter)
        except StopIteration:
            # The wrapped loader is exhausted; restart it and fetch the first batch.
            self.data_iter = iter(self.loader)
            batch = next(self.data_iter)
        return batch
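
# Example usage (a minimal sketch with a toy DataLoader): RepeatingLoader restarts the
# wrapped loader whenever it runs out of batches, so iteration never raises StopIteration.
#
#   loader = RepeatingLoader(DataLoader(list(range(4)), batch_size=2))
#   it = iter(loader)
#   batches = [next(it) for _ in range(6)]  # tensor([0, 1]), tensor([2, 3]), tensor([0, 1]), ...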


class DeepSpeedDataLoader(object):
    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False):
        self.tput_timer = tput_timer
        self.batch_size = batch_size

        if local_rank >= 0:
            # Distributed training: shard the dataset across data-parallel ranks so each
            # process sees a disjoint subset.
            if data_sampler is None:
                data_sampler = DistributedSampler(dataset=dataset,
                                                  num_replicas=data_parallel_world_size,
                                                  rank=data_parallel_rank)
            device_count = 1
        else:
            # Single-process training: sample randomly and scale the batch size by the
            # number of visible GPUs.
            if data_sampler is None:
                data_sampler = RandomSampler(dataset)
            device_count = torch.cuda.device_count()
            batch_size *= device_count

        if num_local_io_workers is None:
            num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.len = len(self.data_sampler)
        self.data = None
        self.dataloader_drop_last = dataloader_drop_last

    def __iter__(self):
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        return next(self.data)

    def _create_dataloader(self):
        if self.collate_fn is None:
            self.dataloader = DataLoader(self.dataset,
                                         batch_size=self.batch_size,
                                         pin_memory=self.pin_memory,
                                         sampler=self.data_sampler,
                                         num_workers=self.num_local_io_workers,
                                         drop_last=self.dataloader_drop_last)
        else:
            self.dataloader = DataLoader(self.dataset,
                                         batch_size=self.batch_size,
                                         pin_memory=self.pin_memory,
                                         sampler=self.data_sampler,
                                         collate_fn=self.collate_fn,
                                         num_workers=self.num_local_io_workers,
                                         drop_last=self.dataloader_drop_last)
        # Expose the underlying DataLoader through a generator so __next__ can pull batches.
        self.data = (x for x in self.dataloader)
        return self.dataloader

# DataLoader([(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)], batch_size=2)
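

if __name__ == '__main__':
    # Minimal single-rank smoke test sketch: iterate a toy dataset of (feature, label)
    # pairs through DeepSpeedDataLoader. The toy dataset, batch size, and worker count
    # are illustrative choices; passing data_parallel_world_size=1 and
    # data_parallel_rank=0 lets DistributedSampler run without an initialized
    # torch.distributed process group.
    toy_dataset = [(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)]
    toy_loader = DeepSpeedDataLoader(toy_dataset,
                                     batch_size=2,
                                     pin_memory=False,
                                     local_rank=0,
                                     tput_timer=None,
                                     num_local_io_workers=0,
                                     data_parallel_world_size=1,
                                     data_parallel_rank=0)
    for features, labels in toy_loader:
        print(features.shape, labels)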