# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import glob
import re as regex
from functools import partial

import torch
import torch.nn as nn

from deepspeed import comm as dist
from deepspeed.utils import logger

from .. import utils as ds_utils
from ..activation_checkpointing import checkpointing
from .topology import PipeDataParallelTopology, PipelineParallelGrid
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
from deepspeed.accelerator import get_accelerator
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save


class PipelineError(Exception):
    """Errors related to the use of deepspeed.PipelineModule """


class LayerSpec:
    """Building block for specifying pipeline-parallel modules.

    LayerSpec stores the type information and parameters for each stage in a
    PipelineModule. For example:

    .. code-block:: python

        nn.Sequential(
            torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False),
            torch.nn.Linear(self.hidden_dim, self.out_dim)
        )

    becomes

    .. code-block:: python

        layer_specs = [
            LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
            LayerSpec(torch.nn.Linear, self.hidden_dim, self.out_dim)
        ]
    """

    def __init__(self, typename, *module_args, **module_kwargs):
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs

        if not issubclass(typename, nn.Module):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

        if dist.is_initialized():
            self.global_rank = dist.get_rank()
        else:
            self.global_rank = -1

    def __repr__(self):
        return ds_utils.call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

    def build(self, log=False):
        """Build the stored specification."""
        if log:
            logger.info(f'RANK={self.global_rank} building {repr(self)}')

        return self.typename(*self.module_args, **self.module_kwargs)
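
# Editor's illustrative note (a sketch, not part of the original module): a LayerSpec
# defers allocation until build() is called by the owning pipeline stage, so no
# parameters exist when the spec is created.
#
#   spec = LayerSpec(nn.Linear, 16, 32, bias=False)   # nothing allocated yet
#   layer = spec.build()                              # nn.Linear(16, 32, bias=False) created here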


class TiedLayerSpec(LayerSpec):

    def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr='weight', **module_kwargs):
        super().__init__(typename, *module_args, **module_kwargs)
        self.key = key
        self.forward_fn = forward_fn
        self.tied_weight_attr = tied_weight_attr
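
# Editor's illustrative sketch (`EmbeddingPipe` and `transformer_specs` are hypothetical
# names): specs that share the same `key` reuse a single module instance, which is how
# the first and last pipeline stages can tie their embedding weights.
#
#   layer_specs = [
#       TiedLayerSpec('embed', EmbeddingPipe, vocab_size, hidden_dim),
#       *transformer_specs,
#       TiedLayerSpec('embed', EmbeddingPipe, vocab_size, hidden_dim,
#                     forward_fn=lambda module, x: module(x)),
#   ]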


class PipelineModule(nn.Module):
    """Modules to be parallelized with pipeline parallelism.

    The key constraint that enables pipeline parallelism is the
    representation of the forward pass as a sequence of layers
    and the enforcement of a simple interface between them. The
    forward pass is implicitly defined by the module ``layers``. The key
    assumption is that the output of each layer can be directly fed as
    input to the next, like a ``torch.nn.Sequential``. The forward pass is
    implicitly:

    .. code-block:: python

        def forward(self, inputs):
            x = inputs
            for layer in self.layers:
                x = layer(x)
            return x

    .. note::
        Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.

    Args:
        layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
        num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
        topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism for training. Must be provided if ``num_stages`` is ``None``.
        loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)``
        seed_layers (bool, optional): Use a different seed for each layer. Defaults to False.
        seed_fn (callable, optional): A custom seed-generating function. Defaults to the random seed generator.
        base_seed (int, optional): The starting seed. Defaults to 1234.
        partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
        activation_checkpoint_interval (int, optional): The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
        activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
        checkpointable_layers (list, optional): If provided, only layers whose class names appear in this list are activation checkpointed. Defaults to None, which applies no additional filtering.
    """

    def __init__(self,
                 layers,
                 num_stages=None,
                 topology=None,
                 loss_fn=None,
                 seed_layers=False,
                 seed_fn=None,
                 base_seed=1234,
                 partition_method='parameters',
                 activation_checkpoint_interval=0,
                 activation_checkpoint_func=checkpointing.checkpoint,
                 checkpointable_layers=None):

        super().__init__()

        if num_stages is None and topology is None:
            raise RuntimeError('must provide num_stages or topology')

        self.micro_offset = 0

        self.loss_fn = loss_fn

        self.checkpointable_layers = checkpointable_layers
        if checkpointable_layers is not None:
            assert isinstance(checkpointable_layers, list), "param `checkpointable_layers` must be type of list."

        self.seed_layers = seed_layers
        self.seed_fn = seed_fn
        self.base_seed = base_seed
        if dist.get_rank() == 0:
            try:
                seed_str = self.seed_fn.__name__
            except AttributeError:
                seed_str = None
            print(f'SEED_LAYERS={self.seed_layers} BASE_SEED={self.base_seed} SEED_FN={seed_str}')

        # Setup world info
        self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
        self.global_rank = dist.get_rank(group=self.world_group)
        self.world_size = dist.get_world_size(group=self.world_group)
        self.local_rank = int(os.environ.get("LOCAL_RANK", None))
        assert self.local_rank is not None

        if topology:
            self._topo = topology
            self.num_stages = self._topo.get_dim('pipe')
        else:
            self.num_stages = num_stages
            if topology is None:
                if self.world_size % self.num_stages != 0:
                    raise RuntimeError(
                        f'num_stages ({self.num_stages}) must divide distributed world size ({self.world_size})')
                dp = self.world_size // num_stages
                topology = PipeDataParallelTopology(num_pp=num_stages, num_dp=dp)
                self._topo = topology

        # Construct communicators for pipeline topology
        self._grid = PipelineParallelGrid(process_group=self.world_group, topology=self._topo)

        self.stage_id = self._topo.get_coord(self.global_rank).pipe

        # Initialize partition information
        self._layer_specs = list(layers)
        self._num_layers = len(self._layer_specs)
        self._local_start = 0
        self._local_stop = None
        self._partition_layers(method=partition_method)

        self.forward_funcs = []
        self.fwd_map = {}
        self.tied_modules = nn.ModuleDict()
        self.tied_weight_attrs = {}

        # Offset the random seed by the stage ID.
        #newseed = get_accelerator().initial_seed() + self._grid.get_stage_id()
        #ds_utils.set_random_seed(newseed)

        #with torch.random.fork_rng(devices=[get_accelerator().current_device_name()]):
        self._build()
        self.to(get_accelerator().device_name(self.local_rank))

        self.tied_comms = self._index_tied_modules()
        self._synchronize_tied_weights()

        self.activation_checkpoint_interval = activation_checkpoint_interval
        self.activation_checkpoint_func = activation_checkpoint_func
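
    # Editor's worked example of the default topology above: with world_size=8 and
    # num_stages=4, dp = 8 // 4 = 2, so the module constructs
    # PipeDataParallelTopology(num_pp=4, num_dp=2), i.e. 4 pipeline stages each
    # replicated across 2 data parallel ranks.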

    def _build(self):
        specs = self._layer_specs

        for local_idx, layer in enumerate(specs[self._local_start:self._local_stop]):
            layer_idx = local_idx + self._local_start
            if self.seed_layers:
                if self.seed_fn:
                    self.seed_fn(self.base_seed + layer_idx)
                else:
                    ds_utils.set_random_seed(self.base_seed + layer_idx)

            # Recursively build PipelineModule objects
            if isinstance(layer, PipelineModule):
                raise NotImplementedError('RECURSIVE BUILD NOT YET IMPLEMENTED')

            # nn.Module instances are already allocated; just register them.
            elif isinstance(layer, nn.Module):
                name = str(layer_idx)
                self.forward_funcs.append(layer)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, layer)

            # TiedLayerSpec objects contain an nn.Module that should be allocated now.
            elif isinstance(layer, TiedLayerSpec):
                # Build and register the module if we haven't seen it before.
                if layer.key not in self.tied_modules:
                    self.tied_modules[layer.key] = layer.build()
                    self.tied_weight_attrs[layer.key] = layer.tied_weight_attr
                if layer.forward_fn is None:
                    # Just use forward()
                    self.forward_funcs.append(self.tied_modules[layer.key])
                else:
                    # User specified fn with args (module, input)
                    self.forward_funcs.append(partial(layer.forward_fn, self.tied_modules[layer.key]))

            # LayerSpec objects contain an nn.Module that should be allocated now.
            elif isinstance(layer, LayerSpec):
                module = layer.build()
                name = str(layer_idx)
                self.forward_funcs.append(module)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, module)

            # Last option: layer may be a functional (e.g., lambda). We do nothing in
            # that case and just use it in forward()
            else:
                self.forward_funcs.append(layer)

        # All pipeline parameters should be considered as model parallel in the context
        # of our FP16 optimizer
        for p in self.parameters():
            p.ds_pipe_replicated = False

    def _get_frozen_parameter_names(self, layer):
        """ Get names of frozen parameters in the layer.

            Returns:
                A list of frozen parameter names
        """
        if isinstance(layer, LayerSpec):
            l = layer.build()
            return [n for n, p in l.named_parameters() if not p.requires_grad]
        elif isinstance(layer, nn.Module):
            return [n for n, p in layer.named_parameters() if not p.requires_grad]

        return []

    def _count_layer_params(self):
        """Count the trainable parameters in individual layers.

        This routine will only build one layer at a time.

        Returns:
            A list of the number of parameters in each layer.
        """
        param_counts = [0] * len(self._layer_specs)
        for idx, layer in enumerate(self._layer_specs):
            if isinstance(layer, LayerSpec):
                l = layer.build()
                params = filter(lambda p: p.requires_grad, l.parameters())
                param_counts[idx] = sum(p.numel() for p in params)
            elif isinstance(layer, nn.Module):
                params = filter(lambda p: p.requires_grad, layer.parameters())
                param_counts[idx] = sum(p.numel() for p in params)
        return param_counts

    def _find_layer_type(self, layername):
        idxs = []
        typeregex = regex.compile(layername, regex.IGNORECASE)
        for idx, layer in enumerate(self._layer_specs):
            name = None
            if isinstance(layer, LayerSpec):
                name = layer.typename.__name__
            elif isinstance(layer, nn.Module):
                name = layer.__class__.__name__
            else:
                try:
                    name = layer.__name__
                except AttributeError:
                    continue
            if typeregex.search(name):
                idxs.append(idx)

        if len(idxs) == 0:
            raise RuntimeError(f"Partitioning '{layername}' found no valid layers to partition.")
        return idxs

    def forward(self, forward_input):
        # We need to offset the seed by the microbatch ID. Save it in a local var to
        # ensure it is preserved in the closure. Otherwise checkpointed forward funcs
        # will see a different offset.
        self.micro_offset += 1

        def exec_range_func(start, end):
            ''' Helper function to be used with checkpoint()
            Adapted from torch.utils.checkpoint:checkpoint_sequential()
            '''
            local_micro_offset = self.micro_offset + 1

            def exec_func(*inputs):
                # Single tensor inputs need to be unwrapped
                if len(inputs) == 1:
                    inputs = inputs[0]
                for idx, layer in enumerate(self.forward_funcs[start:end]):
                    self.curr_layer = idx + self._local_start
                    if self.seed_layers:
                        new_seed = (self.base_seed * local_micro_offset) + self.curr_layer
                        if self.seed_fn:
                            self.seed_fn(new_seed)
                        else:
                            ds_utils.set_random_seed(new_seed)

                    inputs = layer(inputs)
                return inputs

            return exec_func

        if self.activation_checkpoint_interval == 0:
            func = exec_range_func(0, len(self.forward_funcs))
            x = func(forward_input)
        else:
            num_layers = len(self.forward_funcs)
            x = forward_input
            for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
                end_idx = min(start_idx + self.activation_checkpoint_interval, num_layers)

                funcs = self.forward_funcs[start_idx:end_idx]
                # Since we either pass tensors or tuples of tensors without unpacking, we
                # need to be careful not to double-wrap tensors with tuple.
                if not isinstance(x, tuple):
                    x = (x, )

                if self._is_checkpointable(funcs):
                    x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx), *x)
                else:
                    x = exec_range_func(start_idx, end_idx)(*x)
        return x
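
    # Editor's note (illustrative, not part of the original module): with
    # activation_checkpoint_interval=4 and 10 layers owned by this stage, the loop
    # above checkpoints the ranges [0:4], [4:8], [8:10]; each checkpointed range is
    # recomputed during the backward pass instead of storing its activations.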

    def _partition_layers(self, method='uniform'):
        num_stages = self._topo.get_dim('pipe')
        stage_id = self._topo.get_coord(self.global_rank).pipe

        if self.global_rank == 0:
            logger.info(f'Partitioning pipeline stages with method {method}')

        method = method.lower()

        # Each stage gets a simple uniform number of layers.
        if method == 'uniform':
            num_layers = len(self._layer_specs)
            self.parts = ds_utils.partition_uniform(num_items=num_layers, num_parts=num_stages)
        elif method == 'parameters':
            param_counts = self._count_layer_params()
            self.parts = ds_utils.partition_balanced(weights=param_counts, num_parts=num_stages)
        elif method.startswith('type:'):
            layertype = method.split(':')[1]
            binary_weights = [0] * len(self._layer_specs)
            for idx in self._find_layer_type(layertype):
                binary_weights[idx] = 1
            self.parts = ds_utils.partition_balanced(weights=binary_weights, num_parts=num_stages)
        elif method == 'profile':
            raise NotImplementedError(f'Partitioning method {method} not implemented.')
        else:
            raise NotImplementedError(f'Partitioning method {method} not implemented.')

        # Print some information on the partitioning.
        if self.global_rank == 0:
            for stage in range(num_stages):
                start = self.parts[stage]
                stop = self.parts[stage + 1]
                print(f'stage={stage} layers={stop - start}')
                for idx, layer in enumerate(self._layer_specs[start:stop]):
                    name = str(layer)
                    if isinstance(layer, LayerSpec):
                        name = layer.typename.__name__
                    elif isinstance(layer, nn.Module):
                        name = layer.__class__.__name__
                    else:
                        try:
                            name = layer.__name__
                        except AttributeError:
                            pass
                    print(f'    {idx+start:2d}: {name}')
            if self.loss_fn:
                try:
                    print(f'  loss: {self.loss_fn.__name__}')
                except AttributeError:
                    print(f'  loss: {self.loss_fn.__class__.__name__}')

        self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])
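
    # Editor's illustrative examples of `partition_method` values handled above:
    #   'uniform'           -> the same number of layers per stage
    #   'parameters'        -> stages balanced by trainable parameter count
    #   'type:transformer'  -> stages balanced over layers whose class name matches
    #                          the case-insensitive regex 'transformer'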

    def allreduce_tied_weight_gradients(self):
        '''All reduce the gradients of the tied weights between tied stages'''
        for key, comm in self.tied_comms.items():
            weight = getattr(self.tied_modules[key], comm['weight_attr'])
            dist.all_reduce(weight.grad, group=comm['group'])

    def get_tied_weights_and_groups(self):
        weight_group_list = []
        for key, comm in self.tied_comms.items():
            weight = getattr(self.tied_modules[key], comm['weight_attr'])
            weight_group_list.append((weight, comm['group']))
        return weight_group_list

    def _synchronize_tied_weights(self):
        for key, comm in self.tied_comms.items():
            dist.broadcast(
                getattr(comm['module'], comm['weight_attr']),
                src=min(comm['ranks']),
                group=comm['group'],
            )

    def _index_tied_modules(self):
        ''' Build communication structures for tied modules. '''
        tied_comms = {}
        if self._topo.get_dim('pipe') == 1:
            return tied_comms

        specs = self._layer_specs
        tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec))
        for key in tie_keys:
            # Find the layers that the tied module appears in
            tied_layers = []
            for idx, layer in enumerate(specs):
                if isinstance(layer, TiedLayerSpec) and layer.key == key:
                    tied_layers.append(idx)
            # Find all stages with this tied module
            # TODO: Would be nice to remove the nested data/model parallelism loops and
            # TODO: instead generalize in some way, since we really just care about the
            # TODO: stage that owns the tied layer. Then loop over each (dp, mp, ...)
            # TODO: fiber to generate process groups.
            tied_stages = set(self.stage_owner(idx) for idx in tied_layers)
            for dp in range(self._grid.data_parallel_size):
                for mp in range(self._grid.get_slice_parallel_world_size()):
                    tied_ranks = []
                    for s in sorted(tied_stages):
                        if self._grid.get_slice_parallel_world_size() > 1:
                            tied_ranks.append(self._grid.stage_to_global(stage_id=s, data=dp, model=mp))
                        else:
                            tied_ranks.append(self._grid.stage_to_global(stage_id=s, data=dp))
                    group = dist.new_group(ranks=tied_ranks)

                    # Record this tied module if we own a local copy of it.
                    if self.global_rank in tied_ranks:
                        assert key in self.tied_modules
                        if key in self.tied_modules:
                            tied_comms[key] = {
                                'ranks': tied_ranks,
                                'group': group,
                                'weight_attr': self.tied_weight_attrs[key],
                                'module': self.tied_modules[key],
                            }
                            # Only count the tied module once in the eyes of the FP16 optimizer
                            if self.global_rank != tied_ranks[0]:
                                for p in self.tied_modules[key].parameters():
                                    p.ds_pipe_replicated = True
        '''
        if len(tied_comms) > 0:
            print(f'RANK={self.global_rank} tied_comms={tied_comms}')
        '''

        return tied_comms

    def partitions(self):
        return self.parts

    def stage_owner(self, layer_idx):
        assert 0 <= layer_idx < self._num_layers
        for stage in range(self._topo.get_dim('pipe')):
            if self.parts[stage] <= layer_idx < self.parts[stage + 1]:
                return stage
        raise RuntimeError(f'Layer {layer_idx} not owned? parts={self.parts}')

    def _set_bounds(self, start=None, stop=None):
        """Manually define the range of layers that will be built on this process.

        These boundaries are treated as list slices and so start is inclusive and stop is
        exclusive. The default of None for both results in all layers being built
        locally.
        """
        self._local_start = start
        self._local_stop = stop

    def set_checkpoint_interval(self, interval):
        assert interval >= 0
        # forward() reads `activation_checkpoint_interval`, so update that attribute.
        self.activation_checkpoint_interval = interval

    def topology(self):
        """ ProcessTopology object to query process mappings. """
        return self._topo

    def mpu(self):
        return self._grid

    def num_pipeline_stages(self):
        return self._topo.get_dim('pipe')

    def ckpt_prefix(self, checkpoints_path, tag):
        """Build a prefix for all checkpoint files written by this module. """
        # All checkpoint files start with this
        rank_name = 'module'

        # Data parallelism is omitted from the naming convention because we are agnostic
        # to this in the checkpoint.
        omit_dims = frozenset(['data'])
        axes = [a for a in self._grid._topo.get_axis_names() if a not in omit_dims]
        for dim in axes:
            rank = getattr(self._grid._topo.get_coord(rank=self.global_rank), dim)
            rank_name += f'-{dim}_{rank:02d}'

        ckpt_name = os.path.join(checkpoints_path, str(tag), rank_name)
        return ckpt_name

    def ckpt_layer_path(self, ckpt_dir, local_layer_idx):
        """Customize a prefix for a specific pipeline module layer. """
        idx = local_layer_idx + self._local_start
        layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}')
        rank_repr = self._grid._topo.get_rank_repr(rank=self.global_rank)
        if rank_repr != '':
            layer_ckpt_path += f'-{rank_repr}'
        layer_ckpt_path += '-model_states.pt'
        return layer_ckpt_path

    def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx):
        """Get all ckpt file list for a specific pipeline module layer. """
        idx = local_layer_idx + self._local_start
        layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}-')
        layer_ckpt_path += "*model_states.pt"
        ckpt_files = glob.glob(layer_ckpt_path)
        ckpt_files.sort()
        return ckpt_files
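
    # Editor's illustrative example (a sketch, assuming a topology with axes
    # ('pipe', 'data', 'model')): since the 'data' axis is omitted, ckpt_prefix()
    # for a rank at pipe=1, model=0 yields <checkpoints_path>/<tag>/module-pipe_01-model_00,
    # while ckpt_layer_path() names each locally owned layer file like
    # layer_05-model_states.pt (plus a rank suffix when get_rank_repr() is non-empty).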

    def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=False):
        # Processes having the same model parallel rank on different data parallel instances
        # have identical layer weights. We can distribute the task of saving the layer weights
        # among the data parallel ranks. For example, if a pipeline stage has 9 layers and
        # if there are 2 data parallel instances, rank 0 will save the first 5 layers and
        # rank 1 will save the last 4.
        dp_rank = self._grid.data_parallel_id
        dp_size = self._grid.data_parallel_size
        num_layers = len(self.forward_funcs)
        if self.checkpoint_parallel_write_pipeline:
            # spread layers evenly across data parallel ranks
            offsets = ds_utils.partition_uniform(num_layers, dp_size)
            start, end = offsets[dp_rank], offsets[dp_rank + 1]
        else:
            # data parallel rank 0 writes all layers
            if dp_rank != 0:
                return
            start, end = 0, num_layers
        layer_list = self.forward_funcs[start:end]

        checkpoint_engine.makedirs(save_dir, exist_ok=True)
        for idx, layer in enumerate(layer_list):
            model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
            if not hasattr(layer, 'state_dict'):
                continue

            orig_state_dict = layer.state_dict()
            if exclude_frozen_params:
                for n in self._get_frozen_parameter_names(layer):
                    del orig_state_dict[n]
            final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
            checkpoint_engine.save(final_state_dict, model_ckpt_path)
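
    # Editor's worked example of the parallel-write split described above: with 9
    # layers on this stage and 2 data parallel ranks, partition_uniform(9, 2)
    # produces offsets like [0, 5, 9], so dp_rank 0 saves layers [0:5] and
    # dp_rank 1 saves layers [5:9].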

    def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
        for idx, layer in enumerate(self.forward_funcs):
            # Functions, etc. will not have state_dicts
            if not hasattr(layer, 'load_state_dict'):
                continue

            # get all checkpoint files for the layer.
            model_ckpt_list = self.ckpt_layer_path_list(load_dir, idx)
            mp_rank = self._grid.get_slice_parallel_rank()
            mp_world_size = self._grid.get_slice_parallel_world_size()

            sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list,
                                                      version=2.0,
                                                      checkpoint_engine=checkpoint_engine)
            load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True)

            layer.load_state_dict(checkpoint, strict=strict)

            # if self._grid.data_parallel_id == 0:
            #     logger.info(
            #         f'RANK={self.global_rank} Loaded layer={idx+self._local_start} file={load_path}'
            #     )

        self._synchronize_tied_weights()

    def _is_checkpointable(self, funcs):
        # This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations.
        # Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things.
        # I presume it's related to the discrete inputs that cannot require_grad? Need to revisit.
        if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
            return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
        if self.checkpointable_layers is not None:
            return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

        params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
        return any(len(list(p)) > 0 for p in params)
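

# Editor's illustrative usage (a sketch; the layer class name is taken from the checks
# in _is_checkpointable above): restrict activation checkpointing to transformer
# blocks only.
#
#   model = PipelineModule(layers=specs,
#                          num_stages=4,
#                          activation_checkpoint_interval=1,
#                          checkpointable_layers=['ParallelTransformerLayerPipe'])
#
# With this filter, _is_checkpointable() returns True only for layer ranges made up
# entirely of ParallelTransformerLayerPipe instances; other ranges run without
# activation checkpointing.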