optimizer_utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Functionality of swapping tensors to/from (NVMe) storage devices.
  6. """
  7. import os
  8. import torch
  9. from deepspeed import comm as dist
  10. from deepspeed.utils.logging import logger
  11. from deepspeed.runtime.swap_tensor.constants import *
  12. from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, \
  13. MIN_AIO_BYTES, AIO_ALIGNED_BYTES, get_sized_buffers
  14. from deepspeed.runtime.swap_tensor.utils import SwapBufferManager, SwapBufferPool
  15. from deepspeed.accelerator import get_accelerator
  16. class FlattenedTensorSwapInfo(object):
  17. def __init__(self, path, length, offset):
  18. self.path = path
  19. self.offset = offset
  20. self.length = length
  21. class OptimizerStateSwapInfo(object):
  22. def __init__(self, parameter, numel, base_folder):
  23. self.tensors = []
  24. self.param_id = OptimizerSwapper.parameter_id(parameter)
  25. self.swap_folder = base_folder
  26. self.swap_paths = []
  27. self.swapped_gradients = {}
  28. self.unswapped_gradients = {}
  29. self.tensor_numel = numel
  30. self.tensor_dtype = parameter.dtype
  31. self.tensor_device = parameter.device
  32. self.has_state_tensors = False
  33. self._add_tensors([parameter])
  34. def numel(self):
  35. return self.tensor_numel
  36. def has_gradients(self):
  37. return self.swapped_gradients or self.unswapped_gradients
  38. def _add_tensors(self, tensor_list):
  39. for t in tensor_list:
  40. self.tensors.append(t)
  41. self.swap_paths.append(os.path.join(self.swap_folder, f'{OptimizerSwapper.parameter_id(t)}.tensor.swp'))
  42. def add_state_tensors(self, tensor_list):
  43. self.has_state_tensors = True
  44. self._add_tensors(tensor_list)
  45. def device(self):
  46. return self.tensor_device
  47. def dtype(self):
  48. return self.tensor_dtype
  49. def release_memory(self):
  50. for tensor in self.tensors:
  51. tensor.data = torch.Tensor()
  52. def get_or_create_gradient_paths(self, offsets, lengths):
  53. gradient_paths = []
  54. for offset, length in zip(offsets, lengths):
  55. if not offset in self.swapped_gradients.keys():
  56. path = os.path.join(self.swap_folder, f'{self.param_id}_gradient_{offset}_{length}.tensor.swp')
  57. self.swapped_gradients[offset] = FlattenedTensorSwapInfo(path, length, offset)
  58. gradient_paths.append(self.swapped_gradients[offset].path)
  59. return gradient_paths
  60. def set_swap_buffers(self, buffers):
  61. compute_lengths = [self.numel()] * len(self.tensors)
  62. compute_buffers = get_sized_buffers(buffers, compute_lengths)
  63. for t, buffer in zip(self.tensors, compute_buffers):
  64. t.data = buffer.data
  65. def get_swap_gradient_buffers(self, swap_buffer):
  66. assert self.numel() <= swap_buffer.numel()
  67. return [swap_buffer.narrow(0, grad.offset, grad.length) for grad in self.swapped_gradients.values()]
  68. def get_swap_gradient_paths(self):
  69. return [grad.path for grad in self.swapped_gradients.values()]
  70. def get_unpinned_state_tensors(self):
  71. return [t for t in self.tensors if not get_accelerator().is_pinned(t)]
  72. def read_unswapped_gradients(self, dest_buffer):
  73. num_elem_count = 0
  74. for offset, grad_partition in self.unswapped_gradients.items():
  75. dst_tensor = dest_buffer.narrow(0, offset, grad_partition.numel())
  76. dst_tensor.data.copy_(grad_partition.data)
  77. num_elem_count += grad_partition.numel()
  78. return num_elem_count
  79. def release_unswapped_gradients(self):
  80. self.unswapped_gradients = {}
# When True, emit per-swap debug logging and force timer log output.
SWAPPER_DEBUG_MODE = False

# Timer label used to measure gradient swap-out latency.
SWAP_OUT_GRADIENT_TIMER = 'swap_out_gradient'
  83. class OptimizerSwapper(object):
  84. @staticmethod
  85. def parameter_id(param):
  86. return param.ds_id
  87. def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_numel, device, dtype, timers):
  88. self.swap_config = swap_config
  89. self.aio_config = aio_config
  90. # NVMe swap management
  91. self.swap_params_info = {}
  92. self.swap_element_size = torch.tensor([], dtype=dtype).element_size()
  93. self.swap_folder = os.path.join(base_folder, 'optimizer', f'rank{dist.get_rank()}')
  94. os.makedirs(self.swap_folder, exist_ok=True)
  95. self.optimizer = optimizer
  96. # Read/Write alignment for each thread during Intra-request parallelism
  97. self.min_aio_bytes = max(MIN_AIO_BYTES, aio_config[AIO_BLOCK_SIZE])
  98. self.aligned_bytes = AIO_ALIGNED_BYTES * aio_config[AIO_THREAD_COUNT]
  99. self.numel_alignment = self.aligned_bytes // self.swap_element_size
  100. # Swap buffer management
  101. self.largest_numel = self._io_aligned_numel(largest_numel)
  102. self.dtype = dtype
  103. self.swap_buffer_manager = SwapBufferManager(num_elems=self.largest_numel,
  104. count=swap_config.buffer_count,
  105. dtype=dtype)
  106. # Timers
  107. self.timers = timers
  108. self.timer_names = set()
  109. # Print exclusion list
  110. self.print_exclude_list = [
  111. 'optimizer',
  112. 'swap_buffer_manager',
  113. 'swap_params_info',
  114. 'timers',
  115. 'timer_names',
  116. ]
  117. def swappable_tensor(self, param=None, numel=None):
  118. assert param is not None or numel is not None, "Either param or numel must be provided"
  119. if param is not None:
  120. return self.min_aio_bytes <= (param.numel() * self.swap_element_size)
  121. return self.min_aio_bytes <= (numel * self.swap_element_size)
  122. def init_timers(self):
  123. self.timer_names = set()
  124. def log_timers(self):
  125. if self.timer_names:
  126. self._log_timers(list(self.timer_names), force=True)
  127. def pre_backward(self):
  128. self.init_timers()
  129. def post_backward(self):
  130. pass
  131. def _flush_gradient_swapper(self, gradient_swapper):
  132. if gradient_swapper.has_buffers():
  133. self._start_timer(SWAP_OUT_GRADIENT_TIMER)
  134. pinned_buffers = gradient_swapper.release_buffers()
  135. self.swap_buffer_manager.free(pinned_buffers)
  136. self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
  137. self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)
  138. self.timer_names.update(gradient_swapper.get_timer_names())
  139. def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gradient_swapper):
  140. if not OptimizerSwapper.parameter_id(parameter) in self.swap_params_info.keys():
  141. return
  142. swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)]
  143. swappable_tensors = []
  144. swappable_offsets = []
  145. swappable_lengths = []
  146. aligned_gradients, aligned_offsets = self._adjust_for_misaligned_lengths(tensors=gradient_tensors,
  147. offsets=gradient_offsets)
  148. self._start_timer(SWAP_OUT_GRADIENT_TIMER)
  149. for tensor, offset in zip(aligned_gradients, aligned_offsets):
  150. if not self.swappable_tensor(param=tensor):
  151. swap_info.unswapped_gradients[offset] = tensor
  152. continue
  153. swappable_tensors.append(tensor)
  154. swappable_offsets.append(offset)
  155. swappable_lengths.append(tensor.numel())
  156. if len(swappable_tensors) > 0:
  157. if not gradient_swapper.has_buffers():
  158. pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
  159. gradient_swapper.add_buffers(pinned_buffers)
  160. swappable_paths = swap_info.get_or_create_gradient_paths(swappable_offsets, swappable_lengths)
  161. gradient_swapper.swap_out_tensors(tensor_list=swappable_tensors, path_list=swappable_paths)
  162. self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
  163. self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)
  164. def _initialize_from_swapped_fp16_params(self, aio_handle, fp16_partitions_info, fp16_num_elems,
  165. fp16_pinned_buffers, fp32_parameters):
  166. assert len(fp32_parameters) == len(fp16_partitions_info)
  167. assert len(fp32_parameters) == len(fp16_num_elems)
  168. assert all([get_accelerator().is_pinned(buffer) for buffer in fp16_pinned_buffers])
  169. fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters, num_elems=fp16_num_elems)
  170. fp32_pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
  171. fp16_buffer_numel = [buf.numel() for buf in fp16_pinned_buffers]
  172. assert all([numel >= self.largest_numel for numel in fp16_buffer_numel]), \
  173. f"numel of fp16 buffers {fp16_buffer_numel} is too small for initializing fp32 params {self.largest_numel}"
  174. fp32_swap_buffers = SwapBufferPool(fp32_pinned_buffers)
  175. fp16_swap_buffers = SwapBufferPool(fp16_pinned_buffers)
  176. curr_index = 0
  177. while curr_index < len(fp32_parameters):
  178. fp16_pinned_tensors = self._swap_in_fp16_params(aio_handle=aio_handle,
  179. fp16_num_elems=fp16_num_elems[curr_index:],
  180. fp16_partitions_info=fp16_partitions_info[curr_index:],
  181. fp16_swap_buffers=fp16_swap_buffers)
  182. if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
  183. for i, tensor in enumerate(fp16_pinned_tensors):
  184. true_index = curr_index + i
  185. logger.info(
  186. f'swap_in_fp16_param: fp32_id = {OptimizerSwapper.parameter_id(fp32_parameters[true_index])} index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, swap_num_elem = {fp16_pinned_tensors[i].numel()}'
  187. )
  188. swap_out_count = self._swap_out_fp16_params(aio_handle=aio_handle,
  189. fp32_swap_paths=fp32_swap_paths[curr_index:],
  190. fp32_swap_buffers=fp32_swap_buffers,
  191. fp16_pinned_tensors=fp16_pinned_tensors)
  192. assert swap_out_count == len(fp16_pinned_tensors), \
  193. f"{swap_out_count} does not match {len(fp16_pinned_tensors)}"
  194. fp16_swap_buffers.reset()
  195. fp32_swap_buffers.reset()
  196. curr_index += swap_out_count
  197. self.swap_buffer_manager.free(fp32_pinned_buffers)
  198. def _swap_in_fp16_params(self, aio_handle, fp16_num_elems, fp16_partitions_info, fp16_swap_buffers):
  199. assert len(fp16_num_elems) > 0
  200. swapped_fp16_tensors = []
  201. swap_tensors = []
  202. swap_paths = []
  203. unswapped_srcs = []
  204. unswapped_dsts = []
  205. for i, numel in enumerate(fp16_num_elems):
  206. pinned_tensor, _ = fp16_swap_buffers.allocate_tensor(numel, None, numel)
  207. if pinned_tensor is None:
  208. break
  209. swapped_fp16_tensors.append(pinned_tensor)
  210. offset = 0
  211. for tensor, partition_numel, partition_path in fp16_partitions_info[i]:
  212. dst_tensor = pinned_tensor.narrow(0, offset, partition_numel)
  213. if partition_path is None:
  214. unswapped_srcs.append(tensor)
  215. unswapped_dsts.append(dst_tensor)
  216. else:
  217. swap_paths.append(partition_path)
  218. swap_tensors.append(dst_tensor)
  219. offset += partition_numel
  220. assert len(swapped_fp16_tensors) + len(unswapped_srcs) > 0
  221. ret = swap_in_tensors(aio_handle, swap_tensors, swap_paths)
  222. for src, dst in zip(unswapped_srcs, unswapped_dsts):
  223. dst.data.copy_(src.data)
  224. assert len(swap_tensors) == aio_handle.wait()
  225. return swapped_fp16_tensors
  226. def _swap_out_fp16_params(self, aio_handle, fp32_swap_paths, fp32_swap_buffers, fp16_pinned_tensors):
  227. assert len(fp16_pinned_tensors) <= len(fp32_swap_paths)
  228. swap_out_count = 0
  229. for i, fp16_tensor in enumerate(fp16_pinned_tensors):
  230. if not fp32_swap_buffers.has_space(fp16_tensor.numel()):
  231. fp32_swap_buffers.swap_out(aio_handle)
  232. fp32_swap_buffers.reset()
  233. pinned_tensor, _ = fp32_swap_buffers.insert_tensor(fp16_tensor, fp32_swap_paths[i],
  234. self._io_aligned_numel(fp16_tensor.numel()))
  235. assert pinned_tensor is not None
  236. swap_out_count += 1
  237. if len(fp32_swap_buffers.get_swap_tensors()) > 0:
  238. fp32_swap_buffers.swap_out(aio_handle)
  239. return swap_out_count
  240. def _initialize_parameters(self, parameters, src_tensors, aio_handle):
  241. assert len(parameters) == len(src_tensors)
  242. swap_paths = self._get_swap_paths(parameters=parameters, num_elems=[src.numel() for src in src_tensors])
  243. SWAP_INIT_TIMER = "swap_init_write"
  244. self._start_timer(SWAP_INIT_TIMER)
  245. pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
  246. assert pinned_buffers is not None
  247. self._swap_out_unpinned_tensors(aio_handle=aio_handle,
  248. unpinned_tensors=src_tensors,
  249. dest_paths=swap_paths,
  250. pinned_buffers=pinned_buffers)
  251. if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
  252. for i, tensor in enumerate(src_tensors):
  253. logger.info(
  254. f'copy_in_fp16_param: fp32_id = {OptimizerSwapper.parameter_id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}'
  255. )
  256. self.swap_buffer_manager.free(pinned_buffers)
  257. self._stop_timer(SWAP_INIT_TIMER)
  258. self._log_timers([SWAP_INIT_TIMER])
  259. def _get_swap_paths(self, parameters, num_elems):
  260. swap_info_list = [
  261. self._create_param_swap_info(parameter=p,
  262. numel=numel) \
  263. for p, numel in zip(parameters, num_elems)
  264. ]
  265. assert len(swap_info_list) == len(num_elems)
  266. swap_paths = [info.swap_paths[0] for info in swap_info_list]
  267. return swap_paths
  268. def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers):
  269. swap_buffer_count = len(pinned_buffers)
  270. unpinned_tensor_count = len(unpinned_tensors)
  271. for i in range(0, unpinned_tensor_count, swap_buffer_count):
  272. swap_tensor_count = min((unpinned_tensor_count - i), swap_buffer_count)
  273. src_tensors = unpinned_tensors[i:(i + swap_tensor_count)]
  274. compute_lengths = [t.numel() for t in src_tensors]
  275. compute_buffers = get_sized_buffers(pinned_buffers, compute_lengths)
  276. for dst, src in zip(compute_buffers, src_tensors):
  277. dst.data.copy_(src.data)
  278. swap_lengths = [self._io_aligned_numel(t.numel()) for t in src_tensors]
  279. swap_buffers = get_sized_buffers(pinned_buffers, swap_lengths)
  280. swap_paths = dest_paths[i:(i + swap_tensor_count)]
  281. swap_out_tensors(aio_handle, swap_buffers, swap_paths)
  282. assert aio_handle.wait() == swap_tensor_count
  283. def _adjust_for_misaligned_lengths(self, tensors, offsets):
  284. new_tensors = []
  285. new_offsets = []
  286. for orig_tensor, orig_offset in zip(tensors, offsets):
  287. if not self.swappable_tensor(param=orig_tensor):
  288. new_tensors.append(orig_tensor)
  289. new_offsets.append(orig_offset)
  290. continue
  291. remainder = orig_tensor.numel() % self.numel_alignment
  292. if remainder == 0:
  293. new_tensors.append(orig_tensor)
  294. new_offsets.append(orig_offset)
  295. continue
  296. # Split into two by making remainder a tensor
  297. aligned_length = (orig_tensor.numel() // self.numel_alignment) * self.numel_alignment
  298. new_tensors.append(orig_tensor.narrow(0, 0, aligned_length))
  299. new_offsets.append(orig_offset)
  300. # remainder tensor
  301. new_tensors.append(orig_tensor.narrow(0, aligned_length, remainder))
  302. new_offsets.append(orig_offset + aligned_length)
  303. return new_tensors, new_offsets
  304. def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer):
  305. UNSWAPPED_READ_GRADIENTS = 'unswapped_read_gradients'
  306. self._start_timer(UNSWAPPED_READ_GRADIENTS)
  307. tensor_count = len(swap_info.unswapped_gradients)
  308. num_elem_count = swap_info.read_unswapped_gradients(dest_buffer)
  309. self._stop_timer(UNSWAPPED_READ_GRADIENTS)
  310. self._log_timers([UNSWAPPED_READ_GRADIENTS])
  311. # It should be safe to discard unswapped gradient partitions
  312. swap_info.release_unswapped_gradients()
  313. if SWAPPER_DEBUG_MODE:
  314. logger.info(
  315. f'optimizer_retrieve_unswapped_gradients: param={swap_info.param_id} tensor_count={tensor_count} elem_count={num_elem_count}'
  316. )
  317. def _get_state_tensors(self, parameter):
  318. if not parameter in self.optimizer.state:
  319. return []
  320. tensor_list = []
  321. for state_name, value in self.optimizer.state[parameter].items():
  322. if torch.is_tensor(value):
  323. value.ds_id = state_name + '-' + parameter.ds_id
  324. tensor_list.append(value)
  325. return tensor_list
  326. def _update_param_state_info(self, swap_info, parameter):
  327. if not swap_info.has_state_tensors:
  328. state_tensors = self._get_state_tensors(parameter)
  329. if state_tensors:
  330. swap_info.add_state_tensors(state_tensors)
  331. def _create_param_swap_info(self, parameter, numel):
  332. param_id = OptimizerSwapper.parameter_id(parameter)
  333. assert not param_id in self.swap_params_info
  334. self.swap_params_info[param_id] = OptimizerStateSwapInfo(parameter=parameter,
  335. numel=numel,
  336. base_folder=self.swap_folder)
  337. swap_info = self.swap_params_info[param_id]
  338. self._update_param_state_info(swap_info, parameter)
  339. return swap_info
  340. def _get_param_swap_info(self, parameter):
  341. param_id = OptimizerSwapper.parameter_id(parameter)
  342. swap_info = self.swap_params_info.get(param_id, None)
  343. if swap_info is not None:
  344. self._update_param_state_info(swap_info, parameter)
  345. return swap_info
  346. def _start_timer(self, name):
  347. if self.timers:
  348. self.timers(name).start()
  349. def _stop_timer(self, name):
  350. if self.timers:
  351. self.timers(name).stop()
  352. def _log_timers(self, name_list, force=False):
  353. if self.timers and (SWAPPER_DEBUG_MODE or force):
  354. self.timers.log(name_list)
  355. def _io_aligned_numel(self, numel):
  356. remainder = numel % self.numel_alignment
  357. return numel if remainder == 0 else (numel + self.numel_alignment - remainder)