# partitioned_param_swapper.py

  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team.
  3. Licensed under the MIT license.
  4. Functionality of swapping tensors to/from (NVMe) storage devices.
  5. """
  6. import os
  7. from enum import Enum
  8. import torch
  9. import torch.distributed as dist
  10. from deepspeed.utils.logging import logger
  11. from deepspeed.ops.aio import AsyncIOBuilder
  12. from .constants import *
  13. from .utils import swap_in_tensors, swap_out_tensors, MIN_AIO_BYTES, AIO_ALIGNED_BYTES, print_object, SwapBufferPool
  14. from ..zero.offload_constants import *


def print_rank_0(message, debug=False, force=False):
    if torch.distributed.get_rank() == 0 and (debug or force):
        print(message)


class PartitionedParamStatus(Enum):
    # Partitioned parameters are present and ready for use
    AVAILABLE = 1

    # Partitioned parameters are on some non-memory device
    NOT_AVAILABLE = 2

    # Partitioned parameters are being read from some non-memory device
    INFLIGHT = 3
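

# Status lifecycle (as implemented below): a partition starts NOT_AVAILABLE on
# NVMe, moves to INFLIGHT while an async read is pending, and becomes AVAILABLE
# once synchronize_reads() completes; releasing its buffer returns it to
# NOT_AVAILABLE.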


class AsyncPartitionedParameterSwapper(object):
    def __init__(self, ds_config):
        aio_op = AsyncIOBuilder().load(verbose=False)
        self.aio_handle = aio_op.aio_handle

        # set swap buffers, create aio handles
        self._configure_aio(ds_config)

        # mapping from param_id to path
        self.id_to_path = {}

        # mapping from param_id to buffer id
        self.param_id_to_buffer_id = {}

        # mapping from param_id to swap buffer
        self.param_id_to_swap_buffer = {}

        # number of elements in the param
        self.param_id_to_numel = {}

        self.pending_writes = 0
        self.pending_reads = 0

        # keep track of async swap-in params and buffers
        self.inflight_params = []
        self.inflight_swap_in_buffers = []
        self.inflight_numel = 0

        # keep track of available params
        self.available_params = set()
        self.available_numel = 0

        # for swapping out from partitioned fp32 params
        self.partitioned_swap_buffer = None
        self.partitioned_swap_pool = None

        self.invalid_buffer = torch.tensor(1).half()

        if dist.get_rank() == 0:
            exclude_list = ['aio_read_handle', 'aio_write_handle', 'buffers']
            print_object(obj=self,
                         name='AsyncPartitionedParameterSwapper',
                         exclude_list=exclude_list)
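
    # Illustrative usage (a sketch, not from this file): with a DeepSpeed config
    # that enables ZeRO-3 "offload_param" to NVMe, one swapper is created per
    # rank. The staging flow below is an assumption for the example:
    #
    #   swapper = AsyncPartitionedParameterSwapper(ds_config)
    #   if swapper.swappable_tensor(param=p):
    #       buf = swapper.get_buffer(p, p.ds_tensor.ds_numel)  # pinned staging buffer
    #       buf.copy_(p.ds_tensor.data.view(-1))               # stage the partition
    #       swapper.swap_out_and_release([p])                  # write to NVMe, free buffer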

    def available_swap_in_buffers(self):
        return len(self.available_buffer_ids)

    def _configure_aio(self, ds_config):
        self.swap_config = ds_config.zero_config.offload_param
        self.swap_folder = os.path.join(self.swap_config[OFFLOAD_PARAM_NVME_PATH],
                                        'zero_stage_3',
                                        'fp16params',
                                        f'rank{dist.get_rank()}')
        os.makedirs(self.swap_folder, exist_ok=True)

        self.swap_element_size = torch.tensor([], dtype=torch.half).element_size()

        self.aio_config = ds_config.aio_config

        # Read/write alignment for each thread during intra-request parallelism
        self.min_aio_bytes = max(MIN_AIO_BYTES, self.aio_config[AIO_BLOCK_SIZE])
        self.aligned_bytes = AIO_ALIGNED_BYTES * self.aio_config[AIO_THREAD_COUNT]
        self.numel_alignment = self.aligned_bytes // self.swap_element_size

        self.elements_per_buffer = self.swap_config[OFFLOAD_PARAM_BUFFER_SIZE]
        self.aligned_elements_per_buffer = self._io_aligned_numel(
            self.elements_per_buffer)
        self.param_buffer_count = self.swap_config[OFFLOAD_PARAM_BUFFER_COUNT]

        self.available_buffer_ids = [i for i in range(self.param_buffer_count)]
        self.reserved_buffer_ids = []
        self.buffers = torch.empty(int(self.aligned_elements_per_buffer *
                                       self.param_buffer_count),
                                   dtype=torch.half,
                                   pin_memory=True,
                                   requires_grad=False)

        self.aio_read_handle = self.aio_handle(self.aio_config[AIO_BLOCK_SIZE],
                                               self.aio_config[AIO_QUEUE_DEPTH],
                                               self.aio_config[AIO_SINGLE_SUBMIT],
                                               self.aio_config[AIO_OVERLAP_EVENTS],
                                               self.aio_config[AIO_THREAD_COUNT])

        self.aio_write_handle = self.aio_handle(self.aio_config[AIO_BLOCK_SIZE],
                                                self.aio_config[AIO_QUEUE_DEPTH],
                                                self.aio_config[AIO_SINGLE_SUBMIT],
                                                self.aio_config[AIO_OVERLAP_EVENTS],
                                                self.aio_config[AIO_THREAD_COUNT])

        self.swap_out_params = []
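
    # Buffer layout (from the allocation above): `self.buffers` is one flat,
    # pinned fp16 tensor holding `param_buffer_count` fixed-size slots, each
    # `aligned_elements_per_buffer` elements long. Slot i occupies elements
    # [i * aligned_elements_per_buffer, (i + 1) * aligned_elements_per_buffer),
    # which is why narrow() with those offsets is used below to carve out buffers.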

    # Check whether a partitioned param, or a tensor with `numel` elements,
    # is large enough to be worth swapping
    def swappable_tensor(self, param=None, numel=None):
        if param is not None:
            assert numel is None, "Both param and numel cannot be provided"
            numel = param.ds_tensor.ds_numel
        if numel is not None:
            return self.min_aio_bytes <= numel * self.swap_element_size
        assert False, "Either param or numel must be provided"
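
    # Worked example (values assumed for illustration): if min_aio_bytes is
    # 1048576 (1 MB) and swap_element_size is 2 bytes (fp16), a tensor is
    # swappable only when numel * 2 >= 1048576, i.e. numel >= 524288. Smaller
    # partitions stay in memory, since very small NVMe requests achieve poor
    # throughput.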

    def get_path(self, param, must_exist=False):
        paths = self._get_swap_paths([param], must_exist=must_exist)
        return paths[0]

    def _get_swap_paths(self, params, must_exist=False):
        paths = []
        for param in params:
            param_id = param.ds_id
            if param_id in self.id_to_path.keys():
                param_path = self.id_to_path[param_id]
            else:
                assert not must_exist, f"Path for param id {param_id} does not exist"
                param_path = os.path.join(self.swap_folder,
                                          f'{param_id}_param.tensor.swp')
                self.id_to_path[param_id] = param_path
            paths.append(param_path)
        return paths
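
    # Example path produced above (assuming the configured NVMe mount point is
    # /local_nvme): for param ds_id 42 on rank 0 this yields
    # /local_nvme/zero_stage_3/fp16params/rank0/42_param.tensor.swp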

    def _get_swap_buffers(self, params):
        buffers = []
        for param in params:
            param_id = param.ds_id
            assert param_id in self.param_id_to_swap_buffer.keys(), \
                f'param {param_id} has not been assigned a swap buffer'
            buffers.append(self.param_id_to_swap_buffer[param_id])
        return buffers

    def _track_numel(self, params):
        for param in params:
            assert param.ds_tensor is not None, "Partitioned tensor is None"
            self.param_id_to_numel[param.ds_id] = param.ds_tensor.ds_numel

    def _allocate_and_return_buffers_for_swap_in(self, params):
        compute_buffers = []
        swap_buffers = []

        for param in params:
            param_id = param.ds_id
            assert param_id in self.param_id_to_numel.keys(), f"Number of elements in param {param_id} is unknown"
            assert param_id not in self.param_id_to_buffer_id.keys(), f"param {param_id} already assigned swap buffer id {self.param_id_to_buffer_id[param_id]}"
            assert param_id not in self.param_id_to_swap_buffer.keys(), f"param {param_id} has already been assigned a swap buffer"

            buffer_id = self.available_buffer_ids.pop()
            print_rank_0(
                f"param {param.ds_id} is assigned swap in buffer id {buffer_id}")
            self.param_id_to_buffer_id[param_id] = buffer_id
            aligned_swap_numel = self._io_aligned_numel(self.param_id_to_numel[param_id])
            swap_buffer = self.buffers.narrow(
                0,
                int(buffer_id * self.aligned_elements_per_buffer),
                aligned_swap_numel)

            self.param_id_to_swap_buffer[param_id] = swap_buffer
            compute_buffer = swap_buffer.narrow(0, 0, self.param_id_to_numel[param_id])
            compute_buffers.append(compute_buffer)
            swap_buffers.append(swap_buffer)

        return compute_buffers, swap_buffers
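
    # Note the two views returned above: the swap buffer is the full
    # I/O-aligned narrow() of the slot (what the aio engine reads into), while
    # the compute buffer is its leading `numel` elements (what the param's
    # ds_tensor will actually point at).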

    # Waits for the in-flight NVMe writes to complete
    def synchronize_writes(self):
        if self.pending_writes == 0:
            return
        assert self.pending_writes == self.aio_write_handle.wait()
        self.pending_writes = 0
        self.remove_partition_and_release_buffers(self.swap_out_params)
        self.swap_out_params = []

    # Waits for the in-flight NVMe reads to complete
    def synchronize_reads(self):
        if self.pending_reads == 0:
            return

        assert self.pending_reads == self.aio_read_handle.wait()
        self.pending_reads = 0

        for param, swap_in_buffer in zip(self.inflight_params,
                                         self.inflight_swap_in_buffers):
            param_id = param.ds_id
            compute_buffer = swap_in_buffer.narrow(0, 0, self.param_id_to_numel[param_id])
            param.ds_tensor.data = compute_buffer.data
            param.ds_tensor.status = PartitionedParamStatus.AVAILABLE

        self.available_params.update([param.ds_id for param in self.inflight_params])
        self.available_numel += self.inflight_numel

        self.inflight_params = []
        self.inflight_swap_in_buffers = []
        self.inflight_numel = 0

    # Removes the memory assignment and releases the buffers.
    # Should only be executed after swapping out the tensors.
    def remove_partition_and_release_buffers(self, params):
        for param in params:
            param_id = param.ds_id

            if param_id in self.param_id_to_buffer_id.keys():
                buffer_id = self.param_id_to_buffer_id[param_id]
                assert buffer_id is not None, "Missing buffer id for releasing"

                self.available_buffer_ids.append(buffer_id)
                del self.param_id_to_buffer_id[param_id]
                del self.param_id_to_swap_buffer[param_id]
                print_rank_0(f"param {param.ds_id} releases buffer id {buffer_id}")

                if param_id in self.available_params:
                    self.available_params.remove(param_id)
                    self.available_numel -= self.param_id_to_numel[param_id]

            param.ds_tensor.data = self.invalid_buffer.data
            param.ds_tensor.status = PartitionedParamStatus.NOT_AVAILABLE

    # Writes from memory to NVMe. Does not release the buffers
    def _swap_out(self, params, async_op=True):
        swap_out_paths = self._get_swap_paths(params)
        swap_out_params = self._get_swap_buffers(params)
        self._track_numel(params)

        swap_out_tensors(self.aio_write_handle, swap_out_params, swap_out_paths)

        self.pending_writes += len(swap_out_params)
        self.swap_out_params += params

        if not async_op:
            self.synchronize_writes()
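
    # Write path at a glance: _swap_out() submits async writes and records the
    # params in self.swap_out_params; a later synchronize_writes() waits on the
    # aio handle and then releases the buffers via
    # remove_partition_and_release_buffers().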

    # Blocking swap out followed by releasing the memory buffers
    def swap_out_and_release(self, params, async_op=False, force_buffer_release=False):
        if async_op:
            assert force_buffer_release, "Should not release preallocated buffers without completing the swap out. Set force_buffer_release to True to do it anyway."
        self._swap_out(params, async_op=async_op)

    # Bookkeeping function for in-flight swap-in
    def _update_inflight_swap_in(self, params, swap_in_buffers, inflight_numel):
        self.inflight_params.extend(params)
        self.inflight_swap_in_buffers.extend(swap_in_buffers)
        self.inflight_numel += inflight_numel

        for param in params:
            param.ds_tensor.status = PartitionedParamStatus.INFLIGHT

        self.pending_reads += len(params)

    # Assigns an in-memory buffer and swaps in from NVMe
    def swap_in(self, params, async_op=True, swap_in_buffers=None):
        assert all([param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE for param in params]), "Some params are already available or in flight"
        swap_in_paths = self._get_swap_paths(params)

        if swap_in_buffers is None:
            if len(self.available_buffer_ids) < len(swap_in_paths):
                print_rank_0(
                    f'Not enough swap in buffers {len(self.available_buffer_ids)} for params {len(swap_in_paths)}',
                    force=True)
                print_rank_0(
                    f'Num inflight: params {len(self.inflight_params)}, buffers {len(self.inflight_swap_in_buffers)}, numel = {self.inflight_numel}',
                    force=True)
                print_rank_0(
                    f'Num available: params {len(self.available_params)}, numel = {self.available_numel}',
                    force=True)

            assert len(swap_in_paths) <= len(self.available_buffer_ids), f"Not enough buffers {len(self.available_buffer_ids)} for swapping {len(swap_in_paths)}"
            compute_buffers, swap_in_buffers = self._allocate_and_return_buffers_for_swap_in(params)
            inflight_numel = sum([t.numel() for t in compute_buffers])
        else:
            inflight_numel = sum([t.numel() for t in swap_in_buffers])

        swap_in_tensors(self.aio_read_handle, swap_in_buffers, swap_in_paths)

        self._update_inflight_swap_in(params, swap_in_buffers, inflight_numel)

        if not async_op:
            self.synchronize_reads()
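
    # Prefetch sketch (illustrative, not from this file): the async default
    # lets NVMe reads overlap with compute on partitions already in memory:
    #
    #   swapper.swap_in(next_layer_params, async_op=True)  # submit reads
    #   run_forward(current_layer)                         # overlap compute
    #   swapper.synchronize_reads()                        # next params now AVAILABLE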

    # Enables swapping into a buffer that is outside the control of the swapper.
    # This is always synchronous.
    def swap_into_buffer(self, param, dest_buffer):
        assert param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE, f"param {param.ds_id} is already available or inflight"

        require_swap_buffer = not (dest_buffer.is_pinned()
                                   and self._is_io_aligned(dest_buffer.numel()))

        if require_swap_buffer:
            assert len(self.available_buffer_ids) > 0, f"No buffer available to swap param {param.ds_id}."
            compute_buffers, swap_in_buffers = self._allocate_and_return_buffers_for_swap_in([param])
            inflight_numel = compute_buffers[0].numel()
        else:
            swap_in_buffers = [dest_buffer]
            inflight_numel = dest_buffer.numel()

        swap_in_paths = self._get_swap_paths([param])

        swap_in_tensors(self.aio_read_handle, swap_in_buffers, swap_in_paths)
        self._update_inflight_swap_in([param], swap_in_buffers, inflight_numel)
        self.synchronize_reads()

        if require_swap_buffer:
            dest_buffer.data.copy_(param.ds_tensor.data)
            # Release swap buffer memory assignment. Note, this will mark the parameter not available.
            self.remove_partition_and_release_buffers([param])
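
    # Why the pinned/aligned check above: the async read path expects
    # page-locked, I/O-aligned buffers (as with O_DIRECT-style reads). An
    # arbitrary destination therefore takes a detour through one of the
    # swapper's own pinned buffers plus a copy_, while a pinned and aligned
    # dest_buffer is read into directly.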

    # Assigns a buffer to a param and returns the compute buffer
    def get_buffer(self, param, numel):
        param_id = param.ds_id

        assert self.available_swap_in_buffers() > 0, f"No swap buffers to allocate for fp16 param {param_id} of numel = {numel}"
        assert numel < self.elements_per_buffer, f"More elements {numel} than buffer size {self.elements_per_buffer}"

        self.param_id_to_numel[param_id] = numel
        buffer_id = self.available_buffer_ids.pop()
        self.param_id_to_buffer_id[param_id] = buffer_id
        aligned_swap_numel = self._io_aligned_numel(self.param_id_to_numel[param_id])
        swap_buffer = self.buffers.narrow(
            0,
            int(buffer_id * self.aligned_elements_per_buffer),
            aligned_swap_numel)

        self.param_id_to_swap_buffer[param_id] = swap_buffer
        compute_buffer = swap_buffer.narrow(0, 0, self.param_id_to_numel[param_id])
        print_rank_0(f"param {param.ds_id} is assigned swap in buffer id {buffer_id}")
        return compute_buffer

    def reserve_available_buffers(self):
        buffers = []
        for buffer_id in self.available_buffer_ids:
            buffers.append(
                self.buffers.narrow(0,
                                    int(buffer_id * self.aligned_elements_per_buffer),
                                    int(self.aligned_elements_per_buffer)))
            self.reserved_buffer_ids.append(buffer_id)

        self.available_buffer_ids = []
        return buffers

    def release_reserved_buffers(self):
        for buffer_id in self.reserved_buffer_ids:
            self.available_buffer_ids.append(buffer_id)
        self.reserved_buffer_ids = []

    def _io_aligned_numel(self, numel):
        remainder = numel % self.numel_alignment
        return numel if remainder == 0 else (numel + self.numel_alignment - remainder)

    def _is_io_aligned(self, numel):
        return (numel % self.numel_alignment) == 0
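
    # Worked example (numbers assumed for illustration): with aligned_bytes of
    # 2048 and 2-byte fp16 elements, numel_alignment is 1024, so
    # _io_aligned_numel(1000) -> 1024, _io_aligned_numel(1024) -> 1024, and
    # _io_aligned_numel(1025) -> 2048.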

    def reserve_partitioned_swap_space(self, partition_num_elems):
        aligned_numel = sum(
            [self._io_aligned_numel(numel) for numel in partition_num_elems])
        self.partitioned_swap_buffer = torch.zeros(aligned_numel,
                                                   device='cpu',
                                                   dtype=torch.half).pin_memory()
        self.partitioned_swap_pool = SwapBufferPool([self.partitioned_swap_buffer])

    def swap_out_partitioned_params(self, dst_fp16_params, src_fp32_params):
        assert self.partitioned_swap_buffer is not None, 'partitioned swap buffers for fp16 params not initialized'
        assert self.partitioned_swap_pool is not None, 'partitioned swap pool for fp16 params not initialized'
        assert len(dst_fp16_params) == len(src_fp32_params), \
            f'mismatch in number of fp16 params {len(dst_fp16_params)} and fp32 params {len(src_fp32_params)}'

        fp16_swap_paths = self._get_swap_paths(dst_fp16_params, must_exist=True)
        self.synchronize_writes()
        self.partitioned_swap_pool.reset()
        for i, fp32_tensor in enumerate(src_fp32_params):
            swap_tensor, _ = self.partitioned_swap_pool.insert_tensor(
                fp32_tensor,
                fp16_swap_paths[i],
                self._io_aligned_numel(fp32_tensor.numel()))
            assert swap_tensor is not None
            dst_fp16_params[i].ds_tensor.status = PartitionedParamStatus.AVAILABLE

        self.partitioned_swap_pool.swap_out(self.aio_write_handle)

        for param in dst_fp16_params:
            param.ds_tensor.status = PartitionedParamStatus.NOT_AVAILABLE
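

# Optimizer-step sketch (illustrative, not from this file): the two methods
# above work as a pair when fp32 master partitions are written back to NVMe as
# fp16. The variable names below are assumptions for the example:
#
#   swapper.reserve_partitioned_swap_space(
#       [p.ds_tensor.ds_numel for p in fp16_params])   # once: pinned staging area
#   swapper.swap_out_partitioned_params(
#       dst_fp16_params=fp16_params,
#       src_fp32_params=fp32_partitions)               # casts via the pool, writes out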