# utils.py  (extraction artifacts removed: file-size banner and fused line numbers)
  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. Licensed under the MIT license.
  4. Functionality of swapping tensors to/from (NVMe) storage devices.
  5. """
  6. import os
  7. import torch
  8. from deepspeed.utils.logging import logger
  9. from deepspeed.runtime.swap_tensor.constants import AIO_BLOCK_SIZE, AIO_QUEUE_DEPTH, \
  10. AIO_THREAD_COUNT, AIO_SINGLE_SUBMIT, AIO_OVERLAP_EVENTS
  11. MIN_AIO_BYTES = 1024**2
  12. AIO_ALIGNED_BYTES = 1024
  13. def swap_in_tensors(swap_handle, tensor_buffers, swap_paths):
  14. for buffer, path in zip(tensor_buffers, swap_paths):
  15. assert (swap_handle.async_pread(buffer, path) == 0)
  16. def swap_out_tensors(swap_handle, tensor_buffers, swap_paths):
  17. for buffer, path in zip(tensor_buffers, swap_paths):
  18. assert (swap_handle.async_pwrite(buffer, path) == 0)
  19. def print_object(obj, name, exclude_list=[]):
  20. logger.info('{}:'.format(name))
  21. for arg in sorted(vars(obj)):
  22. if not arg in exclude_list:
  23. dots = '.' * (29 - len(arg))
  24. logger.info(' {} {} {}'.format(arg, dots, getattr(obj, arg)))
  25. class SwapBuffer(object):
  26. def __init__(self, buffer):
  27. self.buffer = buffer
  28. self.reset()
  29. def reset(self):
  30. self.offset = 0
  31. self.swap_tensors = {}
  32. self.compute_tensors = {}
  33. self.swap_paths = {}
  34. self.num_elem = 0
  35. def insert_tensor(self, tensor, swap_path, aligned_numel):
  36. swap_tensor, compute_tensor = self.allocate_tensor(swap_path, tensor.numel(), aligned_numel)
  37. compute_tensor.data.copy_(tensor.data)
  38. return swap_tensor, compute_tensor
  39. def allocate_tensor(self, swap_path, numel, aligned_numel):
  40. assert self.has_space(aligned_numel)
  41. assert not self.offset in self.swap_tensors
  42. allocate_offset = self.offset
  43. swap_tensor = self.buffer.narrow(0, allocate_offset, aligned_numel)
  44. dest_tensor = swap_tensor.narrow(0, 0, numel)
  45. self.swap_tensors[allocate_offset] = swap_tensor
  46. self.compute_tensors[allocate_offset] = dest_tensor
  47. self.swap_paths[allocate_offset] = swap_path
  48. self.offset += aligned_numel
  49. self.num_elem += numel
  50. return self.swap_tensors[allocate_offset], self.compute_tensors[allocate_offset]
  51. def has_space(self, numel):
  52. return (self.offset + numel) <= self.buffer.numel()
  53. def get_swap_tensors(self):
  54. return [tensor for tensor in self.swap_tensors.values()]
  55. def get_swap_paths(self):
  56. return [path for path in self.swap_paths.values()]
  57. def get_compute_tensors(self):
  58. return [tensor for tensor in self.compute_tensors.values()]
  59. def get_num_elem(self):
  60. return self.num_elem
  61. def get_swap_tensor(self, offset):
  62. return self.swap_tensors.get(offset, None)
  63. def get_compute_tensor(self, offset):
  64. return self.compute_tensors.get(offset, None)
  65. def get_swap_path(self, offset):
  66. return self.swap_paths(offset, None)
  67. class SwapBufferPool(object):
  68. def __init__(self, buffers):
  69. assert all([buf.is_pinned() for buf in buffers])
  70. self.buffers = [SwapBuffer(buf) for buf in buffers]
  71. self.current_index = 0
  72. def reset(self):
  73. self.current_index = 0
  74. for buffer in self.buffers:
  75. buffer.reset()
  76. def allocate_tensor(self, numel, swap_path, aligned_numel):
  77. if self.has_space(aligned_numel):
  78. swap_tensor, compute_tensor = self._get_current_buffer().allocate_tensor(swap_path, numel, aligned_numel)
  79. return swap_tensor, compute_tensor
  80. return None, None
  81. def insert_tensor(self, tensor, swap_path, aligned_numel):
  82. if self.has_space(aligned_numel):
  83. swap_tensor, compute_tensor = self._get_current_buffer().insert_tensor(tensor, swap_path, aligned_numel)
  84. return swap_tensor, compute_tensor
  85. return None, None
  86. def get_swap_tensors(self):
  87. swap_tensors = []
  88. for buffer in self._get_used_buffers():
  89. swap_tensors += buffer.get_swap_tensors()
  90. return swap_tensors
  91. def get_swap_paths(self):
  92. swap_paths = []
  93. for buffer in self._get_used_buffers():
  94. swap_paths += buffer.get_swap_paths()
  95. return swap_paths
  96. def get_compute_tensors(self):
  97. compute_tensors = []
  98. for buffer in self._get_used_buffers():
  99. compute_tensors += buffer.get_compute_tensors()
  100. return compute_tensors
  101. def has_space(self, numel):
  102. if self._get_current_buffer().has_space(numel):
  103. return True
  104. if self.current_index == len(self.buffers) - 1:
  105. return False
  106. self.current_index += 1
  107. return self._get_current_buffer().has_space(numel)
  108. def swap_out(self, aio_handle, async_op=False):
  109. swap_tensors = self.get_swap_tensors()
  110. swap_paths = self.get_swap_paths()
  111. assert all([p is not None for p in swap_paths])
  112. swap_out_tensors(aio_handle, swap_tensors, swap_paths)
  113. if not async_op:
  114. assert len(swap_tensors) == aio_handle.wait()
  115. def swap_in(self, aio_handle, async_op=False):
  116. swap_tensors = self.get_swap_tensors()
  117. swap_paths = self.get_swap_paths()
  118. assert all([p is not None for p in swap_paths])
  119. swap_in_tensors(aio_handle, swap_tensors, swap_paths)
  120. if not async_op:
  121. assert len(swap_tensors) == aio_handle.wait()
  122. def _get_current_buffer(self):
  123. return self.buffers[self.current_index]
  124. def _get_used_buffers(self):
  125. return self.buffers[:self.current_index + 1]
  126. class SwapBufferManager(object):
  127. def __init__(self, num_elems, count, dtype):
  128. self.num_elems = num_elems
  129. self.count = count
  130. self.dtype = dtype
  131. self.all_buffers = [
  132. torch.zeros(num_elems,
  133. device='cpu',
  134. dtype=dtype).pin_memory() for _ in range(count)
  135. ]
  136. self.free_buffer_index = [i for i in range(count)]
  137. self.used_buffer_index = {}
  138. self.gigabytes = (self.all_buffers[0].element_size() * num_elems * count) / (1024
  139. **3)
  140. if torch.distributed.get_rank() == 0:
  141. exclude_list = ['all_buffers']
  142. print_object(obj=self, name='SwapBufferManager', exclude_list=exclude_list)
  143. def allocate(self, num_elems, count, dtype):
  144. assert dtype == self.dtype
  145. assert num_elems <= self.num_elems
  146. if count > len(self.free_buffer_index):
  147. return None
  148. used_indices = self.free_buffer_index[-count:]
  149. self.free_buffer_index = self.free_buffer_index[:-count]
  150. buffers = []
  151. for i in used_indices:
  152. tmp_buffer = self.all_buffers[i].narrow(0, 0, num_elems)
  153. buffers.append(tmp_buffer)
  154. self.used_buffer_index[id(tmp_buffer)] = i
  155. return buffers
  156. def allocate_all(self, num_elems, dtype):
  157. return self.allocate(num_elems=num_elems,
  158. count=len(self.free_buffer_index),
  159. dtype=dtype)
  160. def free(self, buffers):
  161. buffer_ids = []
  162. for buf in buffers:
  163. buffer_ids.append(id(buf))
  164. assert all([b_id in self.used_buffer_index for b_id in buffer_ids])
  165. for b_id in buffer_ids:
  166. self.free_buffer_index.append(self.used_buffer_index[b_id])
  167. del (self.used_buffer_index[b_id])
  168. def get_sized_buffer(buffer, num_elems):
  169. assert num_elems <= buffer.numel(), \
  170. f'num_elems {num_elems} > buffer {buffer.numel()}'
  171. return buffer.narrow(0, 0, num_elems) if num_elems < buffer.numel() else buffer
  172. def get_sized_buffers(buffer_list, num_elems_list):
  173. swap_buffers = [
  174. get_sized_buffer(buffer, num_elems) \
  175. for buffer, num_elems in zip(buffer_list, num_elems_list)
  176. ]
  177. return swap_buffers