partition_parameters.py 86 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import math
  5. import os
  6. import types
  7. from typing import Callable, Iterable
  8. from enum import Enum
  9. import functools
  10. import itertools
  11. from typing import List
  12. import logging
  13. import torch
  14. from torch import Tensor
  15. from deepspeed import comm as dist
  16. from torch.nn import Module
  17. from torch.nn import Parameter
  18. from .linear import zero3_linear_wrap
  19. from deepspeed.utils import groups
  20. import deepspeed
  21. from ..utils import get_only_unique_item, see_memory_usage
  22. from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
  23. from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks
  24. from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
  25. from deepspeed.runtime.config_utils import get_config_default
  26. from deepspeed.utils import instrument_w_nvtx, logger
  27. from deepspeed.comm.comm import init_distributed
  28. from deepspeed.utils.debug import (debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name,
  29. debug_param2name_id, debug_param2name_id_shape_status)
  30. from deepspeed.accelerator import get_accelerator
  31. from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
  32. partitioned_param_data_shape = [0]
  33. zero_init_context = 0
  34. top_level_context = None
  35. class NoGatherHandle:
  36. def __init__(self, param: Parameter) -> None:
  37. if param.ds_status != ZeroParamStatus.INFLIGHT:
  38. raise RuntimeError(f"expected param {param.ds_summary()} to be available")
  39. param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(),
  40. non_blocking=True).view(param.ds_shape)
  41. self.__param = param
  42. def wait(self) -> None:
  43. get_accelerator().current_stream().synchronize()
  44. self.__param.ds_status = ZeroParamStatus.AVAILABLE
  45. class NoGatherCoalescedHandle:
  46. def __init__(self, params: List[Parameter]) -> None:
  47. self.__params = params
  48. self.__complete = False
  49. for param in self.__params:
  50. if param.ds_status != ZeroParamStatus.INFLIGHT:
  51. raise RuntimeError(f"expected param {param.ds_summary()} to not be available")
  52. param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(),
  53. non_blocking=True).view(param.ds_shape)
  54. @instrument_w_nvtx
  55. def wait(self) -> None:
  56. if self.__complete:
  57. return
  58. get_accelerator().current_stream().synchronize()
  59. for param in self.__params:
  60. assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
  61. param.ds_status = ZeroParamStatus.AVAILABLE
  62. self.__complete = True
  63. def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
  64. return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True)
  65. def print_rank_0(message, debug=False, force=False):
  66. rank = dist.get_rank()
  67. if rank == 0 and (debug or force):
  68. print(message)
  69. # other variations
  70. # - print for all ranks w/o interleaving
  71. # printflock(f"[{rank}] {message}")
  72. # - print to log file per rank
  73. # log_rank_file(rank, message)
  74. def debug_rank0(msg: str) -> None:
  75. if dist.get_rank() == 0:
  76. logger.debug(msg)
  77. def is_zero_param(parameter):
  78. if not torch.is_tensor(parameter):
  79. return False
  80. return hasattr(parameter, 'ds_id')
  81. def _init_external_params(module):
  82. if not hasattr(module, '_external_params'):
  83. module._external_params = {}
  84. def external_parameters(self):
  85. return self._external_params.items()
  86. def all_parameters(self):
  87. return itertools.chain(self.named_parameters(self, recurse=False), external_parameters(self))
  88. module.ds_external_parameters = types.MethodType(external_parameters, module)
  89. module.all_parameters = types.MethodType(all_parameters, module)
  90. def register_external_parameter(module, parameter):
  91. """Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in
  92. the forward and backward passes of ``module``.
  93. This is used when a parameter is accessed outside of its owning module's
  94. ``forward()``. DeepSpeed must know to collect it from its partitioned
  95. state and when to release the memory.
  96. .. note::
  97. This is only applicable to training with ZeRO stage 3.
  98. Args:
  99. module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass.
  100. parameter (``torch.nn.Parameter``): The parameter to register.
  101. Raises:
  102. RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
  103. Examples
  104. ========
  105. #. Register a weight that is used in another module's forward pass (line 6).
  106. Parameter ``layer1.weight`` is used by ``layer2`` (line 11).
  107. .. code-block:: python
  108. :linenos:
  109. :emphasize-lines: 6,11
  110. class ModuleZ3(torch.nn.Module):
  111. def __init__(self, *args):
  112. super().__init__(self, *args)
  113. self.layer1 = SomeLayer()
  114. self.layer2 = OtherLayer()
  115. deepspeed.zero.register_external_parameter(self, self.layer1.weight)
  116. def forward(self, input):
  117. x = self.layer1(input)
  118. # self.layer1.weight is required by self.layer2.forward
  119. y = self.layer2(x, self.layer1.weight)
  120. return y
  121. """
  122. if not isinstance(parameter, torch.nn.Parameter):
  123. raise RuntimeError('Parameter is not a torch.nn.Parameter')
  124. if not hasattr(module, '_external_params'):
  125. _init_external_params(module)
  126. key = id(parameter)
  127. module._external_params[key] = parameter
  128. def unregister_external_parameter(module, parameter):
  129. """Reverses the effects of :meth:`register_external_parameter`.
  130. Args:
  131. module (``torch.nn.Module``): The module to affect.
  132. parameter (``torch.nn.Parameter``): The parameter to unregister.
  133. Raises:
  134. RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
  135. RuntimeError: If ``parameter`` is not a registered external parameter of ``module``.
  136. """
  137. if not isinstance(parameter, torch.nn.Parameter):
  138. raise RuntimeError('Parameter is not a torch.nn.Parameter')
  139. if not hasattr(module, '_external_params') or id(parameter) not in module._external_params:
  140. raise RuntimeError('Parameter is not a registered external parameter of module.')
  141. key = id(parameter)
  142. del module._external_params[key]
  143. class ZeroParamType(Enum):
  144. # same as regular pytorch parameters
  145. NORMAL = 1
  146. # parameters are partitioned across data parallel process
  147. PARTITIONED = 2
  148. # the parameter is held with a unique process rank
  149. # and is not available on all other process
  150. REMOTE = 3
  151. class ZeroParamStatus(Enum):
  152. # parameters are fully present and ready for use on all processes
  153. AVAILABLE = 1
  154. # parameters are either partitioned or remote in some or all process
  155. NOT_AVAILABLE = 2
  156. # parameters are being gathered.
  157. INFLIGHT = 3
  158. _orig_torch_tensor = torch.tensor
  159. _orig_torch_empty = torch.empty
  160. _orig_torch_zeros = torch.zeros
  161. _orig_torch_ones = torch.ones
  162. _orig_torch_full = torch.full
  163. _orig_torch_arange = torch.arange
  164. _orig_torch_eye = torch.eye
  165. _orig_torch_randn = torch.randn
  166. def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
  167. def wrapped_fn(*args, **kwargs) -> Tensor:
  168. if kwargs.get("device", None) is None:
  169. kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
  170. tensor: Tensor = fn(*args, **kwargs)
  171. if tensor.is_floating_point():
  172. tensor = tensor.to(target_fp_dtype)
  173. return tensor
  174. return wrapped_fn
  175. def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:
  176. def new_tensor(cls, *args) -> Tensor:
  177. device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
  178. tensor = _orig_torch_empty(0, device=device).new_empty(*args)
  179. if tensor.is_floating_point():
  180. tensor = tensor.to(dtype)
  181. return tensor
  182. return new_tensor
  183. # https://stackoverflow.com/a/63851681/9201239
  184. def get_all_subclasses(cls):
  185. subclass_list = []
  186. def recurse(cl):
  187. for subclass in cl.__subclasses__():
  188. subclass_list.append(subclass)
  189. recurse(subclass)
  190. recurse(cls)
  191. return set(subclass_list)
  192. @instrument_w_nvtx
  193. def free_param(param: Parameter) -> None:
  194. """Free underlying storage of a parameter."""
  195. assert not param.ds_active_sub_modules, param.ds_summary()
  196. if get_accelerator().on_accelerator(param.data):
  197. # need to make sure that we don't free the parameter while it is still
  198. # being used for computation
  199. if not get_accelerator().is_synchronized_device():
  200. param.data.record_stream(get_accelerator().current_stream())
  201. # param.data doesn't store anything meaningful in partitioned state
  202. param.data = torch.empty(0, dtype=param.dtype, device=param.device)
  203. param.ds_status = ZeroParamStatus.NOT_AVAILABLE
  204. reuse_buffers = False
  205. temp_contiguous_tensor = None
  206. empty_buffers = {}
  207. # Inserts _post_init_method at the end of init method
  208. # for all sub classes of torch.nn.Module
  209. class InsertPostInitMethodToModuleSubClasses(object):
  210. num_module_parameters = 0
  211. num_module_elements = 0
  212. def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None):
  213. self.mem_efficient_linear = mem_efficient_linear
  214. self.enabled = enabled
  215. self._set_dtype(ds_config, dtype)
  216. assert self.dtype in [
  217. torch.half, torch.bfloat16, torch.float
  218. ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
  219. self.wrapped_cls = set()
  220. def __enter__(self):
  221. if not self.enabled:
  222. return
  223. global zero_init_context
  224. if zero_init_context == 0:
  225. self.patch_init_and_builtins()
  226. global top_level_context
  227. top_level_context = self
  228. zero_init_context += 1
  229. def __exit__(self, exc_type, exc_value, traceback):
  230. if not self.enabled:
  231. return
  232. global zero_init_context
  233. zero_init_context -= 1
  234. # Exiting the top level context
  235. if zero_init_context == 0:
  236. self.unpatch_init_and_builtins()
  237. global top_level_context
  238. top_level_context = None
  239. if dist.get_rank() == 0:
  240. billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9
  241. num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters
  242. logger.info(
  243. f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B")
  244. # Now that we cleaned up the metaclass injection, raise the exception.
  245. if exc_type is not None:
  246. return False
  247. # To be implemented by inheriting classes
  248. def _post_init_method(self, module):
  249. pass
  250. def _set_dtype(self, ds_config, dtype):
  251. if ds_config is not None and dtype is None:
  252. if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
  253. raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")
  254. if ds_config.bfloat16_enabled:
  255. self.dtype = torch.bfloat16
  256. elif ds_config.fp16_enabled:
  257. self.dtype = torch.half
  258. else:
  259. self.dtype = torch.float
  260. else:
  261. self.dtype = dtype or torch.half
  262. def patch_init_and_builtins(self):
  263. def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
  264. """many models make use of child modules like Linear or Embedding which
  265. perform their own weight initialization in their __init__ methods,
  266. but will then have more weight initialization in a parent module's __init__
  267. method that modifies weights of child modules, which is typically done
  268. using the Module.apply method.
  269. since the Init context manager partitions child modules immediately after
  270. they are initialized, without modifying apply we would entirely skip
  271. any initialization done by parent modules.
  272. to get around this issue, we wrap the function passed to Module.apply
  273. so that the applied function is applied to child modules correctly.
  274. """
  275. def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable:
  276. if hasattr(fn_to_apply, "wrapped"):
  277. return fn_to_apply
  278. @functools.wraps(fn_to_apply)
  279. def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None:
  280. """gathers parameters before calling apply function. afterwards
  281. parameters are broadcasted to ensure consistency across all ranks
  282. then re-partitioned.
  283. takes the following steps:
  284. 1. allgathers parameters for the current module being worked on
  285. 2. calls the original function
  286. 3. broadcasts root rank's parameters to the other ranks
  287. 4. re-partitions the parameters
  288. """
  289. # TODO Delay error checking for dangling partitioned parameters to post module init
  290. # raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
  291. # f"were zero params, is it possible that the parameters were "
  292. # f"overwritten after they were initialized? "
  293. # f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ")
  294. params_to_apply_fn_to: Iterable[Parameter] = list(
  295. sorted([p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)],
  296. key=lambda p: p.ds_id))
  297. for param in params_to_apply_fn_to:
  298. param.all_gather()
  299. fn_to_apply(module_to_apply_fn_to)
  300. for param in params_to_apply_fn_to:
  301. dist.broadcast(param.data, 0, group=param.ds_process_group)
  302. for param in params_to_apply_fn_to:
  303. param.partition(has_been_updated=True)
  304. wrapped_fn_to_apply.wrapped = True
  305. return wrapped_fn_to_apply
  306. @functools.wraps(orig_module_apply_fn)
  307. def wrapped_apply(module: Module, fn_to_apply: Callable) -> None:
  308. orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply))
  309. return wrapped_apply
  310. def partition_after(f):
  311. @functools.wraps(f)
  312. def wrapper(module, *args, **kwargs):
  313. # important logic: We want to run post_init only after child's __init__ is
  314. # completed, and do nothing after __init__ of any of its parents and grandparents in
  315. # the inheritance ancestry. This way the partitioning will need to happen only once
  316. # when the whole object is ready to be partitioned and not before. This is because
  317. # often the child module will need to tweak the weights - for example running a
  318. # custom weights init function. So if a parent created the weights param, the child
  319. # won't need to gather it in order to tweak it
  320. print_rank_0(f'Before initializing {module.__class__.__name__}', force=False)
  321. is_child_module = False
  322. if not hasattr(module, "_ds_child_entered"):
  323. # child's __init__ was called, since parents all see the same object they can now skip post_init
  324. is_child_module = True
  325. setattr(module, "_ds_child_entered", True)
  326. f(module, *args, **kwargs)
  327. if is_child_module:
  328. # child's __init__ is done, now we can run a single post_init on the child object
  329. delattr(module, "_ds_child_entered")
  330. print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False)
  331. self._post_init_method(module)
  332. print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False)
  333. return wrapper
  334. def _enable_class(cls):
  335. cls._old_init = cls.__init__
  336. cls.__init__ = partition_after(cls.__init__)
  337. def _init_subclass(cls, **kwargs):
  338. cls._old_init = cls.__init__
  339. cls.__init__ = partition_after(cls.__init__)
  340. # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
  341. for subclass in get_all_subclasses(torch.nn.modules.module.Module):
  342. _enable_class(subclass)
  343. # holding onto some methods so we can put them back the way they were in __exit__
  344. torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
  345. torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
  346. torch.Tensor.__old_new__ = torch.Tensor.__new__
  347. # Replace .__init__() for future subclasses of torch.nn.Module
  348. torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
  349. if Init.override_module_apply:
  350. torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)
  351. self._add_tensor_creation_wrappers()
  352. if self.mem_efficient_linear:
  353. print_rank_0(
  354. "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
  355. force=False)
  356. self.linear_bk = torch.nn.functional.linear
  357. torch.nn.functional.linear = zero3_linear_wrap
  358. self.patched = True
  359. def unpatch_init_and_builtins(self):
  360. if self.patched:
  361. def _disable_class(cls):
  362. cls.__init__ = cls._old_init
  363. for subclass in get_all_subclasses(torch.nn.modules.module.Module):
  364. _disable_class(subclass)
  365. # putting methods back the way we found them
  366. torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
  367. if Init.override_module_apply:
  368. torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
  369. self._remove_tensor_creation_wrappers()
  370. self.patched = False
  371. def _add_tensor_creation_wrappers(self):
  372. torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
  373. torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype)
  374. torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype)
  375. torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype)
  376. torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
  377. torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
  378. torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype)
  379. torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype)
  380. torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype)
  381. def _remove_tensor_creation_wrappers(self):
  382. torch.Tensor.__new__ = torch.Tensor.__old_new__
  383. torch.tensor = _orig_torch_tensor
  384. torch.empty = _orig_torch_empty
  385. torch.zeros = _orig_torch_zeros
  386. torch.ones = _orig_torch_ones
  387. torch.full = _orig_torch_full
  388. torch.arange = _orig_torch_arange
  389. torch.eye = _orig_torch_eye
  390. torch.randn = _orig_torch_randn
  391. def shutdown_init_context():
  392. """
  393. This function is used to initialize deepspeed engine inside the context of Init.
  394. We need to remove the wrappers but keep the context.
  395. """
  396. if top_level_context:
  397. top_level_context.unpatch_init_and_builtins()
  398. def restore_init_context():
  399. """
  400. This function is used to restore the wrappers after deepspeed engine is initialized.
  401. """
  402. if top_level_context:
  403. top_level_context.patch_init_and_builtins()
  404. class AllGatherHandle:
  405. def __init__(self, handle, param: Parameter, quantization=None) -> None:
  406. if param.ds_status != ZeroParamStatus.INFLIGHT:
  407. raise RuntimeError(f"expected param {param.ds_summary()} to be available")
  408. self.__handle = handle
  409. self.__param = param
  410. self.__quantization = quantization
  411. def wait(self) -> None:
  412. instrument_w_nvtx(self.__handle.wait)()
  413. if self.__quantization:
  414. instrument_w_nvtx(self.__quantization.quant_handle.wait)()
  415. self.__param.data = self.__quantization.backend.dequantize(
  416. self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device)
  417. self.__param.ds_status = ZeroParamStatus.AVAILABLE
class AllGatherCoalescedHandle:
    """Completion handle for one coalesced all-gather that covers several parameters.

    ``wait()`` slices ``self.partitions[rank]`` with ``narrow`` at a running offset, which
    implies each rank's buffer holds that rank's shards of all ``params`` laid end to end in
    ``params`` order. It stitches every parameter back together from those shards, views
    it as ``ds_shape`` and marks it AVAILABLE. ``wait()`` is idempotent.

    Raises:
        RuntimeError: (constructor) if any parameter is not in the INFLIGHT state.
    """

    def __init__(
        self,
        allgather_handle,
        params: List[Parameter],
        partitions: List[Tensor],
        world_size: int,
        use_secondary_tensor=False,
        forward=False,
        quantization=None,
    ) -> None:
        self.allgather_handle = allgather_handle
        self.params = params
        self.partitions = partitions
        self.world_size = world_size
        self.use_secondary_tensor = use_secondary_tensor
        self.forward = forward
        self.complete = False
        self.quantization = quantization

        # the caller must already have marked every parameter INFLIGHT
        for param in self.params:
            if param.ds_status != ZeroParamStatus.INFLIGHT:
                raise RuntimeError(f"expected param {param.ds_summary()} to not be available")

    @instrument_w_nvtx
    def wait(self) -> None:
        """Wait for the gather, then rebuild each parameter's full data; no-op once complete."""
        if self.complete:
            return

        instrument_w_nvtx(self.allgather_handle.wait)()

        if self.quantization:
            # Quantized path: wait for the quantized transfer, dequantize into one flat tensor
            # on the first param's device, and re-split it into world_size chunks of
            # partition_sz elements that replace self.partitions.
            # NOTE(review): world_size / partition_sz are not set by QuantizationInfo.__init__;
            # presumably the caller assigns them -- confirm.
            instrument_w_nvtx(self.quantization.quant_handle.wait)()
            flat_tensor = self.quantization.backend.dequantize(
                self.quantization.quantized_param, self.quantization.scale_buffer).to(self.params[0].device)

            self.partitions: List[Tensor] = []
            for i in range(self.quantization.world_size):
                self.partitions.append(
                    flat_tensor.narrow(0, self.quantization.partition_sz * i, self.quantization.partition_sz))

        # split the single tensor out into individual tensors
        param_offset = 0  # start of the current param's shard inside each rank's buffer
        for param in self.params:
            assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
            partitions: List[Tensor] = []
            # length of this param's shard on each rank
            ds_tensor_numel = param.ds_tensor.ds_numel
            if self.use_secondary_tensor and not self.forward:
                # secondary-tensor gathers outside forward use shards that are
                # ds_secondary_tensor_num_of_groups times longer
                ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups
            for rank in range(self.world_size):
                param_start = rank * ds_tensor_numel
                # ranks whose shard begins at or past ds_numel contribute nothing to this param
                if param_start < param.ds_numel:
                    # the last contributing shard is trimmed to the param's remaining elements
                    part_to_copy = self.partitions[rank].narrow(0, param_offset,
                                                                min(param.ds_numel - param_start, ds_tensor_numel))
                    partitions.append(part_to_copy)
            param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
            param.ds_status = ZeroParamStatus.AVAILABLE

            # the shards are views into the gathered buffers; record them on the current stream
            # so that memory is not reused before that stream's queued work completes
            for part_to_copy in partitions:
                if not get_accelerator().is_synchronized_device():
                    part_to_copy.record_stream(get_accelerator().current_stream())

            param_offset += ds_tensor_numel

        self.complete = True
  474. class QuantizationInfo:
  475. # a placeholder object to store all quant related vars used in handles
  476. def __init__(self) -> None:
  477. self.quantized_param = None
  478. self.backend = None
  479. self.quant_handle = None
  480. self.scale_buffer = None
  481. class CUDAQuantizer:
  482. async_flag = True
  483. target_group_size = 8000 # the optimal size is 4k, so we set the target to be below 8k
  484. group_size_cache = dict()
  485. def __init__(self):
  486. self.quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load()
  487. def quantize(self, param, groups=None):
  488. if groups is None:
  489. try:
  490. groups = self.group_size_cache[param.numel()]
  491. except KeyError:
  492. groups = math.ceil(param.numel() / self.target_group_size)
  493. while groups < param.numel():
  494. if param.numel() % (8 * groups) == 0:
  495. break
  496. groups += 1
  497. while True:
  498. if param.numel() % (8 * groups * 2) == 0 and param.numel(
  499. ) / groups > self.target_group_size: #hard limit of 16k group_size
  500. groups *= 2
  501. else:
  502. break
  503. assert (
  504. param.numel() % (8 * groups) == 0
  505. ), f"Qantized weight requires the number of weights be a multiple of 8. Yet {param.numel()} cannot be divided by 8*{groups}"
  506. assert (param.numel() / groups < 16000), f"{param.numel()} / {groups} is larger than 16k"
  507. assert param.numel(
  508. ) > groups, f"Adaptive grouping algorithm cannot find a group size for input tensor of size {param.numel()}"
  509. self.group_size_cache[param.numel()] = groups
  510. return self.quantizer_cuda_module.quantize(param.to(get_accelerator().device_name()), groups, 8,
  511. self.quantizer_cuda_module.Symmetric)
  512. def dequantize(self, quantized_param, scale):
  513. return self.quantizer_cuda_module.dequantize(quantized_param, scale, scale.numel(), 8,
  514. self.quantizer_cuda_module.Symmetric)
  515. def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandle:
  516. for param in params:
  517. if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
  518. raise RuntimeError(param.ds_summary())
  519. param.ds_status = ZeroParamStatus.INFLIGHT
  520. params = sorted(params, key=lambda p: p.ds_id)
  521. if len(params) == 1:
  522. param, = params
  523. return NoGatherHandle(param)
  524. return NoGatherCoalescedHandle(params)
  525. # Replaces all parameters in module with Scattered Parameters
  526. class Init(InsertPostInitMethodToModuleSubClasses):
    # Running counter that hands out a unique ``ds_id`` to every converted
    # parameter (read and incremented in ``_convert_to_deepspeed_param``).
    param_id = 0
    # Persistence thresholds. They start from the DeepSpeedZeroConfig defaults and
    # are overwritten by ``_update_persist_config``, which stores the model-wide
    # value divided by ``num_partitions``.
    param_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "param_persistence_threshold")
    model_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "model_persistence_threshold")
    # Number and total element count of parameters marked ``ds_persist``.
    # These are class attributes, so the budget is shared by all Init instances.
    num_persisted_parameters = 0
    num_persisted_elements = 0
    # Persistence is only applied once ``_update_persist_config`` has set this.
    apply_param_persistence = False
    # Overwritten from the ZeRO config in ``__init__`` when a config is given.
    override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply")
    def __init__(self,
                 module=None,
                 data_parallel_group=None,
                 mem_efficient_linear=True,
                 remote_device=None,
                 pin_memory=False,
                 config_dict_or_path=None,
                 config=None,
                 enabled=True,
                 dtype=None,
                 mpu=None,
                 zero_param_parallel_group=None,
                 zero_quantized_weights=False):
        """A context to enable massive model construction for training with
        ZeRO-3. Models are automatically partitioned (or, sharded) across the
        system and converted to half precision.

        Args:
            module (``torch.nn.Module``, optional): If provided, partition the model as
                if it was constructed in the context. The module is converted in
                place by this constructor.
            data_parallel_group (``deepspeed.comm`` process group, optional):
                The group of processes to partition among. Defaults to all processes.
            mem_efficient_linear (bool, optional): Replace
                torch.nn.functional.linear with an implementation that allows
                DeepSpeed to partition parameters. Defaults to ``True``. When a
                config is given, its ``memory_efficient_linear`` value is used instead.
            remote_device (string, optional): The initial device to store model
                weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
                memory. The model may still be moved to GPU based on the
                offload settings for training. Defaults to the local accelerator
                device. When the config defines ``offload_param``, its ``device``
                overrides this argument.
            pin_memory (bool, optional): Potentially increase performance by
                using pinned memory for model weights. Only honoured when the
                remote device is ``"cpu"`` or ``"nvme"``. When the config defines
                ``offload_param``, its ``pin_memory`` overrides this argument.
                Defaults to ``False``.
            config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
                for swapping fp16 params to NVMe.
            config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
            enabled (bool, optional): If ``False``, this context has no
                effect. Defaults to ``True``.
            dtype (``dtype``, optional): Can be used to change the data type of the parameters.
                Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``.
            mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
            zero_param_parallel_group (``object``, optional): Parallel (comm) group for dual partitioning of ZeRO params.
            zero_quantized_weights (bool, optional): If ``True``, turn on quantized weights in all gather weights.
                Also enabled when the config sets ``zero_quantized_weights``. Default is ``False``.

        This context accelerates model initialization and enables models that
        are too large to allocate in their entirety in CPU memory. It has the
        following effects:

        #. allocates tensors to either GPU or CPU memory or NVMe
        #. converts floating point tensors to half precision
        #. immediately partitions tensors among the group of data-parallel devices
        #. (*optional*) replaces ``torch.nn.functional.linear`` with a more
           memory-efficient implementation

        These modifications allow for models that exceed the size of local CPU/GPU
        memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
        or GPU memory or NVMe) across all nodes. Consider initializing a model with one
        trillion parameters, whose weights occupy two terabytes (TB) in half
        precision. The initial CPU allocation in full precision requires 4TB of
        memory *per process*, and so a system with 8 GPUs per node would need 32TB of
        CPU memory due to data-parallel redundancies. Instead, by immediately
        partitioning tensors we remove the redundancies. The result is that
        regardless of the number of GPUs, we still only require the original 4TB. This
        allows for a linear increase in model size with the aggregate system memory.
        For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
        parameter model with 4 nodes and 32 GPUs.

        Important: If the fp16 weights of the model can't fit onto a single GPU memory
        this feature must be used.

        .. note::
            Initializes ``deepspeed.comm`` if it has not already been done so.
            See :meth:`deepspeed.init_distributed` for more information.

        .. note::
            Only applicable to training with ZeRO-3.

        .. note::
            Reads the ``LOCAL_RANK`` environment variable to select the local
            device; a ``KeyError`` is raised if it is unset.

        Examples
        --------

        #. Allocate a model and partition it among all processes:

            .. code-block:: python

                with deepspeed.zero.Init():
                    model = MyLargeModel()


        #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:

            .. code-block:: python

                with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                                         remote_device="cpu",
                                         pin_memory=True):
                    model = MyLargeModel()


        #. Partition an already-allocated model in CPU memory (the module is
           converted in place; the returned object is the ``Init`` context, not
           the model):

            .. code-block:: python

                deepspeed.zero.Init(module=model)
        """
        # ``config`` is a deprecated alias for ``config_dict_or_path``.
        if config is not None:
            config_dict_or_path = config
            logger.warning(
                f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.')
        _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path,
                                                              mpu) if config_dict_or_path is not None else None
        # The ZeRO config, when present, takes precedence over the argument.
        if _ds_config is not None:
            mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear
        super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype)
        if not dist.is_initialized():
            init_distributed()
            assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"

        # Primary partitioning group: the given data-parallel group or the whole world.
        if data_parallel_group is None:
            self.ds_process_group = dist.get_world_group()
        else:
            self.ds_process_group = data_parallel_group

        self.rank = dist.get_rank(group=self.ds_process_group)
        self.dp_world_size = dist.get_world_size(group=self.ds_process_group)

        # hpZeRO secondary parameter group. It is created from the config when
        # zero_hpz_partition_size > 1 and the caller did not pass a group.
        self.zero_param_process_group = zero_param_parallel_group
        if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None:
            groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size)
            self.zero_param_process_group = groups._get_zero_param_intra_parallel_group()

        # Without a secondary group, the whole data-parallel group is one parameter group.
        self.num_ranks_in_param_group = self.dp_world_size
        self.rank_in_group = self.rank
        self.num_param_groups = 1

        if self.zero_param_process_group is not None:
            self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size()
            # NOTE(review): assumes dp_world_size is divisible by the intra-group size
            # (int() silently truncates otherwise).
            self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group)
            self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup()
            print_rank_0(f"hpZeRO group size? {self.num_ranks_in_param_group}", force=True)

            logger.debug(
                "hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} "
                .format(self.rank, self.rank_in_group, groups._get_zero_param_intra_parallel_group_ranks()))

        # Local device is the device where the parameters are consumed, must be default device.
        # It is the device where parameters are fully instantiated using allgather
        self.local_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        get_accelerator().set_device(self.local_device)

        # Quantized all-gather is enabled by either the argument or the config flag.
        self.quantized_weights = zero_quantized_weights
        if _ds_config is not None and _ds_config.zero_config.zero_quantized_weights and not self.quantized_weights:
            self.quantized_weights = _ds_config.zero_config.zero_quantized_weights
        # Kept so all_gather_coalesced can skip quantization while the module is in eval mode.
        self.module = module
        if (self.quantized_weights):
            self.quantizer_module = CUDAQuantizer()
            print_rank_0(f'Using quantizer: {self.quantizer_module.__class__.__name__}', force=True)

        if _ds_config is not None:
            # Class attribute: affects every Init instance.
            Init.override_module_apply = _ds_config.zero_config.override_module_apply

            # Config offload settings override the remote_device/pin_memory arguments.
            if _ds_config.zero_config.offload_param is not None:
                remote_device = _ds_config.zero_config.offload_param.device
                pin_memory = _ds_config.zero_config.offload_param.pin_memory

        self._validate_remote_device(remote_device, _ds_config)

        # Remote device is the device where parameter partitions are stored
        # It can be same as local_device or it could be CPU or NVMe.
        self.remote_device = self.local_device if remote_device in [None, OffloadDeviceEnum.none] else remote_device
        # Pinned memory only makes sense for host-side (CPU / NVMe-staged) partitions.
        self.pin_memory = pin_memory if (self.remote_device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme
                                                                ]) else False

        # Enable fp16 param swapping to NVMe
        # (self.dtype is presumably set by the base class __init__ — confirm.)
        if self.remote_device == OffloadDeviceEnum.nvme:
            self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config, self.dtype)
        else:
            self.param_swapper = None

        # If we are provided an already-allocated module to prepare.
        if module is not None:
            assert isinstance(module, torch.nn.Module)
            self._convert_to_zero_parameters(module.parameters(recurse=True))

        self.use_all_gather_into_tensor = dist.has_all_gather_into_tensor()
        if not self.use_all_gather_into_tensor:
            logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}")
  686. def _update_persist_config(self, ds_config):
  687. Init.apply_param_persistence = True
  688. Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold
  689. Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions
  690. def _zero_init_param(self, param):
  691. self._convert_to_deepspeed_param(param)
  692. if dist.get_world_group() == self.get_dp_process_group():
  693. dist.broadcast(param, 0, self.get_dp_process_group())
  694. else:
  695. dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0), self.get_dp_process_group())
  696. param.partition()
  697. def _convert_to_zero_parameters(self, param_list):
  698. for param in param_list:
  699. if is_zero_param(param):
  700. continue
  701. param.data = param.data.to(self.local_device)
  702. self._zero_init_param(param)
  703. def _validate_remote_device(self, remote_device, ds_config):
  704. if ds_config is not None:
  705. if remote_device in [None, OffloadDeviceEnum.cpu]:
  706. if ds_config.zero_config.offload_param is not None:
  707. offload_param_device = ds_config.zero_config.offload_param.device
  708. assert offload_param_device != OffloadDeviceEnum.nvme, \
  709. f"'device' in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
  710. if remote_device == OffloadDeviceEnum.nvme:
  711. assert ds_config.zero_config.offload_param is not None, \
  712. f'"offload_param" must be defined in DeepSpeed Config if remote device is {OffloadDeviceEnum.nvme}.'
  713. assert ds_config.zero_config.offload_param.nvme_path is not None, \
  714. f'"nvme_path" in DeepSpeed Config cannot be None if remote device is {OffloadDeviceEnum.nvme}'
  715. def _post_init_method(self, module):
  716. #see_memory_usage(f"Before converting params in {module.__class__.__name__}", force=False)
  717. print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
  718. see_memory_usage(f"Before converting and partitioning params in {module.__class__.__name__}", force=False)
  719. for name, param in module.named_parameters(recurse=False):
  720. print_rank_0(f'Analyzing param {name} in {module.__class__.__name__}', force=False)
  721. InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1
  722. InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel()
  723. if not is_zero_param(param):
  724. if not get_accelerator().on_accelerator(param):
  725. param.data = param.data.to(self.local_device)
  726. self._zero_init_param(param)
  727. print_rank_0(
  728. f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}")
  729. see_memory_usage(
  730. f"Param count {InsertPostInitMethodToModuleSubClasses.num_module_elements}. After converting and partitioning params in {module.__class__.__name__}",
  731. force=False)
    def _convert_to_deepspeed_param(self, param):
        """Turn ``param`` into a ZeRO-3 parameter in place.

        Attaches the ``ds_*`` bookkeeping attributes: status, original shape and
        element count, partition placeholders, persistence flag, process groups,
        NVMe swapper and a unique ``ds_id``. Also binds closures such as
        ``all_gather``, ``all_gather_coalesced``, ``partition`` and
        ``partition_gradients`` as attributes of the parameter.

        No data is moved or partitioned here. The parameter is left ``AVAILABLE``
        with ``ds_tensor`` set to ``None``.

        Args:
            param: the parameter to convert. It is modified in place.
        """

        # Partitioned, Normal, Remote
        param.ds_param_type = ZeroParamType.PARTITIONED

        # Replicated vs Partitioned vs Inflight
        param.ds_status = ZeroParamStatus.AVAILABLE

        # Stores the shape of the original tensor
        param.ds_shape = param.shape

        # Stores the number of elements in the original parameter without padding
        param.ds_numel = param.numel()

        # Stores the partitioned copy of the tensor
        param.ds_tensor = None

        # Keeps track of how many active sub-modules need this param at any given point in time
        param.ds_active_sub_modules = set()

        # If this flag is true, then the parameters are replicated throughout training
        # And only partitioned before the step.
        # A parameter persists only when all of the following hold:
        #   * persistence has been enabled (Init.apply_param_persistence);
        #   * it is no larger than Init.param_persistence_threshold;
        #   * adding it keeps the running total of persisted elements within
        #     Init.model_persistence_threshold.
        # These counters are class attributes shared by all Init instances.
        if Init.apply_param_persistence and param.ds_numel <= Init.param_persistence_threshold and Init.num_persisted_elements + param.ds_numel <= Init.model_persistence_threshold:
            param.ds_persist = True
            Init.num_persisted_parameters += 1
            Init.num_persisted_elements += param.ds_numel
        else:
            param.ds_persist = False

        param.is_external_param = False

        # The group that the parameter is scattered across.
        param.ds_process_group = self.ds_process_group

        # Stores the secondary partitioned copy of the tensor
        param.ds_secondary_tensor = None

        #Process group for secondary partition all (group) gather
        param.ds_zero_param_process_group = self.zero_param_process_group
        param.ds_secondary_tensor_group_size = self.num_ranks_in_param_group
        param.ds_secondary_tensor_num_of_groups = self.num_param_groups

        # This is set to the Async Param swapper if remote device is nvme
        # else this is set to None
        param.nvme_swapper = self.param_swapper

        # DeepSpeed Param ID
        param.ds_id = Init.param_id
        Init.param_id += 1

        def all_gather(param_list=None, async_op=False, hierarchy=0):
            # Gather this parameter (or ``param_list`` when given) via Init._all_gather.
            cls = param
            if param_list is None:
                param_list = [cls]
            return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

        @instrument_w_nvtx
        def all_gather_coalesced(params: Iterable[Parameter],
                                 forward: bool,
                                 safe_mode: bool = False) -> AllGatherCoalescedHandle:
            """Launch one all-gather covering all of ``params``.

            Args:
                params: parameters to gather. Each must be ``NOT_AVAILABLE`` and
                    is marked ``INFLIGHT``; otherwise ``RuntimeError`` is raised.
                forward: when ``False`` and a secondary (hpZeRO) group exists,
                    the gather runs over that intra-group and uses the secondary
                    partitions where present.
                safe_mode: when ``True``, cross-check the ``ds_id`` lists and
                    partition sizes across ranks before communicating.

            Returns:
                One of the following handles:
                * a no-gather handle when there is a single partition;
                * ``AllGatherHandle`` for one parameter;
                * ``AllGatherCoalescedHandle`` otherwise.
                Quantized gathers attach a ``QuantizationInfo``.
            """

            # fetches from nvme if the partition is not available and in nvme
            self._ensure_availability_of_partitioned_params(params)

            # Quantization is skipped while the module given to Init (if any) is in eval mode.
            quant = self.quantized_weights
            if self.module is not None and self.module.training is False:
                quant = False

            if self.num_partitions == 1:
                return _no_gather_coalesced(params)

            # NOTE: ``param`` is assigned in this closure, which makes it a local
            # name here. It never refers to the enclosing parameter being converted.
            for param in params:
                if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                    raise RuntimeError(param.ds_summary())
                param.ds_status = ZeroParamStatus.INFLIGHT

            #use appropriate all gather process group
            ds_process_group = self.ds_process_group
            rank_in_group = self.rank
            world_size = self.dp_world_size
            use_secondary_tensor = False
            if self.zero_param_process_group and not forward:
                ds_process_group = self.zero_param_process_group  #intragroup
                rank_in_group = self.rank_in_group
                world_size = self.num_ranks_in_param_group

            # ensure that each rank has params in same order. the allgather
            # is done by flattening the parameter list into a single tensor that
            # can be allgathered in a single call - this means that if each rank
            # gives a list of the same parameters in a different order we will
            # silently get incorrect parameter values, and have very difficult
            # to debug correctness issues.
            params = sorted(params, key=lambda p: p.ds_id)

            if logger.isEnabledFor(logging.DEBUG):
                debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")

            if safe_mode:
                # ensure that same list (with same ordering) of parameters are
                # being allgathered across all ranks, otherwise could mix
                # data between tensors.
                assert_ints_same_as_other_ranks([p.ds_id for p in params])
                # ensure that tensors from each rank agree on the same ds_numel
                # otherwise could mix data between tensors.
                assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params])

            if len(params) == 1:
                # have an opportunity to avoid some intermediate memory allocations
                param, = params
                # Output buffer: ds_numel rounded up to a multiple of world_size.
                buffer_size = math.ceil(param.ds_numel / world_size) * world_size
                if not forward and param.ds_secondary_tensor is not None:
                    buffer_size = param.ds_secondary_tensor.shape[0] * world_size  #make sure out is appropriately sized

                param_buffer = torch.empty(
                    buffer_size,
                    dtype=param.dtype if not quant else torch.int8,
                    device=get_accelerator().current_device_name(),
                    requires_grad=False,
                )
                param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
                if not quant:
                    handles = _dist_allgather_fn(
                        param_ds_tensor.to(get_accelerator().current_device_name()),
                        param_buffer,
                        ds_process_group,
                    )
                    # param.data is pointed at the unpadded prefix of the gather
                    # buffer right away. Presumably callers wait on the returned
                    # handle before reading it.
                    param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
                    return AllGatherHandle(handles, param)
                else:
                    # Gather the int8 payload and the float32 per-group scales separately.
                    quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor)
                    handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device_name()),
                                                param_buffer, ds_process_group)

                    quant_scale_buffer = torch.empty(
                        scales.numel() * world_size,
                        dtype=torch.float32,
                        device=get_accelerator().current_device_name(),
                        requires_grad=False,
                    )
                    quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()),
                                                      quant_scale_buffer, ds_process_group)
                    quant_info = QuantizationInfo()

                    quant_info.quantized_param = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(
                        param.device)
                    quant_info.backend = self.quantizer_module
                    quant_info.quant_handle = quant_handle
                    quant_info.scale_buffer = quant_scale_buffer
                    return AllGatherHandle(handle, param, quantization=quant_info)

            else:
                # Per-rank slice = concatenation of every parameter's partition.
                # NOTE(review): the secondary-tensor decision below is taken from
                # params[0] only, so it assumes all params agree on having one.
                partition_sz = sum(p.ds_tensor.ds_numel for p in params)

                if params[0].ds_secondary_tensor is not None and not forward:
                    # Secondary partitions are num_of_groups times larger than the
                    # primary ones (assumes even divisibility — confirm with padding).
                    partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

                flat_tensor = torch.empty(partition_sz * world_size,
                                          dtype=get_only_unique_item(p.dtype
                                                                     for p in params) if not quant else torch.int8,
                                          device=get_accelerator().current_device_name(),
                                          requires_grad=False)
                if not quant:
                    # Views of flat_tensor, one per rank; this rank's own slice is filled in place.
                    partitions: List[Parameter] = []
                    for i in range(world_size):
                        partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))

                    if params[0].ds_secondary_tensor is not None and not forward:
                        use_secondary_tensor = True
                        instrument_w_nvtx(torch.cat)(
                            [p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
                            out=partitions[rank_in_group])
                    else:
                        instrument_w_nvtx(
                            torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
                                       out=partitions[rank_in_group])
                    handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
                    #Fix get_partition_dp_group(params[0]))

                    return AllGatherCoalescedHandle(
                        allgather_handle=handle,
                        params=params,
                        partitions=partitions,
                        world_size=world_size,
                        use_secondary_tensor=use_secondary_tensor,
                        forward=forward,
                    )
                else:
                    # Quantize the concatenated local partitions, then gather payload and scales.
                    if params[0].ds_secondary_tensor is not None and not forward:
                        use_secondary_tensor = True
                        quantized_param, scales = self.quantizer_module.quantize(
                            instrument_w_nvtx(torch.cat)(
                                [p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params]))
                    else:
                        quantized_param, scales = self.quantizer_module.quantize(
                            instrument_w_nvtx(
                                torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
                    handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)

                    quant_info = QuantizationInfo()
                    quant_scale_buffer = torch.empty(
                        scales.numel() * world_size,
                        dtype=torch.float32,
                        device=get_accelerator().current_device_name(),
                        requires_grad=False,
                    )
                    quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
                    quant_info.quantized_param = flat_tensor
                    quant_info.backend = self.quantizer_module
                    quant_info.quant_handle = quant_handle
                    quant_info.scale_buffer = quant_scale_buffer
                    quant_info.partition_sz = partition_sz
                    quant_info.world_size = world_size
                    return AllGatherCoalescedHandle(
                        allgather_handle=handle,
                        params=params,
                        partitions=None,
                        world_size=world_size,
                        use_secondary_tensor=use_secondary_tensor,
                        forward=forward,
                        quantization=quant_info,
                    )

        def partition(param_list=None, backward=False, hierarchy=0, has_been_updated=False):
            # Partition this parameter (or ``param_list``). ``backward`` is not used here.
            cls = param
            print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                         force=False)
            if param_list is None:
                param_list = [cls]
            self._partition(param_list, has_been_updated=has_been_updated)

        def reduce_gradients_at_owner(param_list=None, hierarchy=0):
            # Reduce-scatter the gradients of this parameter (or ``param_list``) to their owners.
            cls = param
            if param_list is None:
                param_list = [cls]
            print_rank_0(
                f"{'--'*hierarchy}----Reducing Gradients for param with ids {[param.ds_id for param in param_list]} to owner"
            )
            self._reduce_scatter_gradients(param_list)

        def partition_gradients(param_list=None, partition_buffers=None, hierarchy=0, accumulate=False):
            # Partition the gradients of this parameter (or ``param_list``).
            # A single tensor given as ``partition_buffers`` is wrapped in a list.
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(cls)}")
            if param_list is None:
                param_list = [cls]
                if isinstance(partition_buffers, torch.Tensor):
                    partition_buffers = [partition_buffers]

            self._partition_gradients(param_list, partition_buffers=partition_buffers, accumulate=accumulate)

        def aligned_size():
            return self._aligned_size(param)

        def padding_size():
            return self._padding_size(param)

        def partition_numel():
            return self._partition_numel(param)

        # NOTE(review): item_override is never bound to the parameter in this
        # function (``param.item`` is wrapped with allgather_before below), and
        # ``param._orig_item`` is not set here either. This looks unused.
        def item_override():
            param.all_gather()
            return param._orig_item()

        def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict:
            # Debug snapshot of the parameter's ZeRO state (used in error messages).
            return {
                "id": debug_param2name_id(slf) if use_debug_name else slf.ds_id,
                "status": slf.ds_status.name,
                "numel": slf.numel(),
                "ds_numel": slf.ds_numel,
                "shape": tuple(slf.shape),
                "ds_shape": tuple(slf.ds_shape),
                "requires_grad": slf.requires_grad,
                "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None,
                "persist": slf.ds_persist,
                "active_sub_modules": slf.ds_active_sub_modules,
                "ds_tensor.shape": slf.ds_tensor.shape if slf.ds_tensor is not None else None
            }

        def convert_to_zero_parameters(param_list):
            self._convert_to_zero_parameters(param_list)

        def allgather_before(func: Callable) -> Callable:
            # Wrap ``func`` so this parameter is all-gathered before ``func`` runs.

            def wrapped(*args, **kwargs):
                param.all_gather()
                return func(*args, **kwargs)

            return wrapped

        # Collectives for gathering and partitioning parameters
        param.all_gather = all_gather
        param.all_gather_coalesced = all_gather_coalesced
        param.partition = partition

        # Gradient collectives (reduce-scatter to owners / gradient partitioning)
        param.reduce_gradients_at_owner = reduce_gradients_at_owner
        param.partition_gradients = partition_gradients

        # Partitioning size utilities
        param.aligned_size = aligned_size
        param.padding_size = padding_size
        param.partition_numel = partition_numel
        # Bound as a real method so it can be called as ``param.ds_summary()``.
        param.ds_summary = types.MethodType(ds_summary, param)

        # ``item()`` needs the full data, so the parameter is gathered first.
        param.item = allgather_before(param.item)

        param.convert_to_zero_parameters = convert_to_zero_parameters
  989. def _aligned_size(self, param):
  990. return param.ds_numel + self._padding_size(param)
  991. def _padding_size(self, param):
  992. remainder = param.ds_numel % self.num_partitions
  993. return (self.num_partitions - remainder) if remainder else 0
  994. def _partition_numel(self, param):
  995. return param.ds_tensor.ds_numel
  996. def _ensure_availability_of_partitioned_params(self, params):
  997. swap_in_list = []
  998. swap_in_flight = []
  999. for param in params:
  1000. if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
  1001. assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
  1002. swap_in_list.append(param)
  1003. if param.ds_tensor.status == PartitionedParamStatus.INFLIGHT:
  1004. assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
  1005. swap_in_flight.append(param)
  1006. if len(swap_in_list) > 0:
  1007. swap_in_list[0].nvme_swapper.swap_in(swap_in_list, async_op=False)
  1008. elif len(swap_in_flight) > 0:
  1009. swap_in_flight[0].nvme_swapper.synchronize_reads()
  1010. @instrument_w_nvtx
  1011. def _all_gather(self, param_list, async_op=False, hierarchy=None):
  1012. # fetches from nvme if the partition is not available and in nvme
  1013. self._ensure_availability_of_partitioned_params(param_list)
  1014. handles = []
  1015. all_gather_list = []
  1016. for param in param_list:
  1017. if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
  1018. if async_op:
  1019. handle = self._allgather_param(param, async_op=async_op, hierarchy=hierarchy)
  1020. param.ds_status = ZeroParamStatus.INFLIGHT # if async_op else ZeroParamStatus.AVAILABLE
  1021. handles.append(handle)
  1022. else:
  1023. all_gather_list.append(param)
  1024. if not async_op:
  1025. if len(param_list) == 1:
  1026. ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
  1027. else:
  1028. ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)
  1029. for param in all_gather_list:
  1030. param.ds_status = ZeroParamStatus.AVAILABLE
  1031. return ret_value
  1032. return handles
  1033. def _partition(self, param_list, force=False, has_been_updated=False):
  1034. for param in param_list:
  1035. print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
  1036. if self.zero_param_process_group is not None:
  1037. self._partition_param_sec(param, has_been_updated=has_been_updated)
  1038. self._partition_param(param, has_been_updated=has_been_updated)
  1039. param.ds_status = ZeroParamStatus.NOT_AVAILABLE
  1040. # if param.ds_tensor is not None:
  1041. # assert id(param.data) == id(param.ds_tensor.data), \
  1042. # "After the parameters are initially partitioned, make sure we are not recreating the partition."
  1043. #print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False)
    @instrument_w_nvtx
    def _partition_param(self, param, buffer=None, has_been_updated=False):
        """Store this rank's slice of ``param`` in ``param.ds_tensor`` and free the full data.

        Only parameters whose status is ``AVAILABLE`` are processed. For any other
        (non-``INFLIGHT``) status the call only logs. There are two cases:

        * ``ds_tensor`` already exists and ``has_been_updated`` is false: the
          existing partition is kept and the full data is freed. NVMe-backed
          partitions also release their buffers.
        * Otherwise: the partition is allocated on first use, then refilled from
          the full parameter, and the full data is freed. NVMe-backed partitions
          are then swapped out.

        Args:
            param: a parameter prepared by ``_convert_to_deepspeed_param``.
            buffer: its incoming value is never read. The name is only reused
                internally for the NVMe swap buffer.
            has_been_updated: re-copy the partition from the full parameter even
                if ``ds_tensor`` already exists.

        Raises:
            AssertionError: if ``param`` is ``INFLIGHT``.
        """
        assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"

        # Module-level flag; only read here for logging.
        global reuse_buffers
        print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
        if param.ds_status is ZeroParamStatus.AVAILABLE:
            print_rank_0(f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}", force=False)

            if param.ds_tensor is not None and not has_been_updated:  ##param already partitioned
                # Keep the existing partition; just drop the full-size data.
                see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
                # param.data does not store anything meaningful in partitioned state
                free_param(param)
                see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

                if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
                    print_rank_0(f"Param {param.ds_id} partition released since it exists in nvme", force=False)
                    param.nvme_swapper.remove_partition_and_release_buffers([param])
                    print_rank_0(
                        f"after swap Param {param.ds_id} {param.ds_tensor.shape} partition released since it exists in nvme",
                        force=False)

                return

            # Pad to a multiple of num_partitions so every rank gets an equal slice.
            tensor_size = self._aligned_size(param)
            partition_size = tensor_size // self.num_partitions
            if param.ds_tensor is None:
                # First partitioning: allocate the partition storage.
                final_location = None
                if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor(
                        numel=partition_size):
                    # Back the partition directly with an NVMe swap buffer.
                    final_location = OffloadDeviceEnum.nvme
                    buffer = self.param_swapper.get_buffer(param, partition_size)
                    partitioned_tensor = torch.empty(0, dtype=param.dtype, device=buffer.device)
                    partitioned_tensor.data = buffer.data
                    print_rank_0(f"ID {param.ds_id} Initializing partition for the first time for nvme offload.")

                else:
                    # Persistent params stay on the local device. NVMe params that
                    # are not swappable fall back to CPU. Everything else goes to
                    # remote_device.
                    if param.ds_persist:
                        device = self.local_device
                    elif self.remote_device == OffloadDeviceEnum.nvme:
                        device = OffloadDeviceEnum.cpu
                    else:
                        device = self.remote_device

                    partitioned_tensor = torch.empty(partition_size, dtype=param.dtype, device=device)

                    if device == OffloadDeviceEnum.cpu and self.pin_memory:
                        partitioned_tensor = get_accelerator().pin_memory(partitioned_tensor)

                partitioned_tensor.requires_grad = False
                param.ds_tensor = partitioned_tensor
                param.ds_tensor.ds_numel = partition_size
                param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
                param.ds_tensor.final_location = final_location

            # This rank owns elements [start, end) of the flattened, padded parameter.
            start = partition_size * self.get_partition_rank()
            end = start + partition_size

            one_dim_param = param.contiguous().view(-1)

            if start < param.ds_numel and end <= param.ds_numel:
                # Slice lies entirely within the real data.
                src_tensor = one_dim_param.narrow(0, start, partition_size)

                param.ds_tensor.copy_(src_tensor)

            else:
                # Slice overlaps (or lies entirely in) the padding: copy only the real elements.
                # NOTE(review): the remaining tail of ds_tensor is never written here.
                # For a freshly allocated (torch.empty) partition it holds
                # uninitialized values; otherwise it keeps its previous contents.
                # Confirm that consumers ignore the padding.
                if start < param.ds_numel:
                    elements_to_copy = param.ds_numel - start
                    param.ds_tensor.narrow(0, 0,
                                           elements_to_copy).copy_(one_dim_param.narrow(0, start, elements_to_copy))

            # param.data does not store anything meaningful in partitioned state

            see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
            free_param(param)
            see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

            if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
                self.param_swapper.swap_out_and_release([param])
                print_rank_0(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.")
                see_memory_usage(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.", force=False)

            print_rank_0(f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}")
  1128. @instrument_w_nvtx
  1129. def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
  1130. assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
  1131. global reuse_buffers
  1132. ##support for NVME secondary param offload
  1133. #print_rank_0(f"SEC Param id {param.ds_id} status is {param.ds_status}", force=True)
  1134. if param.ds_status is ZeroParamStatus.AVAILABLE:
  1135. if param.ds_secondary_tensor is not None and not has_been_updated: ##param already partitioned
  1136. return
  1137. #check padding
  1138. tensor_size = self._aligned_size(param)
  1139. partition_size = tensor_size // self.dp_world_size
  1140. secondary_partition_size = int(tensor_size // self.num_ranks_in_param_group)
  1141. if param.ds_secondary_tensor is None:
  1142. final_location = None
  1143. secondary_partitioned_tensor = torch.empty(secondary_partition_size,
  1144. dtype=param.dtype,
  1145. device=self.remote_device)
  1146. if self.pin_memory:
  1147. secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory()
  1148. secondary_partitioned_tensor.requires_grad = False
  1149. param.ds_secondary_tensor = secondary_partitioned_tensor
  1150. param.ds_secondary_tensor.ds_numel = secondary_partition_size
  1151. param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE
  1152. param.ds_secondary_tensor.final_location = final_location
  1153. #use rank in group for secondary tensor
  1154. secondary_start = secondary_partition_size * self.rank_in_group
  1155. secondary_end = secondary_start + secondary_partition_size
  1156. one_dim_param = param.contiguous().view(-1)
  1157. start = partition_size * self.rank
  1158. end = start + partition_size
  1159. if start < param.ds_numel and end <= param.ds_numel:
  1160. if secondary_start < param.ds_numel and secondary_end <= param.ds_numel:
  1161. sec_src_tensor = one_dim_param.narrow(0, secondary_start, secondary_partition_size)
  1162. param.ds_secondary_tensor.copy_(sec_src_tensor)
  1163. else:
  1164. if start < param.ds_numel:
  1165. elements_to_copy = param.ds_numel - start
  1166. elements_to_copy_sec = elements_to_copy * param.ds_secondary_tensor_num_of_groups
  1167. param.ds_secondary_tensor.narrow(0, 0, elements_to_copy_sec).copy_(
  1168. one_dim_param.narrow(0, secondary_start, elements_to_copy_sec))
  1169. print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}",
  1170. force=False)
  1171. def _param_status(self, param):
  1172. if param.ds_tensor is not None:
  1173. print_rank_0(
  1174. f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned numel {param.ds_tensor.numel()}, data numel {param.data.numel()}"
  1175. )
  1176. else:
  1177. print_rank_0(
  1178. f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned ds_tensor {param.ds_tensor}, data numel {param.data.numel()}"
  1179. )
  1180. def _allgather_param(self, param, async_op=False, hierarchy=0):
  1181. partition_size = param.ds_tensor.ds_numel
  1182. tensor_size = partition_size * self.num_partitions
  1183. aligned_param_size = self._aligned_size(param)
  1184. assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'
  1185. print_rank_0(
  1186. f"{'--'* hierarchy}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}"
  1187. )
  1188. see_memory_usage(
  1189. f'Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ',
  1190. force=False)
  1191. flat_tensor = torch.zeros(aligned_param_size, dtype=param.dtype, device=param.device).view(-1)
  1192. see_memory_usage(
  1193. f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
  1194. force=False)
  1195. get_accelerator().synchronize()
  1196. print_rank_0(
  1197. f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
  1198. )
  1199. # if not flat_tensor.numel() > 100000:
  1200. # replicated_tensor = flat_tensor.narrow(0,
  1201. # 0,
  1202. # param.ds_numel).view(param.ds_shape)
  1203. # param.data = replicated_tensor.data
  1204. # return None
  1205. if self.use_all_gather_into_tensor:
  1206. handle = dist.all_gather_into_tensor(flat_tensor,
  1207. param.ds_tensor.to(get_accelerator().device_name()),
  1208. group=self.get_partition_dp_group(param),
  1209. async_op=async_op)
  1210. else:
  1211. partitions = []
  1212. for i in range(self.num_partitions):
  1213. partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size))
  1214. if i == dist.get_rank(group=self.get_partition_dp_group(param)):
  1215. partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
  1216. handle = dist.all_gather(partitions,
  1217. partitions[self.get_partition_rank()],
  1218. group=self.get_partition_dp_group(param),
  1219. async_op=async_op)
  1220. replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
  1221. param.data = replicated_tensor.data
  1222. return handle
  1223. def _allgather_params_coalesced(self, param_list, hierarchy=0):
  1224. """ blocking call
  1225. avoid explicit memory copy in _allgather_params
  1226. """
  1227. if len(param_list) == 0:
  1228. return
  1229. if self.num_partitions == 1:
  1230. handle = _no_gather_coalesced(param_list)
  1231. handle.wait()
  1232. return None
  1233. # collect local tensors and partition sizes
  1234. partition_sizes = []
  1235. local_tensors = []
  1236. for param in param_list:
  1237. partition_sizes.append(param.ds_tensor.ds_numel)
  1238. local_tensors.append(param.ds_tensor.to(get_accelerator().device_name()))
  1239. # allocate memory for allgather params
  1240. allgather_params = []
  1241. for psize in partition_sizes:
  1242. tensor_size = psize * self.num_partitions
  1243. flat_tensor = torch.empty(tensor_size, dtype=param_list[0].dtype, device=self.local_device).view(-1)
  1244. flat_tensor.requires_grad = False
  1245. allgather_params.append(flat_tensor)
  1246. # launch
  1247. launch_handles = []
  1248. for param_idx, param in enumerate(param_list):
  1249. input_tensor = local_tensors[param_idx].view(-1)
  1250. if self.use_all_gather_into_tensor:
  1251. # try the _all_gather_base from Pytorch master
  1252. h = dist.all_gather_into_tensor(allgather_params[param_idx],
  1253. input_tensor,
  1254. group=self.get_partition_dp_group(param),
  1255. async_op=True)
  1256. else:
  1257. output_list = []
  1258. for i in range(self.num_partitions):
  1259. psize = partition_sizes[param_idx]
  1260. partition = allgather_params[param_idx].narrow(0, i * psize, psize)
  1261. output_list.append(partition)
  1262. if not get_accelerator().on_accelerator(partition):
  1263. logger.warning(
  1264. f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}')
  1265. # back to old all_gather function
  1266. h = dist.all_gather(output_list, input_tensor, group=self.get_partition_dp_group(param), async_op=True)
  1267. launch_handles.append(h)
  1268. # Wait ensures the operation is enqueued, but not necessarily complete.
  1269. launch_handles[-1].wait()
  1270. # assign to param.data (not copy)
  1271. for i, param in enumerate(param_list):
  1272. gathered_tensor = allgather_params[i]
  1273. param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
  1274. # guarantee the communication to be completed
  1275. get_accelerator().synchronize()
  1276. return None
  1277. def _allgather_params(self, param_list, hierarchy=0):
  1278. if len(param_list) == 0:
  1279. return
  1280. partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
  1281. tensor_size = partition_size * self.num_partitions
  1282. flat_tensor = torch.empty(tensor_size, dtype=param_list[0].dtype, device=self.local_device)
  1283. flat_tensor.requires_grad = False
  1284. partitions = []
  1285. for i in range(self.num_partitions):
  1286. start = partition_size * i
  1287. partitions.append(flat_tensor.narrow(0, start, partition_size))
  1288. if i == self.get_partition_rank():
  1289. offset = 0
  1290. for param in param_list:
  1291. param_numel = param.ds_tensor.ds_numel
  1292. partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
  1293. offset += param_numel
  1294. dist.all_gather(partitions,
  1295. partitions[self.get_partition_rank()],
  1296. group=self.get_partition_dp_group(param),
  1297. async_op=False)
  1298. param_offset = 0
  1299. for param in param_list:
  1300. param_partition_size = param.ds_tensor.ds_numel
  1301. param_size = param.ds_numel
  1302. replicated_tensor = torch.empty(param.ds_shape, dtype=param.dtype, device=self.local_device)
  1303. for i in range(self.num_partitions):
  1304. start = i * partition_size
  1305. param_start = i * param_partition_size
  1306. if param_start < param_size:
  1307. numel_to_copy = min(param_size - param_start, param_partition_size)
  1308. part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
  1309. replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
  1310. #param_offset += param.data.numel()
  1311. param_offset += param.ds_tensor.ds_numel
  1312. param.data = replicated_tensor.data
  1313. return None
  1314. def _reduce_scatter_gradients(self, param_list):
  1315. #print_rank_0([param.grad for param in param_list])
  1316. #assert any([param.grad is None for param in param_list]), "None gradients cannot be reduce scattered"
  1317. handles_and_reduced_partitions = []
  1318. for param in param_list:
  1319. assert param.grad.numel(
  1320. ) == param.ds_numel, f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter gradients whose size is not same as the params"
  1321. handles_and_reduced_partitions.append(self._reduce_scatter_gradient(param))
  1322. for param, (handle, reduced_partition) in zip(param_list, handles_and_reduced_partitions):
  1323. if handle is not None:
  1324. handle.wait()
  1325. # some ranks may have partitions that are padded to go beyond the grad size.
  1326. # For these ranks the output of reduce scatter is a separate buffer and needs
  1327. # to be copied in
  1328. partition_size = param.ds_tensor.ds_numel
  1329. start = self.get_partition_rank() * partition_size
  1330. end = start + partition_size
  1331. #print_rank_0("REduce scatter was executed for param {param.ds_id}")
  1332. if start < param.ds_numel < end:
  1333. elements = param.ds_numel - start
  1334. param.grad.view(-1).narrow(0, start, elements).copy_(reduced_partition.narrow(0, 0, elements))
  1335. def _reduce_scatter_gradient(self, param):
  1336. partition_size = param.ds_tensor.ds_numel
  1337. #output = torch.empty(partition_size, dtype=param.dtype, device=param.device)
  1338. total_size = partition_size * self.num_partitions
  1339. input_list = []
  1340. for i in range(self.num_partitions):
  1341. start = i * partition_size
  1342. end = start + partition_size
  1343. #print("before reduce scatter gradients")
  1344. if start < param.ds_numel and end <= param.ds_numel:
  1345. input = param.grad.view(-1).narrow(0, start, partition_size)
  1346. else:
  1347. input = torch.zeros(partition_size, dtype=param.dtype, device=param.device)
  1348. if start < param.ds_numel:
  1349. elements = param.ds_numel - start
  1350. input.narrow(0, 0, elements).copy_(param.grad.view(-1).narrow(0, start, elements))
  1351. #print("after reduce scatter gradients")
  1352. input_list.append(input)
  1353. rank = dist.get_rank(group=self.get_partition_dp_group(param))
  1354. handle = dist.reduce_scatter(input_list[rank],
  1355. input_list,
  1356. group=self.get_partition_dp_group(param),
  1357. async_op=True)
  1358. return handle, input_list[rank]
  1359. def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False):
  1360. if partition_buffers is None:
  1361. partition_buffers = [None] * len(param_list)
  1362. for param, partition_buffer in zip(param_list, partition_buffers):
  1363. self._partition_gradient(param, partition_buffer=partition_buffer, accumulate=accumulate)
  1364. def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
  1365. #import pdb;pdb.set_trace()
  1366. # param.grad=None
  1367. # param.grad.test()
  1368. print_rank_0(
  1369. f"Partitioning param {param.ds_id} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.ds_numel}"
  1370. )
  1371. see_memory_usage("Before partitioning gradients", force=False)
  1372. partition_size = param.ds_tensor.ds_numel
  1373. if partition_buffer is None:
  1374. assert not accumulate, "No buffer to accumulate to"
  1375. partition_buffer = torch.zeros(partition_size, dtype=param.dtype, device=param.device)
  1376. else:
  1377. assert partition_buffer.numel(
  1378. ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
  1379. rank = dist.get_rank(group=self.get_partition_dp_group(param))
  1380. start = partition_size * rank
  1381. end = start + partition_size
  1382. dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size)
  1383. #print("before partition gradients")
  1384. if start < param.ds_numel:
  1385. elements = min(param.ds_numel - start, partition_size)
  1386. dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
  1387. src_tensor = param.grad.view(-1).narrow(0, start, elements)
  1388. # just copy the grad partition to the buffer
  1389. if not accumulate:
  1390. dest_tensor.copy_(src_tensor)
  1391. # if source and destination are on same device,
  1392. # add to the provided buffer
  1393. elif src_tensor.device == dest_tensor.device:
  1394. dest_tensor.add_(src_tensor)
  1395. # if source and destination are on different device, copy first to src
  1396. # then add and move back to the destination. This seems to run faster
  1397. # when src is gpu and dest is cpu
  1398. # adding directly to cpu is very slow
  1399. else:
  1400. acc_tensor = torch.empty(src_tensor.numel(), dtype=param.dtype, device=param.device)
  1401. acc_tensor.copy_(dest_tensor)
  1402. acc_tensor.add_(src_tensor)
  1403. dest_tensor.copy_(acc_tensor)
  1404. # partition_buffer.view(-1).narrow(
  1405. # 0,
  1406. # 0,
  1407. # elements).copy_(param.grad.view(-1).narrow(0,
  1408. # start,
  1409. # elements))
  1410. #print("after partition gradients")
  1411. param.grad.data = dest_tensor_full_buffer.data
  1412. see_memory_usage("After partitioning gradients", force=False)
  1413. def get_partition_dp_group(self, param):
  1414. return param.ds_process_group
  1415. def get_partition_rank(self):
  1416. """subclass can overload to specify different relative rank in
  1417. parameter partition group"""
  1418. return self.rank
  1419. @property
  1420. def num_partitions(self):
  1421. return self.dp_world_size
  1422. def get_dp_process_group(self):
  1423. """ Return the communication group with all data-parallel ranks """
  1424. return self.ds_process_group
  1425. class GatheredParameters:
  1426. def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
  1427. """A context that collects parameters that were partitioned via a
  1428. :class:`deepspeed.zero.Init` context. The parameters are partitioned
  1429. again upon exit.
  1430. Args:
  1431. params (``torch.nn.Parameter``): A single parameter, or an iterable of parameters (list, tuple, generator) of parameters to collect.
  1432. It's assumed that all parameters are zero params.
  1433. modifier_rank (int, optional): If specified, this rank's parameter will be
  1434. broadcasted on exit from the context. This argument is required if ``params`` are
  1435. modified, so that all processes have a consistent view of the data. Defaults
  1436. to ``None``.
  1437. fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
  1438. registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
  1439. enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.
  1440. Important: Make sure to use ``modifier_rank`` that is not ``None`` (e.g., ``modifier_rank=0``)
  1441. if you need the GPU memory allocated by gather to be released upon exit from the context manager.
  1442. Important: if ``params`` isn't an iterable of parameters or a single parameter it'll be silently ignored!
  1443. Examples
  1444. ========
  1445. #. Allocate a partitioned module, initialize its weight on rank 0, and update all
  1446. processes.
  1447. .. code-block:: python
  1448. with deepspeed.zero.Init():
  1449. linear = torch.nn.Linear(1000,1000)
  1450. with deepspeed.zero.GatheredParameters(linear.weight,
  1451. modifier_rank=0):
  1452. if deepspeed.comm.get_rank() == 0:
  1453. linear.weight.zero_()
  1454. with deepspeed.zero.GatheredParameters(linear.weight,
  1455. modifier_rank=0):
  1456. if deepspeed.comm.get_rank() == 0:
  1457. linear.weight.zero_()
  1458. #. Collect a partitioned weight to pass to another module during
  1459. training. The parameter will be registered as an external parameter
  1460. and made available during the backward pass.
  1461. .. code-block:: python
  1462. :emphasize-lines: 6
  1463. def forward(self, input):
  1464. x = self.layer1(input)
  1465. # self.layer1.weight is required by self.layer2.forward
  1466. with deepspeed.zero.GatheredParameters(self.layer1.weight,
  1467. fwd_module=self):
  1468. y = self.layer2(x, self.layer1.weight)
  1469. return y
  1470. #. Pretrained model loading
  1471. .. code-block:: python
  1472. with deepspeed.zero.Init():
  1473. model = MyModel()
  1474. state_dict = torch.load(model_path, map_location="cpu")
  1475. def load(module: nn.Module, prefix=""):
  1476. # because zero3 puts placeholders in model params, this context
  1477. # manager gathers (unpartitions) the params of the current layer, then loads from
  1478. # the state dict and then re-partitions them again
  1479. with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
  1480. if deepspeed.comm.get_rank() == 0:
  1481. module._load_from_state_dict(state_dict, prefix)
  1482. for name, child in module._modules.items():
  1483. if child is not None:
  1484. load(child, prefix + name + ".")
  1485. load(model, prefix="")
  1486. If this approach is not used, then the full model will first be copied to each GPU. For models
  1487. bigger than the memory of a single GPU, this method is required.
  1488. """
  1489. self.enabled = enabled
  1490. if not enabled:
  1491. return
  1492. if isinstance(params, Iterable) and not isinstance(params, torch.Tensor):
  1493. # deal with generators like model.parameters()
  1494. # must convert to list to be able to iterate more than once if we get a generator
  1495. params = list(params)
  1496. else:
  1497. # single param
  1498. params = [params]
  1499. # enable if at least one is zero-param, otherwise a noop
  1500. if not any(is_zero_param(p) for p in params):
  1501. self.enabled = False
  1502. return
  1503. self.params = [p for p in params if hasattr(p, "ds_id")]
  1504. self.params = sorted(
  1505. set(self.params), key=lambda x: x.ds_id
  1506. ) # remove the duplicates to prevent racing condition, we must also make sure the order is the same on all ranks otherwise we'll get deadlocks
  1507. self.src_rank = None
  1508. if modifier_rank is not None:
  1509. if self.params[0].ds_process_group == dist.get_world_group():
  1510. self.src_rank = modifier_rank
  1511. else:
  1512. # A group was specified; convert DP rank to global rank
  1513. self.src_rank = dist.get_global_rank(self.params[0].ds_process_group, modifier_rank)
  1514. self.fwd_module = fwd_module
  1515. if self.fwd_module is not None:
  1516. # is a no-op if already registered
  1517. for p in self.params:
  1518. register_external_parameter(self.fwd_module, p)
  1519. def __enter__(self):
  1520. if not self.enabled:
  1521. return
  1522. self.params[0].all_gather(param_list=self.params)
  1523. def __exit__(self, *exc):
  1524. if not self.enabled:
  1525. return
  1526. if self.src_rank is None:
  1527. self.params[0].partition(param_list=self.params, has_been_updated=False)
  1528. return
  1529. handles = [dist.broadcast(p, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params]
  1530. for h in handles:
  1531. h.wait()
  1532. self.params[0].partition(param_list=self.params, has_been_updated=True)