'''Copyright The Microsoft DeepSpeed Team'''

import os
import glob
import re as regex
from functools import partial

import torch
import torch.nn as nn

from deepspeed import comm as dist
from deepspeed.utils import logger

from .. import utils as ds_utils
from ..activation_checkpointing import checkpointing
from .topology import PipeDataParallelTopology, PipelineParallelGrid
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
from deepspeed.accelerator import get_accelerator


class PipelineError(Exception):
    """Errors related to the use of deepspeed.PipelineModule """

class LayerSpec:
    """Building block for specifying pipeline-parallel modules.

    LayerSpec stores the type information and parameters for each stage in a
    PipelineModule. For example:

    .. code-block:: python

        nn.Sequential(
            torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False),
            torch.nn.Linear(self.hidden_hidden, self.out_dim)
        )

    becomes

    .. code-block:: python

        layer_specs = [
            LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
            LayerSpec(torch.nn.Linear, self.hidden_hidden, self.out_dim)
        ]
    """

    def __init__(self, typename, *module_args, **module_kwargs):
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs

        if not issubclass(typename, nn.Module):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

        if dist.is_initialized():
            self.global_rank = dist.get_rank()
        else:
            self.global_rank = -1

    def __repr__(self):
        return ds_utils.call_to_str(self.typename.__name__,
                                    self.module_args,
                                    self.module_kwargs)

    def build(self, log=False):
        """Build the stored specification."""
        if log:
            logger.info(f'RANK={self.global_rank} building {repr(self)}')

        return self.typename(*self.module_args, **self.module_kwargs)
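
# Illustrative sketch (not part of the original module): a LayerSpec defers
# allocation of the wrapped nn.Module until build() is called, so each pipeline
# stage constructs only the layers it owns. The dimensions below are hypothetical.
#
#   spec = LayerSpec(torch.nn.Linear, 1024, 4096, bias=False)
#   print(spec)           # roughly: Linear(1024, 4096, bias=False)
#   layer = spec.build()  # the nn.Linear is allocated only here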

class TiedLayerSpec(LayerSpec):
    def __init__(self,
                 key,
                 typename,
                 *module_args,
                 forward_fn=None,
                 tied_weight_attr='weight',
                 **module_kwargs):
        super().__init__(typename, *module_args, **module_kwargs)
        self.key = key
        self.forward_fn = forward_fn
        self.tied_weight_attr = tied_weight_attr
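
# Illustrative sketch (not part of the original module): layers that share the
# same `key` reuse a single underlying nn.Module, which is how e.g. input and
# output embeddings can be tied across pipeline stages. `EmbeddingPipe`,
# `vocab_size`, and `hidden` are hypothetical names; `tied_weight_attr` names
# the attribute that is broadcast/all-reduced between the owning stages.
#
#   specs = [
#       TiedLayerSpec('embed', EmbeddingPipe, vocab_size, hidden),
#       # ... intermediate LayerSpecs ...
#       TiedLayerSpec('embed', EmbeddingPipe, vocab_size, hidden,
#                     forward_fn=lambda module, x: module(x)),
#   ]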

class PipelineModule(nn.Module):
    """Modules to be parallelized with pipeline parallelism.

    The key constraint that enables pipeline parallelism is the
    representation of the forward pass as a sequence of layers
    and the enforcement of a simple interface between them. The
    forward pass is implicitly defined by the module ``layers``. The key
    assumption is that the output of each layer can be directly fed as
    input to the next, like a ``torch.nn.Sequential``. The forward pass is
    implicitly:

    .. code-block:: python

        def forward(self, inputs):
            x = inputs
            for layer in self.layers:
                x = layer(x)
            return x

    .. note::
        Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.

    Args:
        layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
        num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
        topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism for training. Must be provided if ``num_stages`` is ``None``.
        loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)``
        seed_layers (bool, optional): Use a different seed for each layer. Defaults to False.
        seed_fn (type, optional): The custom seed generating function. Defaults to random seed generator.
        base_seed (int, optional): The starting seed. Defaults to 1234.
        partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
        activation_checkpoint_interval (int, optional): The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
        activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
        checkpointable_layers (list, optional): Layer class names that are eligible for activation checkpointing. Defaults to None, which applies no additional filtering.
    """
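
    # Illustrative sketch (not part of the original module): a typical
    # construction, assuming distributed communication has already been
    # initialized (e.g. via deepspeed.init_distributed()) and `layer_specs`
    # is a list of LayerSpec/TiedLayerSpec objects (names are hypothetical).
    #
    #   net = PipelineModule(layers=layer_specs,
    #                        num_stages=4,
    #                        loss_fn=torch.nn.CrossEntropyLoss(),
    #                        partition_method='parameters',
    #                        activation_checkpoint_interval=1)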

    def __init__(self,
                 layers,
                 num_stages=None,
                 topology=None,
                 loss_fn=None,
                 seed_layers=False,
                 seed_fn=None,
                 base_seed=1234,
                 partition_method='parameters',
                 activation_checkpoint_interval=0,
                 activation_checkpoint_func=checkpointing.checkpoint,
                 checkpointable_layers=None):

        super().__init__()

        if num_stages is None and topology is None:
            raise RuntimeError('must provide num_stages or topology')

        self.micro_offset = 0

        self.loss_fn = loss_fn

        self.checkpointable_layers = checkpointable_layers
        if checkpointable_layers is not None:
            assert isinstance(checkpointable_layers, list), "param `checkpointable_layers` must be type of list."

        self.seed_layers = seed_layers
        self.seed_fn = seed_fn
        self.base_seed = base_seed
        if dist.get_rank() == 0:
            try:
                seed_str = self.seed_fn.__name__
            except AttributeError:
                seed_str = None
            print(
                f'SEED_LAYERS={self.seed_layers} BASE_SEED={self.base_seed} SEED_FN={seed_str}'
            )

        # Setup world info
        self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
        self.global_rank = dist.get_rank(group=self.world_group)
        self.world_size = dist.get_world_size(group=self.world_group)
        self.local_rank = int(os.environ.get("LOCAL_RANK", None))
        assert self.local_rank is not None

        if topology:
            self._topo = topology
            self.num_stages = self._topo.get_dim('pipe')
        else:
            self.num_stages = num_stages
            if topology is None:
                if self.world_size % self.num_stages != 0:
                    raise RuntimeError(
                        f'num_stages ({self.num_stages}) must divide distributed world size ({self.world_size})'
                    )
                dp = self.world_size // num_stages
                topology = PipeDataParallelTopology(num_pp=num_stages, num_dp=dp)
            self._topo = topology

        # Construct communicators for pipeline topology
        self._grid = PipelineParallelGrid(process_group=self.world_group,
                                          topology=self._topo)

        self.stage_id = self._topo.get_coord(self.global_rank).pipe

        # Initialize partition information
        self._layer_specs = list(layers)
        self._num_layers = len(self._layer_specs)
        self._local_start = 0
        self._local_stop = None
        self._partition_layers(method=partition_method)

        self.forward_funcs = []
        self.fwd_map = {}
        self.tied_modules = nn.ModuleDict()
        self.tied_weight_attrs = {}

        # Offset the random seed by the stage ID.
        #newseed = get_accelerator().initial_seed() + self._grid.get_stage_id()
        #ds_utils.set_random_seed(newseed)

        #with torch.random.fork_rng(devices=[get_accelerator().current_device_name()]):
        self._build()
        self.to(get_accelerator().device_name(self.local_rank))

        self.tied_comms = self._index_tied_modules()
        self._synchronize_tied_weights()

        self.activation_checkpoint_interval = activation_checkpoint_interval
        self.activation_checkpoint_func = activation_checkpoint_func

    def _build(self):
        specs = self._layer_specs

        for local_idx, layer in enumerate(specs[self._local_start:self._local_stop]):
            layer_idx = local_idx + self._local_start
            if self.seed_layers:
                if self.seed_fn:
                    self.seed_fn(self.base_seed + layer_idx)
                else:
                    ds_utils.set_random_seed(self.base_seed + layer_idx)

            # Recursively build PipelineModule objects
            if isinstance(layer, PipelineModule):
                raise NotImplementedError('RECURSIVE BUILD NOT YET IMPLEMENTED')

            # nn.Modules provided directly are registered as-is.
            elif isinstance(layer, nn.Module):
                name = str(layer_idx)
                self.forward_funcs.append(layer)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, layer)

            # TiedLayerSpec objects contain an nn.Module that should be allocated now.
            elif isinstance(layer, TiedLayerSpec):
                # Build and register the module if we haven't seen it before.
                if layer.key not in self.tied_modules:
                    self.tied_modules[layer.key] = layer.build()
                    self.tied_weight_attrs[layer.key] = layer.tied_weight_attr
                if layer.forward_fn is None:
                    # Just use forward()
                    self.forward_funcs.append(self.tied_modules[layer.key])
                else:
                    # User specified fn with args (module, input)
                    self.forward_funcs.append(
                        partial(layer.forward_fn,
                                self.tied_modules[layer.key]))

            # LayerSpec objects contain an nn.Module that should be allocated now.
            elif isinstance(layer, LayerSpec):
                module = layer.build()
                name = str(layer_idx)
                self.forward_funcs.append(module)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, module)

            # Last option: layer may be a functional (e.g., lambda). We do nothing in
            # that case and just use it in forward()
            else:
                self.forward_funcs.append(layer)

        # All pipeline parameters should be considered as model parallel in the context
        # of our FP16 optimizer
        for p in self.parameters():
            p.ds_pipe_replicated = False

    def _count_layer_params(self):
        """Count the trainable parameters in individual layers.

        This routine will only build one layer at a time.

        Returns:
            A list of the number of parameters in each layer.
        """
        param_counts = [0] * len(self._layer_specs)
        for idx, layer in enumerate(self._layer_specs):
            if isinstance(layer, LayerSpec):
                l = layer.build()
                params = filter(lambda p: p.requires_grad, l.parameters())
                param_counts[idx] = sum(p.numel() for p in params)
            elif isinstance(layer, nn.Module):
                params = filter(lambda p: p.requires_grad, layer.parameters())
                param_counts[idx] = sum(p.numel() for p in params)
        return param_counts

    def _find_layer_type(self, layername):
        idxs = []
        typeregex = regex.compile(layername, regex.IGNORECASE)
        for idx, layer in enumerate(self._layer_specs):
            name = None
            if isinstance(layer, LayerSpec):
                name = layer.typename.__name__
            elif isinstance(layer, nn.Module):
                name = layer.__class__.__name__
            else:
                try:
                    name = layer.__name__
                except AttributeError:
                    continue
            if typeregex.search(name):
                idxs.append(idx)

        if len(idxs) == 0:
            raise RuntimeError(
                f"Partitioning '{layername}' found no valid layers to partition.")
        return idxs

    def forward(self, forward_input):
        # We need to offset the seed by the microbatch ID. Save it in a local var to
        # ensure it is preserved in the closure. Otherwise checkpointed forward funcs
        # will see a different offset.
        self.micro_offset += 1

        def exec_range_func(start, end):
            ''' Helper function to be used with checkpoint()
            Adapted from torch.utils.checkpoint:checkpoint_sequential()
            '''
            local_micro_offset = self.micro_offset + 1

            def exec_func(*inputs):
                # Single tensor inputs need to be unwrapped
                if len(inputs) == 1:
                    inputs = inputs[0]
                for idx, layer in enumerate(self.forward_funcs[start:end]):
                    self.curr_layer = idx + self._local_start
                    if self.seed_layers:
                        new_seed = (self.base_seed *
                                    local_micro_offset) + self.curr_layer
                        if self.seed_fn:
                            self.seed_fn(new_seed)
                        else:
                            ds_utils.set_random_seed(new_seed)

                    inputs = layer(inputs)
                return inputs

            return exec_func

        if self.activation_checkpoint_interval == 0:
            func = exec_range_func(0, len(self.forward_funcs))
            x = func(forward_input)
        else:
            num_layers = len(self.forward_funcs)
            x = forward_input
            for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
                end_idx = min(start_idx + self.activation_checkpoint_interval,
                              num_layers)

                funcs = self.forward_funcs[start_idx:end_idx]
                # Since we either pass tensors or tuples of tensors without unpacking, we
                # need to be careful not to double-wrap tensors with tuple.
                if not isinstance(x, tuple):
                    x = (x, )

                if self._is_checkpointable(funcs):
                    x = self.activation_checkpoint_func(
                        exec_range_func(start_idx,
                                        end_idx),
                        *x)
                else:
                    x = exec_range_func(start_idx, end_idx)(*x)
        return x
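
    # Worked example (not part of the original module): with 8 local layers and
    # activation_checkpoint_interval=3, forward() executes the layer chunks
    # [0:3], [3:6], [6:8]. Each chunk is wrapped by activation_checkpoint_func
    # only if _is_checkpointable() returns True for its layers; otherwise it is
    # run directly.
    #
    #   >>> [(s, min(s + 3, 8)) for s in range(0, 8, 3)]
    #   [(0, 3), (3, 6), (6, 8)]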

    def _partition_layers(self, method='uniform'):
        num_stages = self._topo.get_dim('pipe')
        stage_id = self._topo.get_coord(self.global_rank).pipe

        if self.global_rank == 0:
            logger.info(f'Partitioning pipeline stages with method {method}')

        method = method.lower()

        # Each stage gets a simple uniform number of layers.
        if method == 'uniform':
            num_layers = len(self._layer_specs)
            self.parts = ds_utils.partition_uniform(num_items=num_layers,
                                                    num_parts=num_stages)
        elif method == 'parameters':
            param_counts = self._count_layer_params()
            self.parts = ds_utils.partition_balanced(weights=param_counts,
                                                     num_parts=num_stages)
        elif method.startswith('type:'):
            layertype = method.split(':')[1]
            binary_weights = [0] * len(self._layer_specs)
            for idx in self._find_layer_type(layertype):
                binary_weights[idx] = 1
            self.parts = ds_utils.partition_balanced(weights=binary_weights,
                                                     num_parts=num_stages)
        elif method == 'profile':
            raise NotImplementedError(f'Partitioning method {method} not implemented.')
        else:
            raise NotImplementedError(f'Partitioning method {method} not implemented.')

        # Print some information on the partitioning.
        if self.global_rank == 0:
            for stage in range(num_stages):
                start = self.parts[stage]
                stop = self.parts[stage + 1]
                print(f'stage={stage} layers={stop - start}')
                for idx, layer in enumerate(self._layer_specs[start:stop]):
                    name = str(layer)
                    if isinstance(layer, LayerSpec):
                        name = layer.typename.__name__
                    if isinstance(layer, nn.Module):
                        name = layer.__class__.__name__
                    else:
                        try:
                            name = layer.__name__
                        except AttributeError:
                            pass
                    print(f' {idx+start:2d}: {name}')
            if self.loss_fn:
                try:
                    print(f' loss: {self.loss_fn.__name__}')
                except AttributeError:
                    print(f' loss: {self.loss_fn.__class__.__name__}')

        self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])
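
    # Illustrative sketch (not part of the original module): the partition_method
    # values supported by the implementation above.
    #
    #   PipelineModule(layers=specs, num_stages=4, partition_method='uniform')
    #   PipelineModule(layers=specs, num_stages=4, partition_method='parameters')
    #   # Balance only the layers whose class name matches a (case-insensitive)
    #   # regex, e.g. all transformer blocks ('TransformerLayer' is hypothetical):
    #   PipelineModule(layers=specs, num_stages=4,
    #                  partition_method='type:TransformerLayer')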

    def allreduce_tied_weight_gradients(self):
        '''All reduce the gradients of the tied weights between tied stages'''
        for key, comm in self.tied_comms.items():
            weight = getattr(self.tied_modules[key], comm['weight_attr'])
            dist.all_reduce(weight.grad, group=comm['group'])

    def get_tied_weights_and_groups(self):
        weight_group_list = []
        for key, comm in self.tied_comms.items():
            weight = getattr(self.tied_modules[key], comm['weight_attr'])
            weight_group_list.append((weight, comm['group']))
        return weight_group_list

    def _synchronize_tied_weights(self):
        for key, comm in self.tied_comms.items():
            dist.broadcast(
                getattr(comm['module'],
                        comm['weight_attr']),
                src=min(comm['ranks']),
                group=comm['group'],
            )

    def _index_tied_modules(self):
        ''' Build communication structures for tied modules. '''
        tied_comms = {}
        if self._topo.get_dim('pipe') == 1:
            return tied_comms

        specs = self._layer_specs
        tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec))
        for key in tie_keys:
            # Find the layers that the tied module appears in
            tied_layers = []
            for idx, layer in enumerate(specs):
                if isinstance(layer, TiedLayerSpec) and layer.key == key:
                    tied_layers.append(idx)
            # Find all stages with this tied module
            # TODO: Would be nice to remove the nested data/model parallelism loops and
            # TODO: instead generalize in some way, since we really just care about the
            # TODO: stage that owns the tied layer. Then loop over each (dp, mp, ...)
            # TODO: fiber to generate process groups.
            tied_stages = set(self.stage_owner(idx) for idx in tied_layers)
            for dp in range(self._grid.data_parallel_size):
                for mp in range(self._grid.get_slice_parallel_world_size()):
                    tied_ranks = []
                    for s in sorted(tied_stages):
                        if self._grid.get_slice_parallel_world_size() > 1:
                            tied_ranks.append(
                                self._grid.stage_to_global(stage_id=s,
                                                           data=dp,
                                                           model=mp))
                        else:
                            tied_ranks.append(
                                self._grid.stage_to_global(stage_id=s,
                                                           data=dp))
                    group = dist.new_group(ranks=tied_ranks)

                    # Record this tied module if we own a local copy of it.
                    if self.global_rank in tied_ranks:
                        assert key in self.tied_modules
                        if key in self.tied_modules:
                            tied_comms[key] = {
                                'ranks': tied_ranks,
                                'group': group,
                                'weight_attr': self.tied_weight_attrs[key],
                                'module': self.tied_modules[key],
                            }
                            # Only count the tied module once in the eyes of the FP16 optimizer
                            if self.global_rank != tied_ranks[0]:
                                for p in self.tied_modules[key].parameters():
                                    p.ds_pipe_replicated = True
        '''
        if len(tied_comms) > 0:
            print(f'RANK={self.global_rank} tied_comms={tied_comms}')
        '''

        return tied_comms

    def partitions(self):
        return self.parts

    def stage_owner(self, layer_idx):
        assert 0 <= layer_idx < self._num_layers
        for stage in range(self._topo.get_dim('pipe')):
            if self.parts[stage] <= layer_idx < self.parts[stage + 1]:
                return stage
        raise RuntimeError(f'Layer {layer_idx} not owned? parts={self.parts}')

    def _set_bounds(self, start=None, stop=None):
        """Manually define the range of layers that will be built on this process.

        These boundaries are treated as list slices and so start is inclusive and stop is
        exclusive. The default of None for both results in all layers being built
        locally.
        """
        self._local_start = start
        self._local_stop = stop

    def set_checkpoint_interval(self, interval):
        assert interval >= 0
        # Update the attribute consulted by forward() when chunking layers for
        # activation checkpointing.
        self.activation_checkpoint_interval = interval

    def topology(self):
        """ ProcessTopology object to query process mappings. """
        return self._topo

    def mpu(self):
        return self._grid

    def num_pipeline_stages(self):
        return self._topo.get_dim('pipe')

    def ckpt_prefix(self, checkpoints_path, tag):
        """Build a prefix for all checkpoint files written by this module. """
        # All checkpoint files start with this
        rank_name = 'module'

        # Data parallelism is omitted from the naming convention because we are agnostic
        # to this in the checkpoint.
        omit_dims = frozenset(['data'])

        axes = [a for a in self._grid._topo.get_axis_names() if a not in omit_dims]
        for dim in axes:
            rank = getattr(self._grid._topo.get_coord(rank=self.global_rank), dim)
            rank_name += f'-{dim}_{rank:02d}'

        ckpt_name = os.path.join(checkpoints_path, str(tag), rank_name)
        return ckpt_name
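
    # Worked example (not part of the original module): for a hypothetical
    # topology with 'pipe' and 'model' axes, a rank at pipe=1, model=0 and tag
    # 'global_step100' would get a prefix along the lines of
    #   <checkpoints_path>/global_step100/module-pipe_01-model_00
    # (the 'data' axis is intentionally omitted from the name).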

    def ckpt_layer_path(self, ckpt_dir, local_layer_idx):
        """Customize a prefix for a specific pipeline module layer. """
        idx = local_layer_idx + self._local_start
        layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}')
        rank_repr = self._grid._topo.get_rank_repr(rank=self.global_rank)
        if rank_repr != '':
            layer_ckpt_path += f'-{rank_repr}'
        layer_ckpt_path += '-model_states.pt'
        return layer_ckpt_path

    def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx):
        """Get all ckpt file list for a specific pipeline module layer. """
        idx = local_layer_idx + self._local_start
        layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}-')
        layer_ckpt_path += "*model_states.pt"
        ckpt_files = glob.glob(layer_ckpt_path)
        ckpt_files.sort()
        return ckpt_files

    def save_state_dict(self, save_dir, checkpoint_engine):
        # Processes having the same model parallel rank on different data parallel instances
        # have identical layer weights. We can distribute the task of saving the layer weights
        # among the data parallel ranks. For example, if a pipeline stage has 9 layers and
        # if there are 2 data parallel instances, rank 0 will save the first 5 layers and
        # rank 1 will save the last 4.
        dp_rank = self._grid.data_parallel_id
        dp_size = self._grid.data_parallel_size
        num_layers = len(self.forward_funcs)
        if self.checkpoint_parallel_write_pipeline:
            # spread layers evenly across data parallel ranks
            offsets = ds_utils.partition_uniform(num_layers, dp_size)
            start, end = offsets[dp_rank], offsets[dp_rank + 1]
        else:
            # data parallel rank 0 writes all layers
            if dp_rank != 0:
                return
            start, end = 0, num_layers
        layer_list = self.forward_funcs[start:end]

        checkpoint_engine.makedirs(save_dir, exist_ok=True)
        for idx, layer in enumerate(layer_list):
            model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
            if not hasattr(layer, 'state_dict'):
                continue

            # We pass cloned tensors to torch.save() to avoid checkpoint bloat that occurs because torch.save()
            # saves the underlying storage rather than the slice of the storage corresponding to individual tensors.
            # This is a problem in DeepSpeed because we often allocate tensors using slices of large flattened buffers.
            # Tensor cloning helps to avoid this problem because the storage of cloned tensors is closer to the true size.
            # It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat.
            # See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing
            orig_state_dict = layer.state_dict()
            final_state_dict = type(orig_state_dict)(
                {k: v.clone() for k, v in orig_state_dict.items()})
            checkpoint_engine.save(final_state_dict, model_ckpt_path)
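
    # Worked example (not part of the original module), matching the comment
    # above: with 9 layers and 2 data parallel ranks,
    #   ds_utils.partition_uniform(9, 2)
    # is expected to return offsets like [0, 5, 9], so dp rank 0 saves layers
    # 0-4 and dp rank 1 saves layers 5-8.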

    def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
        for idx, layer in enumerate(self.forward_funcs):
            # Functions, etc. will not have state_dicts
            if not hasattr(layer, 'load_state_dict'):
                continue

            # get all checkpoint files for the layer.
            model_ckpt_list = self.ckpt_layer_path_list(load_dir, idx)
            mp_rank = self._grid.get_slice_parallel_rank()
            mp_world_size = self._grid.get_slice_parallel_world_size()

            sd_loader = SDLoaderFactory.get_sd_loader(
                model_ckpt_list,
                version=2.0,
                checkpoint_engine=checkpoint_engine)
            load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True)

            layer.load_state_dict(checkpoint)

            # if self._grid.data_parallel_id == 0:
            #     logger.info(
            #         f'RANK={self.global_rank} Loaded layer={idx+self._local_start} file={load_path}'
            #     )

        self._synchronize_tied_weights()

    def _is_checkpointable(self, funcs):
        # This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations.
        # Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things.
        # I presume it's related to the discrete inputs that cannot require_grad? Need to revisit.
        if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
            return all('ParallelTransformerLayerPipe' in f.__class__.__name__
                       for f in funcs)
        if self.checkpointable_layers is not None:
            return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

        params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
        return any(len(list(p)) > 0 for p in params)
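
# Illustrative sketch (not part of the original module): restricting activation
# checkpointing to specific layer classes via the constructor argument documented
# above ('TransformerLayer' is a hypothetical class name).
#
#   net = PipelineModule(layers=specs,
#                        num_stages=2,
#                        activation_checkpoint_interval=1,
#                        checkpointable_layers=['TransformerLayer'])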