partition_parameters.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658
  1. """
  2. "Copyright 2020 The Microsoft DeepSpeed Team.
  3. Licensed under the MIT license.
  4. """
  5. import math
  6. import os
  7. import types
  8. from typing import Callable, Iterable
  9. from enum import Enum
  10. import functools
  11. import itertools
  12. from typing import List
  13. import torch
  14. from torch import Tensor
  15. from deepspeed import comm as dist
  16. from torch.nn import Module
  17. from torch.nn import Parameter
  18. from .linear import zero3_linear_wrap
  19. import deepspeed
  20. from ..utils import get_only_unique_item, see_memory_usage
  21. from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks
  22. from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
  23. from deepspeed.utils import instrument_w_nvtx, logger
  24. from deepspeed.comm.comm import init_distributed
  25. from deepspeed.utils.debug import (debug_param2name_id_shape,
  26. debug_param2name_id_shape_device,
  27. debug_module2name,
  28. debug_param2name_id,
  29. debug_param2name_id_shape_status)
  30. from deepspeed.accelerator import get_accelerator
  31. from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
# Running count of parameter elements (sum of ``param.numel()``) seen by
# ``Init._post_init_method``; logged in billions on rank 0 when an
# ``InsertPostInitMethodToModuleSubClasses`` context exits.
param_count = 0
# NOTE(review): not referenced by the code that sits alongside this definition;
# confirm its users before changing it.
partitioned_param_data_shape = [0]
# True while the ZeRO-Init construction hooks are installed: set by
# ``InsertPostInitMethodToModuleSubClasses.__enter__`` and cleared by
# ``shutdown_init_context`` (which is a no-op while this is False).
zero_init_enabled = False
  35. def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
  36. return instrument_w_nvtx(dist.allgather_fn)(output_tensor,
  37. input_tensor,
  38. group=group,
  39. async_op=True)
  40. def print_rank_0(message, debug=False, force=False):
  41. rank = dist.get_rank()
  42. if rank == 0 and (debug or force):
  43. print(message)
  44. # other variations
  45. # - print for all ranks w/o interleaving
  46. # printflock(f"[{rank}] {message}")
  47. # - print to log file per rank
  48. # log_rank_file(rank, message)
  49. def debug_rank0(msg: str) -> None:
  50. if dist.get_rank() == 0:
  51. logger.debug(msg)
  52. def is_zero_param(parameter):
  53. if not torch.is_tensor(parameter):
  54. return False
  55. return hasattr(parameter, 'ds_id')
  56. def _init_external_params(module):
  57. if not hasattr(module, '_external_params'):
  58. module._external_params = {}
  59. def external_parameters(self):
  60. return self._external_params.items()
  61. def all_parameters(self):
  62. return itertools.chain(self.named_parameters(self,
  63. recurse=False),
  64. external_parameters(self))
  65. module.ds_external_parameters = types.MethodType(external_parameters, module)
  66. module.all_parameters = types.MethodType(all_parameters, module)
  67. def register_external_parameter(module, parameter):
  68. """Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in
  69. the forward and backward passes of ``module``.
  70. This is used when a parameter is accessed outside of its owning module's
  71. ``forward()``. DeepSpeed must know to collect it from its partitioned
  72. state and when to release the memory.
  73. .. note::
  74. This is only applicable to training with ZeRO stage 3.
  75. Args:
  76. module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass.
  77. parameter (``torch.nn.Parameter``): The parameter to register.
  78. Raises:
  79. RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
  80. Examples
  81. ========
  82. #. Register a weight that is used in another module's forward pass (line 6).
  83. Parameter ``layer1.weight`` is used by ``layer2`` (line 11).
  84. .. code-block:: python
  85. :linenos:
  86. :emphasize-lines: 6,11
  87. class ModuleZ3(torch.nn.Module):
  88. def __init__(self, *args):
  89. super().__init__(self, *args)
  90. self.layer1 = SomeLayer()
  91. self.layer2 = OtherLayer()
  92. deepspeed.zero.register_external_parameter(self, self.layer1.weight)
  93. def forward(self, input):
  94. x = self.layer1(input)
  95. # self.layer1.weight is required by self.layer2.forward
  96. y = self.layer2(x, self.layer1.weight)
  97. return y
  98. """
  99. if not isinstance(parameter, torch.nn.Parameter):
  100. raise RuntimeError('Parameter is not a torch.nn.Parameter')
  101. if not hasattr(module, '_external_params'):
  102. _init_external_params(module)
  103. key = id(parameter)
  104. module._external_params[key] = parameter
  105. def unregister_external_parameter(module, parameter):
  106. """Reverses the effects of :meth:`register_external_parameter`.
  107. Args:
  108. module (``torch.nn.Module``): The module to affect.
  109. parameter (``torch.nn.Parameter``): The parameter to unregister.
  110. Raises:
  111. RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
  112. RuntimeError: If ``parameter`` is not a registered external parameter of ``module``.
  113. """
  114. if not isinstance(parameter, torch.nn.Parameter):
  115. raise RuntimeError('Parameter is not a torch.nn.Parameter')
  116. if not hasattr(module,
  117. '_external_params') or id(parameter) not in module._external_params:
  118. raise RuntimeError('Parameter is not a registered external parameter of module.')
  119. key = id(parameter)
  120. del module._external_params[key]
  121. class ZeroParamType(Enum):
  122. # same as regular pytorch parameters
  123. NORMAL = 1
  124. # parameters are partitioned across data parallel process
  125. PARTITIONED = 2
  126. # the parameter is held with a unique process rank
  127. # and is not available on all other process
  128. REMOTE = 3
  129. class ZeroParamStatus(Enum):
  130. # parameters are fully present and ready for use on all processes
  131. AVAILABLE = 1
  132. # parameters are either partitioned or remote in some or all process
  133. NOT_AVAILABLE = 2
  134. # parameters are being gathered.
  135. INFLIGHT = 3
# Unpatched tensor factories, captured at import time.
# ``InsertPostInitMethodToModuleSubClasses.__enter__`` replaces
# torch.empty/zeros/ones/full with dtype/device-forcing wrappers built from these,
# ``shutdown_init_context`` puts these originals back, and
# ``get_new_tensor_fn_for_dtype`` allocates through the original ``torch.empty``.
_orig_torch_empty = torch.empty
_orig_torch_zeros = torch.zeros
_orig_torch_ones = torch.ones
_orig_torch_full = torch.full
  140. def zero_wrapper_for_fp_tensor_constructor(fn: Callable,
  141. target_fp_dtype: torch.dtype) -> Callable:
  142. def wrapped_fn(*args, **kwargs) -> Tensor:
  143. if kwargs.get("device", None) is None:
  144. kwargs['device'] = torch.device(get_accelerator().device_name(
  145. os.environ["LOCAL_RANK"]))
  146. tensor: Tensor = fn(*args, **kwargs)
  147. if tensor.is_floating_point():
  148. tensor = tensor.to(target_fp_dtype)
  149. return tensor
  150. return wrapped_fn
  151. def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:
  152. def new_tensor(cls, *args) -> Tensor:
  153. device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
  154. tensor = _orig_torch_empty(0, device=device).new_empty(*args)
  155. if tensor.is_floating_point():
  156. tensor = tensor.to(dtype)
  157. return tensor
  158. return new_tensor
  159. # https://stackoverflow.com/a/63851681/9201239
  160. def get_all_subclasses(cls):
  161. subclass_list = []
  162. def recurse(cl):
  163. for subclass in cl.__subclasses__():
  164. subclass_list.append(subclass)
  165. recurse(subclass)
  166. recurse(cls)
  167. return set(subclass_list)
  168. @instrument_w_nvtx
  169. def free_param(param: Parameter) -> None:
  170. """Free underlying storage of a parameter."""
  171. assert not param.ds_active_sub_modules, param.ds_summary()
  172. if get_accelerator().on_accelerator(param.data):
  173. # need to make sure that we don't free the parameter while it is still
  174. # being used for computation
  175. param.data.record_stream(get_accelerator().current_stream())
  176. # param.data doesn't store anything meaningful in partitioned state
  177. param.data = torch.empty(0, dtype=param.dtype, device=param.device)
  178. param.ds_status = ZeroParamStatus.NOT_AVAILABLE
# NOTE(review): these module-level buffers are not referenced by the code that
# sits alongside them; confirm their users before changing or removing them.
reuse_buffers = False
temp_contiguous_tensor = None
empty_buffers = {}
  182. # Inserts _post_init_method at the end of init method
  183. # for all sub classes of torch.nn.Module
  184. class InsertPostInitMethodToModuleSubClasses(object):
  185. def __init__(self,
  186. enabled=True,
  187. mem_efficient_linear=True,
  188. ds_config=None,
  189. dtype=None):
  190. self.mem_efficient_linear = mem_efficient_linear
  191. self.enabled = enabled
  192. self._set_dtype(ds_config, dtype)
  193. assert self.dtype in [torch.half, torch.bfloat16, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
  194. def __enter__(self):
  195. global zero_init_enabled
  196. if not self.enabled:
  197. return
  198. zero_init_enabled = True
  199. def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
  200. """many models make use of child modules like Linear or Embedding which
  201. perform their own weight initialization in their __init__ methods,
  202. but will then have more weight initialization in a parent module's __init__
  203. method that modifies weights of child modules, which is typically done
  204. using the Module.apply method.
  205. since the Init context manager partitions child modules immediately after
  206. they are initialized, without modifying apply we would entirely skip
  207. any initialization done by parent modules.
  208. to get around this issue, we wrap the function passed to Module.apply
  209. so that the applied function is applied to child modules correctly.
  210. """
  211. def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable:
  212. if hasattr(fn_to_apply, "wrapped"):
  213. return fn_to_apply
  214. @functools.wraps(fn_to_apply)
  215. def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None:
  216. """gathers parameters before calling apply function. afterwards
  217. parameters are broadcasted to ensure consistency across all ranks
  218. then re-partitioned.
  219. takes the following steps:
  220. 1. allgathers parameters for the current module being worked on
  221. 2. calls the original function
  222. 3. broadcasts root rank's parameters to the other ranks
  223. 4. re-partitions the parameters
  224. """
  225. if not all(
  226. is_zero_param(p)
  227. for p in module_to_apply_fn_to.parameters(recurse=False)):
  228. raise RuntimeError(
  229. f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
  230. f"were zero params, is it possible that the parameters were "
  231. f"overwritten after they were initialized? "
  232. f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} "
  233. )
  234. params_to_apply_fn_to: Iterable[Parameter] = list(
  235. sorted(module_to_apply_fn_to.parameters(recurse=False),
  236. key=lambda p: p.ds_id))
  237. for param in params_to_apply_fn_to:
  238. param.all_gather()
  239. fn_to_apply(module_to_apply_fn_to)
  240. for param in params_to_apply_fn_to:
  241. dist.broadcast(param.data, 0, group=param.ds_process_group)
  242. for param in params_to_apply_fn_to:
  243. param.partition(has_been_updated=True)
  244. wrapped_fn_to_apply.wrapped = True
  245. return wrapped_fn_to_apply
  246. @functools.wraps(orig_module_apply_fn)
  247. def wrapped_apply(module: Module, fn_to_apply: Callable) -> None:
  248. orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply))
  249. return wrapped_apply
  250. def partition_after(f):
  251. @functools.wraps(f)
  252. def wrapper(module, *args, **kwargs):
  253. # important logic: We want to run post_init only after child's __init__ is
  254. # completed, and do nothing after __init__ of any of its parents and grandparents in
  255. # the inheritance ancestry. This way the partitioning will need to happen only once
  256. # when the whole object is ready to be partitioned and not before. This is because
  257. # often the child module will need to tweak the weights - for example running a
  258. # custom weights init function. So if a parent created the weights param, the child
  259. # won't need to gather it in order to tweak it
  260. print_rank_0(f'Before initializing {module.__class__.__name__}',
  261. force=False)
  262. is_child_module = False
  263. if not hasattr(module, "_ds_child_entered"):
  264. # child's __init__ was called, since parents all see the same object they can now skip post_init
  265. is_child_module = True
  266. setattr(module, "_ds_child_entered", True)
  267. f(module, *args, **kwargs)
  268. if is_child_module:
  269. # child's __init__ is done, now we can run a single post_init on the child object
  270. delattr(module, "_ds_child_entered")
  271. print_rank_0(f'Running post_init for {module.__class__.__name__}',
  272. force=False)
  273. self._post_init_method(module)
  274. print_rank_0(
  275. f'After initializing followed by post init for {module.__class__.__name__}',
  276. force=False)
  277. return wrapper
  278. def _enable_class(cls):
  279. cls._old_init = cls.__init__
  280. cls.__init__ = partition_after(cls.__init__)
  281. def _init_subclass(cls, **kwargs):
  282. cls.__init__ = partition_after(cls.__init__)
  283. # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
  284. for subclass in get_all_subclasses(torch.nn.modules.module.Module):
  285. # print(f"subclass={subclass.__module__}.{subclass.__qualname__}")
  286. _enable_class(subclass)
  287. # holding onto some methods so we can put them back the way they were in __exit__
  288. torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
  289. torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
  290. torch.Tensor.__old_new__ = torch.Tensor.__new__
  291. # Replace .__init__() for future subclasses of torch.nn.Module
  292. torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
  293. torch.nn.modules.module.Module.apply = apply_with_gather(
  294. torch.nn.modules.module.Module._old_apply)
  295. torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
  296. torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty,
  297. self.dtype)
  298. torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros,
  299. self.dtype)
  300. torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
  301. torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
  302. if self.mem_efficient_linear:
  303. print_rank_0(
  304. "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
  305. force=False)
  306. self.linear_bk = torch.nn.functional.linear
  307. torch.nn.functional.linear = zero3_linear_wrap
  308. def __exit__(self, exc_type, exc_value, traceback):
  309. if not self.enabled:
  310. return
  311. shutdown_init_context()
  312. if dist.get_rank() == 0:
  313. logger.info("finished initializing model with %.2fB parameters",
  314. param_count / 1e9)
  315. # Now that we cleaned up the metaclass injection, raise the exception.
  316. if exc_type is not None:
  317. return False
  318. # To be implemented by inheriting classes
  319. def _post_init_method(self, module):
  320. pass
  321. def _set_dtype(self, ds_config, dtype):
  322. if ds_config is not None and dtype is None:
  323. if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
  324. raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")
  325. if ds_config.bfloat16_enabled:
  326. self.dtype = torch.bfloat16
  327. elif ds_config.fp16_enabled:
  328. self.dtype = torch.half
  329. else:
  330. self.dtype = torch.float
  331. else:
  332. self.dtype = dtype or torch.half
  333. def shutdown_init_context():
  334. global zero_init_enabled
  335. if not zero_init_enabled:
  336. return
  337. def _disable_class(cls):
  338. cls.__init__ = cls._old_init
  339. # Replace .__init__() for all existing subclasses of torch.nn.Module
  340. for subclass in get_all_subclasses(torch.nn.modules.module.Module):
  341. _disable_class(subclass)
  342. # putting methods back the way we found them
  343. torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
  344. torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
  345. torch.Tensor.__new__ = torch.Tensor.__old_new__
  346. torch.empty = _orig_torch_empty
  347. torch.zeros = _orig_torch_zeros
  348. torch.ones = _orig_torch_ones
  349. torch.full = _orig_torch_full
  350. # un doing it here will undo it during training
  351. # if self.mem_efficient_linear:
  352. # torch.nn.functional.linear = self.linear_bk
  353. # if self.mem_efficient_linear:
  354. # torch.nn.functional.linear = self.linear_bk
  355. zero_init_enabled = False
  356. class AllGatherHandle:
  357. def __init__(self, handle, param: Parameter) -> None:
  358. if param.ds_status != ZeroParamStatus.INFLIGHT:
  359. raise RuntimeError(f"expected param {param.ds_summary()} to be available")
  360. self.__handle = handle
  361. self.__param = param
  362. def wait(self) -> None:
  363. instrument_w_nvtx(self.__handle.wait)()
  364. self.__param.ds_status = ZeroParamStatus.AVAILABLE
  365. class AllGatherCoalescedHandle:
  366. def __init__(
  367. self,
  368. allgather_handle,
  369. params: List[Parameter],
  370. partitions: List[Tensor],
  371. world_size: int,
  372. ) -> None:
  373. self.__allgather_handle = allgather_handle
  374. self.__params = params
  375. self.__partitions = partitions
  376. self.__world_size = world_size
  377. self.__complete = False
  378. for param in self.__params:
  379. if param.ds_status != ZeroParamStatus.INFLIGHT:
  380. raise RuntimeError(
  381. f"expected param {param.ds_summary()} to not be available")
  382. @instrument_w_nvtx
  383. def wait(self) -> None:
  384. if self.__complete:
  385. return
  386. instrument_w_nvtx(self.__allgather_handle.wait)()
  387. # split the single tensor out into individual tensors
  388. param_offset = 0
  389. for param in self.__params:
  390. assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
  391. partitions: List[Tensor] = []
  392. for rank in range(self.__world_size):
  393. param_start = rank * param.ds_tensor.ds_numel
  394. if param_start < param.ds_numel:
  395. part_to_copy = self.__partitions[rank].narrow(
  396. 0,
  397. param_offset,
  398. min(param.ds_numel - param_start,
  399. param.ds_tensor.ds_numel))
  400. partitions.append(part_to_copy)
  401. param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
  402. param.ds_status = ZeroParamStatus.AVAILABLE
  403. for part_to_copy in partitions:
  404. part_to_copy.record_stream(get_accelerator().current_stream())
  405. param_offset += param.ds_tensor.ds_numel
  406. self.__complete = True
  407. # Replaces all parameters in module with Scattered Parameters
  408. class Init(InsertPostInitMethodToModuleSubClasses):
  409. param_id = 0
  410. def __init__(self,
  411. module=None,
  412. data_parallel_group=None,
  413. mem_efficient_linear=True,
  414. remote_device=None,
  415. pin_memory=False,
  416. config_dict_or_path=None,
  417. config=None,
  418. enabled=True,
  419. dtype=None,
  420. mpu=None):
  421. """A context to enable massive model construction for training with
  422. ZeRO-3. Models are automatically partitioned (or, sharded) across the
  423. system and converted to half precision.
  424. Args:
  425. module (``torch.nn.Module``, optional): If provided, partition the model as
  426. if it was constructed in the context.
  427. data_parallel_group (``deepspeed.comm`` process group, optional):
  428. The group of processes to partition among. Defaults to all processes.
  429. mem_efficient_linear (bool, optional): Replace
  430. torch.nn.functional.linear with an implementation that allows
  431. DeepSpeed to partition parameters. Defaults to ``True``.
  432. remote_device (string, optional): The initial device to store model
  433. weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
  434. memory. The model may still be moved to GPU based on the
  435. offload settings for training. Defaults to param offload device if a config is
  436. defined, otherwise GPU.
  437. pin_memory (bool, optional): Potentially increase performance by
  438. using pinned memory for model weights. ``remote_device`` must be
  439. ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``.
  440. config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
  441. for swapping fp16 params to NVMe.
  442. config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
  443. enabled (bool, optional): If ``False``, this context has no
  444. effect. Defaults to ``True``.
  445. dtype (``dtype``, optional): Can be used to change the data type of the parameters.
  446. Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
  447. mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
  448. This context accelerates model initialization and enables models that
  449. are too large to allocate in their entirety in CPU memory. It has the
  450. following effects:
  451. #. allocates tensors to either GPU or CPU memory or NVMe
  452. #. converts floating point tensors to half precision
  453. #. immediately partitions tensors among the group of data-parallel devices
  454. #. (*optional*) replaces ``torch.nn.functional.linear`` with a more
  455. memory-efficient implementation
  456. These modifications allow for models that exceed the size of local CPU/GPU
  457. memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
  458. or GPU memory or NVMe) across all nodes. Consider initializing a model with one
  459. trillion parameters, whose weights occupy two terabytes (TB) in half
  460. precision. The initial CPU allocation in full precision requires 4TB of
  461. memory *per process*, and so a system with 8 GPUs per node would need 32TB of
  462. CPU memory due to data-parallel redundancies. Instead, by immediately
  463. partitioning tensors we remove the redundancies. The result is that
  464. regardless of the number of GPUs, we still only require the original 4TB. This
  465. allows for a linear increase in model size with the aggregate system memory.
  466. For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
  467. parameter model with 4 nodes and 32 GPUs.
  468. Important: If the fp16 weights of the model can't fit onto a single GPU memory
  469. this feature must be used.
  470. .. note::
  471. Initializes ``deepspeed.comm`` if it has not already been done so.
  472. See :meth:`deepspeed.init_distributed` for more information.
  473. .. note::
  474. Can also be used as a decorator:
  475. .. code-block:: python
  476. @deepspeed.zero.Init()
  477. def get_model():
  478. return MyLargeModel()
  479. .. note::
  480. Only applicable to training with ZeRO-3.
  481. Examples
  482. --------
  483. #. Allocate a model and partition it among all processes:
  484. .. code-block:: python
  485. with deepspeed.zero.Init():
  486. model = MyLargeModel()
  487. #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:
  488. .. code-block:: python
  489. with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
  490. remote_device="cpu",
  491. pin_memory=True):
  492. model = MyLargeModel()
  493. #. Partition an already-allocated model in CPU memory:
  494. .. code-block:: python
  495. model = deepspeed.zero.Init(module=model)
  496. """
  497. if config is not None:
  498. config_dict_or_path = config
  499. logger.warning(
  500. f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.'
  501. )
  502. _ds_config = deepspeed.runtime.config.DeepSpeedConfig(
  503. config_dict_or_path,
  504. mpu) if config_dict_or_path is not None else None
  505. super().__init__(enabled=enabled,
  506. mem_efficient_linear=mem_efficient_linear,
  507. ds_config=_ds_config,
  508. dtype=dtype)
  509. if not dist.is_initialized():
  510. init_distributed()
  511. assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"
  512. if data_parallel_group is None:
  513. self.ds_process_group = dist.get_world_group()
  514. else:
  515. self.ds_process_group = data_parallel_group
  516. self.rank = dist.get_rank(group=self.ds_process_group)
  517. self.world_size = dist.get_world_size(group=self.ds_process_group)
  518. # Local device is the device where the parameters are consumed, must be default device.
  519. # It is the device where parameters are fully instantiated using allgather
  520. self.local_device = torch.device(get_accelerator().device_name(
  521. os.environ["LOCAL_RANK"]))
  522. get_accelerator().set_device(self.local_device)
  523. if _ds_config is not None and _ds_config.zero_config.offload_param is not None:
  524. remote_device = _ds_config.zero_config.offload_param.device
  525. pin_memory = _ds_config.zero_config.offload_param.pin_memory
  526. self._validate_remote_device(remote_device, _ds_config)
  527. # Remote device is the device where parameter partitions are stored
  528. # It can be same as local_device or it could be CPU or NVMe.
  529. self.remote_device = self.local_device if remote_device in [
  530. None,
  531. OffloadDeviceEnum.none
  532. ] else remote_device
  533. self.pin_memory = pin_memory if (
  534. self.remote_device in [OffloadDeviceEnum.cpu,
  535. OffloadDeviceEnum.nvme]) else False
  536. # Enable fp16 param swapping to NVMe
  537. if self.remote_device == OffloadDeviceEnum.nvme:
  538. self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config, self.dtype)
  539. else:
  540. self.param_swapper = None
  541. # If we are provided an already-allocated module to prepare.
  542. if module is not None:
  543. assert isinstance(module, torch.nn.Module)
  544. self._convert_to_zero_parameters(module.parameters(recurse=True))
  545. self.use_all_gather_base = False
  546. if dist.has_allgather_base():
  547. self.use_all_gather_base = True
  548. else:
  549. logger.info(
  550. f"_all_gather_base API is not available in torch {torch.__version__}")
  551. def _convert_to_zero_parameters(self, param_list):
  552. for param in param_list:
  553. if is_zero_param(param):
  554. continue
  555. self._convert_to_deepspeed_param(param)
  556. param.partition()
  557. def _validate_remote_device(self, remote_device, ds_config):
  558. if ds_config is not None:
  559. if remote_device in [None, OffloadDeviceEnum.cpu]:
  560. if ds_config.zero_config.offload_param is not None:
  561. offload_param_device = ds_config.zero_config.offload_param.device
  562. assert offload_param_device != OffloadDeviceEnum.nvme, \
  563. f"'device' in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
  564. if remote_device == OffloadDeviceEnum.nvme:
  565. assert ds_config.zero_config.offload_param is not None, \
  566. f'"offload_param" must be defined in DeepSpeed Config if remote device is {OffloadDeviceEnum.nvme}.'
  567. assert ds_config.zero_config.offload_param.nvme_path is not None, \
  568. f'"nvme_path" in DeepSpeed Config cannot be None if remote device is {OffloadDeviceEnum.nvme}'
  569. def _post_init_method(self, module):
  570. #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
  571. print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
  572. see_memory_usage(
  573. f"Before converting and partitioning parmas in {module.__class__.__name__}",
  574. force=False)
  575. global param_count
  576. for name, param in module.named_parameters(recurse=False):
  577. param_count += param.numel()
  578. if not is_zero_param(param):
  579. self._convert_to_deepspeed_param(param)
  580. print_rank_0(
  581. f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}"
  582. )
  583. if get_accelerator().on_accelerator(param):
  584. dist.broadcast(param, 0, self.ds_process_group)
  585. else:
  586. if dist.get_rank() == 0:
  587. logger.warn(f"param `{name}` in {module.__class__.__name__} "
  588. f"not on GPU so was not broadcasted from rank 0")
  589. param.partition()
  590. see_memory_usage(
  591. f"Param count {param_count}. After converting and partitioning parmas in {module.__class__.__name__}",
  592. force=False)
    def _convert_to_deepspeed_param(self, param):
        """Attach ZeRO-3 bookkeeping attributes and collective closures to ``param`` in place.

        Sets the ``ds_*`` metadata fields: param type, status (starts AVAILABLE),
        original shape/numel, an empty partition slot (``ds_tensor = None``),
        process group, NVMe swapper and a unique ``ds_id``. It then binds
        per-parameter helpers that delegate to ``self``: ``all_gather``,
        ``all_gather_coalesced``, ``partition``, ``reduce_gradients_at_owner``,
        ``partition_gradients``, ``aligned_size``, ``padding_size``,
        ``partition_numel``, ``ds_summary`` and ``convert_to_zero_parameters``.
        ``param.item`` is wrapped so that it calls ``param.all_gather()`` first.
        The tensor data itself is not modified here.
        """

        # Partitioned, Normal, Remote
        param.ds_param_type = ZeroParamType.PARTITIONED

        # Replicated vs Partitioned vs Inflight
        param.ds_status = ZeroParamStatus.AVAILABLE

        # Stores the shape of the original tensor
        param.ds_shape = param.shape

        # Stores the number of elements in the original parameter without padding
        param.ds_numel = param.numel()

        # Stores the partitioned copy of the tensor
        param.ds_tensor = None

        # Keeps track of how many active sub-modules need this param at any given point in time
        param.ds_active_sub_modules = set()

        # If this flag is true, then the parameters are replicated throughout training
        # and only partitioned before the step
        param.ds_persist = False

        param.is_external_param = False

        # The group that the parameter is scattered across.
        param.ds_process_group = self.ds_process_group

        # This is set to the Async Param swapper if remote device is nvme
        # else this is set to None
        param.nvme_swapper = self.param_swapper

        # DeepSpeed Param ID: taken from a counter stored on the Init class
        # (shared by all instances) and incremented, so ids are unique and increasing.
        param.ds_id = Init.param_id
        Init.param_id += 1

        def all_gather(param_list=None, async_op=False, hierarchy=0):
            """Gather ``param_list`` (default: just this param) via ``self._all_gather``."""
            cls = param
            if param_list is None:
                param_list = [cls]
            return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

        @instrument_w_nvtx
        def all_gather_coalesced(params: Iterable[Parameter],
                                 safe_mode: bool = False) -> AllGatherCoalescedHandle:
            """Launch one all-gather for all of ``params`` and return a handle.

            Every param must be NOT_AVAILABLE on entry (else RuntimeError) and is
            marked INFLIGHT. A single param is gathered into its own padded buffer
            (AllGatherHandle). Several params are packed into one flat buffer whose
            dtype must be unique across the params (AllGatherCoalescedHandle).
            NOTE(review): completion/unpacking is presumably done by the returned
            handle's wait(); confirm in the handle classes.
            """
            # NOTE: ``param`` below is a local loop/unpack variable of this function,
            # not the enclosing ``param`` captured by the other closures.

            # fetches from nvme if the partition is not available and in nvme
            self._ensure_availability_of_partitioned_params(params)

            for param in params:
                if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                    raise RuntimeError(param.ds_summary())
                param.ds_status = ZeroParamStatus.INFLIGHT

            # ensure that each rank has params in same order. the allgather
            # is done by flattening the parameter list into a single tensor that
            # can be allgathered in a single call - this means that if each rank
            # gives a list of the same parameters in a different order we will
            # silently get incorrect parameter values, and have very difficult
            # to debug correctness issues.
            params = sorted(params, key=lambda p: p.ds_id)

            debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")

            if safe_mode:
                # ensure that same list (with same ordering) of parameters are
                # being allgathered across all ranks, otherwise could mix
                # data between tensors.
                assert_ints_same_as_other_ranks([p.ds_id for p in params])
                # ensure that tensors from each rank agree on the same ds_numel
                # otherwise could mix data between tensors.
                assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params])

            if len(params) == 1:
                # have an opportunity to avoid some intermediate memory allocations
                param, = params
                # Buffer length is ds_numel rounded up to a multiple of world_size.
                param_buffer = torch.empty(
                    math.ceil(param.ds_numel / self.world_size) * self.world_size,
                    dtype=param.dtype,
                    device=get_accelerator().current_device_name(),
                    requires_grad=False,
                )
                handle = _dist_allgather_fn(
                    param.ds_tensor.to(get_accelerator().current_device_name()),
                    param_buffer,
                    self.ds_process_group)
                # param.data is pointed at the (possibly still in-flight) buffer,
                # trimmed to the unpadded numel and reshaped.
                param.data = param_buffer.narrow(0,
                                                 0,
                                                 param.ds_numel).view(param.ds_shape).to(
                                                     param.device)
                return AllGatherHandle(handle, param)
            else:
                # Each rank contributes a slot holding all of its partitions back to back.
                partition_sz = sum(p.ds_tensor.ds_numel for p in params)
                flat_tensor = torch.empty(partition_sz * self.world_size,
                                          dtype=get_only_unique_item(p.dtype
                                                                     for p in params),
                                          device=get_accelerator().current_device_name(),
                                          requires_grad=False)
                partitions: List[Parameter] = []
                for i in range(self.world_size):
                    partitions.append(
                        flat_tensor.narrow(0,
                                           partition_sz * i,
                                           partition_sz))

                # Pack this rank's partitions directly into its own slot.
                instrument_w_nvtx(torch.cat)([
                    p.ds_tensor.to(get_accelerator().current_device_name())
                    for p in params
                ],
                                             out=partitions[self.rank])
                handle = _dist_allgather_fn(partitions[self.rank],
                                            flat_tensor,
                                            self.ds_process_group)

                return AllGatherCoalescedHandle(
                    allgather_handle=handle,
                    params=params,
                    partitions=partitions,
                    world_size=self.world_size,
                )

        def partition(param_list=None, hierarchy=0, has_been_updated=False):
            """Partition ``param_list`` (default: just this param) via ``self._partition``."""
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}"
            )
            if param_list is None:
                param_list = [cls]
            self._partition(param_list, has_been_updated=has_been_updated)

        def reduce_gradients_at_owner(param_list=None, hierarchy=0):
            """Reduce-scatter gradients of ``param_list`` (default: just this param)."""
            cls = param
            if param_list is None:
                param_list = [cls]
            print_rank_0(
                f"{'--'*hierarchy}----Reducing Gradients for param with ids {[param.ds_id for param in param_list]} to owner"
            )
            self._reduce_scatter_gradients(param_list)

        def partition_gradients(param_list=None,
                                partition_buffers=None,
                                hierarchy=0,
                                accumulate=False):
            """Keep only this rank's gradient shard for ``param_list`` (default: just this param).

            A single tensor passed as ``partition_buffers`` is wrapped in a list.
            """
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(cls)}"
            )
            if param_list is None:
                param_list = [cls]
            if isinstance(partition_buffers, torch.Tensor):
                partition_buffers = [partition_buffers]

            self._partition_gradients(param_list,
                                      partition_buffers=partition_buffers,
                                      accumulate=accumulate)

        def aligned_size():
            return self._aligned_size(param)

        def padding_size():
            return self._padding_size(param)

        def partition_numel():
            return self._partition_numel(param)

        # NOTE(review): item_override is not referenced anywhere in this block;
        # param.item is wrapped through allgather_before below instead.
        def item_override():
            param.all_gather()
            return param._orig_item()

        def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict:
            """Return a dict describing ``slf``'s ZeRO state (bound to the param below)."""
            return {
                "id": debug_param2name_id(slf) if use_debug_name else slf.ds_id,
                "status": slf.ds_status.name,
                "numel": slf.numel(),
                "ds_numel": slf.ds_numel,
                "shape": tuple(slf.shape),
                "ds_shape": tuple(slf.ds_shape),
                "requires_grad": slf.requires_grad,
                "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None,
                "persist": slf.ds_persist,
                "active_sub_modules": slf.ds_active_sub_modules,
            }

        def convert_to_zero_parameters(param_list):
            self._convert_to_zero_parameters(param_list)

        def allgather_before(func: Callable) -> Callable:
            """Wrap ``func`` so that ``param.all_gather()`` runs before every call."""

            def wrapped(*args, **kwargs):
                param.all_gather()
                return func(*args, **kwargs)

            return wrapped

        # Collectives for gathering and partitioning parameters
        param.all_gather = all_gather
        param.all_gather_coalesced = all_gather_coalesced
        param.partition = partition

        # Collective for averaging gradients
        param.reduce_gradients_at_owner = reduce_gradients_at_owner
        param.partition_gradients = partition_gradients

        # Partitioning size utilities
        param.aligned_size = aligned_size
        param.padding_size = padding_size
        param.partition_numel = partition_numel
        # Bound as a method so callers write param.ds_summary() without passing the param.
        param.ds_summary = types.MethodType(ds_summary, param)

        # item() on a partitioned param would read the empty placeholder; gather first.
        param.item = allgather_before(param.item)

        param.convert_to_zero_parameters = convert_to_zero_parameters
  767. def _aligned_size(self, param):
  768. return param.ds_numel + self._padding_size(param)
  769. def _padding_size(self, param):
  770. remainder = param.ds_numel % self.world_size
  771. return (self.world_size - remainder) if remainder else 0
  772. def _partition_numel(self, param):
  773. return param.ds_tensor.ds_numel
  774. def _ensure_availability_of_partitioned_params(self, params):
  775. swap_in_list = []
  776. swap_in_flight = []
  777. for param in params:
  778. if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
  779. assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
  780. swap_in_list.append(param)
  781. if param.ds_tensor.status == PartitionedParamStatus.INFLIGHT:
  782. assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
  783. swap_in_flight.append(param)
  784. if len(swap_in_list) > 0:
  785. swap_in_list[0].nvme_swapper.swap_in(swap_in_list, async_op=False)
  786. elif len(swap_in_flight) > 0:
  787. swap_in_flight[0].nvme_swapper.synchronize_reads()
  788. @instrument_w_nvtx
  789. def _all_gather(self, param_list, async_op=False, hierarchy=None):
  790. # fetches from nvme if the partition is not available and in nvme
  791. self._ensure_availability_of_partitioned_params(param_list)
  792. handles = []
  793. all_gather_list = []
  794. for param in param_list:
  795. if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
  796. if async_op:
  797. handle = self._allgather_param(param,
  798. async_op=async_op,
  799. hierarchy=hierarchy)
  800. param.ds_status = ZeroParamStatus.INFLIGHT # if async_op else ZeroParamStatus.AVAILABLE
  801. handles.append(handle)
  802. else:
  803. all_gather_list.append(param)
  804. if not async_op:
  805. if len(param_list) == 1:
  806. ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
  807. else:
  808. ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)
  809. for param in all_gather_list:
  810. param.ds_status = ZeroParamStatus.AVAILABLE
  811. return ret_value
  812. return handles
  813. def _partition(self, param_list, force=False, has_been_updated=False):
  814. for param in param_list:
  815. #print_rank_0(f"Before Partitioning Param {param.ds_id}")
  816. # self._param_status(param)
  817. self._partition_param(param, has_been_updated=has_been_updated)
  818. param.ds_status = ZeroParamStatus.NOT_AVAILABLE
  819. # if param.ds_tensor is not None:
  820. # assert id(param.data) == id(param.ds_tensor.data), \
  821. # "After the parameters are initially partitioned, make sure we are not recreating the partition."
  822. #print_rank_0(f"After Partitioning Param {param.ds_id}")
  823. # self._param_status(param)
    @instrument_w_nvtx
    def _partition_param(self, param, buffer=None, has_been_updated=False):
        """Reduce ``param`` to this rank's partition and free its full data.

        Acts only when ``param.ds_status`` is AVAILABLE. Params in any other
        non-INFLIGHT status are left unchanged; INFLIGHT params trigger an
        AssertionError.

        * If a partition (``ds_tensor``) already exists and ``has_been_updated``
          is False, the existing partition is kept and the full data is freed.
          For NVMe-resident partitions the swapper's buffers are also released.
        * Otherwise the partition is allocated on first use, this rank's slice
          of the flattened param is copied into it, the full data is freed and,
          for NVMe partitions, the partition is swapped out. Allocation uses an
          NVMe swap buffer when the remote device is NVMe and the swapper accepts
          the size. If not, it uses a tensor on CPU or on the remote device,
          pinned if ``self.pin_memory`` is set.

        Args:
            param: a param prepared by ``_convert_to_deepspeed_param``.
            buffer: not read on entry; only reassigned locally in the NVMe branch.
            has_been_updated: when True, re-copy the current full data into an
                existing partition instead of keeping the old partition.
        """
        assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"

        # Module-level flag; only read here for the debug message below.
        global reuse_buffers
        if param.ds_status is ZeroParamStatus.AVAILABLE:
            print_rank_0(
                f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}",
                force=False)

            # Existing partition and the caller did not ask for a refresh:
            # just drop the full data.
            if param.ds_tensor is not None and not has_been_updated:

                see_memory_usage(
                    f'Before partitioning param {param.ds_id} {param.shape}',
                    force=False)
                # param.data does not store anything meaningful in partitioned state
                free_param(param)
                see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                                 force=False)

                if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
                    print_rank_0(
                        f"Param {param.ds_id} partition released since it exists in nvme",
                        force=False)
                    param.nvme_swapper.remove_partition_and_release_buffers([param])

                return

            # Padded total size and per-rank partition length.
            tensor_size = self._aligned_size(param)
            partition_size = tensor_size // self.world_size

            if param.ds_tensor is None:
                # First partitioning of this param: allocate the partition storage.
                final_location = None
                if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor(
                        numel=partition_size):
                    final_location = OffloadDeviceEnum.nvme
                    buffer = self.param_swapper.get_buffer(param, partition_size)
                    # Wrap the swapper-owned buffer in a tensor of the param's dtype.
                    partitioned_tensor = torch.empty(0,
                                                     dtype=param.dtype,
                                                     device=buffer.device)
                    partitioned_tensor.data = buffer.data
                    print_rank_0(
                        f"ID {param.ds_id} Initializing partition for the first time for nvme offload."
                    )

                else:
                    # NVMe remote device without a swappable buffer falls back to CPU.
                    partitioned_tensor = torch.empty(
                        partition_size,
                        dtype=param.dtype,
                        device=OffloadDeviceEnum.cpu if self.remote_device
                        == OffloadDeviceEnum.nvme else self.remote_device)
                    if self.pin_memory:
                        partitioned_tensor = get_accelerator().pin_memory(
                            partitioned_tensor)

                partitioned_tensor.requires_grad = False
                param.ds_tensor = partitioned_tensor
                param.ds_tensor.ds_numel = partition_size
                param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
                param.ds_tensor.final_location = final_location

            # This rank owns elements [start, end) of the padded flat tensor.
            start = partition_size * self.rank
            end = start + partition_size

            one_dim_param = param.contiguous().view(-1)

            if start < param.ds_numel and end <= param.ds_numel:
                # Slice lies entirely inside the real data.
                src_tensor = one_dim_param.narrow(0, start, partition_size)

                param.ds_tensor.copy_(src_tensor)

            else:
                # Slice reaches into the padding: copy only the real prefix (if any).
                # The padding tail of ds_tensor is not written here.
                if start < param.ds_numel:
                    elements_to_copy = param.ds_numel - start
                    param.ds_tensor.narrow(0,
                                           0,
                                           elements_to_copy).copy_(
                                               one_dim_param.narrow(
                                                   0,
                                                   start,
                                                   elements_to_copy))

            # param.data does not store anything meaningful in partitioned state

            see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                             force=False)

            free_param(param)

            see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                             force=False)

            if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
                self.param_swapper.swap_out_and_release([param])
                print_rank_0(
                    f"ID {param.ds_id} Offloaded to nvme offload and buffers released.")
                see_memory_usage(
                    f"ID {param.ds_id} Offloaded to nvme offload and buffers released.",
                    force=False)

            print_rank_0(
                f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}"
            )
  926. def _param_status(self, param):
  927. if param.ds_tensor is not None:
  928. print_rank_0(
  929. f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned numel {param.ds_tensor.numel()}, data numel {param.data.numel()}"
  930. )
  931. else:
  932. print_rank_0(
  933. f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned ds_tensor {param.ds_tensor}, data numel {param.data.numel()}"
  934. )
  935. def _allgather_param(self, param, async_op=False, hierarchy=0):
  936. partition_size = param.ds_tensor.ds_numel
  937. tensor_size = partition_size * self.world_size
  938. aligned_param_size = self._aligned_size(param)
  939. assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'
  940. print_rank_0(
  941. f"{'--'* hierarchy}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}"
  942. )
  943. see_memory_usage(
  944. f'Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ',
  945. force=False)
  946. flat_tensor = torch.zeros(aligned_param_size,
  947. dtype=param.dtype,
  948. device=param.device).view(-1)
  949. see_memory_usage(
  950. f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
  951. force=False)
  952. get_accelerator().synchronize()
  953. print_rank_0(
  954. f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
  955. )
  956. # if not flat_tensor.numel() > 100000:
  957. # replicated_tensor = flat_tensor.narrow(0,
  958. # 0,
  959. # param.ds_numel).view(param.ds_shape)
  960. # param.data = replicated_tensor.data
  961. # return None
  962. if self.use_all_gather_base:
  963. # try the _all_gather_base on PyTorch master branch
  964. handle = dist.all_gather_base(flat_tensor,
  965. param.ds_tensor.to(
  966. get_accelerator().device_name()),
  967. group=self.ds_process_group,
  968. async_op=async_op)
  969. else:
  970. partitions = []
  971. for i in range(self.world_size):
  972. partitions.append(
  973. flat_tensor.narrow(0,
  974. partition_size * i,
  975. partition_size))
  976. if i == dist.get_rank(group=self.ds_process_group):
  977. partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
  978. handle = dist.all_gather(partitions,
  979. partitions[self.rank],
  980. group=self.ds_process_group,
  981. async_op=async_op)
  982. replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
  983. param.data = replicated_tensor.data
  984. return handle
  985. def _allgather_params_coalesced(self, param_list, hierarchy=0):
  986. """ blocking call
  987. avoid explicit memory copy in _allgather_params
  988. """
  989. if len(param_list) == 0:
  990. return
  991. # collect local tensors and partition sizes
  992. partition_sizes = []
  993. local_tensors = []
  994. for param in param_list:
  995. partition_sizes.append(param.ds_tensor.ds_numel)
  996. local_tensors.append(param.ds_tensor.to(get_accelerator().device_name()))
  997. # allocate memory for allgather params
  998. allgather_params = []
  999. for psize in partition_sizes:
  1000. tensor_size = psize * self.world_size
  1001. flat_tensor = torch.empty(tensor_size,
  1002. dtype=param_list[0].dtype,
  1003. device=self.local_device).view(-1)
  1004. flat_tensor.requires_grad = False
  1005. allgather_params.append(flat_tensor)
  1006. # launch
  1007. launch_handles = []
  1008. # backend = get_backend(self.ds_process_group)
  1009. # with _batch_p2p_manager(backend):
  1010. for param_idx, param in enumerate(param_list):
  1011. input_tensor = local_tensors[param_idx].view(-1)
  1012. if self.use_all_gather_base:
  1013. # try the _all_gather_base from Pytorch master
  1014. h = dist.all_gather_base(allgather_params[param_idx],
  1015. input_tensor,
  1016. group=self.ds_process_group,
  1017. async_op=True)
  1018. else:
  1019. output_list = []
  1020. for i in range(self.world_size):
  1021. psize = partition_sizes[param_idx]
  1022. partition = allgather_params[param_idx].narrow(0, i * psize, psize)
  1023. output_list.append(partition)
  1024. if not get_accelerator().on_accelerator(partition):
  1025. logger.warning(
  1026. f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}'
  1027. )
  1028. # back to old all_gather function signature
  1029. h = dist.all_gather(output_list,
  1030. input_tensor,
  1031. group=self.ds_process_group,
  1032. async_op=True)
  1033. launch_handles.append(h)
  1034. # Wait ensures the operation is enqueued, but not necessarily complete.
  1035. launch_handles[-1].wait()
  1036. # assign to param.data (not copy)
  1037. for i, param in enumerate(param_list):
  1038. gathered_tensor = allgather_params[i]
  1039. param.data = gathered_tensor.narrow(0,
  1040. 0,
  1041. param.ds_numel).view(param.ds_shape).data
  1042. # guarantee the communication to be completed
  1043. get_accelerator().synchronize()
  1044. return None
  1045. def _allgather_params(self, param_list, hierarchy=0):
  1046. if len(param_list) == 0:
  1047. return
  1048. partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
  1049. tensor_size = partition_size * self.world_size
  1050. flat_tensor = torch.empty(tensor_size,
  1051. dtype=param_list[0].dtype,
  1052. device=self.local_device)
  1053. flat_tensor.requires_grad = False
  1054. partitions = []
  1055. for i in range(self.world_size):
  1056. start = partition_size * i
  1057. partitions.append(flat_tensor.narrow(0, start, partition_size))
  1058. if i == self.rank:
  1059. offset = 0
  1060. for param in param_list:
  1061. param_numel = param.ds_tensor.ds_numel
  1062. partitions[i].narrow(0,
  1063. offset,
  1064. param_numel).copy_(param.ds_tensor.data)
  1065. offset += param_numel
  1066. dist.all_gather(partitions,
  1067. partitions[self.rank],
  1068. group=self.ds_process_group,
  1069. async_op=False)
  1070. param_offset = 0
  1071. for param in param_list:
  1072. param_partition_size = param.ds_tensor.ds_numel
  1073. param_size = param.ds_numel
  1074. replicated_tensor = torch.empty(param.ds_shape,
  1075. dtype=param.dtype,
  1076. device=self.local_device)
  1077. for i in range(self.world_size):
  1078. start = i * partition_size
  1079. param_start = i * param_partition_size
  1080. if param_start < param_size:
  1081. numel_to_copy = min(param_size - param_start, param_partition_size)
  1082. part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
  1083. replicated_tensor.view(-1).narrow(0,
  1084. param_start,
  1085. numel_to_copy).copy_(part_to_copy)
  1086. #param_offset += param.data.numel()
  1087. param_offset += param.ds_tensor.ds_numel
  1088. param.data = replicated_tensor.data
  1089. return None
  1090. def _reduce_scatter_gradients(self, param_list):
  1091. #print_rank_0([param.grad for param in param_list])
  1092. #assert any([param.grad is None for param in param_list]), "None gradients cannot be reduce scattered"
  1093. handles_and_reduced_partitions = []
  1094. for param in param_list:
  1095. assert param.grad.numel(
  1096. ) == param.ds_numel, f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter gradients whose size is not same as the params"
  1097. handles_and_reduced_partitions.append(self._reduce_scatter_gradient(param))
  1098. for param, (handle, reduced_partition) in zip(param_list, handles_and_reduced_partitions):
  1099. if handle is not None:
  1100. handle.wait()
  1101. # some ranks may have partitions that are padded to go beyond the grad size.
  1102. # For these ranks the output of reduce scatter is a separate buffer and needs
  1103. # to be copied in
  1104. partition_size = param.ds_tensor.ds_numel
  1105. start = self.rank * partition_size
  1106. end = start + partition_size
  1107. #print_rank_0("REduce scatter was executed for praam {param.ds_id}")
  1108. if start < param.ds_numel and end > param.ds_numel:
  1109. elements = param.ds_numel - start
  1110. param.grad.view(-1).narrow(0,
  1111. start,
  1112. elements).copy_(
  1113. reduced_partition.narrow(0,
  1114. 0,
  1115. elements))
  1116. def _reduce_scatter_gradient(self, param):
  1117. partition_size = param.ds_tensor.ds_numel
  1118. #output = torch.empty(partition_size, dtype=param.dtype, device=param.device)
  1119. total_size = partition_size * self.world_size
  1120. input_list = []
  1121. for i in range(self.world_size):
  1122. start = i * partition_size
  1123. end = start + partition_size
  1124. #print("before reduce scatter gradients")
  1125. if start < param.ds_numel and end <= param.ds_numel:
  1126. input = param.grad.view(-1).narrow(0, start, partition_size)
  1127. else:
  1128. input = torch.zeros(partition_size,
  1129. dtype=param.dtype,
  1130. device=param.device)
  1131. if start < param.ds_numel:
  1132. elements = param.ds_numel - start
  1133. input.narrow(0,
  1134. 0,
  1135. elements).copy_(
  1136. param.grad.view(-1).narrow(0,
  1137. start,
  1138. elements))
  1139. #print("after reduce scatter gradients")
  1140. input_list.append(input)
  1141. rank = dist.get_rank(group=self.ds_process_group)
  1142. handle = dist.reduce_scatter(input_list[rank],
  1143. input_list,
  1144. group=self.ds_process_group,
  1145. async_op=True)
  1146. return handle, input_list[rank]
  1147. def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False):
  1148. if partition_buffers is None:
  1149. partition_buffers = [None] * len(param_list)
  1150. for param, partition_buffer in zip(param_list, partition_buffers):
  1151. self._partition_gradient(param,
  1152. partition_buffer=partition_buffer,
  1153. accumulate=accumulate)
  1154. def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
  1155. #import pdb;pdb.set_trace()
  1156. # param.grad=None
  1157. # param.grad.test()
  1158. print_rank_0(
  1159. f"Partitioning param {param.ds_id} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.ds_numel}"
  1160. )
  1161. see_memory_usage("Before partitioning gradients", force=False)
  1162. partition_size = param.ds_tensor.ds_numel
  1163. if partition_buffer is None:
  1164. assert not accumulate, "No buffer to accumulate to"
  1165. partition_buffer = torch.zeros(partition_size,
  1166. dtype=param.dtype,
  1167. device=param.device)
  1168. else:
  1169. assert partition_buffer.numel(
  1170. ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
  1171. rank = dist.get_rank(group=self.ds_process_group)
  1172. start = partition_size * rank
  1173. end = start + partition_size
  1174. dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size)
  1175. #print("before partition gradients")
  1176. if start < param.ds_numel:
  1177. elements = min(param.ds_numel - start, partition_size)
  1178. dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
  1179. src_tensor = param.grad.view(-1).narrow(0, start, elements)
  1180. # just copy the grad partition to the buffer
  1181. if not accumulate:
  1182. dest_tensor.copy_(src_tensor)
  1183. # if source and destination are on same device,
  1184. # add to the provided buffer
  1185. elif src_tensor.device == dest_tensor.device:
  1186. dest_tensor.add_(src_tensor)
  1187. # if source and destination are on different device, copy first to src
  1188. # then add and move back to the destination. This seems to run faster
  1189. # when src is gpu and dest is cpu
  1190. # adding directly to cpu is very slow
  1191. else:
  1192. acc_tensor = torch.empty(src_tensor.numel(),
  1193. dtype=param.dtype,
  1194. device=param.device)
  1195. acc_tensor.copy_(dest_tensor)
  1196. acc_tensor.add_(src_tensor)
  1197. dest_tensor.copy_(acc_tensor)
  1198. # partition_buffer.view(-1).narrow(
  1199. # 0,
  1200. # 0,
  1201. # elements).copy_(param.grad.view(-1).narrow(0,
  1202. # start,
  1203. # elements))
  1204. #print("after partition gradients")
  1205. param.grad.data = dest_tensor_full_buffer.data
  1206. see_memory_usage("After partitioning gradients", force=False)
  1207. class GatheredParameters:
  1208. def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
  1209. """A context that collects parameters that were partitioned via a
  1210. :class:`deepspeed.zero.Init` context. The parameters are partitioned
  1211. again upon exit.
  1212. Args:
  1213. params (``torch.nn.Parameter``): A single parameter, or an iterable of parameters (list, tuple, generator) of parameters to collect.
  1214. It's assumed that all parameters are zero params.
  1215. modifier_rank (int, optional): If specified, this rank's parameter will be
  1216. broadcasted on exit from the context. This argument is required if ``params`` are
  1217. modified, so that all processes have a consistent view of the data. Defaults
  1218. to ``None``.
  1219. fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
  1220. registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
  1221. enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.
  1222. Important: Make sure to use ``modifier_rank`` that is not ``None`` (e.g., ``modifier_rank=0``)
  1223. if you need the GPU memory allocated by gather to be released upon exit from the context manager.
  1224. Important: if ``params`` isn't an iterable of parameters or a single parameter it'll be silently ignored!
  1225. Examples
  1226. ========
  1227. #. Allocate a partitioned module, initialize its weight on rank 0, and update all
  1228. processes.
  1229. .. code-block:: python
  1230. with deepspeed.zero.Init():
  1231. linear = torch.nn.Linear(1000,1000)
  1232. with deepspeed.zero.GatheredParameters(linear.weight,
  1233. modifier_rank=0):
  1234. if deepspeed.comm.get_rank() == 0:
  1235. linear.weight.zero_()
  1236. with deepspeed.zero.GatheredParameters(linear.weight,
  1237. modifier_rank=0):
  1238. if deepspeed.comm.get_rank() == 0:
  1239. linear.weight.zero_()
  1240. #. Collect a partitioned weight to pass to another module during
  1241. training. The parameter will be registered as an external parameter
  1242. and made available during the backward pass.
  1243. .. code-block:: python
  1244. :emphasize-lines: 6
  1245. def forward(self, input):
  1246. x = self.layer1(input)
  1247. # self.layer1.weight is required by self.layer2.forward
  1248. with deepspeed.zero.GatheredParameters(self.layer1.weight,
  1249. fwd_module=self):
  1250. y = self.layer2(x, self.layer1.weight)
  1251. return y
  1252. #. Pretrained model loading
  1253. .. code-block:: python
  1254. with deepspeed.zero.Init():
  1255. model = MyModel()
  1256. state_dict = torch.load(model_path, map_location="cpu")
  1257. def load(module: nn.Module, prefix=""):
  1258. # because zero3 puts placeholders in model params, this context
  1259. # manager gathers (unpartitions) the params of the current layer, then loads from
  1260. # the state dict and then re-partitions them again
  1261. with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
  1262. if deepspeed.comm.get_rank() == 0:
  1263. module._load_from_state_dict(state_dict, prefix)
  1264. for name, child in module._modules.items():
  1265. if child is not None:
  1266. load(child, prefix + name + ".")
  1267. load(model, prefix="")
  1268. If this approach is not used, then the full model will first be copied to each GPU. For models
  1269. bigger than the memory of a single GPU, this method is required.
  1270. """
  1271. self.enabled = enabled
  1272. if not enabled:
  1273. return
  1274. if isinstance(params, Iterable) and not isinstance(params, torch.Tensor):
  1275. # deal with generators like model.parameters()
  1276. # must convert to list to be able to iterate more than once if we get a generator
  1277. params = list(params)
  1278. else:
  1279. # single param
  1280. params = [params]
  1281. # enable if at least one is zero-param, otherwise a noop
  1282. if not any(is_zero_param(p) for p in params):
  1283. self.enabled = False
  1284. return
  1285. self.params = [p for p in params if hasattr(p, "ds_id")]
  1286. self.src_rank = None
  1287. if modifier_rank is not None:
  1288. if self.params[0].ds_process_group == dist.get_world_group():
  1289. self.src_rank = modifier_rank
  1290. else:
  1291. # A group was specified; convert DP rank to global rank
  1292. self.src_rank = dist.get_global_rank(self.params[0].ds_process_group,
  1293. modifier_rank)
  1294. self.fwd_module = fwd_module
  1295. if self.fwd_module is not None:
  1296. # is a no-op if already registered
  1297. for p in self.params:
  1298. register_external_parameter(self.fwd_module, p)
  1299. def __enter__(self):
  1300. if not self.enabled:
  1301. return
  1302. self.params[0].all_gather(param_list=self.params)
  1303. def __exit__(self, *exc):
  1304. if not self.enabled:
  1305. return
  1306. if self.src_rank is None:
  1307. self.params[0].partition(param_list=self.params, has_been_updated=False)
  1308. return
  1309. handles = [
  1310. dist.broadcast(p,
  1311. self.src_rank,
  1312. group=p.ds_process_group,
  1313. async_op=True) for p in self.params
  1314. ]
  1315. for h in handles:
  1316. h.wait()
  1317. self.params[0].partition(param_list=self.params, has_been_updated=True)