parameter_offload.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import sys
  5. import torch
  6. from collections import OrderedDict
  7. from deepspeed.runtime.utils import see_memory_usage
  8. from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
  9. from deepspeed.runtime.zero.partition_parameters import _init_external_params
  10. from deepspeed.runtime.zero.partition_parameters import *
  11. from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params
  12. from deepspeed import comm as dist
  13. from deepspeed.accelerator import get_accelerator
  14. FWD_MODULE_STACK = list()
  15. def is_builtin_type(obj):
  16. # https://stackoverflow.com/a/17795199
  17. return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
  18. def isinstance_namedtuple(obj: object) -> bool:
  19. """
  20. Is this an instance of namedtuple/NamedTuple?
  21. From: https://stackoverflow.com/a/62692640
  22. Args:
  23. obj (object): An object.
  24. Returns:
  25. bool: True if namedtuple/NamedTuple else False.
  26. """
  27. return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields')
# ensure we only warn once, otherwise every iteration will trigger a warning
# (set to True after the first unknown-output-type warning is emitted)
warned = False
  30. def _apply_to_tensors_only(module, functional, backward_function, outputs):
  31. """
  32. Apply a torch.autograd.Function that calls a `backward_function` to every Tensor in `outputs`.
  33. Args:
  34. module (torch.nn.Module): A torch module
  35. functional (Type[torch.autograd.Function]): The function class to apply.
  36. backward_function (Callable[[torch.nn.Module], None]): A backward_function to pass to
  37. `functional.apply`.
  38. outputs (Any): The output of `module`.
  39. Returns:
  40. Any: The output of `module`.
  41. """
  42. if isinstance(outputs, (tuple, list)):
  43. touched_outputs = []
  44. for output in outputs:
  45. touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
  46. touched_outputs.append(touched_output)
  47. if isinstance_namedtuple(outputs):
  48. # namedtuples require a slightly different syntax.
  49. return outputs.__class__(*touched_outputs)
  50. return outputs.__class__(touched_outputs)
  51. elif isinstance(outputs, dict):
  52. # apply inplace to avoid recreating dict inherited objects
  53. for key in outputs.keys():
  54. outputs[key] = _apply_to_tensors_only(module, functional, backward_function, outputs[key])
  55. return outputs
  56. elif isinstance(outputs, torch.Tensor):
  57. # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter
  58. touched_outputs = functional.apply(module, backward_function, outputs)
  59. # restore zero param attributes if those get stripped by `backward_function`
  60. if not is_zero_param(touched_outputs) and is_zero_param(outputs):
  61. touched_outputs.ds_param_alias = outputs
  62. return touched_outputs
  63. else:
  64. if not is_builtin_type(outputs):
  65. global warned
  66. if not warned and dist.get_rank() == 0:
  67. logger.warning(
  68. f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
  69. "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
  70. "output tensors and therefore may not get triggered properly.")
  71. warned = True
  72. return outputs
  73. #for each tensor in outputs run the forward_function and register backward_function as hook
  74. def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs):
  75. if type(outputs) is tuple:
  76. touched_outputs = []
  77. for output in outputs:
  78. touched_output = _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function,
  79. output)
  80. touched_outputs.append(touched_output)
  81. return tuple(touched_outputs)
  82. elif type(outputs) is torch.Tensor:
  83. forward_function(outputs)
  84. if outputs.requires_grad:
  85. outputs.register_hook(backward_function)
  86. return outputs
  87. else:
  88. return outputs
  89. class ZeROOrderedDict(OrderedDict):
  90. def __init__(self, parent_module, *args, **kwargs):
  91. """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
  92. Args:
  93. parent_module (``collections.OrderedDict``): the collection to replace
  94. """
  95. super().__init__(*args, **kwargs)
  96. self._parent_module = parent_module
  97. self._in_forward = False
  98. def __getitem__(self, key):
  99. param = super().__getitem__(key)
  100. # Params can be registered as None (e.g., bias)
  101. if param is None:
  102. return param
  103. if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
  104. if self._parent_module._parameters._in_forward:
  105. register_external_parameter(FWD_MODULE_STACK[-1], param)
  106. param.all_gather()
  107. print_rank_0(f'Registering external parameter from getter {key} ds_id = {param.ds_id}', force=False)
  108. return param
  109. def _inject_parameters(module, cls):
  110. for module in module.modules():
  111. if cls == ZeROOrderedDict:
  112. new_param = cls(parent_module=module)
  113. else:
  114. new_param = cls()
  115. for key, param in module._parameters.items():
  116. new_param[key] = param
  117. module._parameters = new_param
  118. class PreBackwardFunction(torch.autograd.Function):
  119. @staticmethod
  120. def forward(ctx, module, pre_backward_function, outputs):
  121. ctx.module = module
  122. ctx.pre_backward_function = pre_backward_function
  123. if not hasattr(module, "applied_pre_backward_ref_cnt"):
  124. module.applied_pre_backward_ref_cnt = 0
  125. module.applied_pre_backward_ref_cnt += 1
  126. #print(f"After Forward: {ctx.module.__class__.__name__}")
  127. outputs = outputs.detach()
  128. return outputs
  129. @staticmethod
  130. def backward(ctx, *args):
  131. #print(f"Before Backward: {ctx.module.__class__.__name__}")
  132. ctx.pre_backward_function(ctx.module)
  133. return (None, None) + args
  134. class PostBackwardFunction(torch.autograd.Function):
  135. @staticmethod
  136. def forward(ctx, module, pre_backward_function, output):
  137. ctx.module = module
  138. if output.requires_grad:
  139. #TODO SOME TIMES post backward does not seem to be triggered debug in detail
  140. #Should only cause increase in memory not correctness issue
  141. #if output.grad_fn.__class__.__name__ == 'ViewBackward':
  142. # ctx.view=True
  143. # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
  144. #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
  145. #if module.ds_grads_remaining == 0:
  146. # print(f"Before Forward: {ctx.module.__class__.__name__}")
  147. module.ds_grads_remaining += 1
  148. ctx.pre_backward_function = pre_backward_function
  149. output = output.detach()
  150. return output
  151. @staticmethod
  152. def backward(ctx, *args):
  153. ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
  154. if ctx.module.ds_grads_remaining == 0:
  155. ctx.pre_backward_function(ctx.module)
  156. #print(f"After Backward: {ctx.module.__class__.__name__}")
  157. return (None, None) + args
class DeepSpeedZeRoOffload(object):
    """Coordinates ZeRO stage-3 parameter partitioning/offload for ``module``.

    Converts the module's parameters to ZeRO (partitioned) parameters, marks
    small parameters persistent, and installs forward hooks that fetch
    (all-gather) each submodule's parameters just before it runs and release
    (re-partition) them afterwards; autograd Functions mirror this for the
    backward pass.
    """

    def __init__(
        self,
        module,
        timers,
        ds_config,
        overlap_comm=True,
        prefetch_bucket_size=50000000,
        max_reuse_distance=1000000000,
        max_live_parameters=1000000000,
        param_persistence_threshold=100000,
        model_persistence_threshold=sys.maxsize,
        dp_process_group=None,
        offload_param_config=None,
        mpu=None,
        zero_param_parallel_group=None,
        zero_quantized_weights=False,
        zero_quantized_nontrainable_weights=False,
    ):
        see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)

        print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False)

        self.module = module
        self.timers = timers
        # dtype taken from the first parameter — assumes a uniform-dtype model
        self.dtype = list(module.parameters())[0].dtype
        self.dp_process_group = dp_process_group
        self.offload_device = None
        self.offload_param_pin_memory = False
        self.zero_param_parallel_group = zero_param_parallel_group
        self.zero_quantized_weights = zero_quantized_weights
        self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights

        # Only honor the offload config when a real target device is selected.
        if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
            self.offload_device = offload_param_config.device
            self.offload_param_pin_memory = offload_param_config.pin_memory

        self._convert_to_zero_parameters(ds_config, module, mpu)

        for m in module.modules():
            _init_external_params(m)

        # Swap every submodule's _parameters dict for a ZeROOrderedDict so that
        # direct parameter access during forward can trigger an all-gather.
        _inject_parameters(module, ZeROOrderedDict)

        self.param_numel_persistence_threshold = int(param_persistence_threshold)
        self.model_persistence_threshold = int(model_persistence_threshold)
        self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold,
                                                                     self.model_persistence_threshold)

        self.param_coordinators = {}
        self._prefetch_bucket_sz = int(prefetch_bucket_size)
        self._max_reuse_distance_in_numel = int(max_reuse_distance)
        self._max_available_parameters_in_numel = int(max_live_parameters)
        # Synchronized devices need no stream; otherwise a dedicated stream is
        # used when overlap_comm is on so all-gathers can overlap compute.
        self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream(
        ) if overlap_comm else get_accelerator().default_stream()

        if not hasattr(module, "ds_inflight_param_registry"):
            module.ds_inflight_param_registry = dict()
            # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator
            module.ds_inflight_param_registry[True] = InflightParamRegistry()
            module.ds_inflight_param_registry[False] = InflightParamRegistry()
        self.__inflight_param_registry = module.ds_inflight_param_registry

        self.forward_hooks = []
        self.backward_hooks = []
        self.setup_zero_stage3_hooks()
        print_rank_0(
            f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
            force=False)

        see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True)

    @instrument_w_nvtx
    def partition_all_parameters(self):
        """Partitioning Parameters that were not partitioned usually if parameters
        of modules whose input parameters do not require grad computation do not
        trigger post call and will therefore will remain unpartitioned"""
        self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
        # Sanity check: after the release above every parameter must be back in
        # the partitioned (NOT_AVAILABLE) state.
        for param in iter_params(self.module, recurse=True):
            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                raise RuntimeError(f"{param.ds_summary()} expected to be released")

    def get_param_coordinator(self, training):
        """Return (lazily creating) the coordinator for ``training`` (bool) mode.

        Separate coordinators are kept for train and eval so their traces and
        in-flight registries do not interfere.
        """
        if not training in self.param_coordinators:
            self.param_coordinators[training] = PartitionedParameterCoordinator(
                prefetch_bucket_sz=self._prefetch_bucket_sz,
                max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
                max_available_parameters_in_numel=self._max_available_parameters_in_numel,
                allgather_stream=self.__allgather_stream,
                inflight_param_registry=self.__inflight_param_registry[training],
                prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
                timers=self.timers,
                zero_quantized_weights=self.zero_quantized_weights,
                zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
            )

        return self.param_coordinators[training]

    def empty_partition_cache(self):
        # Alias for partition_all_parameters(): forces every parameter back to
        # its partitioned state.
        self.partition_all_parameters()

    def _convert_to_zero_parameters(self, ds_config, module, mpu):
        """Convert any non-ZeRO parameters of ``module`` to ZeRO parameters.

        If some parameters are already ZeRO params, reuse their conversion
        machinery; otherwise run a full ``Init`` over the module.
        """
        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
        if non_zero_params:
            zero_params = [p for p in module.parameters() if is_zero_param(p)]
            if zero_params:
                zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
            else:
                group = None
                if mpu:
                    group = mpu.get_data_parallel_group()

                Init(module=module,
                     data_parallel_group=group,
                     dtype=self.dtype,
                     config_dict_or_path=ds_config,
                     remote_device=self.offload_device,
                     pin_memory=self.offload_param_pin_memory,
                     mpu=mpu,
                     zero_param_parallel_group=self.zero_param_parallel_group,
                     zero_quantized_weights=self.zero_quantized_weights,
                     zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights)

    def destroy(self):
        # Detach all hooks installed by setup_zero_stage3_hooks().
        self._remove_module_hooks()

    def _remove_module_hooks(self):
        num_forward_hooks = len(self.forward_hooks)
        num_backward_hooks = len(self.backward_hooks)

        for hook in self.forward_hooks:
            hook.remove()

        for hook in self.backward_hooks:
            hook.remove()

        print_rank_0(f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
                     force=False)

    def setup_zero_stage3_hooks(self):
        """Install the fetch/release hooks on every submodule of ``self.module``."""
        self.hierarchy = 0

        #reset step if in inference mode
        @instrument_w_nvtx
        def _end_of_forward_hook(module, *args):
            if not torch._C.is_grad_enabled():
                self.get_param_coordinator(training=False).reset_step()

        #likely one of them should be enough but just to be safe
        self._register_hooks_recursively(self.module)
        self.module.register_forward_hook(_end_of_forward_hook)

        # Add top module to stack trace
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(self.module)

    def mark_persistent_parameters(self, param_threshold, model_threshold):
        """Mark small parameters as persistent (never re-partitioned).

        A parameter persists if its numel is <= ``param_threshold`` and the
        running persistent total stays within ``model_threshold``.

        Returns:
            list: the parameters that were marked persistent.
        """
        persistent_params = []
        total_persistent_parameters = 0
        params_count = 0
        for name, param in self.module.named_parameters(recurse=True):
            # Skip (do not persist) any param that would push the model-wide
            # persistent budget over the threshold.
            if param.ds_numel + total_persistent_parameters > model_threshold:
                continue

            if param.ds_numel <= param_threshold:
                params_count += 1
                param.ds_persist = True
                persistent_params.append(param)
                total_persistent_parameters += param.ds_numel

        print_rank_0(
            f"Parameter Offload: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
            force=True)

        return persistent_params

    def _register_hooks_recursively(self, module, count=[0]):
        """Depth-first walk: assign each submodule a unique ``id`` and attach the
        four ZeRO-3 hooks.

        NOTE(review): the mutable default ``count`` is used deliberately as a
        counter shared across the recursion; it also persists across calls in
        the same process, so ids keep incrementing globally.
        """
        my_count = count[0]
        module.id = my_count

        #print(f"{module.__class__} : {module.id}")

        for child in module.children():
            count[0] = count[0] + 1
            self._register_hooks_recursively(child, count=count)

        @instrument_w_nvtx
        def _pre_forward_module_hook(module, *args):
            # gather this submodule's partitioned parameters before it runs
            self.pre_sub_module_forward_function(module)

        @instrument_w_nvtx
        def _post_forward_module_hook(module, input, output):
            global FWD_MODULE_STACK
            FWD_MODULE_STACK.pop()

            # Normalize `output` into a flat list of candidate tensors.
            if output is None:
                output = []
            elif not isinstance(output, (list, tuple)):
                if torch.is_tensor(output):
                    output = [output]
                else:
                    #print(f'got UNKNOWN type {type(output)}')
                    outputs = []
                    output = output if isinstance(output, dict) else vars(output)
                    for name, val in output.items():
                        if not name.startswith('__') and torch.is_tensor(val):
                            outputs.append(val)
                    output = outputs

            # Any ZeRO param escaping this module must be registered as external
            # on an enclosing module that is still on the forward stack.
            for item in filter(lambda item: is_zero_param(item) or hasattr(item, 'ds_param_alias'), output):
                key = id(item) if hasattr(item, 'ds_id') else id(item.ds_param_alias)
                actual_external_param = item if hasattr(item, 'ds_id') else item.ds_param_alias

                if not any(key in m._external_params for m in FWD_MODULE_STACK):
                    actual_external_param.is_external_param = True
                    module_to_register = FWD_MODULE_STACK[-1]
                    register_external_parameter(module_to_register, actual_external_param)
                    print_rank_0(
                        f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {actual_external_param.ds_id}.',
                        force=False)

                    # It's possible that the parameter was already external to the
                    # completed module. If so, remove the registration as it will
                    # be covered by the outer module instead.
                    if key in module._external_params:
                        print_rank_0(
                            f' Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {actual_external_param.ds_id}',
                            force=False)
                        unregister_external_parameter(module, actual_external_param)

                    actual_external_param.all_gather()

            self.post_sub_module_forward_function(module)

        def _pre_backward_module_hook(module, inputs, output):

            @instrument_w_nvtx
            def _run_before_backward_function(sub_module):
                # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
                # before doing backwards, so each backward will need a pre-fetch - using reference
                # counting to support this scenario
                #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
                if sub_module.applied_pre_backward_ref_cnt > 0:
                    self.pre_sub_module_backward_function(sub_module)
                    sub_module.applied_pre_backward_ref_cnt -= 1
                    #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

            return _apply_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output)

        #This is an alternate to doing _post_backward_module_hook
        #it uses tensor.register_hook instead of using torch.autograd.Function
        def _alternate_post_backward_module_hook(module, inputs):
            module.ds_grads_remaining = 0

            #print(f"Before Forward {module.__class__.__name__}")

            def _run_after_backward_hook(*unused):
                module.ds_grads_remaining = module.ds_grads_remaining - 1
                if module.ds_grads_remaining == 0:
                    #print(f"After backward {module.__class__.__name__}")
                    self.post_sub_module_backward_function(module)

            def _run_before_forward_function(input):
                if input.requires_grad:
                    module.ds_grads_remaining += 1

            return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function,
                                                               _run_after_backward_hook, inputs)

        def _post_backward_module_hook(module, inputs):
            module.ds_grads_remaining = 0

            @instrument_w_nvtx
            def _run_after_backward_function(sub_module):
                if sub_module.ds_grads_remaining == 0:
                    self.post_sub_module_backward_function(sub_module)

            return _apply_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs)

        # Pre forward hook
        self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))

        # Post forward hook
        self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

        # Pre backward hook
        self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))

        # post backward hook
        self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))

    def pre_sub_module_forward_function(self, sub_module):
        """Fetch ``sub_module``'s parameters (and record/replay the trace) before it runs."""
        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
        prev_grad_state = torch.is_grad_enabled(
        )  # we don't want to enable grad for sub modules fetching, yet the subfunction need to know if grad is enabled
        torch.set_grad_enabled(False)
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(sub_module)
        param_coordinator = self.get_param_coordinator(training=sub_module.training)
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
        torch.set_grad_enabled(prev_grad_state)
        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)

    @torch.no_grad()
    def post_sub_module_forward_function(self, sub_module):
        """Release ``sub_module``'s non-persistent parameters after its forward."""
        see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                         force=False)

        param_coordinator = self.get_param_coordinator(training=sub_module.training)
        param_coordinator.release_sub_module(sub_module, backward=False)

        see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
                         force=False)

    @torch.no_grad()
    def pre_sub_module_backward_function(self, sub_module):
        """Fetch ``sub_module``'s parameters just before its backward pass."""
        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
        param_coordinator = self.get_param_coordinator(training=True)
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module, forward=False)

    @torch.no_grad()
    def post_sub_module_backward_function(self, sub_module):
        """Release ``sub_module``'s parameters after its backward pass completes."""
        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
            force=False)

        self.get_param_coordinator(training=True).release_sub_module(sub_module, backward=True)

        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
            force=False)