partitioned_optimizer_swapper.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. Licensed under the MIT license.
  4. Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
  5. """
  6. import os
  7. import torch
  8. from deepspeed.utils.logging import logger
  9. from deepspeed.ops.aio import AsyncIOBuilder
  10. from deepspeed.runtime.swap_tensor.constants import *
  11. from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object, \
  12. MIN_AIO_BYTES, AIO_ALIGNED_BYTES, get_sized_buffers, get_sized_buffer
  13. from deepspeed.runtime.swap_tensor.async_swapper import AsyncTensorSwapper
  14. from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
  15. DEBUG_MODE = False
  16. SWAP_IN_PARAM_TIMER = 'swap_in_param'
  17. SWAP_OUT_PARAM_TIMER = 'swap_out_param'
  18. SWAP_IN_GRADIENT_TIMER = 'swap_in_gradient'
class PartitionedOptimizerSwapper(OptimizerSwapper):
    """Swaps partitioned optimizer state tensors to/from (NVMe) storage.

    Parameter/optimizer-state swap-ins and swap-outs go through a single
    async I/O handle (``self.aio_handle``) and are waited on synchronously;
    gradient swap-outs are handed to an ``AsyncTensorSwapper`` so the writes
    overlap with other work.
    """

    def __init__(self,
                 swap_config,
                 aio_config,
                 base_folder,
                 optimizer,
                 largest_numel,
                 device,
                 dtype,
                 timers):
        """Construct the swapper.

        Args:
            swap_config: swap-tensor configuration (forwarded to base class).
            aio_config: dict of async-I/O settings; keys AIO_BLOCK_SIZE,
                AIO_QUEUE_DEPTH, AIO_SINGLE_SUBMIT, AIO_OVERLAP_EVENTS and
                AIO_THREAD_COUNT are read here.
            base_folder: storage folder for swap files (forwarded to base).
            optimizer: the optimizer whose state is swapped (forwarded to base).
            largest_numel: largest element count to size swap buffers for.
            device: compute device (forwarded to base class).
            dtype: dtype of the swapped tensors.
            timers: timer object used by the _start_timer/_stop_timer helpers.
        """
        super(PartitionedOptimizerSwapper,
              self).__init__(swap_config,
                             aio_config,
                             base_folder,
                             optimizer,
                             largest_numel,
                             device,
                             dtype,
                             timers)

        # Build the native async-I/O handle used for every read/write below.
        aio_op = AsyncIOBuilder().load()
        self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE],
                                            aio_config[AIO_QUEUE_DEPTH],
                                            aio_config[AIO_SINGLE_SUBMIT],
                                            aio_config[AIO_OVERLAP_EVENTS],
                                            aio_config[AIO_THREAD_COUNT])

        # Overlap swapping out
        self.gradient_swapper = AsyncTensorSwapper(aio_handle=self.aio_handle,
                                                   numel_alignment=self.numel_alignment,
                                                   timers=self.timers)

        # Keep handles/verbose members out of the print_object() dump below.
        self.print_exclude_list += [
            'aio_handle',
            'gradient_swapper',
            'print_exclude_list'
        ]

        # Only rank 0 logs the configured object.
        if torch.distributed.get_rank() == 0:
            print_object(obj=self,
                         name='PartitionedOptimizerSwapper',
                         exclude_list=self.print_exclude_list)

    def initialize_parameters(self, parameters, src_tensors):
        """Initialize swap files for `parameters` from `src_tensors` using the
        synchronous aio handle (delegates to the base-class helper)."""
        self._initialize_parameters(parameters=parameters,
                                    src_tensors=src_tensors,
                                    aio_handle=self.aio_handle)

    def initialize_from_swapped_fp16_params(self,
                                            fp16_partitions_info,
                                            fp16_num_elems,
                                            fp16_pinned_buffers,
                                            fp32_parameters):
        """Initialize fp32 parameter swap state from already-swapped fp16
        partitions (delegates to the base-class helper)."""
        self._initialize_from_swapped_fp16_params(
            aio_handle=self.aio_handle,
            fp16_partitions_info=fp16_partitions_info,
            fp16_num_elems=fp16_num_elems,
            fp16_pinned_buffers=fp16_pinned_buffers,
            fp32_parameters=fp32_parameters)

    def flush_gradients(self):
        """Block until all pending async gradient swap-outs have completed."""
        self._flush_gradient_swapper(self.gradient_swapper)

    def swap_in_optimizer_state(self, parameter, async_parameter=None):
        """Swap `parameter`'s optimizer state (and swapped gradients, if any)
        from storage into freshly allocated pinned buffers.

        No-op if the parameter has no swap info. `async_parameter` is accepted
        for interface compatibility but unused here.
        """
        swap_info = self._get_param_swap_info(parameter)
        if swap_info is None:
            return

        # Gradients may still be in flight from an async swap-out; finish
        # them before reading state back in.
        self._flush_gradient_swapper(self.gradient_swapper)

        # One pinned buffer per state tensor, plus one extra for gradients.
        required_buffer_count = len(
            swap_info.tensors) + (1 if swap_info.has_gradients() else 0)
        aligned_numel = self._io_aligned_numel(swap_info.numel())
        pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel,
                                                           count=required_buffer_count,
                                                           dtype=parameter.dtype)
        assert pinned_buffers is not None
        # Remember the buffers so swap_out_optimizer_state can free them.
        self.allocated_swap_buffers = pinned_buffers.copy()

        self._start_timer(SWAP_IN_PARAM_TIMER)
        self._swap_in_parameter(aio_handle=self.aio_handle,
                                parameter=parameter,
                                dest_buffers=pinned_buffers[:required_buffer_count])
        self._stop_timer(SWAP_IN_PARAM_TIMER)
        self.timer_names.add(SWAP_IN_PARAM_TIMER)

        # The last allocated buffer is reserved for gradients (see count above).
        self._start_timer(SWAP_IN_GRADIENT_TIMER)
        self._swap_in_gradients(aio_handle=self.aio_handle,
                                parameter=parameter,
                                dest_buffer=pinned_buffers[-1])
        self._stop_timer(SWAP_IN_GRADIENT_TIMER)
        self.timer_names.add(SWAP_IN_GRADIENT_TIMER)

    def swap_out_optimizer_state(self, parameter, async_swap=False):
        """Write `parameter`'s optimizer state tensors back to storage and
        release their memory plus the swap buffers allocated at swap-in.

        No-op if the parameter has no swap info. `async_swap` is accepted for
        interface compatibility but the writes here are waited on.
        """
        swap_info = self._get_param_swap_info(parameter=parameter)

        if swap_info is None:
            return

        self._start_timer(SWAP_OUT_PARAM_TIMER)
        # Pinned tensors can be written directly; unpinned ones must be
        # staged through pinned buffers.
        pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)
        swap_bytes = sum([
            self._io_aligned_numel(t.numel()) * t.element_size()
            for t in swap_info.tensors
        ])

        WRITE_TIMER = 'swap_submit_write'
        self._start_timer(WRITE_TIMER)

        swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
        assert self.aio_handle.wait() == len(pinned_tensors)
        # Release the swapped-out tensors by pointing them at empty storage.
        for t in pinned_tensors:
            t.data = torch.Tensor()

        if len(unpinned_tensors) > 0:
            pinned_buffers = self.swap_buffer_manager.allocate_all(
                num_elems=self.largest_numel,
                dtype=self.dtype)
            self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
                                            unpinned_tensors=unpinned_tensors,
                                            dest_paths=unpinned_paths,
                                            pinned_buffers=pinned_buffers)
            self.allocated_swap_buffers += pinned_buffers

            for t in unpinned_tensors:
                t.data = torch.Tensor()
        self._stop_timer(WRITE_TIMER)

        # Return every buffer allocated during swap-in/swap-out to the pool.
        self.swap_buffer_manager.free(self.allocated_swap_buffers)
        self.allocated_swap_buffers = []

        self._stop_timer(SWAP_OUT_PARAM_TIMER)
        self.timer_names.add(SWAP_OUT_PARAM_TIMER)
        self._log_timers([WRITE_TIMER])

        if DEBUG_MODE and torch.distributed.get_rank() == 0:
            logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB')

    def swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors):
        """Asynchronously swap out gradient partitions of `parameter` via the
        overlapped gradient swapper (delegates to the base-class helper)."""
        self._swap_out_gradients(parameter=parameter,
                                 gradient_offsets=gradient_offsets,
                                 gradient_tensors=gradient_tensors,
                                 gradient_swapper=self.gradient_swapper)

    def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
        """Read `parameter`'s optimizer state tensors into `dest_buffers` and
        re-point the state tensors at views of those buffers."""
        swap_info = self._get_param_swap_info(parameter)
        if swap_info is None:
            return

        assert len(swap_info.tensors) <= len(dest_buffers)

        # Reads are issued at I/O-aligned lengths; compute views are trimmed
        # back to the true element count afterwards.
        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(
            swap_info.tensors)
        swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

        READ_TIMER = 'swap_submit_read_param'
        WAIT_TIMER = 'swap_wait_read_param'

        self._start_timer(READ_TIMER)
        swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
        self._stop_timer(READ_TIMER)

        swap_bytes = sum(
            [buffer.numel() * buffer.element_size() for buffer in swap_buffers])

        self._start_timer(WAIT_TIMER)
        aio_handle.wait()
        self._stop_timer(WAIT_TIMER)

        # Attach the state tensors to unaligned-length views of the buffers.
        compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
        for t, buffer in zip(swap_info.tensors, compute_buffers):
            t.data = buffer.data

        self._log_timers([READ_TIMER, WAIT_TIMER])

        if DEBUG_MODE and torch.distributed.get_rank() == 0:
            logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')

    def _separate_pinned_tensors(self, swap_info):
        """Split `swap_info`'s tensors (and their swap paths) into pinned and
        unpinned groups.

        Returns:
            (pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths)
        """
        pinned_tensors = []
        pinned_paths = []

        unpinned_tensors = []
        unpinned_paths = []

        for tensor, path in zip(swap_info.tensors, swap_info.swap_paths):
            if tensor.is_pinned():
                pinned_tensors.append(tensor)
                pinned_paths.append(path)
            else:
                unpinned_tensors.append(tensor)
                unpinned_paths.append(path)

        return pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths

    def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor):
        """Read all swapped gradient partitions of `parameter` into the
        matching offset/length slices of `gradient_tensor`."""
        swap_info = self.swap_params_info[id(parameter)]
        param_gradients = swap_info.swapped_gradients.values()
        # Each swapped partition lands in its recorded [offset, offset+length)
        # slice of the destination gradient tensor.
        swap_buffers = [
            gradient_tensor.narrow(0,
                                   grad.offset,
                                   grad.length) for grad in param_gradients
        ]
        swap_paths = [grad.path for grad in param_gradients]
        SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
        SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'

        self._start_timer(SWAP_READ_GRADIENTS)
        swap_in_tensors(aio_handle, swap_buffers, swap_paths)
        self._stop_timer(SWAP_READ_GRADIENTS)

        self._start_timer(SWAP_WAIT_GRADIENTS)
        assert len(swap_buffers) == aio_handle.wait()
        self._stop_timer(SWAP_WAIT_GRADIENTS)

        self._log_timers([SWAP_READ_GRADIENTS, SWAP_WAIT_GRADIENTS])

    def _swap_in_gradients(self, aio_handle, parameter, dest_buffer):
        """Materialize `parameter.grad` inside `dest_buffer` and fill it from
        swapped and/or unswapped gradient partitions.

        No-op unless the parameter has swap info with gradients.
        """
        swap_info = self.swap_params_info.get(id(parameter), None)
        if not (swap_info and swap_info.has_gradients()):
            return

        assert dest_buffer.is_pinned()
        assert parameter.numel() <= dest_buffer.numel()

        # The gradient lives in the leading slice of the pinned buffer.
        parameter.grad = dest_buffer.narrow(0, 0, parameter.numel())

        if swap_info.swapped_gradients:
            self._swap_in_pinned_gradients(aio_handle, parameter, parameter.grad)

        if swap_info.unswapped_gradients:
            self._retrieve_unswapped_grad_partitions(swap_info=swap_info,
                                                     dest_buffer=parameter.grad)