# async_swapper.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping tensors to/from (NVMe) storage devices.
"""
import torch

from deepspeed import comm as dist
from deepspeed.utils.logging import logger
from deepspeed.runtime.swap_tensor.utils import swap_out_tensors, SwapBuffer
  11. INVALID_BUFFER_INDEX = -1
  12. ASYNC_SWAPPER_WAIT_TIMER = 'async_swap_gradient_wait'
  13. class AsyncTensorSwapper(object):
  14. def __init__(self, aio_handle, numel_alignment, timers):
  15. self.free_buffer_index = []
  16. self.swapping_buffer_index = []
  17. self.ready_buffer_index = []
  18. self.current_buffer_index = INVALID_BUFFER_INDEX
  19. self.all_buffers = []
  20. self.aio_handle = aio_handle
  21. self.numel_alignment = numel_alignment
  22. self.max_numel = 0
  23. self.num_pending_swaps = 0
  24. self.timers = timers
  25. self.timer_names = set()
  26. self.num_elements_swapped = 0
  27. self.dtype = None
  28. def has_buffers(self):
  29. return len(self.all_buffers) > 0
  30. def add_buffers(self, buffer_list):
  31. assert len(self.all_buffers) == 0
  32. assert all([buffer.is_pinned() for buffer in buffer_list])
  33. dtype = buffer_list[0].dtype
  34. assert all([buffer.dtype == dtype for buffer in buffer_list])
  35. self.dtype = dtype
  36. self.all_buffers = [SwapBuffer(buffer) for buffer in buffer_list]
  37. self.free_buffer_index += [i for i in range(len(self.all_buffers))]
  38. self.max_numel = max([buffer.numel() for buffer in buffer_list])
  39. self.timer_names = set()
  40. def get_timer_names(self):
  41. return list(self.timer_names)
  42. def release_buffers(self):
  43. self._report_statistics('Swapped out[Before flush]')
  44. self._flush_buffers_until_complete()
  45. self._report_statistics('Swapped out[After flush]')
  46. pinned_buffers = [buf.buffer for buf in self.all_buffers]
  47. self.all_buffers = []
  48. self.free_buffer_index = []
  49. self.current_buffer_index = INVALID_BUFFER_INDEX
  50. self.num_elements_swapped = 0
  51. self.dtype = None
  52. return pinned_buffers
  53. def swap_out_tensors(self, tensor_list, path_list):
  54. for tensor, swap_path in zip(tensor_list, path_list):
  55. self._swap_out_tensor(tensor, swap_path)
  56. def _report_statistics(self, message):
  57. if dist.get_rank() == 0:
  58. element_size = torch.tensor([], dtype=self.dtype).element_size()
  59. swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
  60. logger.debug(f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB')
  61. def _swap_out_tensor(self, tensor, swap_path):
  62. assert len(self.all_buffers) > 0
  63. aligned_numel = self._io_aligned_numel(tensor.numel())
  64. assert aligned_numel <= self.max_numel
  65. self._make_swap_space(aligned_numel)
  66. assert self.current_buffer_index != INVALID_BUFFER_INDEX
  67. swap_buffer = self._get_current_buffer()
  68. swap_buffer.insert_tensor(tensor, swap_path, aligned_numel)
  69. def _make_swap_space(self, numel):
  70. if self.current_buffer_index == INVALID_BUFFER_INDEX:
  71. self._allocate_buffer()
  72. return
  73. if not self._get_current_buffer().has_space(numel):
  74. if len(self.free_buffer_index) > 0:
  75. self._flush_ready_buffers()
  76. else:
  77. self._flush_buffers_until_complete()
  78. self._allocate_buffer()
  79. def _io_aligned_numel(self, numel):
  80. remainder = numel % self.numel_alignment
  81. return numel if remainder == 0 else (numel + self.numel_alignment - remainder)
  82. def _allocate_buffer(self):
  83. assert self.current_buffer_index == INVALID_BUFFER_INDEX
  84. assert len(self.all_buffers) > 0
  85. assert len(self.free_buffer_index) > 0
  86. self.current_buffer_index = self.free_buffer_index[-1]
  87. self.free_buffer_index = self.free_buffer_index[:-1]
  88. def _flush_ready_buffers(self):
  89. if self.current_buffer_index != INVALID_BUFFER_INDEX:
  90. self.ready_buffer_index.append(self.current_buffer_index)
  91. self.current_buffer_index = INVALID_BUFFER_INDEX
  92. self._swap_out_ready_buffers()
  93. def _flush_buffers_until_complete(self):
  94. self._flush_ready_buffers()
  95. assert len(self.ready_buffer_index) == 0
  96. self._wait_for_swap_complete()
  97. assert len(self.swapping_buffer_index) == 0
  98. assert len(self.free_buffer_index) == len(self.all_buffers)
  99. def _swap_out_ready_buffers(self):
  100. for buffer_index in self.ready_buffer_index:
  101. buffer = self._get_buffer(buffer_index)
  102. swap_tensors = buffer.get_swap_tensors()
  103. swap_paths = buffer.get_swap_paths()
  104. self.num_pending_swaps += len(swap_tensors)
  105. swap_out_tensors(self.aio_handle, swap_tensors, swap_paths)
  106. self.swapping_buffer_index += self.ready_buffer_index
  107. self.ready_buffer_index = []
  108. def _wait_for_swap_complete(self):
  109. assert len(self.swapping_buffer_index) > 0
  110. self._start_timer(ASYNC_SWAPPER_WAIT_TIMER)
  111. assert self.aio_handle.wait() == self.num_pending_swaps
  112. self._stop_timer(ASYNC_SWAPPER_WAIT_TIMER)
  113. self.timer_names.add(ASYNC_SWAPPER_WAIT_TIMER)
  114. self.num_pending_swaps = 0
  115. for buffer_index in self.swapping_buffer_index:
  116. buffer = self._get_buffer(buffer_index)
  117. self.num_elements_swapped += buffer.get_num_elem()
  118. buffer.reset()
  119. self.free_buffer_index += self.swapping_buffer_index
  120. assert len(self.free_buffer_index) <= len(self.all_buffers)
  121. self.swapping_buffer_index = []
  122. def _get_buffer(self, index):
  123. assert index != INVALID_BUFFER_INDEX
  124. return self.all_buffers[index]
  125. def _get_current_buffer(self):
  126. return self._get_buffer(self.current_buffer_index)
  127. def _start_timer(self, name):
  128. if self.timers:
  129. self.timers(name).start()
  130. def _stop_timer(self, name):
  131. if self.timers:
  132. self.timers(name).stop()
  133. def _log_timers(self, name_list, force=False):
  134. if self.timers and force:
  135. self.timers.log(name_list)