partitioned_param_coordinator.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from dataclasses import dataclass
import collections
from collections import UserDict
from typing import Deque, Set

from deepspeed import comm as dist
from deepspeed.utils.logging import logger
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partitioned_param_profiler import PartitionedParameterProfiler
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
from deepspeed.accelerator import get_accelerator
import logging

# NOTE: the wildcard import above also supplies several names used throughout
# this file (e.g. torch, instrument_w_nvtx, print_rank_0, ZeroParamStatus).

ENABLE_PROFILER = False

def debug_rank0(message: str) -> None:
    if dist.get_rank() == 0:
        logger.debug(message)

@instrument_w_nvtx
def get_all_parameters(sub_module, recurse=False):
    return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())


def iter_params(module: Module, recurse=False) -> Iterable[Parameter]:
    return map(lambda pair: pair[1], get_all_parameters(module, recurse))

class ZeRoTraceMode(Enum):
    # Record trace of the network during a single forward+backward (for training) or forward (for inference)
    RECORD = 1
    # Use recorded network trace to optimize current forward+backward or forward
    COMPLETE = 2
    # Recorded trace does not match current forward+backward or forward pass.
    INVALID = 3

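# The trace modes above form a small state machine. As a rough sketch of how
# the coordinator below drives it (derived from reset_step, trace_prologue and
# _invalidate_trace in this file; this is commentary, not an additional API):
#
#     RECORD   --(reset_step after a cleanly recorded pass)--> COMPLETE
#     COMPLETE --(trace_prologue sees an unexpected module)--> INVALID
#     INVALID  --(reset_step)--> RECORD    # re-record on the next pass
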
class InflightParamRegistry(UserDict):
    """registry for parameters in flight"""

    def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> None:
        if param in self.data:
            raise RuntimeError(f"{param.ds_summary()} already in registry")
        if param.ds_status != ZeroParamStatus.INFLIGHT:
            raise RuntimeError(f"attempted to add non-inflight parameter to registry {param.ds_summary()}")
        self.data[param] = handle

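# Illustrative (hypothetical) usage of InflightParamRegistry, mirroring how the
# coordinator below uses it: the handle returned by a coalesced all-gather is
# stored while the gather is in flight, then popped and waited on when the
# parameter is actually needed. `param` and `handle` are placeholders here.
#
#     registry = InflightParamRegistry()
#     registry[param] = handle          # param.ds_status must be INFLIGHT
#     ...
#     registry.pop(param).wait()        # block until the all-gather completes
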
class PartitionedParameterCoordinator:
    """Handles partitioning and gathering of parameters."""

    FORWARD_FETCH_SUBMIT = 'forward_fetch_submit'
    FORWARD_FETCH_WAIT = 'forward_fetch_wait'
    FORWARD_PREFETCH_SUBMIT = 'forward_prefetch_submit'
    BACKWARD_FETCH_SUBMIT = 'backward_fetch_submit'
    BACKWARD_FETCH_WAIT = 'backward_fetch_wait'
    BACKWARD_PREFETCH_SUBMIT = 'backward_prefetch_submit'
    FORWARD_ALL_GATHER = 'forward_all_gather'
    BACKWARD_ALL_GATHER = 'backward_all_gather'

    @dataclass
    class __ParamInTrace:
        param: Parameter
        step_id_last_used_at: int
    def __init__(
        self,
        prefetch_bucket_sz: int,
        max_reuse_distance_in_numel: int,
        max_available_parameters_in_numel: int,
        allgather_stream: get_accelerator().Stream,
        inflight_param_registry: InflightParamRegistry,
        prefetch_nvme: bool = False,
        timers=None,
        zero_quantized_weights=False,
        zero_quantized_nontrainable_weights=False,
    ) -> None:
        # mapping of param -> handle for each param that is currently in flight
        self.__inflight_param_registry = inflight_param_registry
        # keeps track of the number of submodules invoked so far
        self.__step_id: int = 0
        # network tracing mode
        self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD
        # sequence of submodules/parameters in forward pass + backward pass
        self.__submodule_order: Iterable[Module] = []
        self.__param_order: Iterable[__class__.__ParamInTrace] = []
        self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10))
        self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque())
        # number of available params, and max number of available params
        self.__n_available_params: int = 0
        self.__max_n_available_params: int = max_available_parameters_in_numel
        # max distance between two uses of a module beyond which the module is released
        self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel
        # queue for parameters to fetch. parameters will be popped off the left
        # side of the deque as they are fetched
        self.__param_queue: Deque[__class__.__ParamInTrace] = None
        self.__prefetch_bucket_sz: int = prefetch_bucket_sz
        self.__prefetch_nvme: bool = prefetch_nvme
        self.hierarchy: int = 0
        self.zero_quantized_weights = zero_quantized_weights
        self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights

        # stream that will be used for allgather operations
        self.__allgather_stream: get_accelerator().Stream = allgather_stream

        # Limit the number of fetch events that can be queued at once.
        # Otherwise, memory is allocated by the host thread at the time of the
        # call but not used until later by the asynchronous accelerator stream.
        # Allowing an unbounded number of these to queue up creates memory
        # pressure that becomes detrimental to performance.
        # This is a much less elegant fix than something like
        # cudaMallocAsync/cudaFreeAsync. It is not exposed to the user because
        # ideally it will eventually be replaced by an async allocation
        # mechanism that requires no user configuration.
        self.__ongoing_fetch_events: Deque[get_accelerator().Event] = collections.deque()
        # TODO. make this configurable via JSON
        self.__max_ongoing_fetch_events: int = 2
        self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)
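
    # A minimal sketch of the event-throttling technique described above, as it
    # plays out in fetch_sub_module (the names refer to the attributes
    # initialized in __init__; this is commentary, not additional code):
    #
    #     while ongoing_fetch_events and ongoing_fetch_events[0].query():
    #         ongoing_fetch_events.popleft()                  # drop completed events
    #     if len(ongoing_fetch_events) > max_ongoing_fetch_events:
    #         ongoing_fetch_events.popleft().synchronize()    # block on the oldest one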
  106. """Tracing and Tracking
  107. TODO. consider performing trace before initializing PartitionedParameterCoordinator
  108. and passing trace results into constructor. This way all the code in here can
  109. just assume that the trace is complete and the results can be entirely
  110. immutable.
  111. Bookkeeping operations used to track where we are in the forward/backward pass
  112. """
    def _clear_trace_structures(self) -> None:
        self.__submodule_order = []
        self.__param_order = []
        self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10))
        self.__param_queue = None

    def is_complete_trace(self) -> bool:
        return self.__trace_mode == ZeRoTraceMode.COMPLETE

    def is_invalid_trace(self) -> bool:
        return self.__trace_mode == ZeRoTraceMode.INVALID

    def is_record_trace(self) -> bool:
        return self.__trace_mode == ZeRoTraceMode.RECORD

    def _invalidate_trace(self) -> None:
        if self.is_invalid_trace():
            raise RuntimeError("attempted to invalidate already invalid trace")
        self.__trace_mode = ZeRoTraceMode.INVALID
        self._clear_trace_structures()
    def trace_prologue(self, sub_module: Module) -> None:
        if self.is_complete_trace():
            # sub_module must match expectation else invalidate trace cache
            if len(self.__submodule_order) <= self.__step_id:
                print_rank_0(
                    f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: "
                    f"cache has only {len(self.__submodule_order)} modules",
                    force=True)
                self._invalidate_trace()
                return

            if sub_module != self.__submodule_order[self.__step_id]:
                expected_module_id = self.__submodule_order[self.__step_id].id
                print_rank_0(
                    f"Invalidate trace cache @ step {self.__step_id}: "
                    f"expected module {expected_module_id}, but got module {sub_module.id}",
                    force=True)
                self._invalidate_trace()
    def record_module(self, sub_module: Module) -> None:
        """adds sub module to trace"""
        if not self.is_record_trace():
            raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

        self.__submodule_order.append(sub_module)
        self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)

    def record_parameters(self, sub_module: Module) -> None:
        """adds sub module's parameters to trace"""
        if not self.is_record_trace():
            raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

        step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
        for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id):
            self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))
    def construct_parameter_trace_from_module_trace(self):
        """use module trace to construct parameter trace"""
        self.__param_order = []
        for sub_module in self.__submodule_order:
            self.record_parameters(sub_module)
    def reset_step(self) -> None:
        """indicate that we have completed one fwd+bwd for the model"""
        if self.__inflight_param_registry:
            raise RuntimeError(f"still have inflight params "
                               f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}")

        if not self.is_complete_trace():  # not self.trace_complete:
            # Make sure that recorded submodule orders are identical across ranks
            assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])

            if self.is_record_trace():
                # Successfully recorded a trace
                self.construct_parameter_trace_from_module_trace()
                # Make sure that recorded parameter orders are identical across ranks
                assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order])
                assert_ints_same_as_other_ranks([p.step_id_last_used_at for p in self.__param_order])

                self.__submodule_order = tuple(self.__submodule_order)  # freeze
                self.__param_order = tuple(self.__param_order)  # freeze
                self.__trace_mode = ZeRoTraceMode.COMPLETE
                print_rank_0(
                    f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}",
                    force=False)
            else:
                # Enable trace recording for next forward/backward pass
                self.__trace_mode = ZeRoTraceMode.RECORD
        else:
            if self.__profiler is not None:
                self.__profiler.log_events()

        self.__param_queue = collections.deque(self.__param_order)  # reset fetch queue
        self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10))
        self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque())
        self.__step_id = 0
        self.__n_available_params = 0
        self.__profiler.reset_events()
    def _dump_params(self, tag, sub_module, params, step_id=None):
        if step_id is None:
            step_id = self.__step_id
        param_names = [debug_param2name_id(p) for p in params]
        print_rank_0(f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}',
                     force=False)

    def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None):
        if step_id is None:
            step_id = self.__step_id
        print_rank_0(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}', force=False)
  206. """Fetch and Release
  207. Fetching, prefetching, and releasing parameters
  208. """
    @instrument_w_nvtx
    @torch.no_grad()
    def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
        """This method does the following (in order):
            1. kick off fetch for parameters in immediately required sub module
            2. kick off fetch for next few parameters we will need later (prefetch)
            3. block on parameters in immediately required sub module
        """
        if logger.isEnabledFor(logging.DEBUG):
            debug_rank0(
                f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
                + str({
                    "avail": f"{self.__n_available_params:.1e}",
                    "queue_sz": f"{len(self.__param_queue or [])}",
                    "inflight": [p.ds_id for p in self.__inflight_param_registry],
                }))

        params_to_fetch = frozenset(iter_params(current_submodule))
        fetch_numel = sum(
            [p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
        if fetch_numel > 0:
            event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
            self._dump_param_ids(event_name, current_submodule.id,
                                 [p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
            self.__profiler.start_event(event_name)
            # kick off all gather for params in the immediately required submodule
            if logger.isEnabledFor(logging.DEBUG):
                for param in params_to_fetch:
                    debug_rank0(f"-fetch: {param.ds_summary()}")
            self.__all_gather_params(params_to_fetch, forward)
            self.__profiler.stop_event(event_name, fetch_numel)

        wait_numel = 0
        wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT
        self.__profiler.start_event(wait_event_name)
        # wait for parameters in the immediately needed submodule to become available
        for param in params_to_fetch:
            param.ds_active_sub_modules.add(current_submodule.id)
            if logger.isEnabledFor(logging.DEBUG):
                debug_rank0(f"-wait: {param.ds_summary()}")
            if param in self.__inflight_param_registry:
                wait_numel += param.partition_numel()
                with get_accelerator().stream(self.__allgather_stream):
                    while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query():
                        self.__ongoing_fetch_events.popleft()
                    if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
                        self.__ongoing_fetch_events.popleft().synchronize()

                    self.__inflight_param_registry.pop(param).wait()

                    if not get_accelerator().is_synchronized_device():
                        event = get_accelerator().Event()
                        event.record()
                        self.__ongoing_fetch_events.append(event)

            assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
        if not get_accelerator().is_synchronized_device():
            get_accelerator().current_stream().wait_stream(self.__allgather_stream)
        self.__profiler.stop_event(wait_event_name, wait_numel)

        # kick off parameter prefetches for upcoming modules
        # don't prefetch if we dont have a completed model trace
        if self.is_complete_trace():
            # go through the parameters we need for the current module and pop them
            # off the fetch queue so that they aren't prefetched later.
            # if params have already been popped off the fetch queue by earlier
            # prefetches we won't look for them here
            discarded_from_prefetch_queue = set()
            params_not_already_fetched = set(
                filter(lambda p: self.__most_recent_step_id_param_fetched_for[p] < self.__step_id, params_to_fetch))
            while self.__param_queue and len(discarded_from_prefetch_queue) < len(params_not_already_fetched):
                param_in_trace = self.__param_queue.popleft()
                self.__most_recent_step_id_param_fetched_for[
                    param_in_trace.param] = param_in_trace.step_id_last_used_at
                discarded_from_prefetch_queue.add(param_in_trace.param)

            if discarded_from_prefetch_queue != params_not_already_fetched:
                raise RuntimeError(
                    f"tracing error at step {self.__step_id}: \n"
                    f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
                    f"expected the next {len(params_not_already_fetched)} parameters in the "
                    f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
                    f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.")

            def _is_currently_on_nvme(param):
                if param.nvme_swapper is None:
                    return False

                return param.ds_tensor.final_location == OffloadDeviceEnum.nvme \
                    and param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE

            # kick off all gather for params in the next few submodules (prefetch)
            if self.__prefetch_bucket_sz > 0:
                max_params_to_prefetch = min(self.__max_n_available_params - self.__n_available_params,
                                             self.__prefetch_bucket_sz)
                params_to_prefetch = set()
                numel_prefetching = 0
                while self.__param_queue and numel_prefetching < max_params_to_prefetch:
                    param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft()

                    if _is_currently_on_nvme(param_in_trace.param):
                        # nvme prefetch is handled elsewhere. Need to break here to preserve fetch order
                        self.__param_queue.appendleft(param_in_trace)
                        break

                    do_prefetch = param_in_trace.param.ds_status == ZeroParamStatus.NOT_AVAILABLE
                    if param_in_trace.param in params_to_prefetch:
                        # Avoid duplicates
                        do_prefetch = False

                    self.__most_recent_step_id_param_fetched_for[param_in_trace.param] = \
                        max(self.__most_recent_step_id_param_fetched_for[param_in_trace.param],
                            param_in_trace.step_id_last_used_at)

                    if do_prefetch:
                        params_to_prefetch.add(param_in_trace.param)
                        numel_prefetching += param_in_trace.param.ds_numel

                if numel_prefetching > 0:
                    event_name = __class__.FORWARD_PREFETCH_SUBMIT if forward else __class__.BACKWARD_PREFETCH_SUBMIT
                    self.__profiler.start_event(event_name)
                    if logger.isEnabledFor(logging.DEBUG):
                        for param in params_to_prefetch:
                            debug_rank0(f"-prefetch: {param.ds_summary()}")
                    self.__all_gather_params(params_to_prefetch, forward)
                    self.__profiler.stop_event(event_name, numel_prefetching)

                if self.__prefetch_nvme:
                    self.__prefetch_nvme_param_partitions()

        self.__step_id += 1
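
    # Hedged usage note: fetch_sub_module is intended to be driven from module
    # hooks, with release_sub_module called from the matching post hooks and
    # reset_step called once per completed iteration. A rough, hypothetical
    # pattern for the forward direction (not the exact hook wiring):
    #
    #     coordinator.fetch_sub_module(module, forward=True)       # before module.forward
    #     output = module(*inputs)
    #     coordinator.release_sub_module(module, backward=False)   # after module.forward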
    @instrument_w_nvtx
    @torch.no_grad()
    def release_sub_module(self, submodule: Module, backward: bool) -> None:
        """release the parameters of a sub module, assuming they meet conditions to
        be released."""
        params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
            p.ds_id for p in iter_params(submodule)))
        for param in iter_params(submodule):
            param.ds_active_sub_modules.discard(submodule.id)
            if param.ds_id in params_to_release and not param.is_external_param:
                self.__release_param(param, backward)
    @instrument_w_nvtx
    @torch.no_grad()
    def release_and_reset_all(self, module: Module) -> None:
        """release all module parameters"""
        for param in iter_params(module, recurse=True):
            if param in self.__inflight_param_registry:
                raise RuntimeError(f"param {param.ds_summary()} still in flight")

            # TODO. make this throw if there are still active submodules. currently
            # there's a hook execution issue
            param.ds_active_sub_modules.clear()
            self.__release_param(param, backward=False)

        for param in iter_params(module, recurse=True):
            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                raise RuntimeError(f"{param.ds_summary()} expected to be released")
    @instrument_w_nvtx
    def __all_gather_params(self, params: Set[Parameter], forward: bool) -> None:
        # Split params into those whose partitioned ds_tensor already carries
        # quantization scales and those that do not, so each group is gathered
        # with the appropriate quantize flag.
        quantized_params = []
        nonquantized_params = []
        for param in params:
            if hasattr(param.ds_tensor, 'ds_quant_scale'):
                quantized_params.append(param)
            else:
                nonquantized_params.append(param)
        if quantized_params:
            self.__all_gather_params_(quantized_params, forward, quantize=True)
        if nonquantized_params:
            self.__all_gather_params_(nonquantized_params, forward, quantize=self.zero_quantized_weights)
    def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize: bool = False) -> None:
        """for each partitioned parameter, kick off an async allgather and store
        the work handle for the in flight parameters."""
        partitioned_params = []
        all_gather_numel = 0
        for param in params:
            if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
                partitioned_params.append(param)
                all_gather_numel += param.ds_numel

        if partitioned_params:
            self.__n_available_params += all_gather_numel
            with get_accelerator().stream(self.__allgather_stream):
                event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER
                self.__profiler.start_event(event_name)
                handle = partitioned_params[0].all_gather_coalesced(partitioned_params,
                                                                    forward=forward,
                                                                    quantize=quantize)
                self.__profiler.stop_event(event_name, all_gather_numel)

            for param in partitioned_params:
                assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary()
                self.__inflight_param_registry[param] = handle

        # Release swap buffers for persisted params on nvme since they will never be partitioned or evicted from GPU
        swap_persisted_params = [
            p for p in partitioned_params if p.ds_persist and p.ds_tensor.final_location == OffloadDeviceEnum.nvme
        ]
        if swap_persisted_params:
            swap_persisted_params[0].nvme_swapper.remove_partition_and_release_buffers(swap_persisted_params)
    @instrument_w_nvtx
    def __release_param(self, param: Parameter, backward: bool) -> None:
        if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
            if logger.isEnabledFor(logging.DEBUG):
                debug_rank0(f"-release: {param.ds_summary()}")
            param.partition(backward=backward)
            self.__n_available_params -= param.ds_numel
    @instrument_w_nvtx
    @functools.lru_cache(maxsize=None)
    def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set[int]:
        if not self.is_complete_trace():
            raise RuntimeError("expected trace to be complete")

        params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) if not p.ds_persist)

        # Problem: When prefetcher scans the param trace, it skips AVAILABLE params.
        # This creates issues if those params are released before the skipped uses:
        # 1) It hurts performance as the skipped uses are never prefetched.
        # 2) For nvme params, we run out of swap buffers because the prefetch order
        #    diverges from the trace.
        # Solution: Don't release params whose reuse was skipped by prefetch. This is
        # possible because we detect such skips during prefetch and mark those params.
        for param in iter_params(submodule_to_release):
            if self.__most_recent_step_id_param_fetched_for[param] > step_id:
                params_to_release.discard(param.ds_id)

        # examine all modules within `max_reuse_dist_in_numel` of the current step,
        # if we see any of the candidate parameters to be released reoccur while
        # doing this, remove them from the set of parameters to release.
        params_traversed = 0
        for module in self.__submodule_order[step_id:]:
            if params_traversed >= self.__max_reuse_dist_in_numel:
                break
            for param in iter_params(module):
                params_to_release.discard(param.ds_id)
                params_traversed += param.ds_numel

        return params_to_release
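
    # Worked example (hypothetical numbers) of the reuse-distance rule above:
    # with max_reuse_dist_in_numel = 1e6, a parameter that reappears in an
    # upcoming submodule before roughly 1e6 numel of parameters have been
    # traversed stays out of params_to_release (keeping it is cheaper than
    # re-gathering it), whereas a parameter whose next use lies beyond that
    # window is released now and gathered again later.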
    @instrument_w_nvtx
    def __prefetch_nvme_param_partitions(self) -> None:
        """swap in parameter partitions from nvme for those parameters that will be used
        after the ones that are already being prefetched into full parameters
        """
        if not self.is_complete_trace():
            return

        numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry)

        numel_considered = 0
        swap_in_params = []
        for param_in_trace in self.__param_queue:
            param = param_in_trace.param
            if param.nvme_swapper is None:
                continue
            if (numel_considered > 2 * numel_in_flight
                    or len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers()):
                break
            if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
                swap_in_params.append(param)
            numel_considered += param.ds_numel

        if swap_in_params:
            swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True)
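
# Hedged, illustrative construction of the coordinator (argument values are
# placeholders, not recommendations; in DeepSpeed these normally come from the
# ZeRO stage-3 configuration rather than being assembled by hand):
#
#     coordinator = PartitionedParameterCoordinator(
#         prefetch_bucket_sz=int(5e7),
#         max_reuse_distance_in_numel=int(1e9),
#         max_available_parameters_in_numel=int(1e9),
#         allgather_stream=get_accelerator().Stream(),
#         inflight_param_registry=InflightParamRegistry(),
#         prefetch_nvme=False,
#     )
#     ...
#     coordinator.reset_step()   # once per completed forward+backward pass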