parameter_offload.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import sys
  5. import torch
  6. from collections import OrderedDict
  7. from deepspeed.utils import z3_leaf_module
  8. from deepspeed.runtime.utils import see_memory_usage
  9. from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param
  10. from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
  11. from deepspeed.runtime.zero.partition_parameters import _init_external_params
  12. from deepspeed.runtime.zero.partition_parameters import *
  13. from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params
  14. from deepspeed.accelerator import get_accelerator
  15. FWD_MODULE_STACK = list()
  16. # ensure we only warn once, otherwise every iteration will trigger a warning
  17. warned = False
  18. #for each tensor in outputs run the forward_function and register backward_function as hook
  19. def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs):
  20. if type(outputs) is tuple:
  21. touched_outputs = []
  22. for output in outputs:
  23. touched_output = _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function,
  24. output)
  25. touched_outputs.append(touched_output)
  26. return tuple(touched_outputs)
  27. elif type(outputs) is torch.Tensor:
  28. forward_function(outputs)
  29. if outputs.requires_grad:
  30. outputs.register_hook(backward_function)
  31. return outputs
  32. else:
  33. return outputs
  34. class ZeROOrderedDict(OrderedDict):
  35. def __init__(self, parent_module, *args, **kwargs):
  36. """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
  37. Args:
  38. parent_module (``collections.OrderedDict``): the collection to replace
  39. """
  40. super().__init__(*args, **kwargs)
  41. self._parent_module = parent_module
  42. self._in_forward = False
  43. def __getitem__(self, key):
  44. param = super().__getitem__(key)
  45. # Params can be registered as None (e.g., bias)
  46. if param is None:
  47. return param
  48. if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
  49. if self._parent_module._parameters._in_forward:
  50. register_external_parameter(FWD_MODULE_STACK[-1], param)
  51. param.all_gather()
  52. print_rank_0(f'Registering external parameter from getter {key} ds_id = {param.ds_id}', force=False)
  53. return param
  54. def _inject_parameters(module, cls):
  55. for module in module.modules():
  56. if cls == ZeROOrderedDict:
  57. new_param = cls(parent_module=module)
  58. else:
  59. new_param = cls()
  60. for key, param in module._parameters.items():
  61. new_param[key] = param
  62. module._parameters = new_param
  63. class DeepSpeedZeRoOffload(object):
  64. def __init__(
  65. self,
  66. module,
  67. timers,
  68. ds_config,
  69. overlap_comm=True,
  70. prefetch_bucket_size=50000000,
  71. max_reuse_distance=1000000000,
  72. max_live_parameters=1000000000,
  73. param_persistence_threshold=100000,
  74. model_persistence_threshold=sys.maxsize,
  75. dp_process_group=None,
  76. offload_param_config=None,
  77. mpu=None,
  78. zero_param_parallel_group=None,
  79. zero_quantized_weights=False,
  80. zero_quantized_nontrainable_weights=False,
  81. ):
  82. see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
  83. print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False)
  84. self.module = module
  85. self.timers = timers
  86. self.dtype = list(module.parameters())[0].dtype
  87. self.dp_process_group = dp_process_group
  88. self.offload_device = None
  89. self.offload_param_pin_memory = False
  90. self.zero_param_parallel_group = zero_param_parallel_group
  91. self.zero_quantized_weights = zero_quantized_weights
  92. self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
  93. if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
  94. self.offload_device = offload_param_config.device
  95. self.offload_param_pin_memory = offload_param_config.pin_memory
  96. self._convert_to_zero_parameters(ds_config, module, mpu)
  97. for m in module.modules():
  98. _init_external_params(m)
  99. _inject_parameters(module, ZeROOrderedDict)
  100. self.param_numel_persistence_threshold = int(param_persistence_threshold)
  101. self.model_persistence_threshold = int(model_persistence_threshold)
  102. self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold,
  103. self.model_persistence_threshold)
  104. self.param_coordinators = {}
  105. self._prefetch_bucket_sz = int(prefetch_bucket_size)
  106. self._max_reuse_distance_in_numel = int(max_reuse_distance)
  107. self._max_available_parameters_in_numel = int(max_live_parameters)
  108. self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream(
  109. ) if overlap_comm else get_accelerator().default_stream()
  110. if not hasattr(module, "ds_inflight_param_registry"):
  111. module.ds_inflight_param_registry = dict()
  112. # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator
  113. module.ds_inflight_param_registry[True] = InflightParamRegistry()
  114. module.ds_inflight_param_registry[False] = InflightParamRegistry()
  115. self.__inflight_param_registry = module.ds_inflight_param_registry
  116. self.forward_hooks = []
  117. self.backward_hooks = []
  118. self.setup_zero_stage3_hooks()
  119. print_rank_0(
  120. f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
  121. force=False)
  122. see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True)
  123. @instrument_w_nvtx
  124. def partition_all_parameters(self):
  125. """Partitioning Parameters that were not partitioned usually if parameters
  126. of modules whose input parameters do not require grad computation do not
  127. trigger post call and will therefore will remain unpartitioned"""
  128. self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
  129. for param in iter_params(self.module, recurse=True):
  130. if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
  131. raise RuntimeError(f"{param.ds_summary()} expected to be released")
  132. def get_param_coordinator(self, training):
  133. if not training in self.param_coordinators:
  134. self.param_coordinators[training] = PartitionedParameterCoordinator(
  135. prefetch_bucket_sz=self._prefetch_bucket_sz,
  136. max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
  137. max_available_parameters_in_numel=self._max_available_parameters_in_numel,
  138. allgather_stream=self.__allgather_stream,
  139. inflight_param_registry=self.__inflight_param_registry[training],
  140. prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
  141. timers=self.timers,
  142. zero_quantized_weights=self.zero_quantized_weights,
  143. zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
  144. )
  145. return self.param_coordinators[training]
  146. def empty_partition_cache(self):
  147. self.partition_all_parameters()
  148. def _convert_to_zero_parameters(self, ds_config, module, mpu):
  149. non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
  150. if non_zero_params:
  151. zero_params = [p for p in module.parameters() if is_zero_param(p)]
  152. if zero_params:
  153. zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
  154. else:
  155. group = None
  156. if mpu:
  157. group = mpu.get_data_parallel_group()
  158. Init(module=module,
  159. data_parallel_group=group,
  160. dtype=self.dtype,
  161. config_dict_or_path=ds_config,
  162. remote_device=self.offload_device,
  163. pin_memory=self.offload_param_pin_memory,
  164. mpu=mpu,
  165. zero_param_parallel_group=self.zero_param_parallel_group,
  166. zero_quantized_weights=self.zero_quantized_weights,
  167. zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights)
  168. def destroy(self):
  169. self._remove_module_hooks()
  170. def _remove_module_hooks(self):
  171. num_forward_hooks = len(self.forward_hooks)
  172. num_backward_hooks = len(self.backward_hooks)
  173. for hook in self.forward_hooks:
  174. hook.remove()
  175. for hook in self.backward_hooks:
  176. hook.remove()
  177. print_rank_0(f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
  178. force=False)
  179. def setup_zero_stage3_hooks(self):
  180. self.hierarchy = 0
  181. #reset step if in inference mode
  182. @instrument_w_nvtx
  183. def _end_of_forward_hook(module, *args):
  184. if not torch._C.is_grad_enabled():
  185. self.get_param_coordinator(training=False).reset_step()
  186. #likely one of them should be enough but just to be safe
  187. self._register_hooks_recursively(self.module)
  188. self.module.register_forward_hook(_end_of_forward_hook)
  189. # Add top module to stack trace
  190. global FWD_MODULE_STACK
  191. FWD_MODULE_STACK.append(self.module)
  192. def mark_persistent_parameters(self, param_threshold, model_threshold):
  193. persistent_params = []
  194. total_persistent_parameters = 0
  195. params_count = 0
  196. for name, param in self.module.named_parameters(recurse=True):
  197. if param.ds_numel + total_persistent_parameters > model_threshold:
  198. continue
  199. if param.ds_numel <= param_threshold:
  200. params_count += 1
  201. param.ds_persist = True
  202. persistent_params.append(param)
  203. total_persistent_parameters += param.ds_numel
  204. print_rank_0(
  205. f"Parameter Offload: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
  206. force=True)
  207. return persistent_params
  208. def _register_hooks_recursively(self, module, count=[0]):
  209. my_count = count[0]
  210. module.id = my_count
  211. #print(f"{module.__class__} : {module.id}")
  212. if z3_leaf_module(module):
  213. for param in module.parameters():
  214. param.ds_z3_leaf_module = module
  215. else:
  216. for child in module.children():
  217. count[0] = count[0] + 1
  218. self._register_hooks_recursively(child, count=count)
  219. @instrument_w_nvtx
  220. def _pre_forward_module_hook(module, *args):
  221. self.pre_sub_module_forward_function(module)
  222. @instrument_w_nvtx
  223. def _post_forward_module_hook(module, input, output):
  224. global FWD_MODULE_STACK
  225. FWD_MODULE_STACK.pop()
  226. if output is None:
  227. output = []
  228. elif not isinstance(output, (list, tuple)):
  229. if torch.is_tensor(output):
  230. output = [output]
  231. else:
  232. #print(f'got UNKNOWN type {type(output)}')
  233. outputs = []
  234. output = output if isinstance(output, dict) else vars(output)
  235. for name, val in output.items():
  236. if not name.startswith('__') and torch.is_tensor(val):
  237. outputs.append(val)
  238. output = outputs
  239. for item in filter(lambda item: is_zero_param(item) or hasattr(item, 'ds_param_alias'), output):
  240. key = id(item) if hasattr(item, 'ds_id') else id(item.ds_param_alias)
  241. actual_external_param = item if hasattr(item, 'ds_id') else item.ds_param_alias
  242. if not any(key in m._external_params for m in FWD_MODULE_STACK):
  243. actual_external_param.is_external_param = True
  244. module_to_register = FWD_MODULE_STACK[-1]
  245. register_external_parameter(module_to_register, actual_external_param)
  246. print_rank_0(
  247. f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {actual_external_param.ds_id}.',
  248. force=False)
  249. # It's possible that the parameter was already external to the completed module. If so, remove it the
  250. # registration as it will be covered by the outer module instead.
  251. if key in module._external_params:
  252. print_rank_0(
  253. f' Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {actual_external_param.ds_id}',
  254. force=False)
  255. unregister_external_parameter(module, actual_external_param)
  256. actual_external_param.all_gather()
  257. self.post_sub_module_forward_function(module)
  258. def _bwd_hook_unexpected_inputs_msg(value):
  259. return f"A module has unknown inputs or outputs type ({type(value)}) and the tensors embedded in it cannot be detected. " \
  260. "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " \
  261. "output tensors and therefore may not get triggered properly."
  262. def _pre_backward_module_hook(module, inputs, output):
  263. if not hasattr(module, "pre_bwd_fn"):
  264. @instrument_w_nvtx
  265. def _run_before_backward_function(sub_module):
  266. # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
  267. # before doing backwards, so each backward will need a pre-fetch - using reference
  268. # counting to support this scenario
  269. #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
  270. if sub_module.applied_pre_backward_ref_cnt > 0:
  271. self.pre_sub_module_backward_function(sub_module)
  272. sub_module.applied_pre_backward_ref_cnt -= 1
  273. #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
  274. class PreBackwardFunctionForModule(torch.autograd.Function):
  275. @staticmethod
  276. def forward(ctx, outputs):
  277. # Capture `module` and _run_before_backward_function
  278. ctx.module = module
  279. ctx.pre_backward_function = _run_before_backward_function
  280. if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
  281. ctx.module.applied_pre_backward_ref_cnt = 0
  282. ctx.module.applied_pre_backward_ref_cnt += 1
  283. outputs = outputs.detach()
  284. return outputs
  285. @staticmethod
  286. def backward(ctx, *args):
  287. ctx.pre_backward_function(ctx.module)
  288. return args
  289. module.pre_bwd_fn = PreBackwardFunctionForModule
  290. return apply_to_tensors_only(module.pre_bwd_fn.apply,
  291. output,
  292. warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
  293. #This is an alternate to doing _post_backward_module_hook
  294. #it uses tensor.register_hook instead of using torch.autograd.Function
  295. def _alternate_post_backward_module_hook(module, inputs):
  296. module.ds_grads_remaining = 0
  297. #print(f"Before Forward {module.__class__.__name__}")
  298. def _run_after_backward_hook(*unused):
  299. module.ds_grads_remaining = module.ds_grads_remaining - 1
  300. if module.ds_grads_remaining == 0:
  301. #print(f"After backward {module.__class__.__name__}")
  302. self.post_sub_module_backward_function(module)
  303. def _run_before_forward_function(input):
  304. if input.requires_grad:
  305. module.ds_grads_remaining += 1
  306. return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function,
  307. _run_after_backward_hook, inputs)
  308. def _post_backward_module_hook(module, inputs):
  309. module.ds_grads_remaining = 0
  310. if not hasattr(module, "post_bwd_fn"):
  311. @instrument_w_nvtx
  312. def _run_after_backward_function(sub_module):
  313. if sub_module.ds_grads_remaining == 0:
  314. self.post_sub_module_backward_function(sub_module)
  315. class PostBackwardFunctionModule(torch.autograd.Function):
  316. @staticmethod
  317. def forward(ctx, output):
  318. ctx.module = module
  319. if output.requires_grad:
  320. #TODO SOME TIMES post backward does not seem to be triggered debug in detail
  321. #Should only cause increase in memory not correctness issue
  322. #if output.grad_fn.__class__.__name__ == 'ViewBackward':
  323. # ctx.view=True
  324. # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
  325. #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
  326. #if module.ds_grads_remaining == 0:
  327. # print(f"Before Forward: {ctx.module.__class__.__name__}")
  328. module.ds_grads_remaining += 1
  329. ctx.post_backward_function = _run_after_backward_function
  330. output = output.detach()
  331. return output
  332. @staticmethod
  333. def backward(ctx, *args):
  334. ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
  335. if ctx.module.ds_grads_remaining == 0:
  336. ctx.post_backward_function(ctx.module)
  337. return args
  338. module.post_bwd_fn = PostBackwardFunctionModule
  339. return apply_to_tensors_only(module.post_bwd_fn.apply,
  340. inputs,
  341. warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
  342. # Pre forward hook
  343. self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))
  344. # Post forward hook
  345. self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))
  346. # Pre backward hook
  347. self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))
  348. # post backward hook
  349. self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
  350. @torch.no_grad()
  351. def pre_sub_module_forward_function(self, sub_module):
  352. see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
  353. global FWD_MODULE_STACK
  354. FWD_MODULE_STACK.append(sub_module)
  355. param_coordinator = self.get_param_coordinator(training=sub_module.training)
  356. param_coordinator.trace_prologue(sub_module)
  357. if param_coordinator.is_record_trace():
  358. param_coordinator.record_module(sub_module)
  359. param_coordinator.fetch_sub_module(sub_module, forward=True)
  360. see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
  361. @torch.no_grad()
  362. def post_sub_module_forward_function(self, sub_module):
  363. see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
  364. force=False)
  365. param_coordinator = self.get_param_coordinator(training=sub_module.training)
  366. param_coordinator.release_sub_module(sub_module)
  367. see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
  368. force=False)
  369. @torch.no_grad()
  370. def pre_sub_module_backward_function(self, sub_module):
  371. assert sub_module.training, "backward pass is invalid for module in evaluation mode"
  372. param_coordinator = self.get_param_coordinator(training=True)
  373. param_coordinator.trace_prologue(sub_module)
  374. if param_coordinator.is_record_trace():
  375. param_coordinator.record_module(sub_module)
  376. param_coordinator.fetch_sub_module(sub_module, forward=False)
  377. @torch.no_grad()
  378. def post_sub_module_backward_function(self, sub_module):
  379. assert sub_module.training, "backward pass is invalid for module in evaluation mode"
  380. see_memory_usage(
  381. f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
  382. force=False)
  383. self.get_param_coordinator(training=True).release_sub_module(sub_module)
  384. see_memory_usage(
  385. f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
  386. force=False)