  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team.
  3. Licensed under the MIT license.
  4. Functionality of swapping tensors to/from (NVMe) storage devices.
  5. """
  6. import os
  7. import torch
  8. from deepspeed.utils.logging import logger
  9. from deepspeed.runtime.zero.offload_constants import *
  10. from deepspeed.runtime.swap_tensor.constants import *
  11. from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, \
  12. MIN_AIO_BYTES, AIO_ALIGNED_BYTES, get_sized_buffers, get_sized_buffer
  13. from deepspeed.runtime.swap_tensor.utils import SwapBufferManager, SwapBufferPool


class FlattenedTensorSwapInfo(object):
    def __init__(self, path, length, offset):
        self.path = path
        self.offset = offset
        self.length = length


class OptimizerStateSwapInfo(object):
    def __init__(self, parameter, numel, base_folder):
        self.tensors = []
        self.param_id = id(parameter)
        self.swap_folder = base_folder
        self.swap_paths = []
        self.swapped_gradients = {}
        self.unswapped_gradients = {}
        self.tensor_numel = numel
        self.tensor_dtype = parameter.dtype
        self.tensor_device = parameter.device
        self.has_state_tensors = False
        self._add_tensors([parameter])

    def numel(self):
        return self.tensor_numel

    def has_gradients(self):
        return self.swapped_gradients or self.unswapped_gradients

    def _add_tensors(self, tensor_list):
        # One swap file per tracked tensor, named by the tensor's object id.
        for t in tensor_list:
            self.tensors.append(t)
            self.swap_paths.append(os.path.join(self.swap_folder, f'{id(t)}.tensor.swp'))

    def add_state_tensors(self, tensor_list):
        self.has_state_tensors = True
        self._add_tensors(tensor_list)

    def device(self):
        return self.tensor_device

    def dtype(self):
        return self.tensor_dtype

    def release_memory(self):
        # Drop the tensor storage; the contents live on in the swap files.
        for tensor in self.tensors:
            tensor.data = torch.Tensor()

    def get_or_create_gradient_paths(self, offsets, lengths):
        gradient_paths = []
        for offset, length in zip(offsets, lengths):
            if offset not in self.swapped_gradients:
                path = os.path.join(
                    self.swap_folder,
                    f'{self.param_id}_gradient_{offset}_{length}.tensor.swp')
                self.swapped_gradients[offset] = FlattenedTensorSwapInfo(
                    path,
                    length,
                    offset)

            gradient_paths.append(self.swapped_gradients[offset].path)

        return gradient_paths
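
    # Example gradient swap file name (a sketch with hypothetical values, not
    # from the source): a gradient fragment of length 524288 at offset 0 for a
    # parameter whose id() is 140093 would map to
    # '<swap_folder>/140093_gradient_0_524288.tensor.swp'.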

    def set_swap_buffers(self, buffers):
        compute_lengths = [self.numel()] * len(self.tensors)
        compute_buffers = get_sized_buffers(buffers, compute_lengths)
        for t, buffer in zip(self.tensors, compute_buffers):
            t.data = buffer.data

    def get_swap_gradient_buffers(self, swap_buffer):
        assert self.numel() <= swap_buffer.numel()
        return [
            swap_buffer.narrow(0,
                               grad.offset,
                               grad.length) for grad in self.swapped_gradients.values()
        ]

    def get_swap_gradient_paths(self):
        return [grad.path for grad in self.swapped_gradients.values()]

    def get_unpinned_state_tensors(self):
        return [t for t in self.tensors if not t.is_pinned()]

    def read_unswapped_gradients(self, dest_buffer):
        num_elem_count = 0
        for offset, grad_partition in self.unswapped_gradients.items():
            dst_tensor = dest_buffer.narrow(0, offset, grad_partition.numel())
            dst_tensor.data.copy_(grad_partition.data)
            num_elem_count += grad_partition.numel()

        return num_elem_count

    def release_unswapped_gradients(self):
        self.unswapped_gradients = {}
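

# A minimal usage sketch (not part of the original module) of how
# OptimizerStateSwapInfo lays out swap files: every tracked tensor gets its own
# '<id(tensor)>.tensor.swp' file under the per-rank swap folder. The parameter
# and folder below are hypothetical.
#
#   param = torch.nn.Parameter(torch.randn(1024))
#   info = OptimizerStateSwapInfo(param, numel=1024, base_folder='/tmp/swap')
#   assert len(info.swap_paths) == 1 and info.swap_paths[0].endswith('.tensor.swp')
#   # e.g. Adam's exp_avg/exp_avg_sq become additional swap files:
#   info.add_state_tensors([torch.zeros(1024), torch.zeros(1024)])
#   assert len(info.swap_paths) == 3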


SWAPPER_DEBUG_MODE = False
SWAP_OUT_GRADIENT_TIMER = 'swap_out_gradient'


class OptimizerSwapper(object):
    def __init__(self,
                 swap_config,
                 aio_config,
                 base_folder,
                 optimizer,
                 largest_numel,
                 device,
                 dtype,
                 timers):
        self.swap_config = swap_config
        self.aio_config = aio_config

        # NVMe swap management
        self.swap_params_info = {}
        self.swap_element_size = torch.tensor([], dtype=dtype).element_size()
        self.swap_folder = os.path.join(base_folder,
                                        'optimizer',
                                        f'rank{torch.distributed.get_rank()}')
        os.makedirs(self.swap_folder, exist_ok=True)

        self.optimizer = optimizer

        # Read/write alignment for each thread during intra-request parallelism
        self.min_aio_bytes = max(MIN_AIO_BYTES, aio_config[AIO_BLOCK_SIZE])
        self.aligned_bytes = AIO_ALIGNED_BYTES * aio_config[AIO_THREAD_COUNT]
        self.numel_alignment = self.aligned_bytes // self.swap_element_size

        # Swap buffer management
        self.largest_numel = self._io_aligned_numel(largest_numel)
        self.dtype = dtype
        self.swap_buffer_manager = SwapBufferManager(
            num_elems=self.largest_numel,
            count=swap_config[OFFLOAD_OPTIMIZER_BUFFER_COUNT],
            dtype=dtype)

        # Timers
        self.timers = timers
        self.timer_names = set()

        # Print exclusion list
        self.print_exclude_list = [
            'optimizer',
            'swap_buffer_manager',
            'swap_params_info',
            'timers',
            'timer_names',
        ]

    def swappable_tensor(self, param=None, numel=None):
        assert param is not None or numel is not None, "Either param or numel must be provided"
        if param is not None:
            return self.min_aio_bytes <= (param.numel() * self.swap_element_size)
        return self.min_aio_bytes <= (numel * self.swap_element_size)
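
    # Sketch of the swappability threshold (assumed numbers, not from the
    # source): with an aio block size of 1MB and fp16 elements (2 bytes each),
    # only tensors of at least 512K elements qualify for swapping, since smaller
    # NVMe transfers cannot amortize the I/O latency.
    #
    #   min_aio_bytes = max(MIN_AIO_BYTES, 1048576)   # 1MB block size
    #   numel_threshold = min_aio_bytes // 2          # 524288 fp16 elements
    #   swapper.swappable_tensor(numel=524288)        # -> True
    #   swapper.swappable_tensor(numel=1000)          # -> False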

    def init_timers(self):
        self.timer_names = set()

    def log_timers(self):
        if self.timer_names:
            self._log_timers(list(self.timer_names), force=True)

    def pre_backward(self):
        self.init_timers()

    def post_backward(self):
        pass

    def _flush_gradient_swapper(self, gradient_swapper):
        if gradient_swapper.has_buffers():
            self._start_timer(SWAP_OUT_GRADIENT_TIMER)
            pinned_buffers = gradient_swapper.release_buffers()
            self.swap_buffer_manager.free(pinned_buffers)
            self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
            self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)
            self.timer_names.update(gradient_swapper.get_timer_names())

    def _swap_out_gradients(self,
                            parameter,
                            gradient_offsets,
                            gradient_tensors,
                            gradient_swapper):
        if id(parameter) not in self.swap_params_info:
            return

        swap_info = self.swap_params_info[id(parameter)]

        swappable_tensors = []
        swappable_offsets = []
        swappable_lengths = []

        aligned_gradients, aligned_offsets = self._adjust_for_misaligned_lengths(
            tensors=gradient_tensors,
            offsets=gradient_offsets)

        self._start_timer(SWAP_OUT_GRADIENT_TIMER)
        for tensor, offset in zip(aligned_gradients, aligned_offsets):
            if not self.swappable_tensor(param=tensor):
                # Too small to amortize NVMe I/O; keep in memory until step time.
                swap_info.unswapped_gradients[offset] = tensor
                continue

            swappable_tensors.append(tensor)
            swappable_offsets.append(offset)
            swappable_lengths.append(tensor.numel())

        if len(swappable_tensors) > 0:
            if not gradient_swapper.has_buffers():
                pinned_buffers = self.swap_buffer_manager.allocate_all(
                    num_elems=self.largest_numel,
                    dtype=self.dtype)
                gradient_swapper.add_buffers(pinned_buffers)

            swappable_paths = swap_info.get_or_create_gradient_paths(
                swappable_offsets,
                swappable_lengths)

            gradient_swapper.swap_out_tensors(tensor_list=swappable_tensors,
                                              path_list=swappable_paths)

        self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
        self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)
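
    # Hypothetical call sketch (names here are illustrative, not from the
    # source): for a parameter partition whose gradients arrive as fragments at
    # element offsets 0 and 524288, the swapper aligns the fragments, files the
    # small tails under swap_info.unswapped_gradients, and writes the rest to
    # NVMe through gradient_swapper.
    #
    #   swapper._swap_out_gradients(parameter=p,
    #                               gradient_offsets=[0, 524288],
    #                               gradient_tensors=[grad_frag0, grad_frag1],
    #                               gradient_swapper=async_gradient_swapper)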

    def _initialize_from_swapped_fp16_params(self,
                                             aio_handle,
                                             fp16_partitions_info,
                                             fp16_num_elems,
                                             fp16_pinned_buffers,
                                             fp32_parameters):
        assert len(fp32_parameters) == len(fp16_partitions_info)
        assert len(fp32_parameters) == len(fp16_num_elems)
        assert all([buffer.is_pinned() for buffer in fp16_pinned_buffers])

        fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters,
                                               num_elems=fp16_num_elems)

        fp32_pinned_buffers = self.swap_buffer_manager.allocate_all(
            num_elems=self.largest_numel,
            dtype=self.dtype)

        fp16_buffer_numel = [buf.numel() for buf in fp16_pinned_buffers]
        assert all([numel >= self.largest_numel for numel in fp16_buffer_numel]), \
            f"numel of fp16 buffers {fp16_buffer_numel} is too small for initializing fp32 params {self.largest_numel}"

        fp32_swap_buffers = SwapBufferPool(fp32_pinned_buffers)
        fp16_swap_buffers = SwapBufferPool(fp16_pinned_buffers)

        curr_index = 0
        while curr_index < len(fp32_parameters):
            fp16_pinned_tensors = self._swap_in_fp16_params(
                aio_handle=aio_handle,
                fp16_num_elems=fp16_num_elems[curr_index:],
                fp16_partitions_info=fp16_partitions_info[curr_index:],
                fp16_swap_buffers=fp16_swap_buffers)

            if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
                for i, tensor in enumerate(fp16_pinned_tensors):
                    true_index = curr_index + i
                    logger.info(
                        f'swap_in_fp16_param: fp32_id = {id(fp32_parameters[true_index])} index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, swap_num_elem = {fp16_pinned_tensors[i].numel()}'
                    )

            swap_out_count = self._swap_out_fp16_params(
                aio_handle=aio_handle,
                fp32_swap_paths=fp32_swap_paths[curr_index:],
                fp32_swap_buffers=fp32_swap_buffers,
                fp16_pinned_tensors=fp16_pinned_tensors)
            assert swap_out_count == len(fp16_pinned_tensors), \
                f"{swap_out_count} does not match {len(fp16_pinned_tensors)}"

            fp16_swap_buffers.reset()
            fp32_swap_buffers.reset()
            curr_index += swap_out_count

        self.swap_buffer_manager.free(fp32_pinned_buffers)

    def _swap_in_fp16_params(self,
                             aio_handle,
                             fp16_num_elems,
                             fp16_partitions_info,
                             fp16_swap_buffers):
        assert len(fp16_num_elems) > 0

        swapped_fp16_tensors = []
        swap_tensors = []
        swap_paths = []
        unswapped_srcs = []
        unswapped_dsts = []

        for i, numel in enumerate(fp16_num_elems):
            pinned_tensor, _ = fp16_swap_buffers.allocate_tensor(numel, None, numel)
            if pinned_tensor is None:
                break

            swapped_fp16_tensors.append(pinned_tensor)
            offset = 0
            for tensor, partition_numel, partition_path in fp16_partitions_info[i]:
                dst_tensor = pinned_tensor.narrow(0, offset, partition_numel)
                if partition_path is None:
                    unswapped_srcs.append(tensor)
                    unswapped_dsts.append(dst_tensor)
                else:
                    swap_paths.append(partition_path)
                    swap_tensors.append(dst_tensor)
                offset += partition_numel

        assert len(swapped_fp16_tensors) + len(unswapped_srcs) > 0
        swap_in_tensors(aio_handle, swap_tensors, swap_paths)

        for src, dst in zip(unswapped_srcs, unswapped_dsts):
            dst.data.copy_(src.data)

        assert len(swap_tensors) == aio_handle.wait()

        return swapped_fp16_tensors

    def _swap_out_fp16_params(self,
                              aio_handle,
                              fp32_swap_paths,
                              fp32_swap_buffers,
                              fp16_pinned_tensors):
        assert len(fp16_pinned_tensors) <= len(fp32_swap_paths)

        swap_out_count = 0
        for i, fp16_tensor in enumerate(fp16_pinned_tensors):
            if not fp32_swap_buffers.has_space(fp16_tensor.numel()):
                fp32_swap_buffers.swap_out(aio_handle)
                fp32_swap_buffers.reset()

            pinned_tensor, _ = fp32_swap_buffers.insert_tensor(
                fp16_tensor,
                fp32_swap_paths[i],
                self._io_aligned_numel(fp16_tensor.numel()))
            assert pinned_tensor is not None
            swap_out_count += 1

        if len(fp32_swap_buffers.get_swap_tensors()) > 0:
            fp32_swap_buffers.swap_out(aio_handle)

        return swap_out_count

    def _initialize_parameters(self, parameters, src_tensors, aio_handle):
        assert len(parameters) == len(src_tensors)

        swap_paths = self._get_swap_paths(parameters=parameters,
                                          num_elems=[src.numel() for src in src_tensors])

        SWAP_INIT_TIMER = "swap_init_write"
        self._start_timer(SWAP_INIT_TIMER)

        pinned_buffers = self.swap_buffer_manager.allocate_all(
            num_elems=self.largest_numel,
            dtype=self.dtype)
        assert pinned_buffers is not None

        self._swap_out_unpinned_tensors(aio_handle=aio_handle,
                                        unpinned_tensors=src_tensors,
                                        dest_paths=swap_paths,
                                        pinned_buffers=pinned_buffers)

        if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
            for i, tensor in enumerate(src_tensors):
                logger.info(
                    f'copy_in_fp16_param: fp32_id = {id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}'
                )

        self.swap_buffer_manager.free(pinned_buffers)

        self._stop_timer(SWAP_INIT_TIMER)
        self._log_timers([SWAP_INIT_TIMER])

    def _get_swap_paths(self, parameters, num_elems):
        swap_info_list = [
            self._create_param_swap_info(parameter=p,
                                         numel=numel)
            for p, numel in zip(parameters, num_elems)
        ]
        assert len(swap_info_list) == len(num_elems)

        swap_paths = [info.swap_paths[0] for info in swap_info_list]
        return swap_paths

    def _swap_out_unpinned_tensors(self,
                                   aio_handle,
                                   unpinned_tensors,
                                   dest_paths,
                                   pinned_buffers):
        swap_buffer_count = len(pinned_buffers)
        unpinned_tensor_count = len(unpinned_tensors)

        for i in range(0, unpinned_tensor_count, swap_buffer_count):
            swap_tensor_count = min((unpinned_tensor_count - i), swap_buffer_count)

            src_tensors = unpinned_tensors[i:(i + swap_tensor_count)]
            compute_lengths = [t.numel() for t in src_tensors]
            compute_buffers = get_sized_buffers(pinned_buffers, compute_lengths)

            for dst, src in zip(compute_buffers, src_tensors):
                dst.data.copy_(src.data)

            swap_lengths = [self._io_aligned_numel(t.numel()) for t in src_tensors]
            swap_buffers = get_sized_buffers(pinned_buffers, swap_lengths)

            swap_paths = dest_paths[i:(i + swap_tensor_count)]
            swap_out_tensors(aio_handle, swap_buffers, swap_paths)

            assert aio_handle.wait() == swap_tensor_count
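
    # Sketch of the batching above, assuming 4 pinned buffers and 10 unpinned
    # tensors: tensors are staged into pinned buffers and written out in groups
    # of min(remaining, buffer_count), i.e. 4, 4, then 2, with a blocking
    # aio_handle.wait() between groups.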

    def _adjust_for_misaligned_lengths(self, tensors, offsets):
        new_tensors = []
        new_offsets = []

        for orig_tensor, orig_offset in zip(tensors, offsets):
            if not self.swappable_tensor(param=orig_tensor):
                new_tensors.append(orig_tensor)
                new_offsets.append(orig_offset)
                continue

            remainder = orig_tensor.numel() % self.numel_alignment
            if remainder == 0:
                new_tensors.append(orig_tensor)
                new_offsets.append(orig_offset)
                continue

            # Split into two by making the remainder a separate tensor
            aligned_length = (orig_tensor.numel() //
                              self.numel_alignment) * self.numel_alignment
            new_tensors.append(orig_tensor.narrow(0, 0, aligned_length))
            new_offsets.append(orig_offset)

            # Remainder tensor
            new_tensors.append(orig_tensor.narrow(0, aligned_length, remainder))
            new_offsets.append(orig_offset + aligned_length)

        return new_tensors, new_offsets
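
    # Worked example (sketch) for _adjust_for_misaligned_lengths, assuming
    # numel_alignment = 1024: a swappable 2500-element gradient at offset 0 is
    # split into an aligned 2048-element head at offset 0 plus a 452-element
    # tail at offset 2048; the tail later fails swappable_tensor() and stays in
    # memory as an unswapped gradient.
    #
    #   2500 % 1024 == 452                # remainder
    #   (2500 // 1024) * 1024 == 2048     # aligned_length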

    def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer):
        UNSWAPPED_READ_GRADIENTS = 'unswapped_read_gradients'
        self._start_timer(UNSWAPPED_READ_GRADIENTS)
        tensor_count = len(swap_info.unswapped_gradients)
        num_elem_count = swap_info.read_unswapped_gradients(dest_buffer)
        self._stop_timer(UNSWAPPED_READ_GRADIENTS)
        self._log_timers([UNSWAPPED_READ_GRADIENTS])

        # It should be safe to discard unswapped gradient partitions
        swap_info.release_unswapped_gradients()

        if SWAPPER_DEBUG_MODE:
            logger.info(
                f'optimizer_retrieve_unswapped_gradients: param={swap_info.param_id} tensor_count={tensor_count} elem_count={num_elem_count}'
            )

    def _get_state_tensors(self, parameter):
        if parameter not in self.optimizer.state:
            return []

        tensor_list = []
        for value in self.optimizer.state[parameter].values():
            if torch.is_tensor(value):
                tensor_list.append(value)

        return tensor_list

    def _update_param_state_info(self, swap_info, parameter):
        if not swap_info.has_state_tensors:
            state_tensors = self._get_state_tensors(parameter)
            if state_tensors:
                swap_info.add_state_tensors(state_tensors)

    def _create_param_swap_info(self, parameter, numel):
        param_id = id(parameter)
        assert param_id not in self.swap_params_info

        self.swap_params_info[param_id] = OptimizerStateSwapInfo(
            parameter=parameter,
            numel=numel,
            base_folder=self.swap_folder)
        swap_info = self.swap_params_info[param_id]

        self._update_param_state_info(swap_info, parameter)

        return swap_info

    def _get_param_swap_info(self, parameter):
        param_id = id(parameter)
        swap_info = self.swap_params_info.get(param_id, None)

        if swap_info is not None:
            self._update_param_state_info(swap_info, parameter)

        return swap_info

    def _start_timer(self, name):
        if self.timers:
            self.timers(name).start()

    def _stop_timer(self, name):
        if self.timers:
            self.timers(name).stop()

    def _log_timers(self, name_list, force=False):
        if self.timers and (SWAPPER_DEBUG_MODE or force):
            self.timers.log(name_list)

    def _io_aligned_numel(self, numel):
        remainder = numel % self.numel_alignment
        return numel if remainder == 0 else (numel + self.numel_alignment - remainder)
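
# Worked example (sketch) of _io_aligned_numel, assuming numel_alignment = 1024:
# swap lengths are rounded up to the next alignment boundary so that AIO
# reads/writes stay block-aligned.
#
#   remainder = 2500 % 1024        # 452
#   2500 + (1024 - 452) == 3072    # padded to the next multiple of 1024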