parameter_offload.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. """
  2. "Copyright 2022 The Microsoft DeepSpeed Team.
  3. Licensed under the MIT license.
  4. """
  5. import sys
  6. import torch
  7. from collections import OrderedDict
  8. from deepspeed.runtime.utils import see_memory_usage
  9. from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
  10. from deepspeed.runtime.zero.partition_parameters import _init_external_params
  11. from deepspeed.runtime.zero.partition_parameters import *
  12. from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params
  13. from deepspeed import comm as dist
  14. from deepspeed.accelerator import get_accelerator
  15. FWD_MODULE_STACK = list()
  16. def is_builtin_type(obj):
  17. # https://stackoverflow.com/a/17795199
  18. return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
  19. # ensure we only warn once, otherwise every iteration will trigger a warning
  20. warned = False
  21. #apply torch.autograd.Function that calls a backward_function to tensors in output
  22. def _apply_to_tensors_only(module, functional, backward_function, outputs):
  23. if isinstance(outputs, (tuple, list)):
  24. touched_outputs = []
  25. for output in outputs:
  26. touched_output = _apply_to_tensors_only(module,
  27. functional,
  28. backward_function,
  29. output)
  30. touched_outputs.append(touched_output)
  31. return outputs.__class__(touched_outputs)
  32. elif isinstance(outputs, dict):
  33. # apply inplace to avoid recreating dict inherited objects
  34. for key in outputs.keys():
  35. outputs[key] = _apply_to_tensors_only(module,
  36. functional,
  37. backward_function,
  38. outputs[key])
  39. return outputs
  40. elif isinstance(outputs, torch.Tensor):
  41. # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter
  42. touched_outputs = functional.apply(module, backward_function, outputs)
  43. # restore zero param attributes if those get stripped by `backward_function`
  44. if not is_zero_param(touched_outputs) and is_zero_param(outputs):
  45. touched_outputs.ds_param_alias = outputs
  46. return touched_outputs
  47. else:
  48. if not is_builtin_type(outputs):
  49. global warned
  50. if not warned and dist.get_rank() == 0:
  51. logger.warning(
  52. f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
  53. "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
  54. "output tensors and therefore may not get triggered properly.")
  55. warned = True
  56. return outputs
  57. #for each tensor in outputs run the forward_function and register backward_function as hook
  58. def _apply_forward_and_backward_to_tensors_only(module,
  59. forward_function,
  60. backward_function,
  61. outputs):
  62. if type(outputs) is tuple:
  63. touched_outputs = []
  64. for output in outputs:
  65. touched_output = _apply_forward_and_backward_to_tensors_only(
  66. module,
  67. forward_function,
  68. backward_function,
  69. output)
  70. touched_outputs.append(touched_output)
  71. return tuple(touched_outputs)
  72. elif type(outputs) is torch.Tensor:
  73. forward_function(outputs)
  74. if outputs.requires_grad:
  75. outputs.register_hook(backward_function)
  76. return outputs
  77. else:
  78. return outputs
  79. class ZeROOrderedDict(OrderedDict):
  80. def __init__(self, parent_module, *args, **kwargs):
  81. """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
  82. Args:
  83. parent_module (``collections.OrderedDict``): the collection to replace
  84. """
  85. super().__init__(*args, **kwargs)
  86. self._parent_module = parent_module
  87. self._in_forward = False
  88. def __getitem__(self, key):
  89. param = super().__getitem__(key)
  90. # Params can be registered as None (e.g., bias)
  91. if param is None:
  92. return param
  93. if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
  94. if self._parent_module._parameters._in_forward:
  95. register_external_parameter(FWD_MODULE_STACK[-1], param)
  96. param.all_gather()
  97. print_rank_0(
  98. f'Registering external parameter from getter {key} ds_id = {param.ds_id}',
  99. force=False)
  100. return param
  101. def _inject_parameters(module, cls):
  102. for module in module.modules():
  103. if cls == ZeROOrderedDict:
  104. new_param = cls(parent_module=module)
  105. else:
  106. new_param = cls()
  107. for key, param in module._parameters.items():
  108. new_param[key] = param
  109. module._parameters = new_param
  110. class PreBackwardFunction(torch.autograd.Function):
  111. @staticmethod
  112. def forward(ctx, module, pre_backward_function, outputs):
  113. ctx.module = module
  114. ctx.pre_backward_function = pre_backward_function
  115. if not hasattr(module, "applied_pre_backward_ref_cnt"):
  116. module.applied_pre_backward_ref_cnt = 0
  117. module.applied_pre_backward_ref_cnt += 1
  118. #print(f"After Forward: {ctx.module.__class__.__name__}")
  119. outputs = outputs.detach()
  120. return outputs
  121. @staticmethod
  122. def backward(ctx, *args):
  123. #print(f"Before Backward: {ctx.module.__class__.__name__}")
  124. ctx.pre_backward_function(ctx.module)
  125. return (None, None) + args
  126. class PostBackwardFunction(torch.autograd.Function):
  127. @staticmethod
  128. def forward(ctx, module, pre_backward_function, output):
  129. ctx.module = module
  130. if output.requires_grad:
  131. #TODO SOME TIMES post backward does not seem to be triggered debug in detail
  132. #Should only cause increase in memory not correctness issue
  133. #if output.grad_fn.__class__.__name__ == 'ViewBackward':
  134. # ctx.view=True
  135. # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
  136. #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
  137. #if module.ds_grads_remaining == 0:
  138. # print(f"Before Forward: {ctx.module.__class__.__name__}")
  139. module.ds_grads_remaining += 1
  140. ctx.pre_backward_function = pre_backward_function
  141. output = output.detach()
  142. return output
  143. @staticmethod
  144. def backward(ctx, *args):
  145. ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
  146. if ctx.module.ds_grads_remaining == 0:
  147. ctx.pre_backward_function(ctx.module)
  148. #print(f"After Backward: {ctx.module.__class__.__name__}")
  149. return (None, None) + args
  150. class DeepSpeedZeRoOffload(object):
  151. def __init__(self,
  152. module,
  153. timers,
  154. ds_config,
  155. overlap_comm=True,
  156. prefetch_bucket_size=50000000,
  157. max_reuse_distance=1000000000,
  158. max_live_parameters=1000000000,
  159. param_persistence_threshold=100000,
  160. model_persistence_threshold=sys.maxsize,
  161. offload_param_config=None,
  162. mpu=None):
  163. see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
  164. print_rank_0(f"initialized {__class__.__name__} with args: {locals()}",
  165. force=False)
  166. self.module = module
  167. self.dtype = list(module.parameters())[0].dtype
  168. self.offload_device = None
  169. self.offload_param_pin_memory = False
  170. if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
  171. self.offload_device = offload_param_config.device
  172. self.offload_param_pin_memory = offload_param_config.pin_memory
  173. self._convert_to_zero_parameters(ds_config, module, mpu)
  174. for m in module.modules():
  175. _init_external_params(m)
  176. _inject_parameters(module, ZeROOrderedDict)
  177. self.param_numel_persistence_threshold = int(param_persistence_threshold)
  178. self.model_persistence_threshold = int(model_persistence_threshold)
  179. self.persistent_parameters = self.mark_persistent_parameters(
  180. self.param_numel_persistence_threshold,
  181. self.model_persistence_threshold)
  182. self.param_coordinators = {}
  183. self._prefetch_bucket_sz = int(prefetch_bucket_size)
  184. self._max_reuse_distance_in_numel = int(max_reuse_distance)
  185. self._max_available_parameters_in_numel = int(max_live_parameters)
  186. self.__allgather_stream = get_accelerator().Stream(
  187. ) if overlap_comm else get_accelerator().default_stream()
  188. self.forward_hooks = []
  189. self.backward_hooks = []
  190. self.setup_zero_stage3_hooks()
  191. print_rank_0(
  192. f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
  193. force=False)
  194. see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True)
  195. @instrument_w_nvtx
  196. def partition_all_parameters(self):
  197. """Partitioning Parameters that were not partitioned usually if parameters
  198. of modules whose input parameters do not require grad computation do not
  199. trigger post call and will therefore will remain unpartitioned"""
  200. self.get_param_coordinator(training=self.module.training).release_and_reset_all(
  201. self.module)
  202. for param in iter_params(self.module, recurse=True):
  203. if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
  204. raise RuntimeError(f"{param.ds_summary()} expected to be released")
  205. def get_param_coordinator(self, training):
  206. if not training in self.param_coordinators:
  207. self.param_coordinators[training] = PartitionedParameterCoordinator(
  208. prefetch_bucket_sz=self._prefetch_bucket_sz,
  209. max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
  210. max_available_parameters_in_numel=self.
  211. _max_available_parameters_in_numel,
  212. allgather_stream=self.__allgather_stream,
  213. prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
  214. )
  215. return self.param_coordinators[training]
  216. def _convert_to_zero_parameters(self, ds_config, module, mpu):
  217. non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
  218. if non_zero_params:
  219. zero_params = [p for p in module.parameters() if is_zero_param(p)]
  220. if zero_params:
  221. zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
  222. else:
  223. group = None
  224. if mpu:
  225. group = mpu.get_data_parallel_group()
  226. Init(module=module,
  227. data_parallel_group=group,
  228. dtype=self.dtype,
  229. config_dict_or_path=ds_config,
  230. remote_device=self.offload_device,
  231. pin_memory=self.offload_param_pin_memory,
  232. mpu=mpu)
  233. def destroy(self):
  234. self._remove_module_hooks()
  235. def _remove_module_hooks(self):
  236. num_forward_hooks = len(self.forward_hooks)
  237. num_backward_hooks = len(self.backward_hooks)
  238. for hook in self.forward_hooks:
  239. hook.remove()
  240. for hook in self.backward_hooks:
  241. hook.remove()
  242. print_rank_0(
  243. f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
  244. force=False)
  245. def setup_zero_stage3_hooks(self):
  246. self.hierarchy = 0
  247. #reset step if in inference mode
  248. @instrument_w_nvtx
  249. def _end_of_forward_hook(module, *args):
  250. if not torch._C.is_grad_enabled():
  251. self.get_param_coordinator(training=False).reset_step()
  252. #likely one of them should be enough but just to be safe
  253. self._register_hooks_recursively(self.module)
  254. self.module.register_forward_hook(_end_of_forward_hook)
  255. # Add top module to stack trace
  256. global FWD_MODULE_STACK
  257. FWD_MODULE_STACK.append(self.module)
  258. def mark_persistent_parameters(self, param_threshold, model_threshold):
  259. persistent_params = []
  260. total_persistent_parameters = 0
  261. params_count = 0
  262. for _, param in self.module.named_parameters(recurse=True):
  263. if param.ds_numel + total_persistent_parameters > model_threshold:
  264. continue
  265. if param.ds_numel < param_threshold:
  266. params_count += 1
  267. param.ds_persist = True
  268. persistent_params.append(param)
  269. total_persistent_parameters += param.ds_numel
  270. print_rank_0(
  271. f"Parameter Offload: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
  272. force=True)
  273. return persistent_params
  274. def _register_hooks_recursively(self, module, count=[0]):
  275. my_count = count[0]
  276. module.id = my_count
  277. #print(f"{module.__class__} : {module.id}")
  278. for child in module.children():
  279. count[0] = count[0] + 1
  280. self._register_hooks_recursively(child, count=count)
  281. @instrument_w_nvtx
  282. def _pre_forward_module_hook(module, *args):
  283. self.pre_sub_module_forward_function(module)
  284. @instrument_w_nvtx
  285. def _post_forward_module_hook(module, input, output):
  286. global FWD_MODULE_STACK
  287. FWD_MODULE_STACK.pop()
  288. if output is None:
  289. output = []
  290. elif not isinstance(output, (list, tuple)):
  291. if torch.is_tensor(output):
  292. output = [output]
  293. else:
  294. #print(f'got UNKNOWN type {type(output)}')
  295. outputs = []
  296. output = output if isinstance(output, dict) else vars(output)
  297. for name, val in output.items():
  298. if not name.startswith('__') and torch.is_tensor(val):
  299. outputs.append(val)
  300. output = outputs
  301. for item in filter(
  302. lambda item: is_zero_param(item) or hasattr(item,
  303. 'ds_param_alias'),
  304. output):
  305. key = id(item) if hasattr(item, 'ds_id') else id(item.ds_param_alias)
  306. actual_external_param = item if hasattr(item,
  307. 'ds_id') else item.ds_param_alias
  308. if not any(key in m._external_params for m in FWD_MODULE_STACK):
  309. actual_external_param.is_external_param = True
  310. module_to_register = FWD_MODULE_STACK[-1]
  311. register_external_parameter(module_to_register,
  312. actual_external_param)
  313. print_rank_0(
  314. f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {actual_external_param.ds_id}.',
  315. force=False)
  316. # It's possible that the parameter was already external to the completed module. If so, remove it the
  317. # registration as it will be covered by the outer module instead.
  318. if key in module._external_params:
  319. print_rank_0(
  320. f' Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {actual_external_param.ds_id}',
  321. force=False)
  322. unregister_external_parameter(module, actual_external_param)
  323. actual_external_param.all_gather()
  324. self.post_sub_module_forward_function(module)
  325. def _pre_backward_module_hook(module, inputs, output):
  326. @instrument_w_nvtx
  327. def _run_before_backward_function(sub_module):
  328. # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
  329. # before doing backwards, so each backward will need a pre-fetch - using reference
  330. # counting to support this scenario
  331. #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
  332. if sub_module.applied_pre_backward_ref_cnt > 0:
  333. self.pre_sub_module_backward_function(sub_module)
  334. sub_module.applied_pre_backward_ref_cnt -= 1
  335. #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
  336. return _apply_to_tensors_only(module,
  337. PreBackwardFunction,
  338. _run_before_backward_function,
  339. output)
  340. #This is an alternate to doing _post_backward_module_hook
  341. #it uses tensor.register_hook instead of using torch.autograd.Function
  342. def _alternate_post_backward_module_hook(module, inputs):
  343. module.ds_grads_remaining = 0
  344. #print(f"Before Forward {module.__class__.__name__}")
  345. def _run_after_backward_hook(*unused):
  346. module.ds_grads_remaining = module.ds_grads_remaining - 1
  347. if module.ds_grads_remaining == 0:
  348. #print(f"After backward {module.__class__.__name__}")
  349. self.post_sub_module_backward_function(module)
  350. def _run_before_forward_function(input):
  351. if input.requires_grad:
  352. module.ds_grads_remaining += 1
  353. return _apply_forward_and_backward_to_tensors_only(
  354. module,
  355. _run_before_forward_function,
  356. _run_after_backward_hook,
  357. inputs)
  358. def _post_backward_module_hook(module, inputs):
  359. module.ds_grads_remaining = 0
  360. @instrument_w_nvtx
  361. def _run_after_backward_function(sub_module):
  362. if sub_module.ds_grads_remaining == 0:
  363. self.post_sub_module_backward_function(sub_module)
  364. return _apply_to_tensors_only(module,
  365. PostBackwardFunction,
  366. _run_after_backward_function,
  367. inputs)
  368. # Pre forward hook
  369. self.forward_hooks.append(
  370. module.register_forward_pre_hook(_pre_forward_module_hook))
  371. # Post forward hook
  372. self.forward_hooks.append(
  373. module.register_forward_hook(_post_forward_module_hook))
  374. # Pre backward hook
  375. self.backward_hooks.append(
  376. module.register_forward_hook(_pre_backward_module_hook))
  377. # post backward hook
  378. self.backward_hooks.append(
  379. module.register_forward_pre_hook(_post_backward_module_hook))
  380. @torch.no_grad()
  381. def pre_sub_module_forward_function(self, sub_module):
  382. see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
  383. force=False)
  384. global FWD_MODULE_STACK
  385. FWD_MODULE_STACK.append(sub_module)
  386. param_coordinator = self.get_param_coordinator(training=sub_module.training)
  387. param_coordinator.trace_prologue(sub_module)
  388. if param_coordinator.is_record_trace():
  389. param_coordinator.record_module(sub_module)
  390. param_coordinator.fetch_sub_module(sub_module)
  391. see_memory_usage(
  392. f"Before sub module function {sub_module.__class__.__name__} after fetch",
  393. force=False)
  394. @torch.no_grad()
  395. def post_sub_module_forward_function(self, sub_module):
  396. see_memory_usage(
  397. f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
  398. force=False)
  399. param_coordinator = self.get_param_coordinator(training=sub_module.training)
  400. param_coordinator.release_sub_module(sub_module)
  401. see_memory_usage(
  402. f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
  403. force=False)
  404. @torch.no_grad()
  405. def pre_sub_module_backward_function(self, sub_module):
  406. param_coordinator = self.get_param_coordinator(training=sub_module.training)
  407. param_coordinator.trace_prologue(sub_module)
  408. if param_coordinator.is_record_trace():
  409. param_coordinator.record_module(sub_module)
  410. param_coordinator.fetch_sub_module(sub_module)
  411. @torch.no_grad()
  412. def post_sub_module_backward_function(self, sub_module):
  413. see_memory_usage(
  414. f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
  415. force=False)
  416. self.get_param_coordinator(
  417. training=sub_module.training).release_sub_module(sub_module)
  418. see_memory_usage(
  419. f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
  420. force=False)