# partition_parameters.py

  1. """
  2. "Copyright 2020 The Microsoft DeepSpeed Team.
  3. Licensed under the MIT license.
  4. """
  5. import os
  6. import time
  7. import types
  8. from enum import Enum
  9. import functools
  10. import itertools
  11. import torch
  12. from torch.distributed.distributed_c10d import _get_global_rank, group
  13. import torch.distributed as dist
  14. from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3
  15. from .offload_constants import *
  16. from ..utils import see_memory_usage
  17. from deepspeed.utils import log_dist, init_distributed, logger
  18. from deepspeed.utils.debug import debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, debug_param2name, debug_param2name_id_shape_status, printflock, log_rank_file
  19. from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
  20. from ..config import DeepSpeedConfig
  21. param_count = 0
  22. partitioned_param_data_shape = [1]
  23. def print_rank_0(message, debug=False, force=False):
  24. rank = torch.distributed.get_rank()
  25. if rank == 0 and (debug or force):
  26. print(message)
  27. # other variations
  28. # - print for all ranks w/o interleaving
  29. # printflock(f"[{rank}] {message}")
  30. # - print to log file per rank
  31. # log_rank_file(rank, message)
def is_zero_param(parameter):
    if not torch.is_tensor(parameter):
        return False
    return hasattr(parameter, 'ds_id')


def _init_external_params(module):
    if not hasattr(module, '_external_params'):
        module._external_params = {}

        def external_parameters(self):
            return self._external_params.items()

        def all_parameters(self):
            return itertools.chain(self.named_parameters(self,
                                                         recurse=False),
                                   external_parameters(self))

        module.ds_external_parameters = types.MethodType(external_parameters, module)
        module.all_parameters = types.MethodType(all_parameters, module)
def register_external_parameter(module, parameter):
    """Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in
    the forward and backward passes of ``module``.

    This is used when a parameter is accessed outside of its owning module's
    ``forward()``. DeepSpeed must know to collect it from its partitioned
    state and when to release the memory.

    .. note::
        This is only applicable to training with ZeRO stage 3.

    Args:
        module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass.
        parameter (``torch.nn.Parameter``): The parameter to register.

    Raises:
        RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.

    Examples
    ========

    #. Register a weight that is used in another module's forward pass (line 6).
       Parameter ``layer1.weight`` is used by ``layer2`` (line 11).

        .. code-block:: python
            :linenos:
            :emphasize-lines: 6,11

            class ModuleZ3(torch.nn.Module):
                def __init__(self, *args):
                    super().__init__(*args)
                    self.layer1 = SomeLayer()
                    self.layer2 = OtherLayer()
                    deepspeed.zero.register_external_parameter(self, self.layer1.weight)

                def forward(self, input):
                    x = self.layer1(input)
                    # self.layer1.weight is required by self.layer2.forward
                    y = self.layer2(x, self.layer1.weight)
                    return y
    """
    if not isinstance(parameter, torch.nn.Parameter):
        raise RuntimeError('Parameter is not a torch.nn.Parameter')

    if not hasattr(module, '_external_params'):
        _init_external_params(module)

    key = id(parameter)
    module._external_params[key] = parameter


def unregister_external_parameter(module, parameter):
    """Reverses the effects of :meth:`register_external_parameter`.

    Args:
        module (``torch.nn.Module``): The module to affect.
        parameter (``torch.nn.Parameter``): The parameter to unregister.

    Raises:
        RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
        RuntimeError: If ``parameter`` is not a registered external parameter of ``module``.
    """
    if not isinstance(parameter, torch.nn.Parameter):
        raise RuntimeError('Parameter is not a torch.nn.Parameter')

    if not hasattr(module,
                   '_external_params') or id(parameter) not in module._external_params:
        raise RuntimeError('Parameter is not a registered external parameter of module.')

    key = id(parameter)
    del module._external_params[key]
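
# Illustrative only: once registered, an external parameter shows up in
# ``module.ds_external_parameters()`` (set up by ``_init_external_params`` above), so
# ZeRO-3 knows to gather and release it around that module's forward/backward passes.
# A minimal sketch, assuming ``layer1``/``layer2`` are hypothetical modules built under
# ``deepspeed.zero.Init``:
#
#     deepspeed.zero.register_external_parameter(layer2, layer1.weight)
#     for param_id, p in layer2.ds_external_parameters():
#         print(param_id, p.ds_shape)   # ds_shape exists only on converted (ZeRO) params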
class ZeroParamType(Enum):

    # same as regular pytorch parameters
    NORMAL = 1

    # parameters are partitioned across data parallel process
    PARTITIONED = 2

    # the parameter is held with a unique process rank
    # and is not available on all other process
    REMOTE = 3


class ZeroParamStatus(Enum):
    # parameters are fully present and ready for use on all processes
    AVAILABLE = 1

    # parameters are either partitioned or remote in some or all process
    NOT_AVAILABLE = 2

    # parameters are being gathered.
    INFLIGHT = 3
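
# Illustrative lifecycle of a ZeRO-3 parameter, based on how these statuses are used
# below: a param starts AVAILABLE at construction, becomes NOT_AVAILABLE once
# ``param.partition()`` runs, moves to INFLIGHT while an async ``param.all_gather()``
# is outstanding, and is marked AVAILABLE again once the gather completes. Sketch only;
# in practice the ZeRO-3 runtime performs the wait and the status bookkeeping:
#
#     param.partition()                          # ds_status -> NOT_AVAILABLE
#     handles = param.all_gather(async_op=True)  # ds_status -> INFLIGHT
#     handles[0].wait()                          # full tensor materialized on this rank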
_orig_torch_empty = torch.empty


def empty_cuda_tensor_half(*size, **kwargs):
    if not 'device' in kwargs.keys():
        kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
    tensor = _orig_torch_empty(*size, **kwargs)
    if tensor.is_floating_point():
        return tensor.half()
    else:
        return tensor


def new_cuda_tensor_half(cls, *args):
    device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
    tensor = torch.ones((1, 1), device=device).new_empty(*args).half()
    if tensor.is_floating_point():
        return tensor.half()
    else:
        return tensor


def empty_cuda_tensor(*size, **kwargs):
    if not 'device' in kwargs.keys():
        kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
    tensor = _orig_torch_empty(*size, **kwargs)
    return tensor


def new_cuda_tensor(cls, *args):
    device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
    tensor = torch.ones((1, 1), device=device).new_empty(*args)
    return tensor
# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls):
    subclass_list = []

    def recurse(cl):
        for subclass in cl.__subclasses__():
            subclass_list.append(subclass)
            recurse(subclass)

    recurse(cls)

    return set(subclass_list)


reuse_buffers = False
temp_contiguous_tensor = None
empty_buffers = {}
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
    def __init__(self,
                 enabled=True,
                 mem_efficient_linear=True,
                 ds_config=None,
                 dtype=None):
        self.mem_efficient_linear = mem_efficient_linear
        self.enabled = enabled
        self._set_dtype(ds_config, dtype)
        assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"

    def __enter__(self):
        if not self.enabled:
            return

        def partition_after(f):
            @functools.wraps(f)
            def wrapper(module, *args, **kwargs):

                # important logic: We want to run post_init only after the child's __init__ is
                # completed, and do nothing after the __init__ of any of its parents and grandparents in
                # the inheritance ancestry. This way the partitioning needs to happen only once,
                # when the whole object is ready to be partitioned and not before. This is because
                # often the child module will need to tweak the weights - for example running a
                # custom weights init function. So if a parent created the weights param, the child
                # won't need to gather it in order to tweak it.

                print_rank_0(f'Before initializing {module.__class__.__name__}',
                             force=False)

                is_child_module = False
                if not hasattr(module, "_ds_child_entered"):
                    # child's __init__ was called, since parents all see the same object they can now skip post_init
                    is_child_module = True
                    setattr(module, "_ds_child_entered", True)

                f(module, *args, **kwargs)

                if is_child_module:
                    # child's __init__ is done, now we can run a single post_init on the child object
                    delattr(module, "_ds_child_entered")

                    print_rank_0(f'Running post_init for {module.__class__.__name__}',
                                 force=False)
                    self._post_init_method(module)

                print_rank_0(
                    f'After initializing followed by post init for {module.__class__.__name__}',
                    force=False)

            return wrapper

        def _enable_class(cls):
            cls._old_init = cls.__init__
            cls.__init__ = partition_after(cls.__init__)

        def _init_subclass(cls, **kwargs):
            cls.__init__ = partition_after(cls.__init__)

        # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
        for subclass in get_all_subclasses(torch.nn.modules.module.Module):
            # print(f"subclass={subclass.__module__}.{subclass.__qualname__}")
            _enable_class(subclass)

        # holding on to the current __init_subclass__ for exit
        torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
        torch.Tensor.__old_new__ = torch.Tensor.__new__

        # Replace .__init__() for future subclasses of torch.nn.Module
        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
        if self.dtype == torch.half:
            torch.Tensor.__new__ = new_cuda_tensor_half
            torch.empty = empty_cuda_tensor_half
        else:
            torch.Tensor.__new__ = new_cuda_tensor
            torch.empty = empty_cuda_tensor

        if self.mem_efficient_linear:
            print_rank_0(
                "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
                force=False)
            self.linear_bk = torch.nn.functional.linear
            torch.nn.functional.linear = LinearFunctionForZeroStage3.apply

    def __exit__(self, exc_type, exc_value, traceback):
        if not self.enabled:
            return

        def _disable_class(cls):
            cls.__init__ = cls._old_init

        # Replace .__init__() for all existing subclasses of torch.nn.Module
        for subclass in get_all_subclasses(torch.nn.modules.module.Module):
            _disable_class(subclass)

        # Replace .__init__() for future subclasses of torch.nn.Module
        torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
        torch.Tensor.__new__ = torch.Tensor.__old_new__
        torch.empty = _orig_torch_empty

        # undoing it here would also undo it during training
        # if self.mem_efficient_linear:
        #    torch.nn.functional.linear = self.linear_bk

        # Now that we cleaned up the metaclass injection, raise the exception.
        if exc_type is not None:
            return False

    # To be implemented by inheriting classes
    def _post_init_method(self, module):
        pass

    def _set_dtype(self, ds_config, dtype):
        if ds_config is not None and dtype is None:
            self.dtype = torch.half if ds_config.fp16_enabled else torch.float
        elif dtype is None:
            self.dtype = torch.half
        else:
            self.dtype = dtype
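
# What the context manager above does, in miniature (illustrative sketch, not part of
# the public API): while the ``with`` block is active, every torch.nn.Module subclass
# __init__ is wrapped so that freshly constructed modules get ``_post_init_method``
# applied exactly once, and ``torch.empty`` / ``torch.Tensor.__new__`` allocate directly
# on the local CUDA device (optionally in fp16). On exit, the original __init__,
# __init_subclass__, torch.empty and Tensor.__new__ are restored:
#
#     with InsertPostInitMethodToModuleSubClasses(dtype=torch.half):
#         layer = torch.nn.Linear(4, 4)   # weights allocated on cuda in half precision,
#                                         # then handed to _post_init_method (a no-op here)
#     t = torch.empty(2, 2)               # stock torch.empty again after exit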
# Replaces all parameters in module with Scattered Parameters
class Init(InsertPostInitMethodToModuleSubClasses):
    param_id = 0

    def __init__(self,
                 module=None,
                 data_parallel_group=None,
                 mem_efficient_linear=True,
                 remote_device=None,
                 pin_memory=False,
                 config_dict_or_path=None,
                 config=None,
                 enabled=True,
                 dtype=None,
                 mpu=None):
        """A context to enable massive model construction for training with
        ZeRO-3. Models are automatically partitioned (or, sharded) across the
        system and converted to half precision.

        Args:
            module (``torch.nn.Module``, optional): If provided, partition the model as
                if it was constructed in the context.
            data_parallel_group (``torch.distributed`` process group, optional):
                The group of processes to partition among. Defaults to all processes.
            mem_efficient_linear (bool, optional): Replace
                torch.nn.functional.linear with an implementation that allows
                DeepSpeed to partition parameters. Defaults to ``True``.
            remote_device (string, optional): The initial device to store model
                weights, e.g. ``cpu`` or ``nvme``. Passing ``"cpu"`` will create the model in CPU
                memory. The model may still be moved to GPU based on the
                offload settings for training. Defaults to the param offload device if a config is
                defined, otherwise GPU.
            pin_memory (bool, optional): Potentially increase performance by
                using pinned memory for model weights. ``remote_device`` must be
                ``"cpu"``. Defaults to the pin_memory value in the config, otherwise ``False``.
            config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
                for swapping fp16 params to NVMe.
            config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
            enabled (bool, optional): If ``False``, this context has no
                effect. Defaults to ``True``.
            dtype (``dtype``, optional): Can be used to change the data type of the parameters.
                Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``.
            mpu (``object``, optional): A model parallelism unit object that implements
                ``get_{model,data}_parallel_{rank,group,world_size}``.

        This context accelerates model initialization and enables models that
        are too large to allocate in their entirety in CPU memory. It has the
        following effects:

        #. allocates tensors to either GPU or CPU memory or NVMe
        #. converts floating point tensors to half precision
        #. immediately partitions tensors among the group of data-parallel devices
        #. (*optional*) replaces ``torch.nn.functional.linear`` with a more
           memory-efficient implementation

        These modifications allow for models that exceed the size of local CPU/GPU
        memory/NVMe, but fit within the total system capacity (*i.e.*, aggregate CPU
        or GPU memory or NVMe) across all nodes. Consider initializing a model with one
        trillion parameters, whose weights occupy two terabytes (TB) in half
        precision. The initial CPU allocation in full precision requires 4TB of
        memory *per process*, and so a system with 8 GPUs per node would need 32TB of
        CPU memory due to data-parallel redundancies. Instead, by immediately
        partitioning tensors we remove the redundancies. The result is that
        regardless of the number of GPUs, we still only require the original 4TB. This
        allows for a linear increase in model size with the aggregate system memory.
        For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
        parameter model with 4 nodes and 32 GPUs.

        Important: If the fp16 weights of the model cannot fit into the memory of a
        single GPU, this feature must be used.

        .. note::
            Initializes ``torch.distributed`` if it has not already been initialized.
            See :meth:`deepspeed.init_distributed` for more information.

        .. note::
            Can also be used as a decorator:

            .. code-block:: python

                @deepspeed.zero.Init()
                def get_model():
                    return MyLargeModel()

        .. note::
            Only applicable to training with ZeRO-3.

        Examples
        --------

        #. Allocate a model and partition it among all processes:

            .. code-block:: python

                with deepspeed.zero.Init():
                    model = MyLargeModel()

        #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:

            .. code-block:: python

                with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                                         remote_device="cpu",
                                         pin_memory=True):
                    model = MyLargeModel()

        #. Partition an already-allocated model in CPU memory:

            .. code-block:: python

                model = deepspeed.zero.Init(module=model)
        """
        if config is not None:
            config_dict_or_path = config
            logger.warning(
                'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.'
            )
        _ds_config = DeepSpeedConfig(config_dict_or_path,
                                     mpu) if config_dict_or_path is not None else None
        super().__init__(enabled=enabled,
                         mem_efficient_linear=mem_efficient_linear,
                         ds_config=_ds_config,
                         dtype=dtype)
        if not torch.distributed.is_initialized():
            init_distributed()
            assert torch.distributed.is_initialized(), "Parameters cannot be scattered without initializing torch.distributed"
        if data_parallel_group is None:
            self.ds_process_group = torch.distributed.group.WORLD
        else:
            self.ds_process_group = data_parallel_group

        self.rank = torch.distributed.get_rank(group=self.ds_process_group)
        self.world_size = torch.distributed.get_world_size(group=self.ds_process_group)

        # Local device is the device where the parameters are consumed, must be default device.
        # It is the device where parameters are fully instantiated using allgather.
        self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
        torch.cuda.set_device(self.local_device)

        if _ds_config is not None and _ds_config.zero_config.offload_param is not None:
            remote_device = _ds_config.zero_config.offload_param[OFFLOAD_PARAM_DEVICE]
            pin_memory = _ds_config.zero_config.offload_param[OFFLOAD_PARAM_PIN_MEMORY]

        self._validate_remote_device(remote_device, _ds_config)

        # Remote device is the device where parameter partitions are stored.
        # It can be the same as local_device, or it could be CPU or NVMe.
        self.remote_device = self.local_device if remote_device is None else remote_device
        self.pin_memory = pin_memory if (self.remote_device
                                         == OFFLOAD_CPU_DEVICE) else False

        # Enable fp16 param swapping to NVMe
        if self.remote_device == OFFLOAD_NVME_DEVICE:
            self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config)
        else:
            self.param_swapper = None

        # If we are provided an already-allocated module to prepare.
        if module is not None:
            assert isinstance(module, torch.nn.Module)
            self._convert_to_zero_parameters(module.parameters(recurse=True))

        self.use_all_gather_base = False
        try:
            from torch.distributed.distributed_c10d import _all_gather_base as all_gather
            self.use_all_gather_base = True
        except ImportError:
            logger.info(
                f"_all_gather_base API is not available in torch {torch.__version__}")
    def _convert_to_zero_parameters(self, param_list):
        for param in param_list:
            if is_zero_param(param):
                continue
            self._convert_to_deepspeed_param(param)
            param.partition()

    def _validate_remote_device(self, remote_device, ds_config):
        if ds_config is not None:
            if remote_device in [None, OFFLOAD_CPU_DEVICE]:
                if ds_config.zero_config.offload_param is not None:
                    offload_param_device = ds_config.zero_config.offload_param[
                        OFFLOAD_PARAM_DEVICE]
                    assert offload_param_device != OFFLOAD_NVME_DEVICE, \
                        f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."

            if remote_device == OFFLOAD_NVME_DEVICE:
                assert ds_config.zero_config.offload_param is not None, \
                    f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'

                assert ds_config.zero_config.offload_param[OFFLOAD_PARAM_NVME_PATH] is not None, \
                    f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'
    def _post_init_method(self, module):
        #see_memory_usage(f"Before converting params in {module.__class__.__name__}", force=False)
        print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
        see_memory_usage(
            f"Before converting and partitioning params in {module.__class__.__name__}",
            force=False)

        global param_count
        for param in module.parameters(recurse=False):
            param_count += param.numel()
            if not is_zero_param(param):
                self._convert_to_deepspeed_param(param)
                print_rank_0(
                    f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}"
                )
                param.partition()
        see_memory_usage(
            f"Param count {param_count}. After converting and partitioning params in {module.__class__.__name__}",
            force=False)
    def _convert_to_deepspeed_param(self, param):

        # Partitioned, Normal, Remote
        param.ds_param_type = ZeroParamType.PARTITIONED

        # Replicated vs Partitioned vs Inflight
        param.ds_status = ZeroParamStatus.AVAILABLE

        # Stores the shape of the original tensor
        param.ds_shape = param.shape

        # Stores the number of elements in the original parameter without padding
        param.ds_numel = param.numel()

        # Stores the partitioned copy of the tensor
        param.ds_tensor = None

        # Keeps track of how many active sub-modules need this param at any given point in time
        param.ds_active_sub_modules = 0

        # If this flag is true, then the parameters are replicated throughout training
        # and only partitioned before the step
        param.ds_persist = False

        # The group that the parameter is scattered across.
        param.ds_process_group = self.ds_process_group

        # This is set to the Async Param swapper if remote device is nvme,
        # else this is set to None
        param.nvme_swapper = self.param_swapper

        # DeepSpeed Param ID
        param.ds_id = Init.param_id
        Init.param_id += 1

        def all_gather(param_list=None, async_op=False, hierarchy=0):
            cls = param
            if param_list is None:
                param_list = [cls]
            return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

        def partition(param_list=None, hierarchy=0, has_been_updated=False):
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}"
            )
            if param_list is None:
                param_list = [cls]
            self._partition(param_list, has_been_updated=has_been_updated)

        def reduce_gradients_at_owner(param_list=None, hierarchy=0):
            cls = param
            if param_list is None:
                param_list = [cls]
            print_rank_0(
                f"{'--'*hierarchy}----Reducing Gradients for param with ids {[param.ds_id for param in param_list]} to owner"
            )
            self._reduce_scatter_gradients(param_list)

        def partition_gradients(param_list=None,
                                partition_buffers=None,
                                hierarchy=0,
                                accumulate=False):
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(cls)}"
            )
            if param_list is None:
                param_list = [cls]
            if isinstance(partition_buffers, torch.Tensor):
                partition_buffers = [partition_buffers]
            self._partition_gradients(param_list,
                                      partition_buffers=partition_buffers,
                                      accumulate=accumulate)

        def aligned_size():
            return self._aligned_size(param)

        def padding_size():
            return self._padding_size(param)

        def partitioned_size():
            return self._partitioned_size(param)

        def convert_to_zero_parameters(param_list):
            self._convert_to_zero_parameters(param_list)

        # Collectives for gathering and partitioning parameters
        param.all_gather = all_gather
        param.partition = partition

        # Collective for averaging gradients
        param.reduce_gradients_at_owner = reduce_gradients_at_owner
        param.partition_gradients = partition_gradients

        # Partitioning size utilities
        param.aligned_size = aligned_size
        param.padding_size = padding_size
        param.partitioned_size = partitioned_size

        param.convert_to_zero_parameters = convert_to_zero_parameters
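
    # After conversion, a parameter carries both metadata and collective helpers
    # (sketch, mirroring the attributes and closures assigned above):
    #
    #     p = next(model.parameters())
    #     p.ds_id, p.ds_numel, p.ds_shape      # identity and original size/shape
    #     p.aligned_size(), p.padding_size()   # size after padding to a multiple of world_size
    #     p.all_gather()                       # materialize the full tensor on this rank
    #     p.partition()                        # release it again, keeping only ds_tensor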
    def _aligned_size(self, param):
        return param.ds_numel + self._padding_size(param)

    def _padding_size(self, param):
        remainder = param.ds_numel % self.world_size
        return (self.world_size - remainder) if remainder else 0

    def _partitioned_size(self, param):
        return param.ds_tensor.ds_numel
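
    # Worked example of the padding/alignment arithmetic above: with world_size = 4 and a
    # parameter of ds_numel = 10, remainder = 10 % 4 = 2, so padding_size = 4 - 2 = 2,
    # aligned_size = 12, and each rank owns a partition of 12 // 4 = 3 elements. A rank
    # whose slice extends past element 9 copies only the valid prefix and leaves the rest
    # as padding (see _partition_param below).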
    def _ensure_availability_of_partitioned_params(self, params):
        swap_in_list = []
        swap_in_flight = []
        for param in params:
            if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
                assert param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
                swap_in_list.append(param)
            if param.ds_tensor.status == PartitionedParamStatus.INFLIGHT:
                assert param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
                swap_in_flight.append(param)
        if len(swap_in_list) > 0:
            swap_in_list[0].nvme_swapper.swap_in(swap_in_list, async_op=False)
        elif len(swap_in_flight) > 0:
            swap_in_flight[0].nvme_swapper.synchronize_reads()
    def _all_gather(self, param_list, async_op=False, hierarchy=None):

        # fetches from nvme if the partition is not available and in nvme
        self._ensure_availability_of_partitioned_params(param_list)

        handles = []
        all_gather_list = []
        for param in param_list:
            if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
                if async_op:
                    handle = self._allgather_param(param,
                                                   async_op=async_op,
                                                   hierarchy=hierarchy)
                    param.ds_status = ZeroParamStatus.INFLIGHT  # if async_op else ZeroParamStatus.AVAILABLE
                    handles.append(handle)
                else:
                    all_gather_list.append(param)

        if not async_op:
            # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
            ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)

            for param in all_gather_list:
                param.ds_status = ZeroParamStatus.AVAILABLE
            return ret_value

        return handles
    def _partition(self, param_list, force=False, has_been_updated=False):
        for param in param_list:
            #print_rank_0(f"Before Partitioning Param {param.ds_id}")
            # self._param_status(param)
            self._partition_param(param, has_been_updated=has_been_updated)
            param.ds_status = ZeroParamStatus.NOT_AVAILABLE
            # if param.ds_tensor is not None:
            #    assert id(param.data) == id(param.ds_tensor.data), \
            #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
            #print_rank_0(f"After Partitioning Param {param.ds_id}")
            # self._param_status(param)
    def _partition_param(self, param, buffer=None, has_been_updated=False):
        assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"

        global reuse_buffers
        #print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}")
        if param.ds_status is ZeroParamStatus.AVAILABLE:
            print_rank_0(
                f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}",
                force=False)
            # if reuse_buffers and False:
            #     numel = buffer.numel()
            #     buffer = param.data.view(-1)
            #     print_rank_0(
            #         "Returning buffer for param {param.ds_id} with numel {param.ds_numel} to empty buffers",
            #         force=False)
            #     if numel in empty_buffers:
            #         empty_buffers[numel].append(buffer)

            # if torch.distributed.get_rank():
            #    print(f"Releasing {param.data.numel()}")
            if param.ds_tensor is not None and not has_been_updated:

                #param.data = param.ds_tensor.data

                see_memory_usage(
                    f'Before partitioning param {param.ds_id} {param.shape}',
                    force=False)
                # param.data does not store anything meaningful in partitioned state
                param.data = torch.empty(1, dtype=self.dtype, device=param.device)
                see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                                 force=False)

                if param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE:
                    print_rank_0(
                        f"Param {param.ds_id} partition released since it exists in nvme",
                        force=False)
                    param.nvme_swapper.remove_partition_and_release_buffers([param])

                return

            tensor_size = self._aligned_size(param)
            partition_size = tensor_size // self.world_size

            if param.ds_tensor is None:
                final_location = None
                if self.remote_device == OFFLOAD_NVME_DEVICE and self.param_swapper.swappable_tensor(
                        numel=partition_size):
                    final_location = OFFLOAD_NVME_DEVICE
                    buffer = self.param_swapper.get_buffer(param, partition_size)
                    partitioned_tensor = torch.empty(1,
                                                     dtype=param.dtype,
                                                     device=buffer.device)
                    partitioned_tensor.data = buffer.data
                    print_rank_0(
                        f"ID {param.ds_id} Initializing partition for the first time for nvme offload."
                    )

                else:
                    partitioned_tensor = torch.empty(
                        partition_size,
                        dtype=param.dtype,
                        device=OFFLOAD_CPU_DEVICE if self.remote_device
                        == OFFLOAD_NVME_DEVICE else self.remote_device)

                    if self.pin_memory:
                        partitioned_tensor = partitioned_tensor.pin_memory()

                partitioned_tensor.requires_grad = False
                param.ds_tensor = partitioned_tensor
                param.ds_tensor.ds_numel = partition_size
                param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
                param.ds_tensor.final_location = final_location

            start = partition_size * self.rank
            end = start + partition_size

            one_dim_param = param.contiguous().view(-1)

            if start < param.ds_numel and end <= param.ds_numel:
                src_tensor = one_dim_param.narrow(0, start, partition_size)

                param.ds_tensor.copy_(src_tensor)
                #partitioned_tensor = src_tensor.clone().detach().to(self.remote_device)

            else:
                # partitioned_tensor = torch.zeros(partition_size,
                #                                  dtype=param.dtype,
                #                                  device=self.remote_device)

                if start < param.ds_numel:
                    elements_to_copy = param.ds_numel - start
                    param.ds_tensor.narrow(0,
                                           0,
                                           elements_to_copy).copy_(
                                               one_dim_param.narrow(
                                                   0,
                                                   start,
                                                   elements_to_copy))

            #print(f"Remote device {self.remote_device}")

            #param.ds_tensor = partitioned_tensor

            #param.data = param.ds_tensor.data

            # param.data does not store anything meaningful in partitioned state

            see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                             force=False)
            param.data = torch.ones(1, dtype=self.dtype).to(param.device)
            see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                             force=False)

            if param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE:
                self.param_swapper.swap_out_and_release([param])
                print_rank_0(
                    f"ID {param.ds_id} Offloaded to nvme offload and buffers released.")
                see_memory_usage(
                    f"ID {param.ds_id} Offloaded to nvme offload and buffers released.",
                    force=False)

            print_rank_0(
                f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}"
            )
    def _param_status(self, param):
        if param.ds_tensor is not None:
            print_rank_0(
                f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned numel {param.ds_tensor.numel()}, data numel {param.data.numel()}"
            )
        else:
            print_rank_0(
                f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned ds_tensor {param.ds_tensor}, data numel {param.data.numel()}"
            )
    def _allgather_param(self, param, async_op=False, hierarchy=0):

        partition_size = param.ds_tensor.ds_numel

        tensor_size = partition_size * self.world_size
        aligned_param_size = self._aligned_size(param)
        assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'

        print_rank_0(
            f"{'--'* hierarchy}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}"
        )

        see_memory_usage(
            f'Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ',
            force=False)
        flat_tensor = torch.zeros(aligned_param_size,
                                  dtype=param.dtype,
                                  device=param.device).view(-1)
        see_memory_usage(
            f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
            force=False)

        torch.cuda.synchronize()

        print_rank_0(
            f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
        )
        # if not flat_tensor.numel() > 100000:
        #     replicated_tensor = flat_tensor.narrow(0,
        #                                            0,
        #                                            param.ds_numel).view(param.ds_shape)
        #     param.data = replicated_tensor.data
        #     return None
        if self.use_all_gather_base:
            # try the _all_gather_base on PyTorch master branch
            handle = dist._all_gather_base(flat_tensor,
                                           param.ds_tensor.cuda(),
                                           group=self.ds_process_group,
                                           async_op=async_op)
        else:
            partitions = []
            for i in range(self.world_size):
                partitions.append(
                    flat_tensor.narrow(0,
                                       partition_size * i,
                                       partition_size))

                if i == dist.get_rank(group=self.ds_process_group):
                    partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)

            handle = dist.all_gather(partitions,
                                     partitions[self.rank],
                                     group=self.ds_process_group,
                                     async_op=async_op)

        replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
        param.data = replicated_tensor.data
        return handle
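
    # Layout sketch for _allgather_param above: with world_size = 4 and partition_size = 3,
    # flat_tensor has 12 elements and rank i's partition occupies flat_tensor[3*i : 3*(i+1)].
    # The non-_all_gather_base path builds those four narrow() views and passes them as the
    # output list, roughly (illustrative only):
    #
    #     partitions = [flat_tensor.narrow(0, 3 * i, 3) for i in range(4)]
    #     dist.all_gather(partitions, partitions[rank], group=group)
    #
    # so the gathered data lands in place and the first ds_numel elements can be viewed
    # back into the original parameter shape.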
    def _allgather_params_coalesced(self, param_list, hierarchy=0):
        """ blocking call
        avoid explicit memory copy in _allgather_params
        """
        if len(param_list) == 0:
            return
        # collect local tensors and partition sizes
        partition_sizes = []
        local_tensors = []
        for param in param_list:
            partition_sizes.append(param.ds_tensor.ds_numel)
            local_tensors.append(param.ds_tensor.cuda())

        # allocate memory for allgather params
        allgather_params = []
        for psize in partition_sizes:
            tensor_size = psize * self.world_size
            flat_tensor = torch.empty(tensor_size,
                                      dtype=param_list[0].dtype,
                                      device=self.local_device).view(-1)
            flat_tensor.requires_grad = False
            allgather_params.append(flat_tensor)

        # launch
        launch_handles = []
        # backend = get_backend(self.ds_process_group)
        # with _batch_p2p_manager(backend):
        for param_idx, param in enumerate(param_list):
            input_tensor = local_tensors[param_idx].view(-1)

            if self.use_all_gather_base:
                # try the _all_gather_base from Pytorch master
                h = dist._all_gather_base(allgather_params[param_idx],
                                          input_tensor,
                                          group=self.ds_process_group,
                                          async_op=True)
            else:
                output_list = []
                for i in range(self.world_size):
                    psize = partition_sizes[param_idx]
                    partition = allgather_params[param_idx].narrow(0, i * psize, psize)
                    output_list.append(partition)
                    if not partition.is_cuda:
                        logger.warning(
                            f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}'
                        )

                # back to old all_gather function signature
                h = dist.all_gather(output_list,
                                    input_tensor,
                                    group=self.ds_process_group,
                                    async_op=True)
            launch_handles.append(h)

        # Wait ensures the operation is enqueued, but not necessarily complete.
        launch_handles[-1].wait()

        # assign to param.data (not copy)
        for i, param in enumerate(param_list):
            gathered_tensor = allgather_params[i]
            param.data = gathered_tensor.narrow(0,
                                                0,
                                                param.ds_numel).view(param.ds_shape).data

        # guarantee the communication to be completed
        torch.cuda.synchronize()

        return None
    def _allgather_params(self, param_list, hierarchy=0):
        if len(param_list) == 0:
            return

        partition_size = sum([param.ds_tensor.ds_numel for param in param_list])

        tensor_size = partition_size * self.world_size
        flat_tensor = torch.empty(tensor_size,
                                  dtype=param_list[0].dtype,
                                  device=self.local_device)
        flat_tensor.requires_grad = False
        partitions = []
        for i in range(self.world_size):
            start = partition_size * i

            partitions.append(flat_tensor.narrow(0, start, partition_size))

            if i == self.rank:
                offset = 0
                for param in param_list:
                    param_numel = param.ds_tensor.ds_numel

                    partitions[i].narrow(0,
                                         offset,
                                         param_numel).copy_(param.ds_tensor.data)

                    offset += param_numel

        torch.distributed.all_gather(partitions,
                                     partitions[self.rank],
                                     group=self.ds_process_group,
                                     async_op=False)
        param_offset = 0

        for param in param_list:
            param_partition_size = param.ds_tensor.ds_numel
            param_size = param.ds_numel
            replicated_tensor = torch.empty(param.ds_shape,
                                            dtype=param.dtype,
                                            device=self.local_device)

            for i in range(self.world_size):

                start = i * partition_size

                param_start = i * param_partition_size

                if param_start < param_size:
                    numel_to_copy = min(param_size - param_start, param_partition_size)

                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)

                    replicated_tensor.view(-1).narrow(0,
                                                      param_start,
                                                      numel_to_copy).copy_(part_to_copy)
            #param_offset += param.data.numel()
            param_offset += param.ds_tensor.ds_numel

            param.data = replicated_tensor.data

        return None
    def _reduce_scatter_gradients(self, param_list):
        #print_rank_0([param.grad for param in param_list])
        #assert any([param.grad is None for param in param_list]), "None gradients cannot be reduce scattered"

        handles_and_reduced_partitions = []
        for param in param_list:
            assert param.grad.numel(
            ) == param.ds_numel, f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter gradients whose size is not same as the params"

            handles_and_reduced_partitions.append(self._reduce_scatter_gradient(param))

        for param, (handle, reduced_partition) in zip(param_list, handles_and_reduced_partitions):
            if handle is not None:
                handle.wait()

            # some ranks may have partitions that are padded to go beyond the grad size.
            # For these ranks the output of reduce scatter is a separate buffer and needs
            # to be copied in
            partition_size = param.ds_tensor.ds_numel
            start = self.rank * partition_size
            end = start + partition_size
            #print_rank_0(f"Reduce scatter was executed for param {param.ds_id}")
            if start < param.ds_numel and end > param.ds_numel:
                elements = param.ds_numel - start
                param.grad.view(-1).narrow(0,
                                           start,
                                           elements).copy_(
                                               reduced_partition.narrow(0,
                                                                        0,
                                                                        elements))
    def _reduce_scatter_gradient(self, param):

        partition_size = param.ds_tensor.ds_numel
        #output = torch.empty(partition_size, dtype=param.dtype, device=param.device)

        total_size = partition_size * self.world_size
        input_list = []

        for i in range(self.world_size):
            start = i * partition_size
            end = start + partition_size

            #print("before reduce scatter gradients")
            if start < param.ds_numel and end <= param.ds_numel:
                input = param.grad.view(-1).narrow(0, start, partition_size)
            else:
                input = torch.zeros(partition_size,
                                    dtype=param.dtype,
                                    device=param.device)

                if start < param.ds_numel:
                    elements = param.ds_numel - start
                    input.narrow(0,
                                 0,
                                 elements).copy_(
                                     param.grad.view(-1).narrow(0,
                                                                start,
                                                                elements))
            #print("after reduce scatter gradients")
            input_list.append(input)

        rank = torch.distributed.get_rank(group=self.ds_process_group)
        handle = torch.distributed.reduce_scatter(input_list[rank],
                                                  input_list,
                                                  group=self.ds_process_group,
                                                  async_op=True)

        return handle, input_list[rank]
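
    # Note on the padded case handled in _reduce_scatter_gradients above: with ds_numel = 10,
    # world_size = 4 and partition_size = 3, rank 3's slice covers elements 9..11 but only
    # element 9 exists. Its reduce-scatter input is therefore a separate zero-padded buffer,
    # and after the op completes the single valid element is copied back into
    # param.grad.view(-1)[9:10]. Ranks whose slice lies fully inside the gradient reduce
    # directly into views of param.grad and need no copy.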
    def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False):
        if partition_buffers is None:
            partition_buffers = [None] * len(param_list)

        for param, partition_buffer in zip(param_list, partition_buffers):
            self._partition_gradient(param,
                                     partition_buffer=partition_buffer,
                                     accumulate=accumulate)

    def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
        #import pdb;pdb.set_trace()
        # param.grad=None
        # param.grad.test()
        print_rank_0(
            f"Partitioning param {param.ds_id} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.ds_numel}"
        )
        see_memory_usage("Before partitioning gradients", force=False)
        partition_size = param.ds_tensor.ds_numel

        if partition_buffer is None:
            assert not accumulate, "No buffer to accumulate to"
            partition_buffer = torch.zeros(partition_size,
                                           dtype=param.dtype,
                                           device=param.device)
        else:
            assert partition_buffer.numel(
            ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"

        rank = torch.distributed.get_rank(group=self.ds_process_group)
        start = partition_size * rank
        end = start + partition_size

        dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size)

        #print("before partition gradients")
        if start < param.ds_numel:
            elements = min(param.ds_numel - start, partition_size)

            dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
            src_tensor = param.grad.view(-1).narrow(0, start, elements)

            # just copy the grad partition to the buffer
            if not accumulate:
                dest_tensor.copy_(src_tensor)

            # if source and destination are on same device,
            # add to the provided buffer
            elif src_tensor.device == dest_tensor.device:
                dest_tensor.add_(src_tensor)

            # if source and destination are on different devices, copy first to src
            # then add and move back to the destination. This seems to run faster
            # when src is gpu and dest is cpu.
            # adding directly to cpu is very slow
            else:
                acc_tensor = torch.empty(src_tensor.numel(),
                                         dtype=param.dtype,
                                         device=param.device)

                acc_tensor.copy_(dest_tensor)
                acc_tensor.add_(src_tensor)
                dest_tensor.copy_(acc_tensor)

            # partition_buffer.view(-1).narrow(
            #     0,
            #     0,
            #     elements).copy_(param.grad.view(-1).narrow(0,
            #                                                start,
            #                                                elements))

        #print("after partition gradients")
        param.grad.data = dest_tensor_full_buffer.data
        see_memory_usage("After partitioning gradients", force=False)
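
    # Accumulation sketch for _partition_gradient above: when partition_buffer lives on a
    # different device than the gradient (e.g. CPU offload), the add is staged on the
    # gradient's device instead of adding into CPU memory directly, i.e. roughly:
    #
    #     acc = torch.empty_like(src_tensor)   # staging buffer on the gradient's device
    #     acc.copy_(dest_tensor)               # bring the offloaded partition over
    #     acc.add_(src_tensor)                 # accumulate on the fast device
    #     dest_tensor.copy_(acc)               # write the result back to the partition buffer
    #
    # which, as the comment above notes, tends to be faster than an element-wise add on CPU.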
class GatheredParameters:
    def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
        """A context that collects parameters that were partitioned via a
        :class:`deepspeed.zero.Init` context. The parameters are partitioned
        again upon exit.

        Args:
            params (``torch.nn.Parameter``): A single parameter or a list of parameters to collect.
                It is assumed that all parameters are zero params.
            modifier_rank (int, optional): If specified, this rank's parameter will be
                broadcast on exit from the context. This argument is required if ``params`` are
                modified, so that all processes have a consistent view of the data. Defaults
                to ``None``.
            fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
                registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
            enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.

        Important: Make sure to use a ``modifier_rank`` that is not ``None`` (e.g. ``modifier_rank=0``)
        if you need the GPU memory allocated by the gather to be released upon exit from the context manager.

        Examples
        ========

        #. Allocate a partitioned module, initialize its weight on rank 0, and update all
           processes.

            .. code-block:: python

                with deepspeed.zero.Init():
                    linear = torch.nn.Linear(1000,1000)

                with deepspeed.zero.GatheredParameters(linear.weight,
                                                       modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        linear.weight.zero_()

        #. Collect a partitioned weight to pass to another module during
           training. The parameter will be registered as an external parameter
           and made available during the backward pass.

            .. code-block:: python
                :emphasize-lines: 6

                def forward(self, input):
                    x = self.layer1(input)

                    # self.layer1.weight is required by self.layer2.forward
                    with deepspeed.zero.GatheredParameters(self.layer1.weight,
                                                           fwd_module=self):
                        y = self.layer2(x, self.layer1.weight)
                    return y

        #. Pretrained model loading

            .. code-block:: python

                with deepspeed.zero.Init():
                    model = MyModel()

                state_dict = torch.load(model_path, map_location="cpu")

                def load(module: nn.Module, prefix=""):
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(state_dict, prefix)

                    for name, child in module._modules.items():
                        if child is not None:
                            load(child, prefix + name + ".")

                load(model, prefix="")

        If this approach is not used, then the full model will first get copied to each GPU. For models
        bigger than the memory of a single GPU, this method is required.
        """

        self.enabled = enabled
        if not enabled:
            return

        if not isinstance(params, list):
            params = [params]

        # enable if at least one is zero-param, otherwise a no-op
        if not any(is_zero_param(p) for p in params):
            self.enabled = False
            return

        self.params = [p for p in params if hasattr(p, "ds_id")]
        self.src_rank = None
        if modifier_rank is not None:
            if self.params[0].ds_process_group == torch.distributed.group.WORLD:
                self.src_rank = modifier_rank
            else:
                # A group was specified; convert DP rank to global rank
                self.src_rank = _get_global_rank(self.params[0].ds_process_group,
                                                 modifier_rank)
        self.fwd_module = fwd_module
        if self.fwd_module is not None:
            # is a no-op if already registered
            for p in self.params:
                register_external_parameter(self.fwd_module, p)

    def __enter__(self):
        if not self.enabled:
            return
        self.params[0].all_gather(param_list=self.params)

    def __exit__(self, *exc):
        if not self.enabled:
            return
        if self.src_rank is None:
            return

        handles = [
            torch.distributed.broadcast(p,
                                        self.src_rank,
                                        group=p.ds_process_group,
                                        async_op=True) for p in self.params
        ]
        for h in handles:
            h.wait()
        self.params[0].partition(param_list=self.params, has_been_updated=True)