# engine.py (57 KB)
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from types import MethodType
  5. import torch
  6. from deepspeed import comm as dist
  7. from deepspeed.utils import logger
  8. from deepspeed.utils.timer import ThroughputTimer
  9. from deepspeed.accelerator import get_accelerator
  10. from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
  11. from deepspeed.utils.timer import FORWARD_MICRO_TIMER, FORWARD_GLOBAL_TIMER, BACKWARD_MICRO_TIMER, \
  12. BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_INNER_GLOBAL_TIMER, \
  13. BACKWARD_REDUCE_MICRO_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \
  14. STEP_MICRO_TIMER, STEP_GLOBAL_TIMER
  15. from ..utils import PartitionedTensor
  16. from ..dataloader import RepeatingLoader
  17. from ..zero.config import ZeroStageEnum
  18. from ..activation_checkpointing import checkpointing as ds_checkpointing
  19. from .module import PipelineModule, PipelineError
  20. from . import p2p
  21. from . import schedule
  # Debug selectors. TARGET_ID is not referenced in the visible part of this file.
  TARGET_ID = -2
  # log_for_device() prints only when LOG_STAGE equals the caller's stage id or is -1,
  # and DATA_PARALLEL_ID equals the caller's data-parallel id or is -1.
  # NOTE(review): -2 presumably matches no real id, i.e. logging is off -- confirm.
  LOG_STAGE = -2
  DATA_PARALLEL_ID = -2
  # Timer names. TRAIN_BATCH_TIMER is started/stopped in train_batch() and the PIPE_*
  # names are passed to self.timers.log() there; BATCH_INPUT_TIMER is not used in the
  # visible part of this file.
  BATCH_INPUT_TIMER = 'batch_input'
  TRAIN_BATCH_TIMER = 'train_batch'
  PIPE_SEND_OUTPUT_TIMER = 'pipe_send_output'
  PIPE_SEND_GRAD_TIMER = 'pipe_send_grad'
  PIPE_RECV_INPUT_TIMER = 'pipe_recv_input'
  PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad'
  def is_even(number):
      """Return True when ``number`` leaves no remainder on division by two."""
      remainder = number % 2
      return remainder == 0
  # Module-level counters initialised to zero. They are not read in the visible part of
  # this file -- presumably running allocated/cached memory totals used by memory-status
  # logging; TODO confirm against the rest of the module.
  mem_alloced = 0
  mem_cached = 0
  35. def _tensor_bytes(tensor):
  36. return tensor.numel() * tensor.element_size()
  37. class PipelineEngine(DeepSpeedEngine):
  38. """ A training engine hybrid pipeline, data, and model parallel training.
  39. This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
  40. is provided.
  41. """
  42. ID_TO_DTYPE = [
  43. torch.float32, torch.float64, torch.complex64, torch.complex128, torch.float16, torch.bfloat16, torch.uint8,
  44. torch.int8, torch.int16, torch.int32, torch.int64, torch.bool
  45. ]
  46. DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}
  def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
      """Initialise the base engine, then build the pipeline-specific state.

      Args:
          has_bool_tensors (bool): Stored as ``self.has_bool_tensors``; not otherwise
              used inside this method.
          *super_args: Positional arguments forwarded to ``DeepSpeedEngine.__init__``.
          **super_kwargs: Keyword arguments forwarded to ``DeepSpeedEngine.__init__``.

      Raises:
          AssertionError: If the module is not a :class:`PipelineModule`, the ZeRO stage
              is not below ``ZeroStageEnum.gradients``, elasticity is enabled without
              elastic model-parallel support, or the batch-size / data-parallel-size
              configuration does not agree with the module's grid.
      """
      super().__init__(*super_args, **super_kwargs)
      assert isinstance(self.module, PipelineModule), "model must base PipelineModule"
      assert self.zero_optimization_stage(
      ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"
      # We schedule the all-reduces, so disable it in super().backward()
      self.enable_backward_allreduce = False
      self.has_bool_tensors = has_bool_tensors
      # eval_batch() toggles this flag; when set, the forward pass keeps its outputs in
      # self.outputs.
      self.eval_return_logits = False
      self.outputs = None
      # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
      self.pipeline_enable_backward_allreduce = True
      # Elasticity is only accepted when elastic model parallelism is supported; the
      # inner assert always fails once both outer conditions hold.
      if self.elasticity_enabled():
          if not self.is_elastic_model_parallel_supported():
              assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
                  " with pipeline parallelism."
      # pipeline step for logging
      self.log_batch_step_id = -1
      self.micro_batch_size = self.train_micro_batch_size_per_gpu()
      # One micro-batch per gradient accumulation step.
      self.micro_batches = self.gradient_accumulation_steps()
      # Set Grid and Communication Groups
      self.grid = self.module._grid
      if self.grid.get_global_rank() == 0:
          logger.info(f'CONFIG: micro_batches={self.micro_batches} '
                      f'micro_batch_size={self.micro_batch_size}')
      self.global_rank = self.grid.get_global_rank()
      # The engine's batch configuration must agree with the grid's data-parallel size.
      assert self.dp_world_size == self.grid.data_parallel_size
      assert self.train_batch_size() == \
          self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size
      # Set Stage Inf
      self.num_stages = self.grid.pipe_parallel_size
      self.stage_id = self.grid.get_stage_id()
      self.prev_stage = self.stage_id - 1
      self.next_stage = self.stage_id + 1
      self.data_iterator = None
      self.batch_fn = None
      # Returned by is_gradient_accumulation_boundary(); raised temporarily by
      # _exec_reduce_grads().
      self._force_grad_boundary = False
      self.batch_timer = ThroughputTimer(batch_size=self.train_batch_size(),
                                         logging_fn=self.tput_log,
                                         monitor_memory=False,
                                         steps_per_output=self.steps_per_print())
      # PipelineEngine needs to handle data loading specially due to only the first
      # and last stages loading inputs/labels. We construct a sampler that uses
      # the data-parallel world size and rank (see _build_data_iter()).
      if self.training_data:
          self._build_data_iter(self.training_data)
      self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
      self.is_data_parallel = self.grid.data_parallel_size > 1
      self.is_model_parallel = self.grid.model_parallel_size > 1
      # Partition input/output buffers
      # XXX temporarily disable while I revert some partition hacks.
      self.is_pipe_partitioned = self.is_model_parallel
      self.is_grad_partitioned = self.is_model_parallel
      # Count the trainable parameters held by this stage.
      model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
      num_params = sum([p.numel() for p in model_parameters])
      unique_params = num_params
      # Subtract tied parameters if we don't own them
      if self.module.tied_comms:
          tied_params = 0
          for key, d in self.module.tied_comms.items():
              # The lowest rank in d['ranks'] is treated as the owner of the tied module.
              if self.global_rank != min(d['ranks']):
                  tied_params += sum(p.numel() for p in d['module'].parameters())
          unique_params -= tied_params
      # Combine [stage params, unique params] across the model-parallel group
      # (default reduce op of dist.all_reduce -- presumably SUM; TODO confirm).
      params_tensor = torch.LongTensor(data=[num_params, unique_params]).to(self.device)
      dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
      params_tensor = params_tensor.tolist()
      total_params = params_tensor[0]
      unique_params = params_tensor[1]
      # Only the ranks with data-parallel id 0 report the parameter summary.
      if self.grid.data_parallel_id == 0:
          logger.info(f'RANK={self.global_rank} '
                      f'STAGE={self.stage_id} '
                      f'LAYERS={self.module._local_stop - self.module._local_start} '
                      f'[{self.module._local_start}, {self.module._local_stop}) '
                      f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
                      f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
                      f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')
      #initialize peer-2-peer communication and allreduce groups
      if self.is_pipe_parallel:
          p2p.init_process_groups(self.grid)
      # Pipeline buffers
      # Slots are added later by _reserve_pipe_buffers(); nothing is allocated here.
      self.num_pipe_buffers = 0
      self.pipe_buffers = {
          'inputs': [],  # batch input and received activations
          'labels': [],  # labels from batch input
          'outputs': [],  # activations
          'output_tensors': [],  # tensor object to preserve backward graph
      }
      # Communication state that reset_activation_shape() clears again.
      self.pipe_recv_buf = None
      self.grad_layer = None
      self.meta_buffer = None
      self.first_output_send = True
      self.first_gradient_send = True
      #stores the loss for the current micro batch being processed
      self.loss = torch.tensor(0.0).to(self.device)
      #stores the loss for the entire batch
      self.total_loss = None
      self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
      self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
      # A positive configured interval overrides the module's own checkpoint interval.
      if self._config.pipeline['activation_checkpoint_interval'] > 0:
          self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval']
      self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline
      # Only the last stage keeps a handle on the loss function.
      if self.is_last_stage():
          self.loss_model = self.module.loss_fn
      # Detected purely from the module's class name; see set_has_attention_mask().
      self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe'
      # Initialize pipeline communicators. Just send a 0.
      # Even stages send before receiving and odd stages receive before sending, so each
      # send is paired with a receive on the neighbouring stage.
      if is_even(self.stage_id):
          if not self.is_last_stage():
              p2p.send(self.loss, self.next_stage)
          if not self.is_first_stage():
              p2p.recv(self.loss, self.prev_stage)
      else:
          if not self.is_first_stage():
              p2p.recv(self.loss, self.prev_stage)
          if not self.is_last_stage():
              p2p.send(self.loss, self.next_stage)
      # XXX look into timer reporting timing
      # Initialize some timers because of early weirdness.
      # Each timer is started and stopped once up front.
      if self.wall_clock_breakdown():
          self.timers(FORWARD_MICRO_TIMER).start()
          self.timers(FORWARD_MICRO_TIMER).stop()
          self.timers(BACKWARD_MICRO_TIMER).start()
          self.timers(BACKWARD_MICRO_TIMER).stop()
          self.timers(BACKWARD_INNER_MICRO_TIMER).start()
          self.timers(BACKWARD_INNER_MICRO_TIMER).stop()
          self.timers(BACKWARD_REDUCE_MICRO_TIMER).start()
          self.timers(BACKWARD_REDUCE_MICRO_TIMER).stop()
          self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).start()
          self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).stop()
          self.timers(STEP_MICRO_TIMER).start()
          self.timers(STEP_MICRO_TIMER).stop()
  def set_has_attention_mask(self, value):
      """Override the ``has_attention_mask`` flag.

      Args:
          value (bool): The new flag value. Anything that is not a ``bool`` fails
              the assertion.
      """
      assert isinstance(value, bool)
      self.has_attention_mask = value
  def _build_data_iter(self, dataset):
      """Install a repeating loader over ``dataset`` as this engine's dataloader.

      The dataset is sampled without shuffling through a ``DistributedSampler``
      configured with the data-parallel world size and this rank's data-parallel rank.

      Args:
          dataset: The dataset handed to the sampler and to ``self.deepspeed_io``.
      """
      dp_sampler = torch.utils.data.distributed.DistributedSampler(
          dataset,
          num_replicas=self.dp_world_size,
          rank=self.mpu.get_data_parallel_rank(),
          shuffle=False,
      )
      # Build the loader, then wrap it so that it is repeating.
      repeating_loader = RepeatingLoader(self.deepspeed_io(dataset, data_sampler=dp_sampler))
      self.set_dataloader(repeating_loader)
  def _exec_reduce_tied_grads(self):
      """All-reduce the gradient of every tied weight over its tied group."""
      # The ZeRO epilogue has to run first so that self.averaged_gradients is written.
      # Because this class turns `enable_backward_allreduce` off,
      # `overlapping_partition_gradients_reduce_epilogue()` is never reached through
      # DeepSpeedEngine, so it is invoked here explicitly. The cost of doing so
      # (get_flat_partition in stage2.py may be expensive) has not been profiled.
      # (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944)
      if self.zero_optimization_partition_gradients():
          self.optimizer.overlapping_partition_gradients_reduce_epilogue()

      for weight, group in self.module.get_tied_weights_and_groups():
          # With bf16 enabled the gradient is read from `_hp_grad` instead of `.grad`.
          if self.bfloat16_enabled():
              tied_grad = weight._hp_grad
          else:
              tied_grad = weight.grad
          dist.all_reduce(tied_grad, group=group)
  203. def _exec_reduce_grads(self):
  204. self._force_grad_boundary = True
  205. if self.pipeline_enable_backward_allreduce:
  206. if self.bfloat16_enabled():
  207. # PP+BF16 work for ZeRO Stage 1
  208. self._bf16_reduce_grads()
  209. else:
  210. self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
  211. self._force_grad_boundary = False
  def _bf16_reduce_grads(self):
      """All-reduce the optimizer's gradients through the buffered fallback path.

      The gradients come from ``self.optimizer.get_grads_for_reduction()`` and are
      reduced in buffers of ``MEMORY_OPT_ALLREDUCE_SIZE`` elements.
      """
      self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
                                       elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
  217. def _reserve_pipe_buffers(self, num_buffers):
  218. """Ensure that each pipeline buffer has at least ``num_buffers`` slots.
  219. This method only reserves slots and does not allocate tensors.
  220. Args:
  221. num_buffers (int): The number of buffers to reserve.
  222. """
  223. if self.num_pipe_buffers >= num_buffers:
  224. return
  225. num_added = num_buffers - self.num_pipe_buffers
  226. for key in self.pipe_buffers:
  227. self.pipe_buffers[key].extend([None] * num_added)
  228. self.num_pipe_buffers = num_buffers
  def reset_activation_shape(self):
      """Drop cached communication state after activation/gradient shapes change.

      Curriculum learning that changes the sequence length of each sample is one
      example: this must be called whenever the sequence length is about to change.
      """
      self.first_output_send = True
      for cached_attr in ('pipe_recv_buf', 'grad_layer', 'meta_buffer'):
          setattr(self, cached_attr, None)
  def train_batch(self, data_iter=None):
      """Progress the pipeline to train the next batch of data.

      The engine will ingest ``self.train_batch_size()`` total samples collectively
      across all workers.

      An iterator over training data should be provided as an argument unless
      ``deepspeed.initialize()`` was provided a training set. In that event, the
      training data will automatically be read.

      .. warning::
          A total of ``self.gradient_accumulation_steps()`` entries will be pulled
          from ``data_iter`` by each pipeline. There must be sufficient
          data left in ``data_iter`` or else a ``StopIteration`` will halt training.

          DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
          that wraps data loaders to automatically restart upon a ``StopIteration``.

      Args:
          data_iter (Iterator, optional): Iterator of training data. ``None`` keeps the
              iterator that is already installed on the engine.

      Returns:
          The arithmetic mean of the losses computed this batch.

      Raises:
          RuntimeError: If gradients are disabled when this is called.
      """
      if not torch.is_grad_enabled():
          raise RuntimeError('train_batch() requires gradients enabled. Use eval_batch() instead.')

      # Curriculum learning could change activation shape
      if self.curriculum_enabled_legacy():
          scheduler = self.curriculum_scheduler_legacy
          new_difficulty = scheduler.update_difficulty(self.global_steps + 1)
          if self.global_steps == 0 or scheduler.first_step:
              self.reset_activation_shape()
              scheduler.first_step = False
          elif new_difficulty != scheduler.get_difficulty(self.global_steps):
              self.reset_activation_shape()

      # Compare against None explicitly: a caller-supplied iterable that happens to be
      # falsy (for example one reporting a length of zero) must not be silently ignored.
      if data_iter is not None:
          self.set_dataiterator(data_iter)

      self.module.train()
      self.total_loss = None
      self._compute_loss = True

      # Do the work
      self.timers(TRAIN_BATCH_TIMER).start()
      sched = schedule.TrainSchedule(micro_batches=self.micro_batches,
                                     stages=self.num_stages,
                                     stage_id=self.stage_id)
      self._exec_schedule(sched)
      self.agg_train_loss = self._aggregate_total_loss()
      self.timers(TRAIN_BATCH_TIMER).stop()

      if self.global_steps % self.steps_per_print() == 0:
          if self.global_rank == 0:
              # The timer's elapsed() value is divided by 1000 to report seconds.
              elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0
              iter_time = elapsed / self.steps_per_print()
              tput = self.train_batch_size() / iter_time
              print(f'steps: {self.global_steps} '
                    f'loss: {self.agg_train_loss:0.4f} '
                    f'iter time (s): {iter_time:0.3f} '
                    f'samples/sec: {tput:0.3f}')

      # Monitoring
      if self.global_rank == 0 and self.monitor.enabled:
          self.summary_events = [('Train/Samples/train_loss', self.agg_train_loss.mean().item(),
                                  self.global_samples)]
          self.monitor.write_events(self.summary_events)

      if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0:
          self.timers.log([
              PIPE_SEND_OUTPUT_TIMER,
              PIPE_SEND_GRAD_TIMER,
              PIPE_RECV_INPUT_TIMER,
              PIPE_RECV_GRAD_TIMER,
          ])

      # TODO: should return precisely what loss returned and allow others to be queried?
      return self.agg_train_loss
  def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg'):
      """Evaluate the pipeline on a batch of data from ``data_iter``.

      The engine will evaluate ``self.train_batch_size()`` total samples
      collectively across all workers.

      This method is equivalent to:

      .. code-block:: python

          module.eval()
          with torch.no_grad():
              output = module(batch)

      .. warning::
          A total of ``self.gradient_accumulation_steps()`` entries will be pulled
          from ``data_iter`` by each pipeline. There must be sufficient
          data left in ``data_iter`` or else a ``StopIteration`` will halt training.

          DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
          that wraps data loaders to automatically restart upon a ``StopIteration``.

      Args:
          data_iter (Iterator): Iterator of data to evaluate.
          return_logits (bool, optional): When True, the forward pass keeps its outputs
              and they are returned alongside the reduced result.
          compute_loss (bool, optional): Stored as ``self._compute_loss`` for the forward
              pass. When True, the reduced result is broadcast from the last stage to
              the other pipeline stages.
          reduce_output (str, optional): Reduction handed to ``_reduce_outputs`` on the
              last stage; ``None`` returns the per-micro-batch outputs unreduced.

      Returns:
          The reduced evaluation output, or ``(eval_output, outputs)`` when
          ``return_logits`` is True. Without the broadcast, stages other than the
          last one return ``None`` as the evaluation output.
      """
      self.eval_return_logits = return_logits
      self.module.eval()

      # Curriculum learning could change activation shape
      if self.curriculum_enabled_legacy():
          scheduler = self.curriculum_scheduler_legacy
          new_difficulty = scheduler.update_difficulty(self.global_steps + 1)
          if self.global_steps == 0 or scheduler.first_step:
              self.reset_activation_shape()
              scheduler.first_step = False
          elif new_difficulty != scheduler.get_difficulty(self.global_steps):
              self.reset_activation_shape()

      eval_output = None
      self._compute_loss = compute_loss

      # Use the provided data iterator
      train_iterator = self.data_iterator
      self.set_dataiterator(data_iter)

      # Do the work
      sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
                                         stages=self.num_stages,
                                         stage_id=self.stage_id)

      # prevent dead-lock with multiple evals sequence
      dist.barrier()

      with torch.no_grad():
          self._exec_schedule(sched)

      if self.is_last_stage():
          eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output)

      if compute_loss:
          eval_output = self._bcast_pipe_scalar(eval_output)

      # Only log when there is a tensor to summarise: without the broadcast, rank 0 may
      # hold ``None`` (it is not the last stage) or an unreduced list, and calling
      # ``.mean()`` on either would crash the evaluation.
      if self.global_rank == 0 and self.monitor.enabled and torch.is_tensor(eval_output):
          self.summary_events = [('Train/Samples/eval_loss', eval_output.mean().item(), self.global_samples)]
          self.monitor.write_events(self.summary_events)

      # Restore the training iterator
      self.set_dataiterator(train_iterator)

      # Reset any buffers that may have been populated during the forward passes.
      #ds_checkpointing.reset()
      self.eval_return_logits = False
      if return_logits:
          outputs = self.outputs
          self.outputs = None
          return eval_output, outputs
      return eval_output
  def set_train_batch_size(self, train_batch_size):
      """Change the global batch size by changing the number of micro-batches.

      The number of micro-batches (i.e., gradient accumulation steps) is increased or
      decreased; the size of each micro-batch (i.e.,
      ``train_micro_batch_size_per_gpu``) is left untouched.

      Args:
          train_batch_size (int): The new global batch size for training.

      Raises:
          ValueError: if ``train_batch_size`` is not divisible by the
              configured micro-batch size and data parallelism.
      """
      super().set_train_batch_size(train_batch_size)
      # Re-read the accumulation steps so the pipeline's micro-batch count stays in sync.
      self.micro_batches = self.gradient_accumulation_steps()
  def is_first_stage(self):
      """Tell whether this process belongs to the first pipeline stage (stage 0)."""
      first_stage_id = 0
      return self.stage_id == first_stage_id
  def is_last_stage(self):
      """Tell whether this process belongs to the final pipeline stage."""
      final_stage_id = self.num_stages - 1
      return self.stage_id == final_stage_id
  383. def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True):
  384. if reduce is None:
  385. return outputs
  386. if reduce.lower() == 'avg':
  387. # first sum over all microbatches
  388. if torch.is_tensor(outputs[0]):
  389. reduced = sum(outputs)
  390. else:
  391. assert isinstance(outputs, (list, tuple))
  392. reduced = [torch.zeros_like(o) for o in outputs[0]]
  393. for idx, out in outputs:
  394. reduced[idx] += out
  395. # Average over the microbatches
  396. reduced = self._scale_loss_by_gas(reduced)
  397. # Average over DP groups
  398. if reduce_dp and self.is_data_parallel:
  399. if torch.is_tensor(reduced):
  400. dist.all_reduce(reduced, group=self.mpu.get_data_parallel_group())
  401. reduced /= self.dp_world_size
  402. else:
  403. for idx in range(len(reduced)):
  404. dist.all_reduce(reduced[idx], group=self.mpu.get_data_parallel_group())
  405. reduced[idx] /= self.dp_world_size
  406. return reduced
  407. else:
  408. raise NotImplementedError(f'reduction type {reduce} not supported.')
  def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32):
      """Broadcast ``data`` from one rank to its pipeline-parallel group.

      Args:
          data: Tensor to send; it is only read on the source rank.
          src_rank (int, optional): Global rank of the sender. Defaults to the global
              rank of the last stage (e.g., for broadcasting the loss).
          dtype (torch.dtype, optional): Type the broadcast buffer is converted to.

      Returns:
          The broadcast buffer on ``self.device``.
      """
      if src_rank is None:
          src_rank = self.grid.stage_to_global(self.num_stages - 1)
      assert src_rank in self.grid.pp_group

      is_sender = self.global_rank == src_rank
      if is_sender:
          payload = data.clone().detach().type(dtype).to(self.device)
      else:
          # Receivers only need a one-element placeholder for the broadcast to fill.
          payload = torch.Tensor([0.]).type(dtype).to(self.device)

      dist.broadcast(tensor=payload, src=src_rank, group=self.mpu.get_pipe_parallel_group())
      return payload
  def _aggregate_total_loss(self):
      """Produce the batch loss on every pipeline stage.

      The last stage scales ``self.total_loss``, averages it over the data-parallel
      group and broadcasts the pair ``[dp_group_loss, agg_loss]`` to the other
      stages. Every stage stores the first value in ``self.dp_group_loss``.

      Returns:
          The loss averaged across the data-parallel group.
      """
      if not self.is_last_stage():
          # Get loss from last stage
          src_rank = self.grid.stage_to_global(self.num_stages - 1)
          assert src_rank in self.grid.pp_group
          received = torch.Tensor([0., 0.]).to(self.device)
          dist.broadcast(tensor=received, src=src_rank, group=self.grid.get_pipe_parallel_group())
          self.dp_group_loss = received[0].clone().detach()
          return received[1].clone().detach()

      # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
      scaled = self._scale_loss_by_gas(self.total_loss)
      self.dp_group_loss = scaled.clone().detach()

      ## Average loss across all data-parallel groups
      agg_loss = self.dp_group_loss.clone().detach()
      if self.is_data_parallel:
          dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
          agg_loss /= self.dp_world_size

      assert self.global_rank in self.grid.pp_group
      packed = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
      if self.is_pipe_parallel:
          dist.broadcast(tensor=packed, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
      return agg_loss
  def set_dataloader(self, loader):
      """Store ``loader`` and a fresh iterator over it on the first and last stages.

      Other stages ignore the call.
      """
      if not (self.is_first_stage() or self.is_last_stage()):
          return
      self.training_dataloader = loader
      self.data_iterator = iter(self.training_dataloader)
  def set_dataiterator(self, iterator):
      """Store an iterator to sample for training data.

      Only the first and last stages keep it; on those stages the stored dataloader
      is cleared as well. Other stages ignore the call.
      """
      if not (self.is_first_stage() or self.is_last_stage()):
          return
      self.training_dataloader = None
      self.data_iterator = iterator
  def set_batch_fn(self, fn):
      """Register a post-processing function that is applied to input data.

      Args:
          fn (function): The function to run.
      """
      post_process = fn
      self.batch_fn = post_process
  def is_gradient_accumulation_boundary(self):
      """Report whether a gradient reduction or optimizer step is being executed.

      Overridden from :class:`DeepSpeedEngine` so that reductions and steps are
      forced exactly when the pipeline engine is instructed to perform them.

      Returns:
          bool: whether reductions and optimizer steps should occur.
      """
      boundary_forced = self._force_grad_boundary
      return boundary_forced
  def log_for_device(self, *msg):
      """Print ``msg`` with a rank/stage header when this process is selected.

      A process is selected when ``LOG_STAGE`` equals its stage id or is -1, and
      ``DATA_PARALLEL_ID`` equals its data-parallel id or is -1.
      """
      if LOG_STAGE != self.stage_id and LOG_STAGE != -1:
          return
      if DATA_PARALLEL_ID != self.grid.data_parallel_id and DATA_PARALLEL_ID != -1:
          return
      header = (f'RANK={dist.get_rank()} '
                f'PIPE-ID={self.stage_id} '
                f'DATA-ID={self.grid.data_parallel_id} '
                f'MBATCH-ID={self.microbatch_id} '
                f'STEP-ID={self.log_batch_step_id} '
                '::')
      print(header, *msg, flush=True)
  def tput_log(self, *msg):
      """Print ``msg`` on global rank 0, but only on steps that are a multiple of
      ``steps_per_print()``."""
      if self.global_rank != 0:
          return
      if self.global_steps % self.steps_per_print() != 0:
          return
      print(*msg)
  483. def _next_batch(self):
  484. # If using 3D parallelism, only some first-stage ranks may do IO
  485. batch = None
  486. if self.data_iterator is not None:
  487. batch = next(self.data_iterator)
  488. # Any post-processing, like broadcasting across a slice-parallel group.
  489. if self.batch_fn:
  490. batch = self.batch_fn(batch)
  491. return batch
  def _exec_forward_pass(self, buffer_id):
      """Run the forward pass for the micro-batch held in buffer slot ``buffer_id``.

      Reads ``self.pipe_buffers['inputs'][buffer_id]``, stores the stage output in
      ``self.pipe_buffers['outputs'][buffer_id]`` and, on the last stage, updates
      ``self.loss``, ``self.fwd_outputs`` and ``self.total_loss``.

      Args:
          buffer_id (int): Index of the pipeline buffer slot to process.

      Raises:
          ValueError: If outputs that must be partitioned are neither a tensor nor a
              tuple of tensors.
      """
      self.tput_timer.start()
      self.mem_status('BEFORE FWD', reset_max=True)
      # Work on clones so the buffered inputs themselves are left untouched.
      if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
          inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
      else:
          inputs = self.pipe_buffers['inputs'][buffer_id].clone()
      # collect the partitioned input from the previous stage
      if self.is_pipe_partitioned and not self.is_first_stage():
          # inputs[0] is the partition metadata and inputs[1] the local partition.
          part_input = PartitionedTensor.from_meta(meta=inputs[0],
                                                   local_part=inputs[1],
                                                   group=self.grid.get_slice_parallel_group())
          inputs = (part_input.full(), *inputs[2:])
          inputs[0].requires_grad = True
          # skip mask
          #inputs[1].requires_grad = True
          part_input = None
          # Collapse a one-element tuple back into a bare tensor.
          inputs = inputs[0] if len(inputs) == 1 else inputs
          self.pipe_buffers['inputs'][buffer_id] = inputs
      # Zero out the gradients each time we use the tensor because only the data in
      # tensor changes across batches
      self._zero_grads(inputs)
      outputs = super().forward(inputs)
      # Reset activation checkpointing buffers.
      # Need to call this between evaluation iterations
      if not self.module.training:
          ds_checkpointing.reset()
      # Partition the outputs if we are not the last stage
      if self.is_pipe_partitioned and not self.is_last_stage():
          if isinstance(outputs, tuple):
              first_output = outputs[0]
              # TODO: Improve pipe partitioning to pass multiple tensors that require grads
              assert all([torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]])
              outputs_tail = outputs[1:]
          elif torch.is_tensor(outputs):
              first_output = outputs
              outputs_tail = []
          else:
              raise ValueError("expecting a tensor or a tuple of tensors")
          part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())
          # Clear the large output data, but save the computation graph
          # NOTE(review): torch.zeros(1) is created without an explicit device -- confirm
          # this is intended when first_output lives on an accelerator.
          first_output.data = torch.zeros(1)
          self.pipe_buffers['output_tensors'][buffer_id] = first_output
          # Inject the partitioned tensor into the output before sending
          outputs = (part.to_meta(), part.data(), *outputs_tail)
          part = None
      self.pipe_buffers['outputs'][buffer_id] = outputs
      # Optionally compute loss on the last device
      if self.is_last_stage():
          if self._compute_loss and self.module.loss_fn is not None:
              labels = self.pipe_buffers['labels'][buffer_id]
              self.loss = self.module.loss_fn(outputs, labels)
          else:
              # Some models just return loss from forward()
              self.loss = outputs
          # Keep the raw outputs only when eval_batch() asked for them.
          if self.eval_return_logits:
              self.outputs = outputs
          # Record the detached loss for this micro-batch and add it to the running
          # total; a non-tensor loss is treated as a sequence of tensors.
          if isinstance(self.loss, torch.Tensor):
              self.fwd_outputs.append(self.loss.detach())
              if self.total_loss is None:
                  self.total_loss = torch.zeros_like(self.loss)
              self.total_loss += self.loss.detach()
          else:
              self.fwd_outputs.append([l.detach() for l in self.loss])
              if self.total_loss is None:
                  self.total_loss = [torch.zeros_like(l) for l in self.loss]
              for idx, l in enumerate(self.loss):
                  self.total_loss[idx] += l.detach()
  560. def _exec_backward_pass(self, buffer_id):
  561. assert self.optimizer is not None, "must provide optimizer during " \
  562. "init in order to use backward"
  563. self.mem_status('BEFORE BWD', reset_max=True)
  564. # The last stage just runs backward on the loss using DeepSpeed's typical
  565. # mechanisms.
  566. if self.is_last_stage():
  567. super().backward(self.loss)
  568. self.mem_status('AFTER BWD')
  569. return
  570. outputs = self.pipe_buffers['outputs'][buffer_id]
  571. if self.wall_clock_breakdown():
  572. self.timers(BACKWARD_MICRO_TIMER).start()
  573. self.timers(BACKWARD_GLOBAL_TIMER).start()
  574. self.timers(BACKWARD_INNER_MICRO_TIMER).start()
  575. self.timers(BACKWARD_INNER_GLOBAL_TIMER).start()
  576. # Reconstruct if we previously partitioned the output. We must be
  577. # careful to also restore the computational graph of the tensors we partitioned.
  578. if self.is_pipe_partitioned:
  579. if self.is_grad_partitioned:
  580. part_output = PartitionedTensor.from_meta(meta=outputs[0],
  581. local_part=outputs[1],
  582. group=self.grid.get_slice_parallel_group())
  583. self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
  584. outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:])
  585. else:
  586. # Already restored from partition
  587. self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
  588. outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:])
  589. grad_tensors = self.grad_layer
  590. if self.is_grad_partitioned:
  591. #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
  592. part_grad = PartitionedTensor.from_meta(meta=self.grad_layer[0],
  593. local_part=self.grad_layer[1],
  594. group=self.grid.get_slice_parallel_group())
  595. grad_tensors = (part_grad.full(), *grad_tensors[2:])
  596. part_grad = None
  597. #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
  598. if self.bfloat16_enabled() and not self.is_last_stage():
  599. # manually call because we don't call optimizer.backward()
  600. self.optimizer.clear_lp_grads()
  601. # This handles either a single tensor or tuple of tensors.
  602. if isinstance(outputs, tuple):
  603. out_tensors = [t for t in outputs if t.is_floating_point()]
  604. assert len(out_tensors) == len(grad_tensors)
  605. torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors)
  606. else:
  607. torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))
  608. if self.bfloat16_enabled() and not self.is_last_stage():
  609. # manually call because we don't call optimizer.backward()
  610. self.optimizer.update_hp_grads(clear_lp_grads=False)
  611. # Free up the memory from the output of forward()
  612. self.pipe_buffers['output_tensors'][buffer_id] = None
  613. self.pipe_buffers['outputs'][buffer_id] = None
  614. grad_tensors = None
  615. if self.wall_clock_breakdown():
  616. self.timers(BACKWARD_INNER_MICRO_TIMER).stop()
  617. self.timers(BACKWARD_INNER_GLOBAL_TIMER).stop()
  618. self.timers(BACKWARD_MICRO_TIMER).stop()
  619. self.timers(BACKWARD_GLOBAL_TIMER).stop()
  620. self.mem_status('AFTER BWD')
  621. def _exec_load_micro_batch(self, buffer_id):
  622. if self.wall_clock_breakdown():
  623. self.timers(BATCH_INPUT_TIMER).start()
  624. batch = self._next_batch()
  625. if self.is_first_stage():
  626. loaded = None
  627. if torch.is_tensor(batch[0]):
  628. loaded = batch[0].clone().to(self.device).detach()
  629. loaded.requires_grad = loaded.is_floating_point()
  630. else:
  631. assert isinstance(batch[0], (tuple, list))
  632. # Assume list or tuple
  633. loaded = []
  634. for x in batch[0]:
  635. assert torch.is_tensor(x)
  636. mine = x.clone().detach().to(self.device)
  637. mine.requires_grad = mine.is_floating_point()
  638. loaded.append(mine)
  639. loaded = tuple(loaded)
  640. self.pipe_buffers['inputs'][buffer_id] = loaded
  641. if self.is_last_stage():
  642. loaded = batch[1]
  643. if torch.is_tensor(batch[1]):
  644. loaded = batch[1].to(self.device)
  645. # XXX: torch 1.6.0 DataLoader will auto convert tuple to list
  646. elif isinstance(batch[1], (tuple, list)):
  647. loaded = []
  648. for x in batch[1]:
  649. assert torch.is_tensor(x)
  650. x = x.to(self.device).detach()
  651. loaded.append(x)
  652. loaded = tuple(loaded)
  653. self.pipe_buffers['labels'][buffer_id] = loaded
  654. if self.wall_clock_breakdown():
  655. self.timers(BATCH_INPUT_TIMER).stop()
  656. def _send_tensor_meta(self, buffer, recv_stage):
  657. """ Communicate metadata about upcoming p2p transfers.
  658. Metadata is communicated in this order:
  659. * type (0: tensor, 1: list)
  660. * num_tensors if type=list
  661. foreach tensor in buffer:
  662. * ndims
  663. * shape
  664. """
  665. send_bytes = 0
  666. if isinstance(buffer, torch.Tensor):
  667. type_tensor = torch.LongTensor(data=[0]).to(self.device)
  668. p2p.send(type_tensor, recv_stage)
  669. send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
  670. send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
  671. p2p.send(send_ndims, recv_stage)
  672. p2p.send(send_shape, recv_stage)
  673. send_bytes += _tensor_bytes(buffer)
  674. elif isinstance(buffer, list):
  675. assert (False)
  676. type_tensor = torch.LongTensor(data=[1]).to(self.device)
  677. p2p.send(type_tensor, recv_stage)
  678. count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
  679. p2p.send(count_tensor, recv_stage)
  680. for tensor in buffer:
  681. assert isinstance(tensor, torch.Tensor)
  682. send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
  683. send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
  684. p2p.send(send_ndims, recv_stage)
  685. p2p.send(send_shape, recv_stage)
  686. send_bytes += _tensor_bytes(tensor)
  687. elif isinstance(buffer, tuple):
  688. type_tensor = torch.LongTensor(data=[2]).to(self.device)
  689. p2p.send(type_tensor, recv_stage)
  690. count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
  691. p2p.send(count_tensor, recv_stage)
  692. for idx, tensor in enumerate(buffer):
  693. assert isinstance(tensor, torch.Tensor)
  694. send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
  695. send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
  696. send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device)
  697. p2p.send(send_dtype, recv_stage)
  698. p2p.send(send_ndims, recv_stage)
  699. p2p.send(send_shape, recv_stage)
  700. # Useful for performance debugging.
  701. '''
  702. new_bytes = _tensor_bytes(tensor)
  703. send_bytes += _tensor_bytes(tensor)
  704. # Useful for performance debugging.
  705. if self.grid.data_parallel_id == 0:
  706. print(
  707. f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
  708. )
  709. '''
  710. else:
  711. raise NotImplementedError(f'Could not send meta type {type(buffer)}')
  712. # Useful for performance debugging.
  713. '''
  714. if self.grid.data_parallel_id == 0:
  715. print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB')
  716. '''
  def _recv_tensor_meta(self, send_stage):
      """Receive metadata about upcoming p2p transfers and allocate buffers.

      Every value is read as its own ``LongTensor`` via ``p2p.recv``, in this
      order:

      * type id (``0``: single tensor, ``1``: list, ``2``: tuple)
      * for a single tensor: ndims, shape
      * for a list/tuple: number of tensors, then for each tensor its dtype
        id (resolved through ``self.ID_TO_DTYPE``), ndims, shape

      Args:
          send_stage: stage the metadata is received from.

      Returns:
          The allocated receive buffer: a tensor for type ``0``, a list of
          tensors for type ``1`` or a tuple of tensors for type ``2``.

      Raises:
          NotImplementedError: if the received type id is not 0, 1 or 2.
      """

      def recv_longs(initial):
          # Receive into a LongTensor shaped like ``initial`` and return its values.
          holder = torch.LongTensor(initial).to(self.device)
          p2p.recv(holder, send_stage)
          return holder.tolist()

      recv_type = recv_longs([0])[0]

      # A single tensor will be sent.
      if recv_type == 0:
          ndims = recv_longs([0])[0]
          shape = recv_longs([1] * ndims)
          return self._allocate_buffer(shape, num_buffers=1)[0]

      # List or tuple of tensors.
      if recv_type in (1, 2):
          num_tensors = recv_longs([0])[0]
          shapes_and_dtypes = []
          for _ in range(num_tensors):
              dtype = self.ID_TO_DTYPE[recv_longs([0])[0]]
              ndims = recv_longs([0])[0]
              shape = recv_longs([1] * ndims)
              shapes_and_dtypes.append((shape, dtype))

          buffers = self._allocate_buffers(shapes_and_dtypes, num_buffers=1)[0]
          # Type id 2 asks for a tuple rather than a list.
          if recv_type == 2:
              buffers = tuple(buffers)
          return buffers

      # Report the id that was received; type(recv_type) would always be int.
      raise NotImplementedError(f'Could not receive type {recv_type}')
  763. def _exec_send_activations(self, buffer_id):
  764. if self.wall_clock_breakdown():
  765. self.timers(PIPE_SEND_OUTPUT_TIMER).start()
  766. outputs = self.pipe_buffers['outputs'][buffer_id]
  767. # NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
  768. # We could do char, but with half() we can eventually flatten with other fp16
  769. # messages (TODO)
  770. if self.has_attention_mask or self.has_bool_tensors:
  771. outputs = list(outputs)
  772. outputs[-1] = outputs[-1].half()
  773. outputs = tuple(outputs)
  774. if self.first_output_send:
  775. self.first_output_send = False
  776. self._send_tensor_meta(outputs, self.next_stage)
  777. if isinstance(outputs, torch.Tensor):
  778. p2p.send(outputs, self.next_stage)
  779. elif isinstance(outputs, tuple):
  780. for idx, buffer in enumerate(outputs):
  781. p2p.send(buffer, self.next_stage)
  782. else:
  783. raise NotImplementedError('Could not send output of type '
  784. f'{type(outputs)}')
  785. # Restore the boolean tensor
  786. if self.has_attention_mask or self.has_bool_tensors:
  787. outputs = list(outputs)
  788. outputs[-1] = outputs[-1].bool()
  789. outputs = tuple(outputs)
  790. if self.wall_clock_breakdown():
  791. self.timers(PIPE_SEND_OUTPUT_TIMER).stop()
  792. def _exec_send_grads(self, buffer_id):
  793. if self.wall_clock_breakdown():
  794. self.timers(PIPE_SEND_GRAD_TIMER).start()
  795. inputs = self.pipe_buffers['inputs'][buffer_id]
  796. # Partition the gradient
  797. if self.is_grad_partitioned:
  798. if isinstance(inputs, tuple):
  799. first_input = inputs[0]
  800. assert all([torch.is_tensor(elt) for elt in inputs[1:]])
  801. inputs_grad_tail = [elt.grad for elt in inputs[1:] if elt.grad is not None]
  802. elif torch.is_tensor(inputs):
  803. first_input = inputs
  804. inputs_grad_tail = []
  805. else:
  806. raise ValueError("expecting a tensor or a tuple of tensors")
  807. assert torch.is_tensor(first_input)
  808. part = PartitionedTensor(tensor=first_input.grad, group=self.grid.get_slice_parallel_group())
  809. inputs = (part.to_meta(), part.data(), *inputs_grad_tail)
  810. # XXX Terrible hack
  811. # Drop the attention mask from the input buffer here. It does not have
  812. # a grad that needs to be communicated. We free the buffer immediately
  813. # after, so no need to restore it. The receiver also has a hack that skips
  814. # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
  815. if self.has_attention_mask or self.has_bool_tensors:
  816. inputs = list(inputs)
  817. inputs.pop()
  818. inputs = tuple(inputs)
  819. if isinstance(inputs, torch.Tensor):
  820. assert inputs.grad is not None
  821. p2p.send(inputs.grad, self.prev_stage)
  822. else:
  823. # XXX terrible hacky branch
  824. if self.is_grad_partitioned:
  825. # First two sends are partitioned gradient
  826. p2p.send(inputs[0], self.prev_stage)
  827. p2p.send(inputs[1], self.prev_stage)
  828. else:
  829. for idx, buffer in enumerate(inputs):
  830. # Skip tensors that will not produce a grad
  831. if not buffer.is_floating_point():
  832. assert buffer.grad is None
  833. continue
  834. assert buffer.grad is not None
  835. p2p.send(buffer.grad, self.prev_stage)
  836. # We can free up the input buffer now
  837. self.pipe_buffers['inputs'][buffer_id] = None
  838. if self.wall_clock_breakdown():
  839. self.timers(PIPE_SEND_GRAD_TIMER).stop()
  840. def _exec_recv_activations(self, buffer_id):
  841. if self.wall_clock_breakdown():
  842. self.timers(PIPE_RECV_INPUT_TIMER).start()
  843. recvd = None
  844. # Allocate the buffer if necessary
  845. if self.pipe_recv_buf is None:
  846. self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)
  847. if isinstance(self.pipe_recv_buf, torch.Tensor):
  848. p2p.recv(self.pipe_recv_buf, self.prev_stage)
  849. recvd = self.pipe_recv_buf.clone().detach()
  850. recvd.requires_grad = recvd.is_floating_point()
  851. else:
  852. assert isinstance(self.pipe_recv_buf, tuple)
  853. recvd = [None] * len(self.pipe_recv_buf)
  854. for idx, buffer in enumerate(self.pipe_recv_buf):
  855. assert torch.is_tensor(buffer)
  856. # XXX hardcode meta type
  857. if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long:
  858. if self.meta_buffer is None:
  859. self.meta_buffer = torch.zeros(buffer.size(), dtype=torch.long, device=self.device)
  860. buffer = self.meta_buffer
  861. p2p.recv(buffer, self.prev_stage)
  862. recvd[idx] = buffer.clone().detach()
  863. # NCCL does not like to send torch.BoolTensor types, so un-cast the
  864. # attention mask
  865. if self.has_attention_mask or self.has_bool_tensors:
  866. recvd[-1] = recvd[-1].bool()
  867. recvd = tuple(recvd)
  868. for buffer in recvd:
  869. buffer.requires_grad = buffer.is_floating_point()
  870. self.pipe_buffers['inputs'][buffer_id] = recvd
  871. if self.wall_clock_breakdown():
  872. self.timers(PIPE_RECV_INPUT_TIMER).stop()
  873. def _exec_recv_grads(self, buffer_id):
  874. if self.wall_clock_breakdown():
  875. self.timers(PIPE_RECV_GRAD_TIMER).start()
  876. outputs = self.pipe_buffers['outputs'][buffer_id]
  877. # XXX these shapes are hardcoded for Megatron
  878. # Restore partitioned output if it was partitioned and we are sending full gradients
  879. if self.is_pipe_partitioned and not self.is_grad_partitioned:
  880. part_output = PartitionedTensor.from_meta(meta=outputs[0],
  881. local_part=outputs[1],
  882. group=self.grid.get_slice_parallel_group())
  883. outputs[0].data = part_output.full()
  884. outputs = (outputs[0], *outputs[2:])
  885. # save for backward
  886. self.pipe_buffers['outputs'][buffer_id] = outputs
  887. # Allocate gradient if necessary
  888. if self.grad_layer is None:
  889. if isinstance(outputs, torch.Tensor):
  890. s = list(outputs.size())
  891. self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0]
  892. else:
  893. # XXX This is a HACK
  894. # When we exchange activations/gradients, the two pipe stages
  895. # need to issue the send/recv with the same buffer sizes or
  896. # else there is a deadlock. The is_floating_point() filter is
  897. # used to avoid sending gradients for tensors that do not
  898. # produce gradients. When TP>1, we partition the first
  899. # activations/gradients across TP ranks to save communication
  900. # volume and memory. That partitioned tensor is represented as
  901. # two tensors: a 1/TPth chunk of the original data and also a
  902. # small LongTensor storing the metadata used to reconstruct on
  903. # the other side. When combined, the floating point filter also
  904. # filtered out the metadata tensor. This quick (hacky) fix just
  905. # branches on is_grad_partitioned so we don't filter out the
  906. # metadata tensor.
  907. if self.is_grad_partitioned:
  908. sizes_and_dtypes = [(list(t.size()), t.dtype)
  909. for t in outputs[:2]] + [(list(t.size()), t.dtype)
  910. for t in outputs[2:] if t.is_floating_point()]
  911. else:
  912. sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]
  913. self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]
  914. if isinstance(self.grad_layer, torch.Tensor):
  915. p2p.recv(self.grad_layer, self.next_stage)
  916. else:
  917. assert isinstance(outputs, tuple)
  918. for idx, buffer in enumerate(self.grad_layer):
  919. # XXX GPT-2 hack
  920. if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long:
  921. buffer.data = torch.zeros(buffer.size(), dtype=torch.long, device=self.device)
  922. p2p.recv(buffer, self.next_stage)
  923. if self.wall_clock_breakdown():
  924. self.timers(PIPE_RECV_GRAD_TIMER).stop()
  925. def _exec_optimizer_step(self, lr_kwargs=None):
  926. if self.wall_clock_breakdown():
  927. self.timers(STEP_MICRO_TIMER).start()
  928. self.timers(STEP_GLOBAL_TIMER).start()
  929. self.mem_status('BEFORE STEP', reset_max=True)
  930. self._force_grad_boundary = True
  931. self._take_model_step(lr_kwargs)
  932. self._force_grad_boundary = False
  933. self.mem_status('AFTER STEP')
  934. if self.global_rank == 0 and self.monitor.enabled:
  935. self.summary_events = [(f'Train/Samples/lr', self.get_lr()[0], self.global_samples)]
  936. if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
  937. self.summary_events.append(
  938. (f'Train/Samples/loss_scale', self.optimizer.cur_scale, self.global_samples))
  939. self.monitor.write_events(self.summary_events)
  940. if self.wall_clock_breakdown():
  941. self.timers(STEP_MICRO_TIMER).stop()
  942. self.timers(STEP_GLOBAL_TIMER).stop()
  943. if self.global_steps % self.steps_per_print() == 0:
  944. self.timers.log([
  945. BATCH_INPUT_TIMER,
  946. FORWARD_MICRO_TIMER,
  947. BACKWARD_MICRO_TIMER,
  948. BACKWARD_INNER_MICRO_TIMER,
  949. BACKWARD_REDUCE_MICRO_TIMER,
  950. STEP_MICRO_TIMER,
  951. ])
  952. if self.global_steps % self.steps_per_print() == 0:
  953. self.timers.log([
  954. FORWARD_GLOBAL_TIMER,
  955. BACKWARD_GLOBAL_TIMER,
  956. BACKWARD_INNER_GLOBAL_TIMER,
  957. BACKWARD_REDUCE_GLOBAL_TIMER,
  958. STEP_GLOBAL_TIMER,
  959. ])
  960. def _zero_grads(self, inputs):
  961. if isinstance(inputs, torch.Tensor):
  962. if inputs.grad is not None:
  963. inputs.grad.data.zero_()
  964. else:
  965. for t in inputs:
  966. if t.grad is not None:
  967. t.grad.data.zero_()
  968. def _allocate_zeros(self, shape, **kwargs):
  969. """ Allocate a tensor of zeros on the engine's device.
  970. Arguments:
  971. shape: the shape of the tensor to allocate
  972. kwargs: passed to torch.zeros()
  973. Returns:
  974. A tensor from torch.zeros() allocated on self.device.
  975. """
  976. if "dtype" not in kwargs:
  977. if self.fp16_enabled():
  978. kwargs["dtype"] = torch.half
  979. if self.bfloat16_enabled():
  980. kwargs["dtype"] = torch.bfloat16
  981. return torch.zeros(shape, device=self.device, **kwargs)
  982. def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
  983. buffers = []
  984. if num_buffers == -1:
  985. num_buffers = self.num_pipe_buffers
  986. for count in range(num_buffers):
  987. buffers.append(self._allocate_zeros(shape, **kwargs))
  988. return buffers
  989. def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
  990. buffers = []
  991. if num_buffers == -1:
  992. num_buffers = self.num_pipe_buffers
  993. for count in range(num_buffers):
  994. buffer = []
  995. for shape, dtype in shapes_and_dtypes:
  996. buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad))
  997. buffers.append(buffer)
  998. return buffers
  def forward(self, *args, **kwargs):
      """Not available in pipeline mode; use ``train_batch()`` instead.

      Raises:
          PipelineError: always.
      """
      message = "Only train_batch() is accessible in pipeline mode."
      raise PipelineError(message)
  def backward(self, *args, **kwargs):
      """Not available in pipeline mode; use ``train_batch()`` instead.

      Raises:
          PipelineError: always.
      """
      message = "Only train_batch() is accessible in pipeline mode."
      raise PipelineError(message)
  def step(self, *args, **kwargs):
      """Not available in pipeline mode; use ``train_batch()`` instead.

      Raises:
          PipelineError: always.
      """
      message = "Only train_batch() is accessible in pipeline mode."
      raise PipelineError(message)
  def mem_status(self, msg, print_rank=-1, reset_max=False):
      """Memory-reporting hook; currently a no-op.

      The reporting was disabled by an unconditional ``return`` as the first
      statement, which left the rest of the body unreachable. The unreachable
      code has been removed; the signature is kept so existing call sites
      keep working.

      Args:
          msg: label for the report (unused).
          print_rank: rank to restrict the report to (unused).
          reset_max: whether to reset the peak statistics (unused).

      Returns:
          None
      """
      return
  def module_state_dict(self, exclude_frozen_parameters=False):
      """Save the pipeline module to the current checkpoint path instead of returning a state dict.

      Intended to be called from ``save_checkpoint()``, which is expected to
      have set ``self._curr_ckpt_path``. To save a ``PipelineModule`` outside
      of ``save_checkpoint()``, use ``save_state_dict()``.

      Args:
          exclude_frozen_parameters: forwarded to the module's
              ``save_state_dict`` as ``exclude_frozen_params``.

      Returns:
          None
      """
      pipe_module = self.module
      assert isinstance(pipe_module, PipelineModule)

      ckpt_path = self._curr_ckpt_path
      assert ckpt_path is not None, \
          "PipelineEngine expects module_state_dict() to be called from save_checkpoint()"

      pipe_module.save_state_dict(ckpt_path,
                                  checkpoint_engine=self.checkpoint_engine,
                                  exclude_frozen_params=exclude_frozen_parameters)
      return None
  def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):
      """Load module weights, normally from the current checkpoint directory.

      Pipeline models checkpoint by layer instead of by rank, so when
      ``checkpoint['module']`` is ``None`` or a ``str`` the weights are read
      from ``self._curr_ckpt_path`` through the module's ``load_state_dir``.
      Any other value is handed to the parent class implementation.

      Args:
          checkpoint: mapping whose ``'module'`` entry selects the load path.
          strict (bool, optional): Strict state loading. Defaults to True.
          custom_load_fn: must be ``None``; not supported with pipeline
              parallelism.
          fetch_z3_params: accepted but not used by this method.
      """
      assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"

      module_state = checkpoint['module']
      load_from_dir = module_state is None or isinstance(module_state, str)

      if load_from_dir:
          self.module.load_state_dir(load_dir=self._curr_ckpt_path,
                                     strict=strict,
                                     checkpoint_engine=self.checkpoint_engine)
          return

      # An in-memory state object: defer to the parent class.
      super().load_module_state_dict(module_state, strict)
  # Dispatch table from PipeInstruction type to the function that executes it.
  # The values are the plain functions defined in this class body;
  # _exec_schedule binds the selected one to the engine and calls it with the
  # kwargs the scheduler attached to the instruction.
  _INSTRUCTION_MAP = {
      schedule.OptimizerStep: _exec_optimizer_step,
      schedule.ReduceGrads: _exec_reduce_grads,
      schedule.ReduceTiedGrads: _exec_reduce_tied_grads,
      schedule.LoadMicroBatch: _exec_load_micro_batch,
      schedule.ForwardPass: _exec_forward_pass,
      schedule.BackwardPass: _exec_backward_pass,
      schedule.SendActivation: _exec_send_activations,
      schedule.RecvActivation: _exec_recv_activations,
      schedule.SendGrad: _exec_send_grads,
      schedule.RecvGrad: _exec_recv_grads,
  }
  1089. def _exec_schedule(self, pipe_schedule):
  1090. # Reserve and reset buffers.
  1091. self._reserve_pipe_buffers(pipe_schedule.num_pipe_buffers())
  1092. self.fwd_outputs = []
  1093. # For each step in the schedule
  1094. for step_cmds in pipe_schedule:
  1095. # For each instruction in the step
  1096. for cmd in step_cmds:
  1097. if type(cmd) not in self._INSTRUCTION_MAP:
  1098. raise RuntimeError(f'{self.__class__.__name__} does not understand instruction {repr(cmd)}')
  1099. # Equivalent to: self._exec_forward_pass(buffer_id=0)
  1100. self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
  1101. self._exec_instr(**cmd.kwargs)