async_swapper.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team.
  3. Licensed under the MIT license.
  4. Functionality of swapping tensors to/from (NVMe) storage devices.
  5. """
  6. import torch
  7. from deepspeed.utils.logging import logger
  8. from deepspeed.runtime.swap_tensor.utils import swap_out_tensors, SwapBuffer
  9. INVALID_BUFFER_INDEX = -1
  10. ASYNC_SWAPPER_WAIT_TIMER = 'async_swap_gradient_wait'
  11. class AsyncTensorSwapper(object):
  12. def __init__(self, aio_handle, numel_alignment, timers):
  13. self.free_buffer_index = []
  14. self.swapping_buffer_index = []
  15. self.ready_buffer_index = []
  16. self.current_buffer_index = INVALID_BUFFER_INDEX
  17. self.all_buffers = []
  18. self.aio_handle = aio_handle
  19. self.numel_alignment = numel_alignment
  20. self.max_numel = 0
  21. self.num_pending_swaps = 0
  22. self.timers = timers
  23. self.timer_names = set()
  24. self.num_elements_swapped = 0
  25. self.dtype = None
  26. def has_buffers(self):
  27. return len(self.all_buffers) > 0
  28. def add_buffers(self, buffer_list):
  29. assert len(self.all_buffers) == 0
  30. assert all([buffer.is_pinned() for buffer in buffer_list])
  31. dtype = buffer_list[0].dtype
  32. assert all([buffer.dtype == dtype for buffer in buffer_list])
  33. self.dtype = dtype
  34. self.all_buffers = [SwapBuffer(buffer) for buffer in buffer_list]
  35. self.free_buffer_index += [i for i in range(len(self.all_buffers))]
  36. self.max_numel = max([buffer.numel() for buffer in buffer_list])
  37. self.timer_names = set()
  38. def get_timer_names(self):
  39. return list(self.timer_names)
  40. def release_buffers(self):
  41. self._report_statistics('Swapped out[Before flush]')
  42. self._flush_buffers_until_complete()
  43. self._report_statistics('Swapped out[After flush]')
  44. pinned_buffers = [buf.buffer for buf in self.all_buffers]
  45. self.all_buffers = []
  46. self.free_buffer_index = []
  47. self.current_buffer_index = INVALID_BUFFER_INDEX
  48. self.num_elements_swapped = 0
  49. self.dtype = None
  50. return pinned_buffers
  51. def swap_out_tensors(self, tensor_list, path_list):
  52. for tensor, swap_path in zip(tensor_list, path_list):
  53. self._swap_out_tensor(tensor, swap_path)
  54. def _report_statistics(self, message):
  55. if torch.distributed.get_rank() == 0:
  56. element_size = torch.tensor([], dtype=self.dtype).element_size()
  57. swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
  58. logger.info(
  59. f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB'
  60. )
  61. def _swap_out_tensor(self, tensor, swap_path):
  62. assert len(self.all_buffers) > 0
  63. aligned_numel = self._io_aligned_numel(tensor.numel())
  64. assert aligned_numel <= self.max_numel
  65. self._make_swap_space(aligned_numel)
  66. assert self.current_buffer_index != INVALID_BUFFER_INDEX
  67. swap_buffer = self._get_current_buffer()
  68. swap_buffer.insert_tensor(tensor, swap_path, aligned_numel)
  69. def _make_swap_space(self, numel):
  70. if self.current_buffer_index == INVALID_BUFFER_INDEX:
  71. self._allocate_buffer()
  72. return
  73. if not self._get_current_buffer().has_space(numel):
  74. if len(self.free_buffer_index) > 0:
  75. self._flush_ready_buffers()
  76. else:
  77. self._flush_buffers_until_complete()
  78. self._allocate_buffer()
  79. def _io_aligned_numel(self, numel):
  80. remainder = numel % self.numel_alignment
  81. return numel if remainder == 0 else (numel + self.numel_alignment - remainder)
  82. def _allocate_buffer(self):
  83. assert self.current_buffer_index == INVALID_BUFFER_INDEX
  84. assert len(self.all_buffers) > 0
  85. assert len(self.free_buffer_index) > 0
  86. self.current_buffer_index = self.free_buffer_index[-1]
  87. self.free_buffer_index = self.free_buffer_index[:-1]
  88. def _flush_ready_buffers(self):
  89. if self.current_buffer_index != INVALID_BUFFER_INDEX:
  90. self.ready_buffer_index.append(self.current_buffer_index)
  91. self.current_buffer_index = INVALID_BUFFER_INDEX
  92. self._swap_out_ready_buffers()
  93. def _flush_buffers_until_complete(self):
  94. self._flush_ready_buffers()
  95. assert len(self.ready_buffer_index) == 0
  96. self._wait_for_swap_complete()
  97. assert len(self.swapping_buffer_index) == 0
  98. assert len(self.free_buffer_index) == len(self.all_buffers)
  99. def _swap_out_ready_buffers(self):
  100. for buffer_index in self.ready_buffer_index:
  101. buffer = self._get_buffer(buffer_index)
  102. swap_tensors = buffer.get_swap_tensors()
  103. swap_paths = buffer.get_swap_paths()
  104. self.num_pending_swaps += len(swap_tensors)
  105. swap_out_tensors(self.aio_handle, swap_tensors, swap_paths)
  106. self.swapping_buffer_index += self.ready_buffer_index
  107. self.ready_buffer_index = []
  108. def _wait_for_swap_complete(self):
  109. assert len(self.swapping_buffer_index) > 0
  110. self._start_timer(ASYNC_SWAPPER_WAIT_TIMER)
  111. assert self.aio_handle.wait() == self.num_pending_swaps
  112. self._stop_timer(ASYNC_SWAPPER_WAIT_TIMER)
  113. self.timer_names.add(ASYNC_SWAPPER_WAIT_TIMER)
  114. self.num_pending_swaps = 0
  115. for buffer_index in self.swapping_buffer_index:
  116. buffer = self._get_buffer(buffer_index)
  117. self.num_elements_swapped += buffer.get_num_elem()
  118. buffer.reset()
  119. self.free_buffer_index += self.swapping_buffer_index
  120. assert len(self.free_buffer_index) <= len(self.all_buffers)
  121. self.swapping_buffer_index = []
  122. def _get_buffer(self, index):
  123. assert index != INVALID_BUFFER_INDEX
  124. return self.all_buffers[index]
  125. def _get_current_buffer(self):
  126. return self._get_buffer(self.current_buffer_index)
  127. def _start_timer(self, name):
  128. if self.timers:
  129. self.timers(name).start()
  130. def _stop_timer(self, name):
  131. if self.timers:
  132. self.timers(name).stop()
  133. def _log_timers(self, name_list, force=False):
  134. if self.timers and force:
  135. self.timers.log(name_list)