'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.data_pipeline.data_sampling.data_sampler import DeepSpeedDataSampler
from deepspeed.runtime.data_pipeline.constants import CURRICULUM_LEARNING, \
    DATA_EFFICIENCY, DATA_SAMPLING_NUM_WORKERS
from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, \
    DATA_PARALLEL_GROUP, GLOBAL_RANK


class RepeatingLoader:
    def __init__(self, loader):
        """Wraps an iterator to allow for infinite iteration. This is especially
        useful for DataLoader types that we wish to automatically restart upon
        completion.

        Args:
            loader (iterator): The data loader to repeat.
        """
        self.loader = loader
        self.data_iter = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.data_iter)
        except StopIteration:
            # The underlying iterator is exhausted; restart it transparently.
            self.data_iter = iter(self.loader)
            batch = next(self.data_iter)
        return batch
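
# Illustrative usage sketch (not part of the original module; the tensor shapes
# and batch size below are arbitrary placeholders). RepeatingLoader never raises
# StopIteration, so callers can draw more batches than one epoch provides:
#
#   import torch
#   from torch.utils.data import DataLoader
#
#   loader = RepeatingLoader(
#       DataLoader([(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)],
#                  batch_size=2))
#   it = iter(loader)
#   for _ in range(8):  # 8 > 5 batches per epoch; the loader restarts silently
#       batch = next(it)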


class DeepSpeedDataLoader(object):
    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False,
                 deepspeed_dataloader_config={}):
        self.deepspeed_dataloader_config = deepspeed_dataloader_config
        self.tput_timer = tput_timer
        self.batch_size = batch_size

        self.curriculum_learning_enabled = False
        if CURRICULUM_LEARNING in deepspeed_dataloader_config:
            self.curriculum_learning_enabled = deepspeed_dataloader_config[CURRICULUM_LEARNING]

        if self.curriculum_learning_enabled:
            # Curriculum learning supplies its own batch sampler, which also
            # tracks state for checkpointing (see the state_dict() use in
            # __next__); it overrides any user-provided data_sampler.
            data_sampler = DeepSpeedDataSampler(
                self.deepspeed_dataloader_config[DATA_EFFICIENCY],
                len(dataset),
                self.batch_size,
                data_parallel_rank,
                data_parallel_world_size,
                self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
                self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
                self.deepspeed_dataloader_config[GLOBAL_RANK],
                drop_last=dataloader_drop_last)
            device_count = get_accelerator().device_count()
            num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
        else:
            if local_rank >= 0:
                # Distributed case: each rank loads only its own shard.
                if data_sampler is None:
                    data_sampler = DistributedSampler(dataset=dataset,
                                                      num_replicas=data_parallel_world_size,
                                                      rank=data_parallel_rank)
                device_count = 1
            else:
                # Single-process case: one loader feeds every local device, so
                # the batch size is scaled by the local device count.
                if data_sampler is None:
                    data_sampler = RandomSampler(dataset)
                device_count = get_accelerator().device_count()
                batch_size *= device_count

            if num_local_io_workers is None:
                num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.data = None
        self.dataloader_drop_last = dataloader_drop_last
        self.post_process_func = None

        # The length is derived from the sampler rather than the raw dataset so
        # that sharding is reflected in len(self).
        if self.dataloader_drop_last:
            self.len = len(self.data_sampler) // self.batch_size
        else:
            from math import ceil
            self.len = ceil(len(self.data_sampler) / self.batch_size)

    def __iter__(self):
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        if self.curriculum_learning_enabled:
            data = next(self.data_iterator)
            if self.post_process_func is not None:
                # Give the post-process hook access to the sampler state (e.g.
                # the current curriculum difficulty) alongside the batch.
                data = self.post_process_func(data, self.data_sampler.state_dict())
            return data
        else:
            return next(self.data)

    def _create_dataloader(self):
        if self.curriculum_learning_enabled:
            # The curriculum sampler yields whole batches of indices, so it is
            # passed as batch_sampler and batch_size/drop_last are omitted.
            if self.collate_fn is None:
                self.dataloader = DataLoader(self.dataset,
                                             pin_memory=self.pin_memory,
                                             batch_sampler=self.data_sampler,
                                             num_workers=self.num_local_io_workers)
            else:
                self.dataloader = DataLoader(self.dataset,
                                             pin_memory=self.pin_memory,
                                             batch_sampler=self.data_sampler,
                                             collate_fn=self.collate_fn,
                                             num_workers=self.num_local_io_workers)
            self.data_iterator = iter(self.dataloader)
            return self.dataloader
        else:
            # Standard path: a per-sample sampler, with batching and drop_last
            # handled by the DataLoader itself.
            if self.collate_fn is None:
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             pin_memory=self.pin_memory,
                                             sampler=self.data_sampler,
                                             num_workers=self.num_local_io_workers,
                                             drop_last=self.dataloader_drop_last)
            else:
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             pin_memory=self.pin_memory,
                                             sampler=self.data_sampler,
                                             collate_fn=self.collate_fn,
                                             num_workers=self.num_local_io_workers,
                                             drop_last=self.dataloader_drop_last)
            self.data = (x for x in self.dataloader)
            return self.dataloader

# DataLoader([(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)], batch_size=2)
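#
# Illustrative usage sketch building on the scratch line above (argument values
# are placeholders, not defaults of this module; local_rank=-1 selects the
# single-process RandomSampler path):
#
#   import torch
#
#   loader = DeepSpeedDataLoader(
#       dataset=[(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)],
#       batch_size=2,
#       pin_memory=False,
#       local_rank=-1,
#       tput_timer=None)
#   for inputs, labels in loader:
#       ...  # inputs: (batch, 3, 3), labels: (batch,)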