partition_parameters.py 100 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import math
  5. import os
  6. import types
  7. from typing import Callable, Iterable
  8. from enum import Enum
  9. import functools
  10. import itertools
  11. from typing import List
  12. from collections import defaultdict
  13. import logging
  14. import torch
  15. from torch import Tensor
  16. from deepspeed import comm as dist
  17. from torch.nn import Module
  18. from torch.nn import Parameter
  19. from .linear import zero3_linear_wrap
  20. from deepspeed.utils import groups
  21. import deepspeed
  22. from ..utils import see_memory_usage
  23. from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
  24. from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param
  25. from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
  26. from deepspeed.runtime.config_utils import get_config_default
  27. from deepspeed.utils import instrument_w_nvtx, logger
  28. from deepspeed.comm.comm import init_distributed
  29. from deepspeed.utils.debug import (debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name,
  30. debug_param2name_id, debug_param2name_id_shape_status)
  31. from deepspeed.accelerator import get_accelerator
  32. from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
  33. from deepspeed.inference.quantization.utils import _quantize_param, WEIGHT_QUANTIZATION_LAYERS, wrap_quantized_functional, wrap_load_from_state_dict
  # Module-level mutable state.
  # NOTE(review): not referenced in the visible part of this file; purpose presumably
  # relates to the shape of partitioned parameter data - confirm against its users.
  partitioned_param_data_shape = [0]
  # Nesting depth of enabled Init contexts: incremented in
  # InsertPostInitMethodToModuleSubClasses.__enter__ and decremented in __exit__.
  zero_init_context = 0
  # The outermost active Init context (set when the depth goes 0 -> 1, cleared when it returns to 0).
  top_level_context = None
  class NoGatherHandle:
      """Wait-handle for a parameter that is rebuilt from its local ``ds_tensor``.

      On construction the parameter's ``data`` is replaced by the contents of
      ``param.ds_tensor`` (dequantized first when the tensor carries a
      ``ds_quant_scale`` attribute), copied to the current accelerator device
      with ``non_blocking=True`` and viewed as ``param.ds_shape``.

      Raises:
          RuntimeError: if ``param.ds_status`` is not ``ZeroParamStatus.INFLIGHT``.
      """

      def __init__(self, param: Parameter) -> None:
          if param.ds_status != ZeroParamStatus.INFLIGHT:
              # The check requires INFLIGHT, so the message must say so.
              raise RuntimeError(f"expected param {param.ds_summary()} to be inflight")

          local_partition = param.ds_tensor
          if hasattr(local_partition, "ds_quant_scale"):
              source = Init.quantizer_module.dequantize(local_partition.data, local_partition.ds_quant_scale)
          else:
              source = local_partition.data

          param.data = source.to(device=get_accelerator().current_device_name(),
                                 non_blocking=True).view(param.ds_shape)
          self.__param = param

      def wait(self) -> None:
          """Synchronize the current stream if the device is asynchronous, then mark the param AVAILABLE."""
          if not get_accelerator().is_synchronized_device():
              # The copy in __init__ was issued non-blocking; make sure it has landed.
              get_accelerator().current_stream().synchronize()
          self.__param.ds_status = ZeroParamStatus.AVAILABLE
  class NoGatherCoalescedHandle:
      """Wait-handle covering several parameters rebuilt from their local ``ds_tensor``.

      Construction validates that every parameter is INFLIGHT and assigns each
      one's ``data`` from its ``ds_tensor`` (dequantizing when a
      ``ds_quant_scale`` attribute is present). ``wait`` is idempotent.
      """

      def __init__(self, params: List[Parameter]) -> None:
          self.__params = params
          self.__complete = False

          for p in self.__params:
              if p.ds_status != ZeroParamStatus.INFLIGHT:
                  raise RuntimeError(f"expected param {p.ds_summary()} to not be available")

              shard = p.ds_tensor
              if hasattr(shard, "ds_quant_scale"):
                  materialized = Init.quantizer_module.dequantize(shard.data, shard.ds_quant_scale)
              else:
                  materialized = shard.data
              p.data = materialized.to(device=get_accelerator().current_device_name(),
                                       non_blocking=True).view(p.ds_shape)

      @instrument_w_nvtx
      def wait(self) -> None:
          """Synchronize once (if needed) and flip every parameter to AVAILABLE."""
          if not self.__complete:
              if not get_accelerator().is_synchronized_device():
                  get_accelerator().current_stream().synchronize()

              for p in self.__params:
                  assert p.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {p.ds_summary()} to be inflight"
                  p.ds_status = ZeroParamStatus.AVAILABLE

              self.__complete = True
  def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
      """Launch an asynchronous all-gather of ``input_tensor`` into ``output_tensor``.

      Note the argument swap: ``dist.allgather_fn`` receives the output tensor first.
      Returns whatever the NVTX-instrumented ``dist.allgather_fn`` returns.
      """
      traced_allgather = instrument_w_nvtx(dist.allgather_fn)
      return traced_allgather(output_tensor, input_tensor, group=group, async_op=True)
  def print_rank_0(message, debug=False, force=False):
      """Print ``message`` on rank 0 only, and only when ``debug`` or ``force`` is set."""
      on_rank_zero = dist.get_rank() == 0
      if not on_rank_zero:
          return
      if debug or force:
          print(message)
      # other variations
      # - print for all ranks w/o interleaving
      # printflock(f"[{rank}] {message}")
      # - print to log file per rank
      # log_rank_file(rank, message)
  def debug_rank0(msg: str) -> None:
      """Emit ``msg`` at DEBUG level through ``logger``, on rank 0 only."""
      if dist.get_rank() != 0:
          return
      logger.debug(msg)
  def _init_external_params(module):
      """Attach external-parameter bookkeeping to ``module`` (no-op if already attached).

      Adds three attributes to the module instance:

      * ``_external_params``: dict that callers fill with ``id(param) -> param``.
      * ``ds_external_parameters()``: bound method returning the dict's items.
      * ``all_parameters()``: bound method chaining the module's own (non-recursive)
        ``named_parameters`` with the external entries.
      """
      if hasattr(module, '_external_params'):
          return

      module._external_params = {}

      def external_parameters(self):
          return self._external_params.items()

      def all_parameters(self):
          # ``named_parameters`` is already bound to ``self``; passing ``self`` again
          # would land in the ``prefix`` argument and fail on iteration.
          return itertools.chain(self.named_parameters(recurse=False), external_parameters(self))

      module.ds_external_parameters = types.MethodType(external_parameters, module)
      module.all_parameters = types.MethodType(all_parameters, module)
  def register_external_parameter(module, parameter):
      """Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in
      the forward and backward passes of ``module``.

      This is used when a parameter is accessed outside of its owning module's
      ``forward()``. DeepSpeed must know to collect it from its partitioned
      state and when to release the memory.

      .. note::
          This is only applicable to training with ZeRO stage 3.

      Args:
          module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass.
          parameter (``torch.nn.Parameter``): The parameter to register.

      Raises:
          RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.


      Examples
      ========

      #. Register a weight that is used in another module's forward pass (line 6).
         Parameter ``layer1.weight`` is used by ``layer2`` (line 11).

          .. code-block:: python
              :linenos:
              :emphasize-lines: 6,11

              class ModuleZ3(torch.nn.Module):
                  def __init__(self, *args):
                      super().__init__()
                      self.layer1 = SomeLayer()
                      self.layer2 = OtherLayer()
                      deepspeed.zero.register_external_parameter(self, self.layer1.weight)

                  def forward(self, input):
                      x = self.layer1(input)
                      # self.layer1.weight is required by self.layer2.forward
                      y = self.layer2(x, self.layer1.weight)
                      return y
      """
      if not isinstance(parameter, torch.nn.Parameter):
          raise RuntimeError('Parameter is not a torch.nn.Parameter')

      # Lazily attach the registry the first time a module registers anything.
      if not hasattr(module, '_external_params'):
          _init_external_params(module)

      # Keyed by object identity so the same parameter is never registered twice.
      module._external_params[id(parameter)] = parameter
  def unregister_external_parameter(module, parameter):
      """Undo a previous :meth:`register_external_parameter` call.

      Args:
          module (``torch.nn.Module``): The module whose registry is modified.
          parameter (``torch.nn.Parameter``): The parameter to remove from it.

      Raises:
          RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
          RuntimeError: If ``parameter`` is not a registered external parameter of ``module``.
      """
      if not isinstance(parameter, torch.nn.Parameter):
          raise RuntimeError('Parameter is not a torch.nn.Parameter')

      key = id(parameter)
      is_registered = hasattr(module, '_external_params') and key in module._external_params
      if not is_registered:
          raise RuntimeError('Parameter is not a registered external parameter of module.')

      del module._external_params[key]
  class ZeroParamType(Enum):
      """How a parameter is held across processes.

      * ``NORMAL``: same as regular pytorch parameters.
      * ``PARTITIONED``: parameters are partitioned across data parallel process.
      * ``REMOTE``: the parameter is held with a unique process rank and is
        not available on all other process.
      """

      NORMAL = 1
      PARTITIONED = 2
      REMOTE = 3
  class ZeroParamStatus(Enum):
      """Availability state of a parameter.

      * ``AVAILABLE``: parameters are fully present and ready for use on all processes.
      * ``NOT_AVAILABLE``: parameters are either partitioned or remote in some or all process.
      * ``INFLIGHT``: parameters are being gathered.
      """

      AVAILABLE = 1
      NOT_AVAILABLE = 2
      INFLIGHT = 3
  # References to the unpatched torch tensor constructors, captured at import time.
  # _add_tensor_creation_wrappers replaces the torch.* names with dtype-casting wrappers
  # built from these, and _remove_tensor_creation_wrappers assigns these back.
  _orig_torch_tensor = torch.tensor
  _orig_torch_empty = torch.empty
  _orig_torch_zeros = torch.zeros
  _orig_torch_ones = torch.ones
  _orig_torch_full = torch.full
  _orig_torch_arange = torch.arange
  _orig_torch_eye = torch.eye
  _orig_torch_randn = torch.randn
  def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
      """Wrap tensor constructor ``fn`` so floating-point results are cast to ``target_fp_dtype``.

      The returned callable forwards all arguments to ``fn``. When no ``device``
      keyword is supplied (or it is ``None``), the device is derived from the
      accelerator and the ``LOCAL_RANK`` environment variable. Non floating-point
      results are returned untouched.
      """

      def wrapped_fn(*args, **kwargs) -> Tensor:
          device_missing = kwargs.get("device") is None
          if device_missing:
              kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))

          result: Tensor = fn(*args, **kwargs)
          if not result.is_floating_point():
              return result

          # Re-point .data rather than rebinding, so the same tensor object is returned.
          result.data = result.data.to(target_fp_dtype)
          return result

      return wrapped_fn
  def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:
      """Build a replacement for ``torch.Tensor.__new__`` that casts float tensors to ``dtype``.

      The replacement allocates via ``new_empty`` on the device derived from the
      accelerator and ``LOCAL_RANK``; with no positional arguments it creates a
      zero-element tensor.
      """

      def new_tensor(cls, *args, **kwargs) -> Tensor:
          device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
          size_args = args if args else (0, )
          created = _orig_torch_empty(0, device=device).new_empty(*size_args, **kwargs)
          return created.to(dtype) if created.is_floating_point() else created

      return new_tensor
  # https://stackoverflow.com/a/63851681/9201239
  def get_all_subclasses(cls):
      """Return the set of all direct and indirect subclasses of ``cls`` (``cls`` itself excluded)."""
      discovered = set()
      to_visit = [cls]
      while to_visit:
          current = to_visit.pop()
          for child in current.__subclasses__():
              discovered.add(child)
              to_visit.append(child)
      return discovered
  @instrument_w_nvtx
  def free_param(param: Parameter) -> None:
      """Release the storage behind ``param`` and mark it NOT_AVAILABLE.

      Asserts that no sub-module is still listed in ``param.ds_active_sub_modules``.
      """
      assert not param.ds_active_sub_modules, param.ds_summary()

      # need to make sure that we don't free the parameter while it is still
      # being used for computation
      if get_accelerator().on_accelerator(param.data) and not get_accelerator().is_synchronized_device():
          param.data.record_stream(get_accelerator().current_stream())

      # param.data doesn't store anything meaningful in partitioned state
      placeholder = torch.empty(0, dtype=param.dtype, device=param.device)
      param.data = placeholder
      param.ds_status = ZeroParamStatus.NOT_AVAILABLE
  # Module-level buffer state. None of these names is read or written in the visible
  # part of this file.
  # NOTE(review): presumably toggles/caches for reusing gather buffers - confirm against users.
  reuse_buffers = False
  temp_contiguous_tensor = None
  empty_buffers = {}
  # Inserts _post_init_method at the end of init method
  # for all sub classes of torch.nn.Module
  class InsertPostInitMethodToModuleSubClasses(object):
      """Context manager that patches ``torch.nn.Module`` subclasses during model construction.

      While the outermost enabled context is active, every ``Module`` subclass
      ``__init__`` is wrapped so that ``_post_init_method`` runs once the most-derived
      ``__init__`` returns; tensor constructors are wrapped to cast floating-point
      tensors to ``self.dtype``; and (optionally) ``Module.apply`` and
      ``torch.nn.functional.linear`` are replaced. Subclasses implement
      ``_post_init_method``.

      Args:
          enabled: when False, entering/exiting the context does nothing.
          mem_efficient_linear: replace ``torch.nn.functional.linear`` with ``zero3_linear_wrap``.
          ds_config: config object read for ``bfloat16_enabled``, ``fp16_enabled`` and
              ``weight_quantization_config``.
          dtype: explicit floating-point dtype; takes precedence over ``ds_config``.
      """
      num_module_parameters = 0
      num_module_elements = 0

      def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None):
          self.mem_efficient_linear = mem_efficient_linear
          self.enabled = enabled
          # Nothing is patched until patch_init_and_builtins runs; initialise the flag so
          # unpatch_init_and_builtins is always safe to call.
          self.patched = False
          self._set_dtype(ds_config, dtype)
          assert self.dtype in [
              torch.half, torch.bfloat16, torch.float
          ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
          self.wrapped_cls = set()
          self.skip_init_depth = 0

          self.quantized_initialization = None
          if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization:
              self.quantized_initialization = ds_config.weight_quantization_config.quantized_initialization

      def __enter__(self):
          if not self.enabled:
              return

          global zero_init_context
          if zero_init_context == 0:
              # Only the outermost context installs the patches.
              self.patch_init_and_builtins()
              global top_level_context
              top_level_context = self

          zero_init_context += 1

      def __exit__(self, exc_type, exc_value, traceback):
          if not self.enabled:
              return

          global zero_init_context
          zero_init_context -= 1

          # Exiting the top level context
          if zero_init_context == 0:
              self.unpatch_init_and_builtins()
              global top_level_context
              top_level_context = None

              if dist.get_rank() == 0:
                  billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9
                  num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters
                  logger.info(
                      f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B")

          # Now that we cleaned up the metaclass injection, raise the exception.
          if exc_type is not None:
              return False

      # To be implemented by inheriting classes
      def _post_init_method(self, module):
          pass

      def _set_dtype(self, ds_config, dtype):
          """Resolve ``self.dtype``.

          Precedence: explicit ``dtype`` > ``ds_config`` flags > accelerator capability
          (fp16, then bf16, then float32).

          Raises:
              RuntimeError: if ``ds_config`` enables both bfloat16 and fp16 (only checked
                  when ``dtype`` is not given).
          """
          if ds_config is not None and dtype is None:
              if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
                  raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")

              if ds_config.bfloat16_enabled:
                  self.dtype = torch.bfloat16
              elif ds_config.fp16_enabled:
                  self.dtype = torch.half
              else:
                  self.dtype = torch.float
          elif dtype is not None:
              # An explicit dtype always wins, regardless of what the accelerator supports.
              self.dtype = dtype
          elif get_accelerator().is_fp16_supported():
              self.dtype = torch.float16
          elif get_accelerator().is_bf16_supported():
              self.dtype = torch.bfloat16
          else:
              self.dtype = torch.float32

      def patch_init_and_builtins(self):
          """Install all patches (module ``__init__`` wrappers, ``apply``, tensor constructors, linear)."""

          def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
              """many models make use of child modules like Linear or Embedding which
              perform their own weight initialization in their __init__ methods,
              but will then have more weight initialization in a parent module's __init__
              method that modifies weights of child modules, which is typically done
              using the Module.apply method.

              since the Init context manager partitions child modules immediately after
              they are initialized, without modifying apply we would entirely skip
              any initialization done by parent modules.

              to get around this issue, we wrap the function passed to Module.apply
              so that the applied function is applied to child modules correctly.
              """

              def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable:
                  if hasattr(fn_to_apply, "wrapped"):
                      return fn_to_apply

                  @functools.wraps(fn_to_apply)
                  def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None:
                      """gathers parameters before calling apply function. afterwards
                      parameters are broadcasted to ensure consistency across all ranks
                      then re-partitioned.

                      takes the following steps:
                      1. allgathers parameters for the current module being worked on
                      2. calls the original function
                      3. broadcasts root rank's parameters to the other ranks
                      4. re-partitions the parameters
                      """

                      # TODO Delay error checking for dangling partitioned parameters to post module init
                      # raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
                      #                    f"were zero params, is it possible that the parameters were "
                      #                    f"overwritten after they were initialized? "
                      #                    f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ")

                      params_to_apply_fn_to: Iterable[Parameter] = list(
                          sorted([p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)],
                                 key=lambda p: p.ds_id))

                      for param in params_to_apply_fn_to:
                          param.all_gather()

                      fn_to_apply(module_to_apply_fn_to)

                      for param in params_to_apply_fn_to:
                          dist.broadcast(param.data, 0, group=param.ds_process_group)

                      for param in params_to_apply_fn_to:
                          param.partition(has_been_updated=True)

                  wrapped_fn_to_apply.wrapped = True

                  return wrapped_fn_to_apply

              @functools.wraps(orig_module_apply_fn)
              def wrapped_apply(module: Module, fn_to_apply: Callable) -> None:
                  orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply))

              return wrapped_apply

          def hook_for_skip_init(module):
              # this function is intended for handling the logic of torch.nn.utils.skip_init
              # skip_init:module_cls(*args, **kwargs).to_empty(device=final_device), where kwargs['device']='meta'
              # the function call occurs between module_cls(*args, **kwargs) and to_empty(device=final_device).

              def partition_after_empty_init(f):

                  @functools.wraps(f)
                  def wrapper(module, *args, **kwargs):
                      _module = f(module, *args, **kwargs)
                      # here is the post-hook for module.apply(empty_like...)
                      # after module.apply(empty_like...), the module has completed its empty init on real device
                      # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init
                      self._post_init_method(_module)
                      return _module

                  return wrapper

              def post_wrapper_to_empty(f):
                  # append some wrapper restoration after to_empty() call
                  @functools.wraps(f)
                  def wrapper(*args, **kwargs):
                      res = f(*args, **kwargs)
                      # restore _apply hook
                      for subclass in get_all_subclasses(torch.nn.modules.module.Module):
                          _disable_class_apply(subclass)
                      # self restore
                      module.to_empty = f
                      return res

                  return wrapper

              def _enable_class_apply(cls):
                  cls._old_apply_of_skip_init_hook = cls._apply
                  cls._apply = partition_after_empty_init(cls._apply)

              def _disable_class_apply(cls):
                  cls._apply = cls._old_apply_of_skip_init_hook

              # add hooks for to_empty: apply_(empty_like)
              for subclass in get_all_subclasses(torch.nn.modules.module.Module):
                  _enable_class_apply(subclass)

              # add a restore hook when exiting skip_init
              module.to_empty = post_wrapper_to_empty(module.to_empty)

          def partition_after(f):

              @functools.wraps(f)
              def wrapper(module, *args, **kwargs):

                  # important logic: We want to run post_init only after child's __init__ is
                  # completed, and do nothing after __init__ of any of its parents and grandparents in
                  # the inheritance ancestry. This way the partitioning will need to happen only once
                  # when the whole object is ready to be partitioned and not before. This is because
                  # often the child module will need to tweak the weights - for example running a
                  # custom weights init function. So if a parent created the weights param, the child
                  # won't need to gather it in order to tweak it

                  print_rank_0(f'Before initializing {module.__class__.__name__}', force=False)

                  is_child_module = False
                  if not hasattr(module, "_ds_child_entered"):
                      # child's __init__ was called, since parents all see the same object they can now skip post_init
                      is_child_module = True
                      setattr(module, "_ds_child_entered", True)

                  init_on_meta = 'device' in kwargs and kwargs['device'] == 'meta'
                  if init_on_meta:
                      self.skip_init_depth += 1

                  f(module, *args, **kwargs)
                  if init_on_meta and self.skip_init_depth == 1:
                      # check and handle the logic of empty_init
                      hook_for_skip_init(module)
                  if is_child_module:
                      # child's __init__ is done, now we can run a single post_init on the child object
                      delattr(module, "_ds_child_entered")

                      print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False)
                      if self.skip_init_depth == 0:
                          self._post_init_method(module)

                  print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False)
                  if init_on_meta:
                      self.skip_init_depth -= 1

              return wrapper

          def _enable_class(cls):
              cls._old_init = cls.__init__
              cls.__init__ = partition_after(cls.__init__)

          def _init_subclass(cls, **kwargs):
              cls._old_init = cls.__init__
              cls.__init__ = partition_after(cls.__init__)

          # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
          for subclass in get_all_subclasses(torch.nn.modules.module.Module):
              _enable_class(subclass)

          # holding onto some methods so we can put them back the way they were in __exit__
          torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
          torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
          torch.Tensor.__old_new__ = torch.Tensor.__new__

          # Replace .__init__() for future subclasses of torch.nn.Module
          torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
          if Init.override_module_apply:
              torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)

          self._add_tensor_creation_wrappers()

          if self.mem_efficient_linear:
              print_rank_0(
                  "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
                  force=False)
              self.linear_bk = torch.nn.functional.linear
              torch.nn.functional.linear = zero3_linear_wrap

              # NOTE(review): the flattened source does not show whether this branch is nested
              # under ``mem_efficient_linear``; nesting chosen here - confirm against the original file.
              if self.quantized_initialization:
                  print_rank_0("nn.functional.linear has been overridden with quantized linear version.", force=False)
                  torch.nn.functional.linear = wrap_quantized_functional(torch.nn.functional.linear)
                  torch.nn.functional.embedding = wrap_quantized_functional(torch.nn.functional.embedding)
                  for cls in WEIGHT_QUANTIZATION_LAYERS:
                      cls._load_from_state_dict = wrap_load_from_state_dict(cls._load_from_state_dict)

                  logger.info("Enable Zero3 engine with INT4 quantization.")

          self.patched = True

      def unpatch_init_and_builtins(self):
          """Undo ``patch_init_and_builtins`` (no-op when nothing is patched).

          ``torch.nn.functional.linear`` / ``embedding`` are not restored here.
          """
          if self.patched:

              def _disable_class(cls):
                  cls.__init__ = cls._old_init

              for subclass in get_all_subclasses(torch.nn.modules.module.Module):
                  _disable_class(subclass)

              # putting methods back the way we found them
              torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
              if Init.override_module_apply:
                  torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply

              self._remove_tensor_creation_wrappers()

              self.patched = False

      def _add_tensor_creation_wrappers(self):
          """Swap the torch tensor constructors for wrappers that cast float tensors to ``self.dtype``."""
          torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
          torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype)
          torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype)
          torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype)
          torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
          torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
          torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype)
          torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype)
          torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype)

      def _remove_tensor_creation_wrappers(self):
          """Restore the torch tensor constructors saved at import time."""
          torch.Tensor.__new__ = torch.Tensor.__old_new__
          torch.tensor = _orig_torch_tensor
          torch.empty = _orig_torch_empty
          torch.zeros = _orig_torch_zeros
          torch.ones = _orig_torch_ones
          torch.full = _orig_torch_full
          torch.arange = _orig_torch_arange
          torch.eye = _orig_torch_eye
          torch.randn = _orig_torch_randn
  def shutdown_init_context():
      """
      Remove the Init wrappers while keeping the Init context itself alive.

      Used so that the deepspeed engine can be initialized inside the context of Init.
      Does nothing when no top-level context is active.
      """
      active_context = top_level_context
      if not active_context:
          return
      active_context.unpatch_init_and_builtins()
  def restore_init_context():
      """
      Re-install the Init wrappers after the deepspeed engine has been initialized.

      Does nothing when no top-level context is active.
      """
      active_context = top_level_context
      if not active_context:
          return
      active_context.patch_init_and_builtins()
  class AllGatherHandle:
      """Wait-handle for a single parameter's in-flight all-gather.

      Args:
          handle: object exposing ``wait()`` for the underlying collective.
          param: the parameter being gathered; must be INFLIGHT.
          quantization: optional object exposing ``quant_handle``, ``backend``,
              ``quantized_param`` and ``scale_buffer``; when truthy, ``wait`` also
              waits on ``quant_handle`` and assigns the dequantized result to ``param.data``.

      Raises:
          RuntimeError: if ``param.ds_status`` is not ``ZeroParamStatus.INFLIGHT``.
      """

      def __init__(self, handle, param: Parameter, quantization=None) -> None:
          if param.ds_status != ZeroParamStatus.INFLIGHT:
              # The check requires INFLIGHT, so the message must say so.
              raise RuntimeError(f"expected param {param.ds_summary()} to be inflight")

          self.__handle = handle
          self.__param = param
          self.__quantization = quantization

      def wait(self) -> None:
          """Block on the collective(s), dequantize if configured, then mark the param AVAILABLE."""
          instrument_w_nvtx(self.__handle.wait)()
          if self.__quantization:
              instrument_w_nvtx(self.__quantization.quant_handle.wait)()
              self.__param.data = self.__quantization.backend.dequantize(
                  self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device)
          self.__param.ds_status = ZeroParamStatus.AVAILABLE
  485. class AllGatherCoalescedHandle:
  486. def __init__(
  487. self,
  488. allgather_handle,
  489. params: List[Parameter],
  490. partitions: List[Tensor],
  491. world_size: int,
  492. use_secondary_tensor=False,
  493. quantization=None,
  494. ) -> None:
  495. self.allgather_handle = allgather_handle
  496. self.params = params
  497. self.partitions = partitions
  498. self.world_size = world_size
  499. self.use_secondary_tensor = use_secondary_tensor
  500. self.complete = False
  501. self.quantization = quantization
  502. for param in self.params:
  503. if param.ds_status != ZeroParamStatus.INFLIGHT:
  504. raise RuntimeError(f"expected param {param.ds_summary()} to not be available")
  505. @instrument_w_nvtx
  506. def wait(self) -> None:
  507. if self.complete:
  508. return
  509. instrument_w_nvtx(self.allgather_handle.wait)()
  510. if self.quantization:
  511. instrument_w_nvtx(self.quantization.quant_handle.wait)()
  512. flat_tensor = self.quantization.backend.dequantize(
  513. self.quantization.quantized_param, self.quantization.scale_buffer).to(self.params[0].device)
  514. self.partitions: List[Parameter] = []
  515. for i in range(self.world_size):
  516. self.partitions.append(
  517. flat_tensor.narrow(0, self.quantization.partition_sz * i, self.quantization.partition_sz))
  518. # split the single tensor out into individual tensors
  519. param_offset = 0
  520. for param in self.params:
  521. assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
  522. partitions: List[Tensor] = []
  523. ds_tensor_numel = param.ds_tensor.ds_numel
  524. if self.use_secondary_tensor:
  525. ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups
  526. for rank in range(self.world_size):
  527. param_start = rank * ds_tensor_numel
  528. if param_start < param.ds_numel:
  529. part_to_copy = self.partitions[rank].narrow(0, param_offset,
  530. min(param.ds_numel - param_start, ds_tensor_numel))
  531. partitions.append(part_to_copy)
  532. param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
  533. param.ds_status = ZeroParamStatus.AVAILABLE
  534. for part_to_copy in partitions:
  535. if not get_accelerator().is_synchronized_device():
  536. part_to_copy.record_stream(get_accelerator().current_stream())
  537. param_offset += ds_tensor_numel
  538. self.complete = True
  539. class MultipleAllGatherHandles:
  540. def __init__(self, handles: List[AllGatherCoalescedHandle]):
  541. self.handles = handles
  542. def wait(self) -> None:
  543. for handle in self.handles:
  544. handle.wait()
  545. class QuantizationInfo:
  546. # a placeholder object to store all quant related vars used in handles
  547. def __init__(self) -> None:
  548. self.quantized_param = None
  549. self.backend = None
  550. self.quant_handle = None
  551. self.scale_buffer = None
  552. class CUDAQuantizer:
  553. async_flag = True
  554. target_group_size = 8000 # the optimal size is 4k, so we set the target to be below 8k
  555. group_size_cache = dict()
  556. quantizer_cuda_module = None
  557. def __init__(self) -> None:
  558. if CUDAQuantizer.quantizer_cuda_module is None:
  559. CUDAQuantizer.quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load()
  560. def quantize(self, param, groups=None):
  561. if groups is None:
  562. try:
  563. groups = self.group_size_cache[param.numel()]
  564. except KeyError:
  565. groups = math.ceil(param.numel() / self.target_group_size)
  566. while groups < param.numel():
  567. if param.numel() % (8 * groups) == 0:
  568. break
  569. groups += 1
  570. while True:
  571. if param.numel() % (8 * groups * 2) == 0 and param.numel(
  572. ) / groups > self.target_group_size: #hard limit of 16k group_size
  573. groups *= 2
  574. else:
  575. break
  576. assert (
  577. param.numel() % (8 * groups) == 0
  578. ), f"Qantized weight requires the number of weights be a multiple of 8. Yet {param.numel()} cannot be divided by 8*{groups}"
  579. assert (param.numel() / groups < 16000), f"{param.numel()} / {groups} is larger than 16k"
  580. assert param.numel(
  581. ) > groups, f"Adaptive grouping algorithm cannot find a group size for input tensor of size {param.numel()}"
  582. self.group_size_cache[param.numel()] = groups
  583. return self.quantizer_cuda_module.quantize(param.to(get_accelerator().device_name()), groups, 8,
  584. self.quantizer_cuda_module.Symmetric)
  585. def dequantize(self, quantized_param, scale):
  586. return self.quantizer_cuda_module.dequantize(quantized_param, scale, scale.numel(), 8,
  587. self.quantizer_cuda_module.Symmetric)
  588. def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandle:
  589. for param in params:
  590. if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
  591. raise RuntimeError(f"expect param.ds_status == ZeroParamStatus.NOT_AVAILABLE, got{param.ds_summary()}")
  592. param.ds_status = ZeroParamStatus.INFLIGHT
  593. params = sorted(params, key=lambda p: p.ds_id)
  594. if len(params) == 1:
  595. param, = params
  596. return NoGatherHandle(param)
  597. return NoGatherCoalescedHandle(params)
  598. # Replaces all parameters in module with Scattered Parameters
  599. class Init(InsertPostInitMethodToModuleSubClasses):
    # Next DeepSpeed parameter id to assign; bumped each time a parameter is converted.
    param_id = 0
    # Persistence thresholds. They start at the DeepSpeedZeroConfig defaults and are
    # overwritten from the user config by _update_persist_config().
    param_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "param_persistence_threshold")
    model_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "model_persistence_threshold")
    # Running totals of parameters / elements that have been marked persistent.
    num_persisted_parameters = 0
    num_persisted_elements = 0
    # Persistence is only considered once _update_persist_config() sets this to True.
    apply_param_persistence = False
    # DeepSpeedZeroConfig default; replaced by the config value in __init__ when a config is given.
    override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply")
  607. def __init__(
  608. self,
  609. module=None,
  610. data_parallel_group=None,
  611. mem_efficient_linear=True,
  612. remote_device=None,
  613. pin_memory=False,
  614. config_dict_or_path=None,
  615. config=None,
  616. enabled=True,
  617. dtype=None,
  618. mpu=None,
  619. zero_param_parallel_group=None,
  620. zero_quantized_weights=False,
  621. zero_quantized_nontrainable_weights=False,
  622. sequence_data_parallel_group=None,
  623. param_swapper=None,
  624. ):
  625. """A context to enable massive model construction for training with
  626. ZeRO-3. Models are automatically partitioned (or, sharded) across the
  627. system and converted to half precision.
  628. Args:
  629. module (``torch.nn.Module``, optional): If provided, partition the model as
  630. if it was constructed in the context.
  631. data_parallel_group (``deepspeed.comm`` process group, optional):
  632. The group of processes to partition among. Defaults to all processes.
  633. mem_efficient_linear (bool, optional): Replace
  634. torch.nn.functional.linear with an implementation that allows
  635. DeepSpeed to partition parameters. Defaults to ``True``.
  636. remote_device (string, optional): The initial device to store model
  637. weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
  638. memory. The model may still be moved to GPU based on the
  639. offload settings for training. Defaults to param offload device if a config is
  640. defined, otherwise GPU.
  641. pin_memory (bool, optional): Potentially increase performance by
  642. using pinned memory for model weights. ``remote_device`` must be
  643. ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``.
  644. config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
  645. for swapping fp16 params to NVMe.
  646. config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
  647. enabled (bool, optional): If ``False``, this context has no
  648. effect. Defaults to ``True``.
  649. dtype (``dtype``, optional): Can be used to change the data type of the parameters.
  650. Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
  651. mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
  652. zero_param_parallel_group(``object``, optional): Parallel (comm) group for dual partitioning of ZeRO params.
  653. zero_quantized_weights (bool, optional): If ``True``, turn on quantized weights in all gather weights. Default is ``False``
  654. zero_quantized_nontrainable_weights (bool, optional): If ``True``, nontrainable weights will be stored in quantized format. Default is ``False``
  655. param_swapper (``deepspeed.runtime.swap_tensor.partitioned_param_swapper.AsyncPartitionedParameterSwapper``, optional): [Experimental] Use existing parameter swapper. Defaults to ``None``.
  656. This argument will be removed in the near future.
  657. This context accelerates model initialization and enables models that
  658. are too large to allocate in their entirety in CPU memory. It has the
  659. following effects:
  660. #. allocates tensors to either GPU or CPU memory or NVMe
  661. #. converts floating point tensors to half precision
  662. #. immediately partitions tensors among the group of data-parallel devices
  663. #. (*optional*) replaces ``torch.nn.functional.linear`` with a more
  664. memory-efficient implementation
  665. These modifications allow for models that exceed the size of local CPU/GPU
  666. memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
  667. or GPU memory or NVMe) across all nodes. Consider initializing a model with one
  668. trillion parameters, whose weights occupy two terabytes (TB) in half
  669. precision. The initial CPU allocation in full precision requires 4TB of
  670. memory *per process*, and so a system with 8 GPUs per node would need 32TB of
  671. CPU memory due to data-parallel redundancies. Instead, by immediately
  672. partitioning tensors we remove the redundancies. The result is that
  673. regardless of the number of GPUs, we still only require the original 4TB. This
  674. allows for a linear increase in model size with the aggregate system memory.
  675. For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
  676. parameter model with 4 nodes and 32 GPUs.
  677. Important: If the fp16 weights of the model can't fit onto a single GPU memory
  678. this feature must be used.
  679. .. note::
  680. Initializes ``deepspeed.comm`` if it has not already been done so.
  681. See :meth:`deepspeed.init_distributed` for more information.
  682. .. note::
  683. Only applicable to training with ZeRO-3.
  684. Examples
  685. --------
  686. #. Allocate a model and partition it among all processes:
  687. .. code-block:: python
  688. with deepspeed.zero.Init():
  689. model = MyLargeModel()
  690. #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:
  691. .. code-block:: python
  692. with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
  693. remote_device="cpu",
  694. pin_memory=True):
  695. model = MyLargeModel()
  696. #. Partition an already-allocated model in CPU memory:
  697. .. code-block:: python
  698. model = deepspeed.zero.Init(module=model)
  699. """
  700. if config is not None:
  701. config_dict_or_path = config
  702. logger.warning(
  703. f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.')
  704. _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path,
  705. mpu) if config_dict_or_path is not None else None
  706. if _ds_config is not None:
  707. if _ds_config.zero_config.memory_efficient_linear and _ds_config.compile_config.enabled:
  708. # memory_efficient_linear displays numerous errors when torch.compile is enabled.
  709. # Refer to https://github.com/pytorch/pytorch/issues/119059 for details.
  710. # Further investigation into performance is necessary, even after resolving this issue because
  711. # the `memory_efficient_linear` module may lead to more graph breaks compared to the original implementation.
  712. logger.warning(f'memory_efficient_linear is disabled when torch.compile is enabled.')
  713. mem_efficient_linear = False
  714. else:
  715. mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear
  716. super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype)
  717. if not dist.is_initialized():
  718. init_distributed()
  719. assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"
  720. if data_parallel_group is None and sequence_data_parallel_group is None:
  721. self.ds_process_group = dist.get_world_group()
  722. elif sequence_data_parallel_group is not None:
  723. self.ds_process_group = sequence_data_parallel_group
  724. elif data_parallel_group is not None:
  725. self.ds_process_group = data_parallel_group
  726. else: # both given
  727. raise ValueError(
  728. "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments."
  729. )
  730. self.rank = dist.get_rank(group=self.ds_process_group)
  731. self.dp_world_size = dist.get_world_size(group=self.ds_process_group)
  732. self.zero_param_process_group = zero_param_parallel_group
  733. if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None:
  734. groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size)
  735. self.zero_param_process_group = groups._get_zero_param_intra_parallel_group()
  736. self.num_ranks_in_param_group = self.dp_world_size
  737. self.rank_in_group = self.rank
  738. self.num_param_groups = 1
  739. if self.zero_param_process_group is not None:
  740. self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size()
  741. self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group)
  742. self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup()
  743. print_rank_0(f"hpZeRO group size: {self.num_ranks_in_param_group}", force=True)
  744. logger.debug(
  745. "hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} "
  746. .format(self.rank, self.rank_in_group, groups._get_zero_param_intra_parallel_group_ranks()))
  747. # Local device is the device where the parameters are consumed, must be default device.
  748. # It is the device where parameters are fully instantiated using allgather
  749. self.local_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
  750. get_accelerator().set_device(self.local_device)
  751. self.quantized_weights = zero_quantized_weights
  752. if _ds_config is not None and _ds_config.zero_config.zero_quantized_weights and not self.quantized_weights:
  753. self.quantized_weights = _ds_config.zero_config.zero_quantized_weights
  754. self.quantized_nontrainable_weights = zero_quantized_nontrainable_weights
  755. if _ds_config is not None and _ds_config.zero_config.zero_quantized_nontrainable_weights and not self.quantized_nontrainable_weights:
  756. self.quantized_nontrainable_weights = _ds_config.zero_config.zero_quantized_nontrainable_weights
  757. self.module = module
  758. if (self.quantized_weights or self.quantized_nontrainable_weights):
  759. self.quantizer_module = CUDAQuantizer()
  760. print_rank_0(f'Using quantizer for weights: {self.quantizer_module.__class__.__name__}', force=True)
  761. if _ds_config is not None:
  762. Init.override_module_apply = _ds_config.zero_config.override_module_apply
  763. if _ds_config.zero_config.offload_param is not None:
  764. remote_device = _ds_config.zero_config.offload_param.device
  765. pin_memory = _ds_config.zero_config.offload_param.pin_memory
  766. self._validate_remote_device(remote_device, _ds_config)
  767. # Remote device is the device where parameter partitions are stored
  768. # It can be same as local_device or it could be CPU or NVMe.
  769. self.remote_device = self.local_device if remote_device in [None, OffloadDeviceEnum.none] else remote_device
  770. self.pin_memory = pin_memory if (self.remote_device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme
  771. ]) else False
  772. # Enable fp16 param swapping to NVMe
  773. if self.remote_device == OffloadDeviceEnum.nvme:
  774. self.param_swapper = param_swapper or AsyncPartitionedParameterSwapper(_ds_config, self.dtype)
  775. else:
  776. self.param_swapper = None
  777. # If we are provided an already-allocated module to prepare.
  778. if module is not None:
  779. assert isinstance(module, torch.nn.Module)
  780. self._convert_to_zero_parameters(module.parameters(recurse=True))
  781. self.use_all_gather_into_tensor = dist.has_all_gather_into_tensor()
  782. if not self.use_all_gather_into_tensor:
  783. logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}")
  784. def _update_persist_config(self, ds_config):
  785. Init.apply_param_persistence = True
  786. Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold
  787. Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions
  788. def _zero_init_param(self, param):
  789. self._convert_to_deepspeed_param(param)
  790. if dist.get_world_group() == self.get_dp_process_group():
  791. dist.broadcast(param.data, 0, self.get_dp_process_group())
  792. else:
  793. dist.broadcast(param.data, dist.get_global_rank(self.get_dp_process_group(), 0),
  794. self.get_dp_process_group())
  795. param.partition()
  796. def _convert_to_zero_parameters(self, param_list):
  797. for param in param_list:
  798. if is_zero_param(param):
  799. continue
  800. param.data = param.data.to(self.local_device)
  801. self._zero_init_param(param)
  802. def _validate_remote_device(self, remote_device, ds_config):
  803. if ds_config is not None:
  804. if remote_device in [None, OffloadDeviceEnum.cpu]:
  805. if ds_config.zero_config.offload_param is not None:
  806. offload_param_device = ds_config.zero_config.offload_param.device
  807. assert offload_param_device != OffloadDeviceEnum.nvme, \
  808. f"'device' in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
  809. if remote_device == OffloadDeviceEnum.nvme:
  810. assert ds_config.zero_config.offload_param is not None, \
  811. f'"offload_param" must be defined in DeepSpeed Config if remote device is {OffloadDeviceEnum.nvme}.'
  812. assert ds_config.zero_config.offload_param.nvme_path is not None, \
  813. f'"nvme_path" in DeepSpeed Config cannot be None if remote device is {OffloadDeviceEnum.nvme}'
  814. def _post_init_method(self, module):
  815. #see_memory_usage(f"Before converting params in {module.__class__.__name__}", force=False)
  816. print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
  817. see_memory_usage(f"Before converting and partitioning params in {module.__class__.__name__}", force=False)
  818. for name, param in module.named_parameters(recurse=False):
  819. print_rank_0(f'Analyzing param {name} in {module.__class__.__name__}', force=False)
  820. InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1
  821. InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel()
  822. if not is_zero_param(param):
  823. if not get_accelerator().on_accelerator(param):
  824. param.data = param.data.to(self.local_device)
  825. if name == 'weight' and self.quantized_initialization and type(module) in WEIGHT_QUANTIZATION_LAYERS:
  826. _quantize_param(param, self.quantized_initialization)
  827. self._zero_init_param(param)
  828. print_rank_0(
  829. f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}")
  830. see_memory_usage(
  831. f"Param count {InsertPostInitMethodToModuleSubClasses.num_module_elements}. After converting and partitioning params in {module.__class__.__name__}",
  832. force=False)
  833. def _convert_to_deepspeed_param(self, param):
  834. # Partitioned, Normal, Remote
  835. param.ds_param_type = ZeroParamType.PARTITIONED
  836. # Replicated vs Partitioned vs Inflight
  837. param.ds_status = ZeroParamStatus.AVAILABLE
  838. # Stores the shape of the original tensor
  839. param.ds_shape = param.shape
  840. # Stores the number of elements in the original parameter without padding
  841. param.ds_numel = param.numel()
  842. # Stores the partitioned copy of the tensor
  843. param.ds_tensor = None
  844. # Keeps track of how many active sub-modules need this param at any given point in time
  845. param.ds_active_sub_modules = set()
  846. # If this flag is true, then the parameters are replicated throughput training
  847. # And only partitioned before the step
  848. if Init.apply_param_persistence and param.ds_numel <= Init.param_persistence_threshold and Init.num_persisted_elements + param.ds_numel <= Init.model_persistence_threshold:
  849. param.ds_persist = True
  850. Init.num_persisted_parameters += 1
  851. Init.num_persisted_elements += param.ds_numel
  852. else:
  853. param.ds_persist = False
  854. param.is_external_param = False
  855. # The group that the parameter is scattered across.
  856. param.ds_process_group = self.ds_process_group
  857. # Stores the secondary partitioned copy of the tensor
  858. param.ds_secondary_tensor = None
  859. #Process group for secondary partition all (group) gather
  860. param.ds_zero_param_process_group = self.zero_param_process_group
  861. param.ds_secondary_tensor_group_size = self.num_ranks_in_param_group
  862. param.ds_secondary_tensor_num_of_groups = self.num_param_groups
  863. # This is set to the Async Param swapper if remote device is nvme
  864. # else this is set to None
  865. param.nvme_swapper = self.param_swapper
  866. # DeepSpeed Param ID
  867. param.ds_id = Init.param_id
  868. Init.param_id += 1
  869. def all_gather(param_list=None, async_op=False, hierarchy=0):
  870. cls = param
  871. if param_list is None:
  872. param_list = [cls]
  873. return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
  874. def _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group):
  875. partition_sz = sum(p.ds_tensor.ds_numel for p in params)
  876. use_secondary_tensor = params[0].ds_secondary_tensor is not None
  877. if use_secondary_tensor:
  878. partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
  879. flat_tensor = torch.empty(partition_sz * world_size,
  880. dtype=dtype,
  881. device=get_accelerator().current_device_name(),
  882. requires_grad=False)
  883. partitions: List[Parameter] = []
  884. for i in range(world_size):
  885. partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))
  886. if use_secondary_tensor:
  887. instrument_w_nvtx(
  888. torch.cat)([p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
  889. out=partitions[rank_in_group])
  890. else:
  891. instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
  892. out=partitions[rank_in_group])
  893. handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
  894. #Fix get_partition_dp_group(params[0]))
  895. return AllGatherCoalescedHandle(
  896. allgather_handle=handle,
  897. params=params,
  898. partitions=partitions,
  899. world_size=world_size,
  900. use_secondary_tensor=use_secondary_tensor,
  901. )
  902. @instrument_w_nvtx
  903. def all_gather_coalesced(params: Iterable[Parameter],
  904. safe_mode: bool = False,
  905. quantize: bool = False) -> AllGatherCoalescedHandle:
  906. # fetches from nvme if the partition is not available and in nvme
  907. self._ensure_availability_of_partitioned_params(params)
  908. if self.num_partitions == 1:
  909. return _no_gather_coalesced(params)
  910. for param in params:
  911. if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
  912. raise RuntimeError(param.ds_summary())
  913. param.ds_status = ZeroParamStatus.INFLIGHT
  914. #use appropriate all gather process group
  915. ds_process_group = self.ds_process_group
  916. rank_in_group = self.rank
  917. world_size = self.dp_world_size
  918. use_secondary_tensor = params[0].ds_secondary_tensor is not None
  919. if self.zero_param_process_group and use_secondary_tensor:
  920. ds_process_group = self.zero_param_process_group #intragroup
  921. rank_in_group = self.rank_in_group
  922. world_size = self.num_ranks_in_param_group
  923. #pprint(dir(ds_process_group))
  924. # ensure that each rank has params in same order. the allgather
  925. # is done by flattening the parameter list into a single tensor that
  926. # can be allgathered in a single call - this means that if each rank
  927. # gives a list of the same parameters in a different order we will
  928. # silently get incorrect parameter values, and have very difficult
  929. # to debug correctness issues.
  930. params = sorted(params, key=lambda p: p.ds_id)
  931. if logger.isEnabledFor(logging.DEBUG):
  932. debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")
  933. if safe_mode:
  934. # ensure that same list (with same ordering) of parameters are
  935. # being allgathered across all ranks, otherwise could mix
  936. # data between tensors.
  937. assert_ints_same_as_other_ranks([p.ds_id for p in params])
  938. # ensure that tensors from each rank agree on the same ds_numel
  939. # otherwise could mix data between tensors.
  940. assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params])
  941. if len(params) == 1:
  942. # have an opportunity to avoid some intermediate memory allocations
  943. param = params[0]
  944. buffer_size = math.ceil(param.ds_numel / world_size) * world_size
  945. if use_secondary_tensor:
  946. buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized
  947. param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor
  948. param_buffer = torch.empty(
  949. buffer_size,
  950. dtype=param_ds_tensor.dtype if not quantize else torch.int8,
  951. device=get_accelerator().current_device_name(),
  952. requires_grad=False,
  953. )
  954. if not quantize:
  955. handles = _dist_allgather_fn(
  956. param_ds_tensor.to(get_accelerator().current_device_name()),
  957. param_buffer,
  958. ds_process_group,
  959. )
  960. param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
  961. return AllGatherHandle(handles, param)
  962. else:
  963. if hasattr(param_ds_tensor, "ds_quant_scale"):
  964. scales = param_ds_tensor.ds_quant_scale
  965. quantized_param = param_ds_tensor.data
  966. else:
  967. quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor)
  968. handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device_name()),
  969. param_buffer, ds_process_group)
  970. quant_scale_buffer = torch.empty(
  971. scales.numel() * world_size,
  972. dtype=scales.dtype,
  973. device=get_accelerator().current_device_name(),
  974. requires_grad=False,
  975. )
  976. quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()),
  977. quant_scale_buffer, ds_process_group)
  978. quant_info = QuantizationInfo()
  979. quant_info.quantized_param = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(
  980. param.device)
  981. quant_info.backend = self.quantizer_module
  982. quant_info.quant_handle = quant_handle
  983. quant_info.scale_buffer = quant_scale_buffer
  984. return AllGatherHandle(handle, param, quantization=quant_info)
  985. else:
  986. if not quantize:
  987. dtype_params = defaultdict(list)
  988. for p in params:
  989. dtype_params[p.ds_tensor.dtype].append(p)
  990. handles = []
  991. for dtype, params in dtype_params.items():
  992. handles.append(_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group))
  993. return MultipleAllGatherHandles(handles)
  994. else:
  995. partition_sz = sum(p.ds_tensor.ds_numel for p in params)
  996. if use_secondary_tensor:
  997. partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
  998. flat_tensor = torch.empty(partition_sz * world_size,
  999. dtype=torch.int8,
  1000. device=get_accelerator().current_device_name(),
  1001. requires_grad=False)
  1002. if use_secondary_tensor:
  1003. if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
  1004. quantized_param = instrument_w_nvtx(torch.cat)([
  1005. p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params
  1006. ])
  1007. scales = instrument_w_nvtx(torch.cat)([
  1008. p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name())
  1009. for p in params
  1010. ])
  1011. else:
  1012. quantized_param, scales = self.quantizer_module.quantize(
  1013. instrument_w_nvtx(torch.cat)([
  1014. p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params
  1015. ]))
  1016. else:
  1017. if hasattr(params[0].ds_tensor, "ds_quant_scale"):
  1018. quantized_param = instrument_w_nvtx(torch.cat)(
  1019. [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params])
  1020. scales = instrument_w_nvtx(torch.cat)([
  1021. p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in params
  1022. ])
  1023. else:
  1024. quantized_param, scales = self.quantizer_module.quantize(
  1025. instrument_w_nvtx(torch.cat)(
  1026. [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
  1027. quant_scale_buffer = torch.empty(
  1028. scales.numel() * world_size,
  1029. dtype=torch.float32,
  1030. device=get_accelerator().current_device_name(),
  1031. requires_grad=False,
  1032. )
  1033. handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
  1034. quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
  1035. quant_info = QuantizationInfo()
  1036. quant_info.quantized_param = flat_tensor
  1037. quant_info.backend = self.quantizer_module
  1038. quant_info.quant_handle = quant_handle
  1039. quant_info.scale_buffer = quant_scale_buffer
  1040. quant_info.partition_sz = partition_sz
  1041. quant_info.world_size = world_size
  1042. return AllGatherCoalescedHandle(
  1043. allgather_handle=handle,
  1044. params=params,
  1045. partitions=None,
  1046. world_size=world_size,
  1047. use_secondary_tensor=use_secondary_tensor,
  1048. quantization=quant_info,
  1049. )
  1050. def partition(param_list=None, hierarchy=0, has_been_updated=False):
  1051. cls = param
  1052. print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
  1053. force=False)
  1054. if param_list is None:
  1055. param_list = [cls]
  1056. self._partition(param_list, has_been_updated=has_been_updated)
  1057. def reduce_gradients_at_owner(param_list=None, hierarchy=0):
  1058. cls = param
  1059. if param_list is None:
  1060. param_list = [cls]
  1061. print_rank_0(
  1062. f"{'--'*hierarchy}----Reducing Gradients for param with ids {[param.ds_id for param in param_list]} to owner"
  1063. )
  1064. self._reduce_scatter_gradients(param_list)
  def partition_gradients(param_list=None, partition_buffers=None, hierarchy=0, accumulate=False):
      """Partition gradients of ``param_list`` through ``self._partition_gradients``.

      Args:
          param_list: parameters to process; defaults to the bound parameter.
          partition_buffers: destination buffers, forwarded to
              ``self._partition_gradients``.
          hierarchy: only used to build the dashed prefix of the debug message.
          accumulate: forwarded unchanged.
      """
      bound_param = param
      print_rank_0(
          f"{'--'*hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(bound_param)}")
      if param_list is None:
          param_list = [bound_param]
          # NOTE(review): the bare-tensor wrapping is kept inside the
          # "default param_list" branch; the flattened source does not show
          # the nesting, confirm against the pristine file.
          if isinstance(partition_buffers, torch.Tensor):
              partition_buffers = [partition_buffers]

      self._partition_gradients(param_list, partition_buffers=partition_buffers, accumulate=accumulate)
  def aligned_size():
      """Return ``self._aligned_size`` evaluated for the bound parameter."""
      padded_numel = self._aligned_size(param)
      return padded_numel
  def padding_size():
      """Return ``self._padding_size`` evaluated for the bound parameter."""
      pad_numel = self._padding_size(param)
      return pad_numel
  def partition_numel():
      """Return ``self._partition_numel`` evaluated for the bound parameter."""
      shard_numel = self._partition_numel(param)
      return shard_numel
  def item_override():
      """Gather the bound parameter, then return ``param._orig_item()``.

      NOTE(review): nothing in the visible code binds this closure to the
      parameter or defines ``_orig_item``; it may be dead code.
      """
      param.all_gather()
      value = param._orig_item()
      return value
  def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict:
      """Collect the ZeRO bookkeeping attributes of ``slf`` into a dict.

      Args:
          slf: parameter carrying the ``ds_*`` attributes.
          use_debug_name: report ``debug_param2name_id(slf)`` as the id
              instead of the raw ``ds_id``.

      Returns:
          dict with the keys ``id``, ``status``, ``numel``, ``ds_numel``,
          ``shape``, ``ds_shape``, ``requires_grad``, ``grad_shape``,
          ``persist``, ``active_sub_modules`` and ``ds_tensor.shape``.
      """
      if use_debug_name:
          identifier = debug_param2name_id(slf)
      else:
          identifier = slf.ds_id

      # Both of these may legitimately be absent, hence the None fallbacks.
      grad_shape = None if slf.grad is None else tuple(slf.grad.shape)
      partition_shape = None if slf.ds_tensor is None else slf.ds_tensor.shape

      return {
          "id": identifier,
          "status": slf.ds_status.name,
          "numel": slf.numel(),
          "ds_numel": slf.ds_numel,
          "shape": tuple(slf.shape),
          "ds_shape": tuple(slf.ds_shape),
          "requires_grad": slf.requires_grad,
          "grad_shape": grad_shape,
          "persist": slf.ds_persist,
          "active_sub_modules": slf.ds_active_sub_modules,
          "ds_tensor.shape": partition_shape,
      }
  def convert_to_zero_parameters(param_list):
      """Forward ``param_list`` to ``self._convert_to_zero_parameters``; returns nothing."""
      self._convert_to_zero_parameters(param_list)
      return None
  def allgather_before(func: Callable) -> Callable:
      """Wrap ``func`` so the bound parameter is all-gathered before every call.

      Args:
          func: callable to run once ``param.all_gather()`` has been issued.

      Returns:
          A wrapper that forwards all positional and keyword arguments to
          ``func`` and returns its result.
      """

      def wrapped(*args, **kwargs):
          # The gather is triggered on each invocation, not once at wrap time.
          param.all_gather()
          result = func(*args, **kwargs)
          return result

      return wrapped
  # Attach the closures defined above to the parameter object.

  # Gathering / re-partitioning of the parameter data.
  param.all_gather = all_gather
  param.all_gather_coalesced = all_gather_coalesced
  param.partition = partition

  # Gradient collectives.
  param.reduce_gradients_at_owner = reduce_gradients_at_owner
  param.partition_gradients = partition_gradients

  # Size helpers.
  param.aligned_size = aligned_size
  param.padding_size = padding_size
  param.partition_numel = partition_numel

  # ds_summary takes the parameter as its first argument, so bind it as a method.
  param.ds_summary = types.MethodType(ds_summary, param)

  # item() is wrapped so the parameter is gathered before it runs.
  param.item = allgather_before(param.item)

  param.convert_to_zero_parameters = convert_to_zero_parameters
  1118. def _aligned_size(self, param):
  1119. return param.ds_numel + self._padding_size(param)
  1120. def _padding_size(self, param):
  1121. remainder = param.ds_numel % self.num_partitions
  1122. return (self.num_partitions - remainder) if remainder else 0
  1123. def _partition_numel(self, param):
  1124. return param.ds_tensor.ds_numel
  1125. def _ensure_availability_of_partitioned_params(self, params):
  1126. swap_in_list = []
  1127. swap_in_flight = []
  1128. for param in params:
  1129. if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
  1130. assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
  1131. swap_in_list.append(param)
  1132. if param.ds_tensor.status == PartitionedParamStatus.INFLIGHT:
  1133. assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
  1134. swap_in_flight.append(param)
  1135. if len(swap_in_list) > 0:
  1136. swap_in_list[0].nvme_swapper.swap_in(swap_in_list, async_op=False)
  1137. elif len(swap_in_flight) > 0:
  1138. swap_in_flight[0].nvme_swapper.synchronize_reads()
  @instrument_w_nvtx
  def _all_gather(self, param_list, async_op=False, hierarchy=None):
      """Gather every parameter in ``param_list`` whose status is ``NOT_AVAILABLE``.

      Args:
          param_list: parameters to consider; ones with any other status
              are skipped.
          async_op: when true, one asynchronous gather is launched per
              parameter and the parameter is marked ``INFLIGHT``.
          hierarchy: forwarded to the gather helpers.

      Returns:
          The list of handles when ``async_op`` is true, otherwise ``None``
          (the blocking path marks the gathered parameters ``AVAILABLE``).
      """
      # Fetch from NVMe first if a partition is not resident.
      self._ensure_availability_of_partitioned_params(param_list)

      if async_op:
          handles = []
          for candidate in param_list:
              # Status is re-read per item because it is updated inside this loop.
              if candidate.ds_status == ZeroParamStatus.NOT_AVAILABLE:
                  handles.append(self._allgather_param(candidate, async_op=async_op, hierarchy=hierarchy))
                  candidate.ds_status = ZeroParamStatus.INFLIGHT
          return handles

      # param_list may contain params that are already in flight / available,
      # so only the filtered list is gathered.
      to_gather = [p for p in param_list if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]

      if len(to_gather) == 1:
          self._allgather_params(to_gather, hierarchy=hierarchy)
      else:
          quantized = []
          regular = []
          for candidate in to_gather:
              has_scale = hasattr(candidate.ds_tensor, "ds_quant_scale") or (
                  hasattr(candidate, "ds_secondary_tensor")
                  and hasattr(candidate.ds_secondary_tensor, "ds_quant_scale"))
              if has_scale:
                  quantized.append(candidate)
              else:
                  regular.append(candidate)
          # _allgather_params_coalesced always returns None.
          self._allgather_params_coalesced(regular, hierarchy, quantize=False)
          self._allgather_params_coalesced(quantized, hierarchy, quantize=True)

      for candidate in to_gather:
          candidate.ds_status = ZeroParamStatus.AVAILABLE
      return None
  def _partition(self, param_list, force=False, has_been_updated=False):
      """Partition each parameter and mark it ``NOT_AVAILABLE``.

      Args:
          param_list: parameters to partition.
          force: not read by this method.
          has_been_updated: forwarded to ``_partition_param``.
      """
      for current in param_list:
          print_rank_0(f"Before Partitioning Param {current.ds_id}", force=False)
          # The secondary partition is only maintained when a
          # zero_param_process_group is configured, and is filled first.
          if self.zero_param_process_group is not None:
              self._partition_param_sec(current)
          self._partition_param(current, has_been_updated=has_been_updated)
          current.ds_status = ZeroParamStatus.NOT_AVAILABLE
  @instrument_w_nvtx
  def _partition_param(self, param, buffer=None, has_been_updated=False):
      """Store this rank's slice of ``param`` in ``param.ds_tensor`` and free the full data.

      Nothing happens unless ``param.ds_status`` is ``AVAILABLE``. If a
      partition already exists and ``has_been_updated`` is false, only the
      gathered data is released (plus the swapper buffers when the partition
      lives on NVMe). Otherwise the partition tensor is allocated on first
      use, filled from the flattened parameter, the gathered data is freed
      and NVMe-resident partitions are swapped out.

      Args:
          param: parameter to partition; must not be ``INFLIGHT``.
          buffer: not read by this method.
          has_been_updated: when true, an existing partition is refreshed
              from the current parameter values.
      """
      assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
      global reuse_buffers
      print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
      if param.ds_status is not ZeroParamStatus.AVAILABLE:
          return

      print_rank_0(f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}", force=False)

      # Fast path: the partition already holds the right values.
      if param.ds_tensor is not None and not has_been_updated:
          see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
          # param.data does not store anything meaningful in partitioned state
          free_param(param)
          see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

          if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
              print_rank_0(f"Param {param.ds_id} partition released since it exists in nvme", force=False)
              param.nvme_swapper.remove_partition_and_release_buffers([param])
              print_rank_0(
                  f"after swap Param {param.ds_id} {param.ds_tensor.shape} partition released since it exists in nvme",
                  force=False)
          return

      padded_numel = self._aligned_size(param)
      partition_size = padded_numel // self.num_partitions

      if param.ds_tensor is None:
          final_location = None
          goes_to_nvme = self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor(
              numel=partition_size)
          if goes_to_nvme:
              final_location = OffloadDeviceEnum.nvme
              swap_buffer = self.param_swapper.get_buffer(param, partition_size)
              # Start from an empty tensor and point its storage at the swap buffer.
              partitioned_tensor = torch.empty(0, dtype=param.dtype, device=swap_buffer.device)
              partitioned_tensor.data = swap_buffer.data
              print_rank_0(f"ID {param.ds_id} Initializing partition for the first time for nvme offload.")
          else:
              if param.ds_persist:
                  device = self.local_device
              elif self.remote_device == OffloadDeviceEnum.nvme:
                  # NVMe is configured but this tensor was not swappable: keep it on CPU.
                  device = OffloadDeviceEnum.cpu
              else:
                  device = self.remote_device

              partitioned_tensor = torch.empty(partition_size, dtype=param.dtype, device=device)
              # quantize the tensor if it's not trainable
              if not param.requires_grad and self.quantized_nontrainable_weights:
                  partitioned_tensor, partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize(
                      partitioned_tensor)

              # NOTE(review): if pin_memory returns a new tensor object, a
              # ds_quant_scale attribute set just above would not carry over - confirm.
              if device == OffloadDeviceEnum.cpu and self.pin_memory:
                  partitioned_tensor = get_accelerator().pin_memory(partitioned_tensor)

          partitioned_tensor.requires_grad = False
          param.ds_tensor = partitioned_tensor
          param.ds_tensor.ds_numel = partition_size
          param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
          param.ds_tensor.final_location = final_location

      offset = partition_size * self.get_partition_rank()
      flat_param = param.contiguous().view(-1)

      if offset < param.ds_numel:
          if offset + partition_size <= param.ds_numel:
              # The whole slice lies inside the unpadded parameter.
              source = flat_param.narrow(0, offset, partition_size)
              with torch.no_grad():
                  # make sure param.ds_tensor requires_grad always be false,
                  # otherwise, torch tracer will complain.
                  param.ds_tensor.copy_(source)
          else:
              # The slice runs into the padding: copy only the real elements.
              valid = param.ds_numel - offset
              with torch.no_grad():
                  # make sure param.ds_tensor requires_grad always be false,
                  # otherwise, torch tracer will complain.
                  param.ds_tensor.narrow(0, 0, valid).copy_(flat_param.narrow(0, offset, valid))

      # param.data does not store anything meaningful in partitioned state
      see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
      free_param(param)
      see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

      if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
          self.param_swapper.swap_out_and_release([param])
          print_rank_0(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.")
          see_memory_usage(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.", force=False)

      print_rank_0(f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}")
  @instrument_w_nvtx
  def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
      """Fill the secondary partition ``param.ds_secondary_tensor`` for this rank.

      Nothing happens unless ``param.ds_status`` is ``AVAILABLE``. The padded
      parameter is split into ``self.num_ranks_in_param_group`` chunks and the
      chunk selected by ``self.rank_in_group`` is copied into the secondary
      tensor, which is allocated on ``self.remote_device`` on first use.

      Args:
          param: parameter to partition; must not be ``INFLIGHT``.
          buffer: not read by this method.
          has_been_updated: not read by this method.
      """
      assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
      if param.ds_status is not ZeroParamStatus.AVAILABLE:
          return

      # Size of one secondary chunk, computed on the padded element count.
      tensor_size = self._aligned_size(param)
      secondary_partition_size = int(tensor_size // self.num_ranks_in_param_group)

      if param.ds_secondary_tensor is None:
          final_location = None
          secondary_partitioned_tensor = torch.empty(secondary_partition_size,
                                                     dtype=param.dtype,
                                                     device=self.remote_device)

          if self.pin_memory:
              secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory()
          # quantize the tensor if it's not trainable
          if not param.requires_grad and self.quantized_nontrainable_weights:
              secondary_partitioned_tensor, secondary_partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize(
                  secondary_partitioned_tensor)
          secondary_partitioned_tensor.requires_grad = False
          param.ds_secondary_tensor = secondary_partitioned_tensor
          param.ds_secondary_tensor.ds_numel = secondary_partition_size
          param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE
          param.ds_secondary_tensor.final_location = final_location

      # use rank in group for secondary tensor
      secondary_start = secondary_partition_size * self.rank_in_group
      one_dim_param = param.contiguous().view(-1)

      # ds_numel is unpadded, so this rank's chunk may be only partly backed
      # by real data, or lie entirely in the padding. Clamp to
      # [0, secondary_partition_size] so the element count never goes negative.
      sec_numel = max(0, min(param.ds_numel - secondary_start, secondary_partition_size))

      # copy from full tensor to secondary tensor; skipped when the chunk is
      # pure padding because narrow() rejects a start past the end.
      if sec_numel > 0:
          param.ds_secondary_tensor.narrow(0, 0,
                                           sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))

      # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done
      get_accelerator().current_stream().synchronize()

      print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}",
                   force=False)
  def _param_status(self, param):
      """Print a one-line status report for ``param`` through ``print_rank_0``.

      The report shows the element count of ``param.ds_tensor`` when a
      partition exists, and the (``None``) ``ds_tensor`` value otherwise.
      """
      if param.ds_tensor is None:
          partition_detail = f"partitioned ds_tensor {param.ds_tensor}"
      else:
          partition_detail = f"partitioned numel {param.ds_tensor.numel()}"

      print_rank_0(
          f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, {partition_detail}, data numel {param.data.numel()}"
      )
  def _allgather_param(self, param, async_op=False, hierarchy=0):
      """All-gather a single parameter into a freshly allocated flat buffer.

      ``param.data`` is pointed at a view of the gather buffer before the
      function returns.

      Args:
          param: parameter whose ``ds_tensor`` partition is gathered.
          async_op: forwarded to the collective.
          hierarchy: only used to build the dashed prefix of debug messages.

      Returns:
          Whatever the collective returned for the requested ``async_op``.
      """
      partition_size = param.ds_tensor.ds_numel
      aligned_param_size = self._aligned_size(param)
      tensor_size = partition_size * self.num_partitions
      assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'

      print_rank_0(
          f"{'--'* hierarchy}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}"
      )

      see_memory_usage(
          f'Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ',
          force=False)
      gather_buffer = torch.zeros(aligned_param_size, dtype=param.dtype, device=param.device).view(-1)
      see_memory_usage(
          f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
          force=False)

      get_accelerator().synchronize()

      print_rank_0(
          f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
      )

      if self.use_all_gather_into_tensor:
          handle = dist.all_gather_into_tensor(gather_buffer,
                                               param.ds_tensor.to(get_accelerator().device_name()),
                                               group=self.get_partition_dp_group(param),
                                               async_op=async_op)
      else:
          # One view per rank over the flat buffer; this rank's view is
          # pre-filled with the local partition.
          slices = []
          for index in range(self.num_partitions):
              slices.append(gather_buffer.narrow(0, partition_size * index, partition_size))
              if index == dist.get_rank(group=self.get_partition_dp_group(param)):
                  slices[index].data.copy_(param.ds_tensor.data, non_blocking=True)

          handle = dist.all_gather(slices,
                                   slices[self.get_partition_rank()],
                                   group=self.get_partition_dp_group(param),
                                   async_op=async_op)

      # Drop the padding and restore the original shape.
      param.data = gather_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
      return handle
  1371. def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False):
  1372. """ blocking call
  1373. avoid explicit memory copy in _allgather_params
  1374. """
  1375. if len(param_list) == 0:
  1376. return
  1377. if self.num_partitions == 1:
  1378. handle = _no_gather_coalesced(param_list)
  1379. handle.wait()
  1380. return None
  1381. # collect local tensors and partition sizes
  1382. partition_sizes = []
  1383. local_tensors = []
  1384. if quantize:
  1385. quantize_scale_sizes = []
  1386. quantize_scale_tensors = []
  1387. for param in param_list:
  1388. partition_sizes.append(param.ds_tensor.ds_numel)
  1389. local_tensors.append(param.ds_tensor.to(get_accelerator().device_name()))
  1390. if quantize:
  1391. quantize_scale_sizes.append(param.ds_tensor.ds_quant_scale.numel())
  1392. quantize_scale_tensors.append(param.ds_tensor.ds_quant_scale.to(get_accelerator().device_name()))
  1393. # allocate memory for allgather params
  1394. allgather_params = []
  1395. if quantize:
  1396. allgather_quantize_scale = []
  1397. for psize in partition_sizes:
  1398. tensor_size = psize * self.num_partitions
  1399. flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype,
  1400. device=self.local_device).view(-1)
  1401. flat_tensor.requires_grad = False
  1402. allgather_params.append(flat_tensor)
  1403. if quantize:
  1404. for psize in quantize_scale_sizes:
  1405. tensor_size = psize * self.num_partitions
  1406. flat_tensor = torch.empty(tensor_size,
  1407. dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
  1408. device=self.local_device).view(-1)
  1409. flat_tensor.requires_grad = False
  1410. allgather_quantize_scale.append(flat_tensor)
  1411. # launch
  1412. launch_handles = []
  1413. launch_quantize_handles = []
  1414. for param_idx, param in enumerate(param_list):
  1415. input_tensor = local_tensors[param_idx].view(-1)
  1416. if self.use_all_gather_into_tensor:
  1417. # try the _all_gather_base from Pytorch master
  1418. h = dist.all_gather_into_tensor(allgather_params[param_idx],
  1419. input_tensor,
  1420. group=self.get_partition_dp_group(param),
  1421. async_op=True)
  1422. if quantize:
  1423. quantize_handle = dist.all_gather_into_tensor(allgather_quantize_scale[param_idx],
  1424. quantize_scale_tensors[param_idx],
  1425. group=self.get_partition_dp_group(param),
  1426. async_op=True)
  1427. launch_quantize_handles.append(quantize_handle)
  1428. else:
  1429. output_list = []
  1430. for i in range(self.num_partitions):
  1431. psize = partition_sizes[param_idx]
  1432. partition = allgather_params[param_idx].narrow(0, i * psize, psize)
  1433. output_list.append(partition)
  1434. if not get_accelerator().on_accelerator(partition):
  1435. logger.warning(
  1436. f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}')
  1437. # back to old all_gather function
  1438. h = dist.all_gather(output_list, input_tensor, group=self.get_partition_dp_group(param), async_op=True)
  1439. if quantize:
  1440. output_scale_list = []
  1441. for i in range(self.num_partitions):
  1442. psize = quantize_scale_sizes[param_idx]
  1443. partition = allgather_quantize_scale[param_idx].narrow(0, i * psize, psize)
  1444. output_scale_list.append(partition)
  1445. quant_handle = dist.all_gather(output_scale_list,
  1446. quantize_scale_tensors[param_idx],
  1447. group=self.get_partition_dp_group(param),
  1448. async_op=True)
  1449. launch_quantize_handles.append(quant_handle)
  1450. launch_handles.append(h)
  1451. # Wait ensures the operation is enqueued, but not necessarily complete.
  1452. launch_handles[-1].wait()
  1453. if quantize:
  1454. for quant_handle in launch_quantize_handles:
  1455. quant_handle.wait()
  1456. # assign to param.data (not copy)
  1457. for i, param in enumerate(param_list):
  1458. gathered_tensor = allgather_params[i]
  1459. if quantize:
  1460. gathered_tensor = self.quantizer_module.dequantize(gathered_tensor, allgather_quantize_scale[i])
  1461. param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
  1462. # guarantee the communication to be completed
  1463. get_accelerator().synchronize()
  1464. return None
  1465. def _allgather_params(self, param_list, hierarchy=0):
  1466. if len(param_list) == 0:
  1467. return
  1468. partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
  1469. tensor_size = partition_size * self.num_partitions
  1470. flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
  1471. flat_tensor.requires_grad = False
  1472. partitions = []
  1473. for i in range(self.num_partitions):
  1474. start = partition_size * i
  1475. partitions.append(flat_tensor.narrow(0, start, partition_size))
  1476. if i == self.get_partition_rank():
  1477. offset = 0
  1478. for param in param_list:
  1479. param_numel = param.ds_tensor.ds_numel
  1480. partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
  1481. offset += param_numel
  1482. if hasattr(param_list[0], 'ds_quant_scale'):
  1483. scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
  1484. scale_tensor_size = scale_size * self.world_size
  1485. flat_scale_tensor = torch.empty(scale_tensor_size,
  1486. dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
  1487. device=self.local_device)
  1488. flat_scale_tensor.requires_grad = False
  1489. scale_partitions = []
  1490. for i in range(self.world_size):
  1491. start = scale_tensor_size * i
  1492. scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
  1493. if i == self.rank:
  1494. offset = 0
  1495. for param in param_list:
  1496. param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
  1497. scale_partitions[i].narrow(0, offset,
  1498. param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
  1499. offset += param_scale_numel
  1500. dist.all_gather_into_tensor(flat_tensor,
  1501. partitions[self.get_partition_rank()],
  1502. group=self.get_partition_dp_group(param),
  1503. async_op=False)
  1504. if hasattr(param_list[0], 'ds_quant_scale'):
  1505. dist.all_gather(flat_scale_tensor,
  1506. param_list[0].ds_quant_scale,
  1507. group=self.get_partition_dp_group(param),
  1508. async_op=False)
  1509. param_offset = 0
  1510. for param in param_list:
  1511. param_partition_size = param.ds_tensor.ds_numel
  1512. param_size = param.ds_numel
  1513. replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)
  1514. for i in range(self.num_partitions):
  1515. start = i * partition_size
  1516. param_start = i * param_partition_size
  1517. if param_start < param_size:
  1518. numel_to_copy = min(param_size - param_start, param_partition_size)
  1519. part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
  1520. replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
  1521. #param_offset += param.data.numel()
  1522. param_offset += param.ds_tensor.ds_numel
  1523. if hasattr(param_list[0], 'ds_quant_scale'):
  1524. replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
  1525. param.data = replicated_tensor.data
  1526. return None
  1527. def _reduce_scatter_gradients(self, param_list):
  1528. #print_rank_0([param.grad for param in param_list])
  1529. #assert any([param.grad is None for param in param_list]), "None gradients cannot be reduce scattered"
  1530. handles_and_reduced_partitions = []
  1531. for param in param_list:
  1532. assert param.grad.numel(
  1533. ) == param.ds_numel, f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter gradients whose size is not same as the params"
  1534. handles_and_reduced_partitions.append(self._reduce_scatter_gradient(param))
  1535. for param, (handle, reduced_partition) in zip(param_list, handles_and_reduced_partitions):
  1536. if handle is not None:
  1537. handle.wait()
  1538. # some ranks may have partitions that are padded to go beyond the grad size.
  1539. # For these ranks the output of reduce scatter is a separate buffer and needs
  1540. # to be copied in
  1541. partition_size = param.ds_tensor.ds_numel
  1542. start = self.get_partition_rank() * partition_size
  1543. end = start + partition_size
  1544. #print_rank_0("REduce scatter was executed for param {param.ds_id}")
  1545. if start < param.ds_numel < end:
  1546. elements = param.ds_numel - start
  1547. param.grad.view(-1).narrow(0, start, elements).copy_(reduced_partition.narrow(0, 0, elements))
  def _reduce_scatter_gradient(self, param):
      """Launch an asynchronous reduce-scatter of ``param.grad``.

      The flattened gradient is cut into ``self.num_partitions`` chunks of
      ``param.ds_tensor.ds_numel`` elements. Chunks fully inside the gradient
      are views of it; a chunk that runs past its end is a zero-filled
      temporary holding the remaining elements.

      Args:
          param: parameter whose gradient is reduced.

      Returns:
          ``(handle, chunk)`` where ``chunk`` is the tensor that receives this
          rank's reduced partition.
      """
      shard_numel = param.ds_tensor.ds_numel
      chunks = []

      for index in range(self.num_partitions):
          offset = index * shard_numel
          fits_entirely = offset < param.ds_numel and offset + shard_numel <= param.ds_numel

          if fits_entirely:
              chunk = param.grad.view(-1).narrow(0, offset, shard_numel)
          else:
              chunk = torch.zeros(shard_numel, dtype=param.dtype, device=param.device)
              if offset < param.ds_numel:
                  valid = param.ds_numel - offset
                  chunk.narrow(0, 0, valid).copy_(param.grad.view(-1).narrow(0, offset, valid))
          chunks.append(chunk)

      rank = dist.get_rank(group=self.get_partition_dp_group(param))
      handle = dist.reduce_scatter(chunks[rank],
                                   chunks,
                                   group=self.get_partition_dp_group(param),
                                   async_op=True)

      return handle, chunks[rank]
  1572. def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False):
  1573. if partition_buffers is None:
  1574. partition_buffers = [None] * len(param_list)
  1575. for param, partition_buffer in zip(param_list, partition_buffers):
  1576. self._partition_gradient(param, partition_buffer=partition_buffer, accumulate=accumulate)
  def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
      """Replace ``param.grad`` with this rank's slice of the gradient.

      Args:
          param: parameter whose gradient is partitioned.
          partition_buffer: destination holding at least
              ``param.ds_tensor.ds_numel`` elements; a zero-filled buffer is
              allocated when omitted.
          accumulate: add the slice to the buffer instead of overwriting it;
              requires ``partition_buffer``.
      """
      print_rank_0(
          f"Partitioning param {param.ds_id} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.ds_numel}"
      )
      see_memory_usage("Before partitioning gradients", force=False)
      shard_numel = param.ds_tensor.ds_numel

      if partition_buffer is not None:
          assert partition_buffer.numel(
          ) >= shard_numel, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {shard_numel}"
      else:
          assert not accumulate, "No buffer to accumulate to"
          partition_buffer = torch.zeros(shard_numel, dtype=param.dtype, device=param.device)

      rank = dist.get_rank(group=self.get_partition_dp_group(param))
      offset = shard_numel * rank

      # The buffer may be larger than one shard; only its head is used.
      shard_view = partition_buffer.view(-1).narrow(0, 0, shard_numel)

      if offset < param.ds_numel:
          # The last contributing rank may own fewer real elements than a full shard.
          count = min(param.ds_numel - offset, shard_numel)
          destination = shard_view.narrow(0, 0, count)
          source = param.grad.view(-1).narrow(0, offset, count)

          if not accumulate:
              # just copy the grad partition to the buffer
              destination.copy_(source)
          elif source.device == destination.device:
              # same device: add straight into the provided buffer
              destination.add_(source)
          else:
              # if source and destination are on different device, copy first to src
              # then add and move back to the destination. This seems to run faster
              # when src is gpu and dest is cpu
              # adding directly to cpu is very slow
              staging = torch.empty(source.numel(), dtype=param.dtype, device=param.device)
              staging.copy_(destination)
              staging.add_(source)
              destination.copy_(staging)

      param.grad.data = shard_view.data
      see_memory_usage("After partitioning gradients", force=False)
  def get_partition_dp_group(self, param):
      """Return the process group stored on ``param`` as ``ds_process_group``."""
      group = param.ds_process_group
      return group
  def get_partition_rank(self):
      """Return ``self.rank``.

      Subclasses can overload this to specify a different relative rank in
      the parameter partition group.
      """
      rank = self.rank
      return rank
  @property
  def num_partitions(self):
      """Number of partitions each parameter is split into (``self.dp_world_size``)."""
      partition_count = self.dp_world_size
      return partition_count
  def get_dp_process_group(self):
      """Return the communication group with all data-parallel ranks."""
      group = self.ds_process_group
      return group
  1638. class GatheredParameters:
  1639. def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
  1640. """A context that collects parameters that were partitioned via a
  1641. :class:`deepspeed.zero.Init` context. The parameters are partitioned
  1642. again upon exit.
  1643. Args:
  1644. params (``torch.nn.Parameter``): A single parameter, or an iterable of parameters (list, tuple, generator) of parameters to collect.
  1645. It's assumed that all parameters are zero params.
  1646. modifier_rank (int, optional): If specified, this rank's parameter will be
  1647. broadcasted on exit from the context. This argument is required if ``params`` are
  1648. modified, so that all processes have a consistent view of the data. Defaults
  1649. to ``None``.
  1650. fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
  1651. registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
  1652. enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.
  1653. Important: Make sure to use ``modifier_rank`` that is not ``None`` (e.g., ``modifier_rank=0``)
  1654. if you need the GPU memory allocated by gather to be released upon exit from the context manager.
  1655. Important: if ``params`` isn't an iterable of parameters or a single parameter it'll be silently ignored!
  1656. Examples
  1657. ========
  1658. #. Allocate a partitioned module, initialize its weight on rank 0, and update all
  1659. processes.
  1660. .. code-block:: python
  1661. with deepspeed.zero.Init():
  1662. linear = torch.nn.Linear(1000,1000)
  1663. with deepspeed.zero.GatheredParameters(linear.weight,
  1664. modifier_rank=0):
  1665. if deepspeed.comm.get_rank() == 0:
  1666. linear.weight.zero_()
  1667. with deepspeed.zero.GatheredParameters(linear.weight,
  1668. modifier_rank=0):
  1669. if deepspeed.comm.get_rank() == 0:
  1670. linear.weight.zero_()
  1671. #. Collect a partitioned weight to pass to another module during
  1672. training. The parameter will be registered as an external parameter
  1673. and made available during the backward pass.
  1674. .. code-block:: python
  1675. :emphasize-lines: 6
  1676. def forward(self, input):
  1677. x = self.layer1(input)
  1678. # self.layer1.weight is required by self.layer2.forward
  1679. with deepspeed.zero.GatheredParameters(self.layer1.weight,
  1680. fwd_module=self):
  1681. y = self.layer2(x, self.layer1.weight)
  1682. return y
  1683. #. Pretrained model loading
  1684. .. code-block:: python
  1685. with deepspeed.zero.Init():
  1686. model = MyModel()
  1687. state_dict = torch.load(model_path, map_location="cpu")
  1688. def load(module: nn.Module, prefix=""):
  1689. # because zero3 puts placeholders in model params, this context
  1690. # manager gathers (unpartitions) the params of the current layer, then loads from
  1691. # the state dict and then re-partitions them again
  1692. with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
  1693. if deepspeed.comm.get_rank() == 0:
  1694. module._load_from_state_dict(state_dict, prefix)
  1695. for name, child in module._modules.items():
  1696. if child is not None:
  1697. load(child, prefix + name + ".")
  1698. load(model, prefix="")
  1699. If this approach is not used, then the full model will first be copied to each GPU. For models
  1700. bigger than the memory of a single GPU, this method is required.
  1701. """
  1702. self.enabled = enabled
  1703. if not enabled:
  1704. return
  1705. if isinstance(params, Iterable) and not isinstance(params, torch.Tensor):
  1706. # deal with generators like model.parameters()
  1707. # must convert to list to be able to iterate more than once if we get a generator
  1708. params = list(params)
  1709. else:
  1710. # single param
  1711. params = [params]
  1712. # enable if at least one is zero-param, otherwise a noop
  1713. if not any(is_zero_param(p) for p in params):
  1714. self.enabled = False
  1715. return
  1716. self.params = [p for p in params if hasattr(p, "ds_id")]
  1717. self.params = sorted(
  1718. set(self.params), key=lambda x: x.ds_id
  1719. ) # remove the duplicates to prevent racing condition, we must also make sure the order is the same on all ranks otherwise we'll get deadlocks
  1720. self.src_rank = None
  1721. if modifier_rank is not None:
  1722. if self.params[0].ds_process_group == dist.get_world_group():
  1723. self.src_rank = modifier_rank
  1724. else:
  1725. # A group was specified; convert DP rank to global rank
  1726. self.src_rank = dist.get_global_rank(self.params[0].ds_process_group, modifier_rank)
  1727. self.fwd_module = fwd_module
  1728. if self.fwd_module is not None:
  1729. # is a no-op if already registered
  1730. for p in self.params:
  1731. register_external_parameter(self.fwd_module, p)
  def __enter__(self):
      """Gather the tracked parameters when the ``with`` block is entered.

      When the context is disabled this does nothing. Otherwise a single
      ``all_gather`` call is issued through the first tracked parameter,
      with the complete parameter list passed as ``param_list``.

      Returns:
          None: nothing is bound by ``with ... as name``.
      """
      if self.enabled:
          targets = self.params
          # One call covers the whole list; the first parameter is only the entry point.
          targets[0].all_gather(param_list=targets)
  def __exit__(self, *exc):
      """Re-partition the tracked parameters when the ``with`` block exits.

      When the context is disabled this does nothing. If no source rank was
      recorded (``self.src_rank is None``), the parameters are partitioned
      with ``has_been_updated=False``. Otherwise each parameter's ``data`` is
      first broadcast from ``self.src_rank`` over that parameter's
      ``ds_process_group`` as an asynchronous op, every returned handle is
      waited on, and the parameters are then partitioned with
      ``has_been_updated=True``.

      Args:
          *exc: Exception details supplied by the ``with`` statement; they
              are not inspected.

      Returns:
          None: a falsy result, so an exception raised inside the ``with``
          body is not suppressed.
      """
      if not self.enabled:
          return

      tracked = self.params
      was_modified = self.src_rank is not None

      if was_modified:
          # Launch every broadcast before waiting on any of them.
          pending = []
          for param in tracked:
              pending.append(
                  dist.broadcast(param.data, self.src_rank, group=param.ds_process_group, async_op=True))
          for handle in pending:
              handle.wait()

      # Partitioning is driven through the first parameter for the whole list.
      tracked[0].partition(param_list=tracked, has_been_updated=was_modified)