dataloader.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from deepspeed.accelerator import get_accelerator

from deepspeed.runtime.data_pipeline.data_sampling.data_sampler import DeepSpeedDataSampler
from deepspeed.runtime.data_pipeline.constants import CURRICULUM_LEARNING, \
    DATA_EFFICIENCY, DATA_SAMPLING_NUM_WORKERS
from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, \
    DATA_PARALLEL_GROUP, GLOBAL_RANK


class RepeatingLoader:

    def __init__(self, loader):
        """Wraps an iterator to allow for infinite iteration. This is especially useful
        for DataLoader types that we wish to automatically restart upon completion.

        Args:
            loader (iterator): The data loader to repeat.
        """
        self.loader = loader
        self.data_iter = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.loader)
            batch = next(self.data_iter)
        return batch
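
# Usage sketch (illustrative, not part of the original module): wrapping a finite
# DataLoader so that iteration restarts instead of raising StopIteration. The toy
# dataset below is hypothetical.
#
#   loader = RepeatingLoader(DataLoader(list(range(6)), batch_size=4))
#   next(loader)  # tensor([0, 1, 2, 3])
#   next(loader)  # tensor([4, 5])
#   next(loader)  # tensor([0, 1, 2, 3])  <- loader was restarted transparently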


class DeepSpeedDataLoader(object):

    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False,
                 deepspeed_dataloader_config={}):
        self.deepspeed_dataloader_config = deepspeed_dataloader_config
        self.tput_timer = tput_timer
        self.batch_size = batch_size

        self.curriculum_learning_enabled = False
        if CURRICULUM_LEARNING in deepspeed_dataloader_config:
            self.curriculum_learning_enabled = deepspeed_dataloader_config[CURRICULUM_LEARNING]

        if self.curriculum_learning_enabled:
            data_sampler = DeepSpeedDataSampler(self.deepspeed_dataloader_config[DATA_EFFICIENCY],
                                                len(dataset),
                                                self.batch_size,
                                                data_parallel_rank,
                                                data_parallel_world_size,
                                                self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
                                                self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
                                                self.deepspeed_dataloader_config[GLOBAL_RANK],
                                                drop_last=dataloader_drop_last)
            device_count = get_accelerator().device_count()
            num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
        else:
            if local_rank >= 0:
                if data_sampler is None:
                    data_sampler = DistributedSampler(dataset=dataset,
                                                      num_replicas=data_parallel_world_size,
                                                      rank=data_parallel_rank)
                device_count = 1
            else:
                if data_sampler is None:
                    data_sampler = RandomSampler(dataset)
                device_count = get_accelerator().device_count()
                batch_size *= device_count

            if num_local_io_workers is None:
                num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.data = None
        self.dataloader_drop_last = dataloader_drop_last
        self.post_process_func = None

        if self.dataloader_drop_last:
            self.len = len(self.data_sampler) // self.batch_size
        else:
            from math import ceil
            self.len = ceil(len(self.data_sampler) / self.batch_size)
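        # Illustrative note (added): the length is computed from the sampler using the
        # (possibly device-scaled) batch_size. For example, 10 samples with an
        # effective batch_size of 4 give 10 // 4 == 2 batches when drop_last is set,
        # otherwise ceil(10 / 4) == 3.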

    def __iter__(self):
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        if self.curriculum_learning_enabled:
            data = next(self.data_iterator)
            if self.post_process_func is not None:
                data = self.post_process_func(data, self.data_sampler.state_dict())
            return data
        else:
            return next(self.data)
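
    # Note (added for clarity): in the curriculum-learning path below, the
    # DeepSpeedDataSampler is passed as batch_sampler, so it is expected to yield
    # whole batches of indices; the standard path combines a per-sample sampler
    # with batch_size and drop_last inside torch's DataLoader.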
    def _create_dataloader(self):
        if self.curriculum_learning_enabled:
            if self.collate_fn is None:
                self.dataloader = DataLoader(self.dataset,
                                             pin_memory=self.pin_memory,
                                             batch_sampler=self.data_sampler,
                                             num_workers=self.num_local_io_workers)
            else:
                self.dataloader = DataLoader(self.dataset,
                                             pin_memory=self.pin_memory,
                                             batch_sampler=self.data_sampler,
                                             collate_fn=self.collate_fn,
                                             num_workers=self.num_local_io_workers)
            self.data_iterator = iter(self.dataloader)
            return self.dataloader
        else:
            if self.collate_fn is None:
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             pin_memory=self.pin_memory,
                                             sampler=self.data_sampler,
                                             num_workers=self.num_local_io_workers,
                                             drop_last=self.dataloader_drop_last)
            else:
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             pin_memory=self.pin_memory,
                                             sampler=self.data_sampler,
                                             collate_fn=self.collate_fn,
                                             num_workers=self.num_local_io_workers,
                                             drop_last=self.dataloader_drop_last)
            self.data = (x for x in self.dataloader)
            return self.dataloader


# DataLoader([(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)], batch_size=2)
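

# Minimal end-to-end sketch (illustrative, not part of the original module). It
# assumes a non-distributed run (local_rank=-1, no curriculum-learning config) and
# an environment where get_accelerator().device_count() reports at least one
# device; the toy dataset and names below are hypothetical.
if __name__ == "__main__":
    import torch

    toy_dataset = [(torch.randn(3, 3), torch.tensor(i % 2)) for i in range(10)]
    loader = DeepSpeedDataLoader(toy_dataset,
                                 batch_size=2,
                                 pin_memory=False,
                                 local_rank=-1,
                                 tput_timer=None)
    print(f"{len(loader)} batches per epoch")
    for features, labels in loader:
        # Default collation stacks the tuples into (batch, 3, 3) and (batch,) tensors.
        print(features.shape, labels.shape)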