engine.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376
  1. # Copyright 2019 The Microsoft DeepSpeed Team
  2. from types import MethodType
  3. import torch
  4. from deepspeed import comm as dist
  5. from deepspeed.utils import logger
  6. from deepspeed.utils.timer import ThroughputTimer
  7. from deepspeed.accelerator import get_accelerator
  8. from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
  9. from ..utils import PartitionedTensor
  10. from ..dataloader import RepeatingLoader
  11. from .module import PipelineModule, PipelineError
  12. from . import p2p
  13. from . import schedule
  # Debug-output selectors. LOG_STAGE and DATA_PARALLEL_ID are compared against
  # the current stage / data-parallel id by PipelineEngine.log_for_device, where
  # -1 acts as a wildcard; the default of -2 matches nothing, so that debug
  # output stays off. TARGET_ID is not read anywhere in the visible code.
  TARGET_ID, LOG_STAGE, DATA_PARALLEL_ID = -2, -2, -2
  def is_even(number):
      """Return True when ``number`` is divisible by two.

      PipelineEngine.__init__ uses this to order the send/recv warm-up
      handshake by stage parity.
      """
      remainder = number % 2
      return remainder == 0
  # Module-level memory counters, both starting at zero. NOTE(review): presumably
  # updated by a memory-status helper defined later in this file (not visible
  # here) - confirm.
  mem_alloced = mem_cached = 0
  21. def _tensor_bytes(tensor):
  22. return tensor.numel() * tensor.element_size()
  23. class PipelineEngine(DeepSpeedEngine):
  """A training engine for hybrid pipeline, data, and model parallel training.

  This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
  is provided.
  """
  # Integer codes for tensor dtypes: a dtype's position in this list is its id,
  # so entries must never be reordered. NOTE(review): presumably these ids are
  # what stages exchange when describing tensors to each other - confirm.
  ID_TO_DTYPE = [
      torch.float32,
      torch.float64,
      torch.complex64,
      torch.complex128,
      torch.float16,
      torch.bfloat16,
      torch.uint8,
      torch.int8,
      torch.int16,
      torch.int32,
      torch.int64,
      torch.bool,
  ]
  # Reverse lookup table: dtype -> integer id.
  DTYPE_TO_ID = dict(zip(ID_TO_DTYPE, range(len(ID_TO_DTYPE))))
  def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
      """Build pipeline-parallel state on top of :class:`DeepSpeedEngine`.

      Args:
          has_bool_tensors (bool): stored as ``self.has_bool_tensors``.
              NOTE(review): its consumer is not visible in this chunk.
          *super_args: forwarded unchanged to ``DeepSpeedEngine.__init__``.
          **super_kwargs: forwarded unchanged to ``DeepSpeedEngine.__init__``.
      """
      super().__init__(*super_args, **super_kwargs)
      assert isinstance(self.module, PipelineModule), "model must base PipelineModule"

      assert self.zero_optimization_stage() < 2, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"

      # The pipeline schedule issues its own gradient all-reduces, so the one in
      # super().backward() is disabled.
      self.enable_backward_allreduce = False
      self.has_bool_tensors = has_bool_tensors
      self.eval_return_logits = False
      self.outputs = None

      # Cleared to disable the pipeline all-reduce when 1-bit Adam / 1-bit LAMB is used.
      self.pipeline_enable_backward_allreduce = True

      if self.elasticity_enabled() and not self.is_elastic_model_parallel_supported():
          assert not self.elasticity_enabled(), "Elasticity is not currently supported with pipeline parallelism."

      # Step id reported by log_for_device().
      self.log_batch_step_id = -1

      self.micro_batch_size = self.train_micro_batch_size_per_gpu()
      self.micro_batches = self.gradient_accumulation_steps()

      # Process grid and communication groups come from the PipelineModule.
      self.grid = self.module._grid
      if self.grid.get_global_rank() == 0:
          logger.info(f'CONFIG: micro_batches={self.micro_batches} '
                      f'micro_batch_size={self.micro_batch_size}')

      self.global_rank = self.grid.get_global_rank()

      assert self.dp_world_size == self.grid.data_parallel_size
      assert self.train_batch_size() == \
          self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size

      # Stage topology of this rank.
      self.num_stages = self.grid.pipe_parallel_size
      self.stage_id = self.grid.get_stage_id()
      self.prev_stage = self.stage_id - 1
      self.next_stage = self.stage_id + 1

      self.data_iterator = None
      self.batch_fn = None

      self._force_grad_boundary = False

      self.batch_timer = ThroughputTimer(batch_size=self.train_batch_size(),
                                         logging_fn=self.tput_log,
                                         monitor_memory=False,
                                         steps_per_output=self.steps_per_print())

      # Only the first and last stages load inputs/labels, so the pipeline engine
      # builds its own loader: _build_data_iter() shards the dataset with a
      # DistributedSampler keyed on the data-parallel rank and world size.
      if self.training_data:
          self._build_data_iter(self.training_data)

      self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
      self.is_data_parallel = self.grid.data_parallel_size > 1
      self.is_model_parallel = self.grid.model_parallel_size > 1

      # Partition input/output buffers
      # XXX temporarily disable while I revert some partition hacks.
      self.is_pipe_partitioned = self.is_model_parallel
      self.is_grad_partitioned = self.is_model_parallel

      # Trainable parameters on this stage. Tied parameters count as "unique" only
      # on the lowest rank that shares them.
      num_params = sum(p.numel() for p in self.module.parameters() if p.requires_grad)
      unique_params = num_params
      if self.module.tied_comms:
          tied_params = sum(
              sum(p.numel() for p in tie['module'].parameters())
              for tie in self.module.tied_comms.values()
              if self.global_rank != min(tie['ranks']))
          unique_params -= tied_params

      params_tensor = torch.LongTensor(data=[num_params, unique_params]).to(self.device)
      dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
      total_params, unique_params = params_tensor.tolist()

      if self.grid.data_parallel_id == 0:
          logger.info(f'RANK={self.global_rank} '
                      f'STAGE={self.stage_id} '
                      f'LAYERS={self.module._local_stop - self.module._local_start} '
                      f'[{self.module._local_start}, {self.module._local_stop}) '
                      f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
                      f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
                      f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')

      # Peer-to-peer communication and all-reduce groups.
      if self.is_pipe_parallel:
          p2p.init_process_groups(self.grid)

      # Pipeline buffers, one slot list per kind.
      self.num_pipe_buffers = 0
      self.pipe_buffers = {
          'inputs': [],  # batch input and received activations
          'labels': [],  # labels from batch input
          'outputs': [],  # activations
          'output_tensors': [],  # tensor object to preserve backward graph
      }
      self.pipe_recv_buf = None
      self.grad_layer = None

      self.meta_buffer = None

      self.first_output_send = True
      self.first_gradient_send = True

      # Loss of the micro-batch currently being processed.
      self.loss = torch.tensor(0.0).to(self.device)

      # Loss accumulated over the whole batch.
      self.total_loss = None
      self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
      self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)

      if self._config.pipeline['activation_checkpoint_interval'] > 0:
          self.module.activation_checkpoint_interval = self._config.pipeline[
              'activation_checkpoint_interval']

      self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline

      if self.is_last_stage():
          self.loss_model = self.module.loss_fn

      self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe'

      # Warm up the pipeline communicators by passing a zero scalar between
      # neighbours. Even stages send before receiving, odd stages receive before
      # sending (presumably so neighbouring stages never both block on send).
      if is_even(self.stage_id):
          if not self.is_last_stage():
              p2p.send(self.loss, self.next_stage)
          if not self.is_first_stage():
              p2p.recv(self.loss, self.prev_stage)
      else:
          if not self.is_first_stage():
              p2p.recv(self.loss, self.prev_stage)
          if not self.is_last_stage():
              p2p.send(self.loss, self.next_stage)

      # XXX look into timer reporting timing
      # Start and stop each timer once up front to work around early timing quirks.
      if self.wall_clock_breakdown():
          for timer_name in ('forward_microstep',
                             'backward_microstep',
                             'backward_inner_microstep',
                             'backward_allreduce_microstep',
                             'backward_allreduce',
                             'step_microstep'):
              self.timers(timer_name).start()
              self.timers(timer_name).stop()
  173. def set_has_attention_mask(self, value):
  174. assert isinstance(value, bool)
  175. self.has_attention_mask = value
  def _build_data_iter(self, dataset):
      """Install a repeating, data-parallel-sharded loader over ``dataset``.

      The sampler splits ``dataset`` across ``self.dp_world_size`` replicas using
      this rank's data-parallel rank, without shuffling. The resulting loader is
      wrapped in :class:`RepeatingLoader` and passed to ``set_dataloader``.
      """
      dp_rank = self.mpu.get_data_parallel_rank()
      sharded_sampler = torch.utils.data.distributed.DistributedSampler(dataset,
                                                                        num_replicas=self.dp_world_size,
                                                                        rank=dp_rank,
                                                                        shuffle=False)
      # Build a loader and make it restart automatically when exhausted.
      self.set_dataloader(RepeatingLoader(self.deepspeed_io(dataset, data_sampler=sharded_sampler)))
  186. def _exec_reduce_tied_grads(self):
  187. # We need to run this first to write to self.averaged_gradients;
  188. # since this class turns `enable_backward_allreduce` off,
  189. # `self.overlapping_partition_gradients_reduce_epilogue()` defined in the DeepSpeedEngine
  190. # never actually runs. I suspect this is because of efficiency problems; get_flat_partition in
  191. # stage2.py might do something expensive; someone will have to look into that later. But
  192. # in the meantime, this fixes ZeRO2 + Pipelining enough to run a demo. Further profiling
  193. # needed to decide if it actually breaks everything.
  194. # (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944)
  195. if self.zero_optimization_partition_gradients():
  196. self.optimizer.overlapping_partition_gradients_reduce_epilogue()
  197. weight_group_list = self.module.get_tied_weights_and_groups()
  198. for weight, group in weight_group_list:
  199. grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad
  200. dist.all_reduce(grad, group=group)
  201. def _exec_reduce_grads(self):
  202. self._force_grad_boundary = True
  203. if self.pipeline_enable_backward_allreduce:
  204. if self.bfloat16_enabled():
  205. if self.zero_optimization_stage() == 0:
  206. self._bf16_reduce_grads()
  207. else:
  208. assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported"
  209. raise NotImplementedError()
  210. else:
  211. self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
  212. self._force_grad_boundary = False
  def _bf16_reduce_grads(self):
      """All-reduce the optimizer-provided gradients in bounded-size buffers.

      The gradients come from ``self.optimizer.get_grads_for_reduction()``. They are
      reduced by ``buffered_allreduce_fallback`` with at most
      ``MEMORY_OPT_ALLREDUCE_SIZE`` elements per buffer.
      """
      self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
                                       elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
  218. def _reserve_pipe_buffers(self, num_buffers):
  219. """Ensure that each pipeline buffer has at least ``num_buffers`` slots.
  220. This method only reserves slots and does not allocate tensors.
  221. Args:
  222. num_buffers (int): The number of buffers to reserve.
  223. """
  224. if self.num_pipe_buffers >= num_buffers:
  225. return
  226. num_added = num_buffers - self.num_pipe_buffers
  227. for key in self.pipe_buffers:
  228. self.pipe_buffers[key].extend([None] * num_added)
  229. self.num_pipe_buffers = num_buffers
  230. def reset_activation_shape(self):
  231. """Reset the buffers when the shape of activation and gradient change.
  232. For example, for curriculum learning that changes the seqlen of each
  233. sample, we need to call this whenever the seqlen is going to change.
  234. """
  235. self.first_output_send = True
  236. self.pipe_recv_buf = None
  237. self.grad_layer = None
  238. self.meta_buffer = None
  239. def train_batch(self, data_iter=None):
  240. """Progress the pipeline to train the next batch of data. The engine will ingest
  241. ``self.train_batch_size()`` total samples collectively across all workers.
  242. An iterator that over training data should be provided as an argument
  243. unless ``deepspeed.initialize()`` was provided a training set. In that event,
  244. the training data will automatically be read.
  245. .. warning::
  246. A total of ``self.gradient_accumulation_steps()`` entries will be pulled
  247. from ``data_iter`` by each pipeline. There must be sufficient
  248. data left in ``data_iter`` or else a ``StopIteration`` will halt training.
  249. DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
  250. that wraps data loaders to automatically restart upon a ``StopIteration``.
  251. Args:
  252. data_iter (Iterator, optional): Iterator of training data.
  253. Returns:
  254. The arithmetic mean of the losses computed this batch.
  255. """
  256. if not torch._C.is_grad_enabled():
  257. raise RuntimeError(
  258. f'train_batch() requires gradients enabled. Use eval_batch() instead.')
  259. # Curriculum learning could change activation shape
  260. if self.curriculum_enabled_legacy():
  261. new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \
  262. self.global_steps + 1)
  263. if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step:
  264. self.reset_activation_shape()
  265. self.curriculum_scheduler_legacy.first_step = False
  266. elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \
  267. self.global_steps):
  268. self.reset_activation_shape()
  269. if data_iter:
  270. self.set_dataiterator(data_iter)
  271. self.module.train()
  272. self.total_loss = None
  273. self._compute_loss = True
  274. # Do the work
  275. self.timers('train_batch').start()
  276. sched = schedule.TrainSchedule(micro_batches=self.micro_batches,
  277. stages=self.num_stages,
  278. stage_id=self.stage_id)
  279. self._exec_schedule(sched)
  280. self.agg_train_loss = self._aggregate_total_loss()
  281. self.timers('train_batch').stop()
  282. if self.global_steps % self.steps_per_print() == 0:
  283. if self.global_rank == 0:
  284. elapsed = self.timers('train_batch').elapsed(reset=True) / 1000.0
  285. iter_time = elapsed / self.steps_per_print()
  286. tput = self.train_batch_size() / iter_time
  287. print(f'steps: {self.global_steps} '
  288. f'loss: {self.agg_train_loss:0.4f} '
  289. f'iter time (s): {iter_time:0.3f} '
  290. f'samples/sec: {tput:0.3f}')
  291. # Monitoring
  292. if self.global_rank == 0 and self.monitor.enabled:
  293. self.summary_events = [(f'Train/Samples/train_loss',
  294. self.agg_train_loss.mean().item(),
  295. self.global_samples)]
  296. self.monitor.write_events(self.summary_events)
  297. if self.wall_clock_breakdown(
  298. ) and self.global_steps % self.steps_per_print() == 0:
  299. self.timers.log([
  300. 'pipe_send_output',
  301. 'pipe_send_grad',
  302. 'pipe_recv_input',
  303. 'pipe_recv_grad'
  304. ])
  305. # TODO: should return precisely what loss returned and allow others to be queried?
  306. return self.agg_train_loss
  def eval_batch(self,
                 data_iter,
                 return_logits=False,
                 compute_loss=True,
                 reduce_output='avg'):
      """Evaluate the pipeline on a batch of data from ``data_iter``.

      Collectively, all workers evaluate ``self.train_batch_size()`` samples.
      This method is equivalent to:

      .. code-block:: python

          module.eval()
          with torch.no_grad():
              output = module(batch)

      .. warning::
          A total of ``self.gradient_accumulation_steps()`` entries will be pulled
          from ``data_iter`` by each pipeline. There must be sufficient
          data left in ``data_iter`` or else a ``StopIteration`` will halt training.

          DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
          that wraps data loaders to automatically restart upon a ``StopIteration``.

      The training data iterator and the ``eval_return_logits`` flag are restored
      even if evaluation raises.

      Args:
          data_iter (Iterator): Iterator of data to evaluate.
          return_logits (bool): if True, also return the last stage's forward
              outputs (``self.outputs``) as the second element of a tuple.
          compute_loss (bool): stored as ``self._compute_loss`` for the forward
              passes. When True, the reduced result is broadcast from the last
              stage to every rank of the pipeline via ``_bcast_pipe_scalar``.
          reduce_output (str or None): reduction passed to ``_reduce_outputs``
              (``'avg'`` or ``None``).

      Returns:
          The reduced evaluation output (by default, the arithmetic mean of the
          losses computed this batch), or ``(eval_output, outputs)`` when
          ``return_logits`` is True. Without ``compute_loss``, only the last stage
          holds a value and other stages get ``None``.
      """
      self.eval_return_logits = return_logits
      self.module.eval()

      # Curriculum learning could change activation shape
      if self.curriculum_enabled_legacy():
          new_difficulty = self.curriculum_scheduler_legacy.update_difficulty(
              self.global_steps + 1)
          if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step:
              self.reset_activation_shape()
              self.curriculum_scheduler_legacy.first_step = False
          elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty(
                  self.global_steps):
              self.reset_activation_shape()

      eval_output = None

      self._compute_loss = compute_loss

      # Swap in the provided data iterator, remembering the training one.
      train_iterator = self.data_iterator
      self.set_dataiterator(data_iter)

      try:
          sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
                                             stages=self.num_stages,
                                             stage_id=self.stage_id)

          # prevent dead-lock with multiple evals sequence
          dist.barrier()

          with torch.no_grad():
              self._exec_schedule(sched)

          if self.is_last_stage():
              eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output)

          if compute_loss:
              eval_output = self._bcast_pipe_scalar(eval_output)

          # Only a tensor can be logged with .mean().item(). Without compute_loss,
          # rank 0 on a multi-stage pipeline holds None, and with reduce_output=None
          # or multi-valued losses the output is a list. Previously either case
          # crashed here.
          if self.global_rank == 0 and self.monitor.enabled and torch.is_tensor(eval_output):
              self.summary_events = [('Train/Samples/eval_loss',
                                      eval_output.mean().item(),
                                      self.global_samples)]
              self.monitor.write_events(self.summary_events)
      finally:
          # Restore the training iterator and the logits flag even on failure.
          self.set_dataiterator(train_iterator)
          self.eval_return_logits = False

      if return_logits:
          outputs = self.outputs
          self.outputs = None
          return eval_output, outputs
      return eval_output
  def set_train_batch_size(self, train_batch_size):
      """Change the global batch size by changing the number of micro-batches.

      The per-GPU micro-batch size (``train_micro_batch_size_per_gpu``) is left
      unchanged. The gradient-accumulation step count changes, and
      ``self.micro_batches`` is re-synced with it.

      Args:
          train_batch_size (int): The new global batch size for training.

      Raises:
          ValueError: if ``train_batch_size`` is not divisible by the
              configured micro-batch size and data parallelism.
      """
      super().set_train_batch_size(train_batch_size)
      # The pipeline schedules are built from self.micro_batches, so mirror the
      # new gradient-accumulation step count here.
      self.micro_batches = self.gradient_accumulation_steps()
  387. def is_first_stage(self):
  388. """True if this process is in the first stage in the pipeline."""
  389. return self.stage_id == 0
  390. def is_last_stage(self):
  391. """True if this process is in the last stage in the pipeline."""
  392. return self.stage_id == self.num_stages - 1
  393. def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True):
  394. if reduce is None:
  395. return outputs
  396. if reduce.lower() == 'avg':
  397. # first sum over all microbatches
  398. if torch.is_tensor(outputs[0]):
  399. reduced = sum(outputs)
  400. else:
  401. assert isinstance(outputs, (list, tuple))
  402. reduced = [torch.zeros_like(o) for o in outputs[0]]
  403. for idx, out in outputs:
  404. reduced[idx] += out
  405. # Average over the microbatches
  406. reduced = self._scale_loss_by_gas(reduced)
  407. # Average over DP groups
  408. if reduce_dp and self.is_data_parallel:
  409. if torch.is_tensor(reduced):
  410. dist.all_reduce(reduced, group=self.mpu.get_data_parallel_group())
  411. reduced /= self.dp_world_size
  412. else:
  413. for idx in range(len(reduced)):
  414. dist.all_reduce(reduced[idx],
  415. group=self.mpu.get_data_parallel_group())
  416. reduced[idx] /= self.dp_world_size
  417. return reduced
  418. else:
  419. raise NotImplementedError(f'reduction type {reduce} not supported.')
  def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32):
      """Broadcast a scalar tensor from ``src_rank`` to every rank of this pipeline.

      Args:
          data (torch.Tensor): value to send. It is only read on ``src_rank``.
          src_rank (int, optional): global rank of the sender. Defaults to the
              global rank of the last pipeline stage (e.g. for broadcasting loss).
          dtype (torch.dtype): dtype used for the broadcast on *every* rank.

      Returns:
          torch.Tensor: the broadcast value, of ``dtype``, on ``self.device``.
      """
      # Default to last stage (e.g., for broadcasting loss)
      if src_rank is None:
          src_rank = self.grid.stage_to_global(self.num_stages - 1)
      assert src_rank in self.grid.pp_group

      if self.global_rank == src_rank:
          # Cast on the sender as well. Receivers allocate a ``dtype`` buffer, and
          # a broadcast needs matching dtypes on all ranks; a bf16/fp16 loss would
          # otherwise be sent with a different element type.
          result = data.clone().detach().type(dtype).to(self.device)
      else:
          result = torch.Tensor([0.]).type(dtype).to(self.device)

      dist.broadcast(tensor=result,
                     src=src_rank,
                     group=self.mpu.get_pipe_parallel_group())

      return result
  def _aggregate_total_loss(self):
      """Average the accumulated batch loss and share it across the pipeline.

      On the last stage this method:
        1. Scales ``self.total_loss`` with ``_scale_loss_by_gas`` and stores it as
           ``self.dp_group_loss``.
        2. Averages a copy across the data-parallel group (when data parallel).
        3. Broadcasts ``[dp_group_loss, agg_loss]`` over the pipe-parallel group.

      Other stages receive that pair from the last stage and store it the same way.

      Returns:
          torch.Tensor: the loss averaged over micro-batches and data-parallel ranks.
      """
      if self.is_last_stage():
          loss = self._scale_loss_by_gas(self.total_loss)
          self.dp_group_loss = loss.clone().detach()

          # Average loss across all data-parallel groups.
          agg_loss = self.dp_group_loss.clone().detach()
          if self.is_data_parallel:
              dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
              agg_loss /= self.dp_world_size

          assert self.global_rank in self.grid.pp_group
          # Build the payload on-device with torch.stack. torch.Tensor([...]) would
          # convert each (possibly CUDA) scalar to a Python float, forcing a
          # device-to-host sync, and then copy back. .float() keeps the dtype equal
          # to the float32 receive buffer allocated on the other stages.
          losses = torch.stack([self.dp_group_loss, agg_loss]).float().to(self.device)
          dist.broadcast(tensor=losses,
                         src=self.global_rank,
                         group=self.mpu.get_pipe_parallel_group())
      else:
          # Get loss from last stage
          src_rank = self.grid.stage_to_global(self.num_stages - 1)
          assert src_rank in self.grid.pp_group
          losses = torch.Tensor([0., 0.]).to(self.device)
          dist.broadcast(tensor=losses,
                         src=src_rank,
                         group=self.grid.get_pipe_parallel_group())
          self.dp_group_loss = losses[0].clone().detach()
          agg_loss = losses[1].clone().detach()

      return agg_loss
  460. def set_dataloader(self, loader):
  461. """"""
  462. if self.is_first_stage() or self.is_last_stage():
  463. self.training_dataloader = loader
  464. self.data_iterator = iter(self.training_dataloader)
  465. def set_dataiterator(self, iterator):
  466. """ Store an iterator to sample for training data. """
  467. if self.is_first_stage() or self.is_last_stage():
  468. self.training_dataloader = None
  469. self.data_iterator = iterator
  470. def set_batch_fn(self, fn):
  471. """Execute a post-processing function on input data.
  472. Args:
  473. fn (function): The function to run.
  474. """
  475. self.batch_fn = fn
  476. def is_gradient_accumulation_boundary(self):
  477. """True if the engine is executing a gradient reduction or optimizer step instruction.
  478. This is overridden from :class:`DeepSpeedEngine` to force reductions
  479. and steps when the pipeline engine is instructed to do so.
  480. Returns:
  481. bool: whether reductions and optimizer steps should occur.
  482. """
  483. return self._force_grad_boundary
  def log_for_device(self, *msg):
      """Print a debug line tagged with rank, stage, data-parallel, micro-batch and step ids.

      Output is gated by the module-level ``LOG_STAGE`` and ``DATA_PARALLEL_ID``
      selectors. ``-1`` matches any value; the default ``-2`` matches none.
      """
      if LOG_STAGE != self.stage_id and LOG_STAGE != -1:
          return
      if DATA_PARALLEL_ID != self.grid.data_parallel_id and DATA_PARALLEL_ID != -1:
          return
      print(f'RANK={dist.get_rank()} '
            f'PIPE-ID={self.stage_id} '
            f'DATA-ID={self.grid.data_parallel_id} '
            f'MBATCH-ID={self.microbatch_id} '
            f'STEP-ID={self.log_batch_step_id} '
            '::',
            *msg,
            flush=True)
  496. def tput_log(self, *msg):
  497. if self.global_rank == 0 and self.global_steps % self.steps_per_print() == 0:
  498. print(*msg)
  499. def _next_batch(self):
  500. # If using 3D parallelism, only some first-stage ranks may do IO
  501. batch = None
  502. if self.data_iterator is not None:
  503. batch = next(self.data_iterator)
  504. # Any post-processing, like broadcasting across a slice-parallel group.
  505. if self.batch_fn:
  506. batch = self.batch_fn(batch)
  507. return batch
  def _exec_forward_pass(self, buffer_id):
      """Run the forward pass for the micro-batch held in pipeline slot ``buffer_id``.

      Steps:
        1. Clone the buffered input (a tensor or a tuple of tensors).
        2. On non-first stages with pipe partitioning, rebuild the full activation
           from the received ``(meta, local_part, *rest)`` and mark it as requiring
           grad.
        3. Zero stale input gradients and run ``DeepSpeedEngine.forward``.
        4. On non-last stages with pipe partitioning, partition the first output.
           Its tensor object (data emptied) is kept for the backward graph, and
           ``(meta, local_part, *tail)`` is stored as the output to send.
        5. On the last stage, compute the loss (or take the model output as the
           loss) and accumulate it into ``fwd_outputs`` and ``total_loss``.
      """
      self.tput_timer.start()
      self.mem_status('BEFORE FWD', reset_max=True)

      buffered = self.pipe_buffers['inputs'][buffer_id]
      if isinstance(buffered, tuple):
          inputs = tuple(t.clone() for t in buffered)
      else:
          inputs = buffered.clone()

      # Collect the partitioned input from the previous stage.
      if self.is_pipe_partitioned and not self.is_first_stage():
          full_activation = PartitionedTensor.from_meta(
              meta=inputs[0],
              local_part=inputs[1],
              group=self.grid.get_slice_parallel_group()).full()
          inputs = (full_activation, *inputs[2:])
          # Only the activation is marked as requiring grad. The element after it
          # (described as a mask in the original code) is deliberately left alone.
          inputs[0].requires_grad = True
          if len(inputs) == 1:
              inputs = inputs[0]
          self.pipe_buffers['inputs'][buffer_id] = inputs

      # Zero out the gradients each time we use the tensor because only the data in
      # tensor changes across batches
      self._zero_grads(inputs)

      outputs = super().forward(inputs)

      # Partition the outputs if we are not the last stage.
      if self.is_pipe_partitioned and not self.is_last_stage():
          if isinstance(outputs, tuple):
              first_output = outputs[0]
              outputs_tail = outputs[1:]
              # TODO: Improve pipe partitioning to pass multiple tensors that require grads
              assert all(torch.is_tensor(extra) and extra.requires_grad is False
                         for extra in outputs_tail)
          elif torch.is_tensor(outputs):
              first_output = outputs
              outputs_tail = []
          else:
              raise ValueError("expecting a tensor or a tuple of tensors")

          partitioned = PartitionedTensor(tensor=first_output,
                                          group=self.grid.get_slice_parallel_group())
          # Drop the large activation data but keep the tensor object, and so its
          # computation graph, for the backward pass.
          first_output.data = torch.zeros(1)
          self.pipe_buffers['output_tensors'][buffer_id] = first_output
          # Send the partitioned form in place of the full first output.
          outputs = (partitioned.to_meta(), partitioned.data(), *outputs_tail)
          del partitioned

      self.pipe_buffers['outputs'][buffer_id] = outputs

      # Loss is only computed on the last stage.
      if not self.is_last_stage():
          return

      if self._compute_loss and self.module.loss_fn is not None:
          labels = self.pipe_buffers['labels'][buffer_id]
          self.loss = self.module.loss_fn(outputs, labels)
      else:
          # Some models just return loss from forward()
          self.loss = outputs
      if self.eval_return_logits:
          self.outputs = outputs

      loss = self.loss
      if isinstance(loss, torch.Tensor):
          detached = loss.detach()
          self.fwd_outputs.append(detached)
          if self.total_loss is None:
              self.total_loss = torch.zeros_like(loss)
          self.total_loss += detached
      else:
          # Multiple losses: accumulate each one position by position.
          detached_parts = [part_loss.detach() for part_loss in loss]
          self.fwd_outputs.append(detached_parts)
          if self.total_loss is None:
              self.total_loss = [torch.zeros_like(part_loss) for part_loss in loss]
          for idx, part_loss in enumerate(detached_parts):
              self.total_loss[idx] += part_loss
def _exec_backward_pass(self, buffer_id):
    """Run the backward pass for the micro-batch held in pipeline buffer ``buffer_id``.

    On the last stage this simply calls the parent engine's ``backward()`` on
    ``self.loss``. On every other stage, the forward outputs saved in
    ``self.pipe_buffers['outputs'][buffer_id]`` are back-propagated with
    ``torch.autograd.backward`` using the gradients previously received into
    ``self.grad_layer``. Afterwards, the output buffers for ``buffer_id`` are
    released.

    Args:
        buffer_id: index of the pipeline buffer slot for this micro-batch.
    """
    assert self.optimizer is not None, "must provide optimizer during " \
        "init in order to use backward"

    self.mem_status('BEFORE BWD', reset_max=True)

    # The last stage just runs backward on the loss using DeepSpeed's typical
    # mechanisms.
    if self.is_last_stage():
        super().backward(self.loss)
        self.mem_status('AFTER BWD')
        return

    outputs = self.pipe_buffers['outputs'][buffer_id]

    if self.wall_clock_breakdown():
        self.timers('backward_microstep').start()
        self.timers('backward').start()
        self.timers('backward_inner_microstep').start()
        self.timers('backward_inner').start()

    # Reconstruct if we previously partitioned the output. We must be
    # careful to also restore the computational graph of the tensors we partitioned.
    if self.is_pipe_partitioned:
        if self.is_grad_partitioned:
            # outputs is (meta, local_part, *rest). Rebuild the full tensor
            # across the slice-parallel group and place it in the tensor kept in
            # 'output_tensors', so autograd runs through that tensor's graph.
            part_output = PartitionedTensor.from_meta(
                meta=outputs[0],
                local_part=outputs[1],
                group=self.grid.get_slice_parallel_group())
            self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
            outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:])
        else:
            # Already restored from partition (by _exec_recv_grads, which
            # stores (full_output, *rest) back into pipe_buffers['outputs']).
            self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
            outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:])

    grad_tensors = self.grad_layer
    if self.is_grad_partitioned:
        # grad_layer holds (meta, local_part, *rest): rebuild the full gradient
        # of the first output from its partition.
        part_grad = PartitionedTensor.from_meta(
            meta=self.grad_layer[0],
            local_part=self.grad_layer[1],
            group=self.grid.get_slice_parallel_group())
        grad_tensors = (part_grad.full(), *grad_tensors[2:])
        part_grad = None

    # NOTE: `not self.is_last_stage()` is always true here (the last stage
    # returned above).
    if self.bfloat16_enabled() and not self.is_last_stage():
        # manually call because we don't call optimizer.backward()
        self.optimizer.clear_lp_grads()

    # This handles either a single tensor or tuple of tensors.
    if isinstance(outputs, tuple):
        # Only floating-point outputs take part in backward; they must pair
        # one-to-one with the received gradients.
        out_tensors = [t for t in outputs if t.is_floating_point()]
        assert len(out_tensors) == len(grad_tensors)
        torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors)
    else:
        torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))

    if self.bfloat16_enabled() and not self.is_last_stage():
        # manually call because we don't call optimizer.backward()
        self.optimizer.update_hp_grads(clear_lp_grads=False)

    # Free up the memory from the output of forward()
    self.pipe_buffers['output_tensors'][buffer_id] = None
    self.pipe_buffers['outputs'][buffer_id] = None
    grad_tensors = None

    if self.wall_clock_breakdown():
        self.timers('backward_inner').stop()
        self.timers('backward_inner_microstep').stop()
        self.timers('backward').stop()
        self.timers('backward_microstep').stop()

    self.mem_status('AFTER BWD')
  640. def _exec_load_micro_batch(self, buffer_id):
  641. if self.wall_clock_breakdown():
  642. self.timers('batch_input').start()
  643. batch = self._next_batch()
  644. if self.is_first_stage():
  645. loaded = None
  646. if torch.is_tensor(batch[0]):
  647. loaded = batch[0].clone().to(self.device).detach()
  648. loaded.requires_grad = loaded.is_floating_point()
  649. else:
  650. assert isinstance(batch[0], tuple)
  651. # Assume list or tuple
  652. loaded = []
  653. for x in batch[0]:
  654. assert torch.is_tensor(x)
  655. mine = x.clone().detach().to(self.device)
  656. mine.requires_grad = mine.is_floating_point()
  657. loaded.append(mine)
  658. loaded = tuple(loaded)
  659. self.pipe_buffers['inputs'][buffer_id] = loaded
  660. if self.is_last_stage():
  661. loaded = batch[1]
  662. if torch.is_tensor(batch[1]):
  663. loaded = batch[1].to(self.device)
  664. elif isinstance(batch[1], tuple):
  665. loaded = []
  666. for x in batch[1]:
  667. assert torch.is_tensor(x)
  668. x = x.to(self.device).detach()
  669. loaded.append(x)
  670. loaded = tuple(loaded)
  671. self.pipe_buffers['labels'][buffer_id] = loaded
  672. if self.wall_clock_breakdown():
  673. self.timers('batch_input').stop()
  674. def _send_tensor_meta(self, buffer, recv_stage):
  675. """ Communicate metadata about upcoming p2p transfers.
  676. Metadata is communicated in this order:
  677. * type (0: tensor, 1: list)
  678. * num_tensors if type=list
  679. foreach tensor in buffer:
  680. * ndims
  681. * shape
  682. """
  683. send_bytes = 0
  684. if isinstance(buffer, torch.Tensor):
  685. type_tensor = torch.LongTensor(data=[0]).to(self.device)
  686. p2p.send(type_tensor, recv_stage)
  687. send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
  688. send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
  689. p2p.send(send_ndims, recv_stage)
  690. p2p.send(send_shape, recv_stage)
  691. send_bytes += _tensor_bytes(buffer)
  692. elif isinstance(buffer, list):
  693. assert (False)
  694. type_tensor = torch.LongTensor(data=[1]).to(self.device)
  695. p2p.send(type_tensor, recv_stage)
  696. count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
  697. p2p.send(count_tensor, recv_stage)
  698. for tensor in buffer:
  699. assert isinstance(tensor, torch.Tensor)
  700. send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
  701. send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
  702. p2p.send(send_ndims, recv_stage)
  703. p2p.send(send_shape, recv_stage)
  704. send_bytes += _tensor_bytes(tensor)
  705. elif isinstance(buffer, tuple):
  706. type_tensor = torch.LongTensor(data=[2]).to(self.device)
  707. p2p.send(type_tensor, recv_stage)
  708. count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
  709. p2p.send(count_tensor, recv_stage)
  710. for idx, tensor in enumerate(buffer):
  711. assert isinstance(tensor, torch.Tensor)
  712. send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
  713. send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
  714. send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(
  715. self.device)
  716. p2p.send(send_dtype, recv_stage)
  717. p2p.send(send_ndims, recv_stage)
  718. p2p.send(send_shape, recv_stage)
  719. # Useful for performance debugging.
  720. '''
  721. new_bytes = _tensor_bytes(tensor)
  722. send_bytes += _tensor_bytes(tensor)
  723. # Useful for performance debugging.
  724. if self.grid.data_parallel_id == 0:
  725. print(
  726. f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
  727. )
  728. '''
  729. else:
  730. raise NotImplementedError(f'Could not send meta type {type(buffer)}')
  731. # Useful for performance debugging.
  732. '''
  733. if self.grid.data_parallel_id == 0:
  734. print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB')
  735. '''
  736. def _recv_tensor_meta(self, send_stage):
  737. """Receive metadata about upcoming p2p transfers and return allocated buffers.
  738. Metadata is communicated in this order:
  739. * type (0: tensor, 1: list)
  740. * num_tensors if type=list
  741. foreach tensor in buffer:
  742. * ndims
  743. * shape
  744. Returns:
  745. Allocated buffer for receiving from send_stage.
  746. """
  747. type_tensor = torch.LongTensor(data=[0]).to(self.device)
  748. p2p.recv(type_tensor, send_stage)
  749. recv_type = type_tensor.item()
  750. # A single tensor will be sent.
  751. if recv_type == 0:
  752. recv_ndims = torch.LongTensor(data=[0]).to(self.device)
  753. p2p.recv(recv_ndims, send_stage)
  754. recv_ndims = recv_ndims.item()
  755. recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
  756. p2p.recv(recv_shape, send_stage)
  757. recv_shape = recv_shape.tolist()
  758. return self._allocate_buffer(recv_shape, num_buffers=1)[0]
  759. # List or tuple of tensors
  760. elif recv_type == 1 or recv_type == 2:
  761. count_tensor = torch.LongTensor(data=[0]).to(self.device)
  762. p2p.recv(count_tensor, send_stage)
  763. num_tensors = count_tensor.item()
  764. recv_shapes_and_dtypes = []
  765. for idx in range(num_tensors):
  766. recv_dtype = torch.LongTensor(data=[0]).to(self.device)
  767. p2p.recv(recv_dtype, send_stage)
  768. recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
  769. recv_ndims = torch.LongTensor(data=[0]).to(self.device)
  770. p2p.recv(recv_ndims, send_stage)
  771. recv_ndims = recv_ndims.item()
  772. recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
  773. p2p.recv(recv_shape, send_stage)
  774. recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))
  775. buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
  776. # Convert to tuples if requested.
  777. if recv_type == 2:
  778. buffers = tuple(buffers)
  779. return buffers
  780. else:
  781. raise NotImplementedError(f'Could not receive type {type(recv_type)}')
  782. def _exec_send_activations(self, buffer_id):
  783. if self.wall_clock_breakdown():
  784. self.timers('pipe_send_output').start()
  785. outputs = self.pipe_buffers['outputs'][buffer_id]
  786. # NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
  787. # We could do char, but with half() we can eventually flatten with other fp16
  788. # messages (TODO)
  789. if self.has_attention_mask or self.has_bool_tensors:
  790. outputs = list(outputs)
  791. outputs[-1] = outputs[-1].half()
  792. outputs = tuple(outputs)
  793. if self.first_output_send:
  794. self.first_output_send = False
  795. self._send_tensor_meta(outputs, self.next_stage)
  796. if isinstance(outputs, torch.Tensor):
  797. p2p.send(outputs, self.next_stage)
  798. elif isinstance(outputs, tuple):
  799. for idx, buffer in enumerate(outputs):
  800. p2p.send(buffer, self.next_stage)
  801. else:
  802. raise NotImplementedError('Could not send output of type '
  803. f'{type(outputs)}')
  804. # Restore the boolean tensor
  805. if self.has_attention_mask or self.has_bool_tensors:
  806. outputs = list(outputs)
  807. outputs[-1] = outputs[-1].bool()
  808. outputs = tuple(outputs)
  809. if self.wall_clock_breakdown():
  810. self.timers('pipe_send_output').stop()
  811. def _exec_send_grads(self, buffer_id):
  812. if self.wall_clock_breakdown():
  813. self.timers('pipe_send_grad').start()
  814. inputs = self.pipe_buffers['inputs'][buffer_id]
  815. # Partition the gradient
  816. if self.is_grad_partitioned:
  817. if isinstance(inputs, tuple):
  818. first_input = inputs[0]
  819. assert all([torch.is_tensor(elt) for elt in inputs[1:]])
  820. inputs_grad_tail = [
  821. elt.grad for elt in inputs[1:] if elt.grad is not None
  822. ]
  823. elif torch.is_tensor(inputs):
  824. first_input = inputs
  825. inputs_grad_tail = []
  826. else:
  827. raise ValueError("expecting a tensor or a tuple of tensors")
  828. assert torch.is_tensor(first_input)
  829. part = PartitionedTensor(tensor=first_input.grad,
  830. group=self.grid.get_slice_parallel_group())
  831. inputs = (part.to_meta(), part.data(), *inputs_grad_tail)
  832. # XXX Terrible hack
  833. # Drop the attention mask from the input buffer here. It does not have
  834. # a grad that needs to be communicated. We free the buffer immediately
  835. # after, so no need to restore it. The receiver also has a hack that skips
  836. # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
  837. if self.has_attention_mask or self.has_bool_tensors:
  838. inputs = list(inputs)
  839. inputs.pop()
  840. inputs = tuple(inputs)
  841. if isinstance(inputs, torch.Tensor):
  842. assert inputs.grad is not None
  843. p2p.send(inputs.grad, self.prev_stage)
  844. else:
  845. # XXX terrible hacky branch
  846. if self.is_grad_partitioned:
  847. # First two sends are partitioned gradient
  848. p2p.send(inputs[0], self.prev_stage)
  849. p2p.send(inputs[1], self.prev_stage)
  850. else:
  851. for idx, buffer in enumerate(inputs):
  852. # Skip tensors that will not produce a grad
  853. if not buffer.is_floating_point():
  854. assert buffer.grad is None
  855. continue
  856. assert buffer.grad is not None
  857. p2p.send(buffer.grad, self.prev_stage)
  858. # We can free up the input buffer now
  859. self.pipe_buffers['inputs'][buffer_id] = None
  860. if self.wall_clock_breakdown():
  861. self.timers('pipe_send_grad').stop()
  862. def _exec_recv_activations(self, buffer_id):
  863. if self.wall_clock_breakdown():
  864. self.timers('pipe_recv_input').start()
  865. recvd = None
  866. # Allocate the buffer if necessary
  867. if self.pipe_recv_buf is None:
  868. self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)
  869. if isinstance(self.pipe_recv_buf, torch.Tensor):
  870. p2p.recv(self.pipe_recv_buf, self.prev_stage)
  871. recvd = self.pipe_recv_buf.clone().detach()
  872. recvd.requires_grad = recvd.is_floating_point()
  873. else:
  874. assert isinstance(self.pipe_recv_buf, tuple)
  875. recvd = [None] * len(self.pipe_recv_buf)
  876. for idx, buffer in enumerate(self.pipe_recv_buf):
  877. assert torch.is_tensor(buffer)
  878. # XXX hardcode meta type
  879. if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long:
  880. if self.meta_buffer is None:
  881. self.meta_buffer = torch.zeros(buffer.size(),
  882. dtype=torch.long,
  883. device=self.device)
  884. buffer = self.meta_buffer
  885. p2p.recv(buffer, self.prev_stage)
  886. recvd[idx] = buffer.clone().detach()
  887. # NCCL does not like to send torch.BoolTensor types, so un-cast the
  888. # attention mask
  889. if self.has_attention_mask or self.has_bool_tensors:
  890. recvd[-1] = recvd[-1].bool()
  891. recvd = tuple(recvd)
  892. for buffer in recvd:
  893. buffer.requires_grad = buffer.is_floating_point()
  894. self.pipe_buffers['inputs'][buffer_id] = recvd
  895. if self.wall_clock_breakdown():
  896. self.timers('pipe_recv_input').stop()
def _exec_recv_grads(self, buffer_id):
    """Receive gradients for the outputs in buffer ``buffer_id`` from ``self.next_stage``.

    The gradients land in ``self.grad_layer``. It is allocated once, on first
    use, from the sizes/dtypes of the stored outputs and then reused, so
    ``_exec_backward_pass`` can consume it.

    Args:
        buffer_id: index of the pipeline buffer slot whose output gradients are received.
    """
    if self.wall_clock_breakdown():
        self.timers('pipe_recv_grad').start()

    outputs = self.pipe_buffers['outputs'][buffer_id]

    # XXX these shapes are hardcoded for Megatron
    # Restore partitioned output if it was partitioned and we are sending full gradients
    if self.is_pipe_partitioned and not self.is_grad_partitioned:
        # outputs is (meta, local_part, *rest). The full tensor is written into
        # the outputs[0] object itself and the local part is dropped.
        part_output = PartitionedTensor.from_meta(
            meta=outputs[0],
            local_part=outputs[1],
            group=self.grid.get_slice_parallel_group())
        outputs[0].data = part_output.full()
        outputs = (outputs[0], *outputs[2:])
        # save for backward
        self.pipe_buffers['outputs'][buffer_id] = outputs

    # Allocate gradient if necessary
    if self.grad_layer is None:
        if isinstance(outputs, torch.Tensor):
            s = list(outputs.size())
            self.grad_layer = self._allocate_buffer(s,
                                                    dtype=outputs.dtype,
                                                    num_buffers=1)[0]
        else:
            # XXX This is a HACK
            # When we exchange activations/gradients, the two pipe stages
            # need to issue the send/recv with the same buffer sizes or
            # else there is a deadlock. The is_floating_point() filter is
            # used to avoid sending gradients for tensors that do not
            # produce gradients. When TP>1, we partition the first
            # activations/gradients across TP ranks to save communication
            # volume and memory. That partitioned tensor is represented as
            # two tensors: a 1/TPth chunk of the original data and also a
            # small LongTensor storing the metadata used to reconstruct on
            # the other side. When combined, the floating point filter also
            # filtered out the metadata tensor. This quick (hacky) fix just
            # branches on is_grad_partitioned so we don't filter out the
            # metadata tensor.
            if self.is_grad_partitioned:
                sizes_and_dtypes = [
                    (list(t.size()),
                     t.dtype) for t in outputs[:2]
                ] + [(list(t.size()),
                      t.dtype) for t in outputs[2:] if t.is_floating_point()]
            else:
                sizes_and_dtypes = [(list(t.size()),
                                     t.dtype) for t in outputs
                                    if t.is_floating_point()]
            self.grad_layer = self._allocate_buffers(sizes_and_dtypes,
                                                     num_buffers=1)[0]

    if isinstance(self.grad_layer, torch.Tensor):
        p2p.recv(self.grad_layer, self.next_stage)
    else:
        assert isinstance(outputs, tuple)
        for idx, buffer in enumerate(self.grad_layer):
            # XXX GPT-2 hack
            # With gradient partitioning, slot 0 receives the partition
            # metadata, so force it to a long buffer of the same size.
            if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long:
                buffer.data = torch.zeros(buffer.size(),
                                          dtype=torch.long,
                                          device=self.device)
            p2p.recv(buffer, self.next_stage)

    if self.wall_clock_breakdown():
        self.timers('pipe_recv_grad').stop()
  959. def _exec_optimizer_step(self, lr_kwargs=None):
  960. if self.wall_clock_breakdown():
  961. self.timers('step_microstep').start()
  962. self.timers('step').start()
  963. self.mem_status('BEFORE STEP', reset_max=True)
  964. self._force_grad_boundary = True
  965. self._take_model_step(lr_kwargs)
  966. self._force_grad_boundary = False
  967. self.mem_status('AFTER STEP')
  968. if self.global_rank == 0 and self.monitor.enabled:
  969. self.summary_events = [(f'Train/Samples/lr',
  970. self.get_lr()[0],
  971. self.global_samples)]
  972. if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
  973. self.summary_events.append((f'Train/Samples/loss_scale',
  974. self.optimizer.cur_scale,
  975. self.global_samples))
  976. self.monitor.write_events(self.summary_events)
  977. if self.wall_clock_breakdown():
  978. self.timers('step_microstep').stop()
  979. self.timers('step').stop()
  980. if self.global_steps % self.steps_per_print() == 0:
  981. self.timers.log([
  982. 'batch_input',
  983. 'forward_microstep',
  984. 'backward_microstep',
  985. 'backward_inner_microstep',
  986. 'backward_allreduce_microstep',
  987. 'backward_tied_allreduce_microstep',
  988. 'step_microstep'
  989. ])
  990. if self.global_steps % self.steps_per_print() == 0:
  991. self.timers.log([
  992. 'forward',
  993. 'backward',
  994. 'backward_inner',
  995. 'backward_allreduce',
  996. 'step'
  997. ])
  998. def _zero_grads(self, inputs):
  999. if isinstance(inputs, torch.Tensor):
  1000. if inputs.grad is not None:
  1001. inputs.grad.data.zero_()
  1002. else:
  1003. for t in inputs:
  1004. if t.grad is not None:
  1005. t.grad.data.zero_()
  1006. def _allocate_zeros(self, shape, **kwargs):
  1007. """ Allocate a tensor of zeros on the engine's device.
  1008. Arguments:
  1009. shape: the shape of the tensor to allocate
  1010. kwargs: passed to torch.zeros()
  1011. Returns:
  1012. A tensor from torch.zeros() allocated on self.device.
  1013. """
  1014. if "dtype" not in kwargs:
  1015. if self.fp16_enabled():
  1016. kwargs["dtype"] = torch.half
  1017. if self.bfloat16_enabled():
  1018. kwargs["dtype"] = torch.bfloat16
  1019. return torch.zeros(shape, device=self.device, **kwargs)
  1020. def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
  1021. buffers = []
  1022. if num_buffers == -1:
  1023. num_buffers = self.num_pipe_buffers
  1024. for count in range(num_buffers):
  1025. buffers.append(self._allocate_zeros(shape, **kwargs))
  1026. return buffers
  1027. def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
  1028. buffers = []
  1029. if num_buffers == -1:
  1030. num_buffers = self.num_pipe_buffers
  1031. for count in range(num_buffers):
  1032. buffer = []
  1033. for shape, dtype in shapes_and_dtypes:
  1034. buffer.append(
  1035. self._allocate_zeros(shape,
  1036. dtype=dtype,
  1037. requires_grad=requires_grad))
  1038. buffers.append(buffer)
  1039. return buffers
  1040. def forward(self, *args, **kwargs):
  1041. """Disabled for pipeline parallel training. See ``train_batch()``. """
  1042. raise PipelineError("Only train_batch() is accessible in pipeline mode.")
  1043. def backward(self, *args, **kwargs):
  1044. """Disabled for pipeline parallel training. See ``train_batch()``. """
  1045. raise PipelineError("Only train_batch() is accessible in pipeline mode.")
  1046. def step(self, *args, **kwargs):
  1047. """Disabled for pipeline parallel training. See ``train_batch()``. """
  1048. raise PipelineError("Only train_batch() is accessible in pipeline mode.")
  1049. def mem_status(self, msg, print_rank=-1, reset_max=False):
  1050. return
  1051. global mem_alloced, mem_cached
  1052. if not self.global_steps == 0 or not self.global_steps == 9:
  1053. #return
  1054. pass
  1055. if self.mpu.get_data_parallel_rank() != 0:
  1056. return
  1057. if self.global_rank != 0:
  1058. return
  1059. rank = self.global_rank
  1060. if print_rank != -1 and rank != print_rank:
  1061. return
  1062. get_accelerator().synchronize()
  1063. if reset_max:
  1064. get_accelerator().reset_max_memory_cached()
  1065. get_accelerator().reset_max_memory_allocated()
  1066. new_alloced = get_accelerator().memory_allocated()
  1067. new_cached = get_accelerator().memory_cached()
  1068. delta_alloced = new_alloced - mem_alloced
  1069. delta_cached = new_cached - mem_cached
  1070. mem_cached = new_cached
  1071. mem_alloced = new_alloced
  1072. max_alloced = get_accelerator().max_memory_allocated()
  1073. max_cached = get_accelerator().max_memory_cached()
  1074. # convert to GB for printing
  1075. new_alloced /= 1024**3
  1076. new_cached /= 1024**3
  1077. delta_alloced /= 1024**3
  1078. delta_cached /= 1024**3
  1079. max_alloced /= 1024**3
  1080. max_cached /= 1024**3
  1081. print(
  1082. f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS',
  1083. msg,
  1084. f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
  1085. f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)'
  1086. )
  1087. def module_state_dict(self):
  1088. """Override hack to save a pipe model and return the directory path of the save.
  1089. This method should only be called by DeepSpeed's ``save_checkpoint()``. The
  1090. recommended way of saving a ``PipelineModule`` outside of ``save_checkpoint()``
  1091. is ``save_state_dict()``.
  1092. Returns:
  1093. None
  1094. """
  1095. assert isinstance(self.module, PipelineModule)
  1096. assert self._curr_ckpt_path is not None, \
  1097. "PipelineEngine expects module_state_dict() to be called from save_checkpoint()"
  1098. self.module.save_state_dict(self._curr_ckpt_path,
  1099. checkpoint_engine=self.checkpoint_engine)
  1100. return None
  1101. def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
  1102. """Override hack to instead use a directory path.
  1103. This is important because pipeline models checkpoint by layer instead of rank.
  1104. If ``state_dict`` is not ``None`` or a ``str``, we revert to ``super()`` expecting a ``dict``.
  1105. Args:
  1106. state_dict (str, None): unused
  1107. strict (bool, optional): Strict state loading. Defaults to True.
  1108. """
  1109. assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"
  1110. if (state_dict is not None) and (not isinstance(state_dict, str)):
  1111. super().load_module_state_dict(state_dict, strict)
  1112. return
  1113. self.module.load_state_dir(load_dir=self._curr_ckpt_path,
  1114. strict=strict,
  1115. checkpoint_engine=self.checkpoint_engine)
# A map of PipeInstruction types to methods. Each method will be executed with the
# kwargs provided to the PipeInstruction from the scheduler.
# Keys are matched on the instruction's exact type (``type(cmd)``), so subclasses
# of these instructions need their own entry. Values are plain functions; the
# schedule executor binds each one to the engine with ``MethodType`` before
# calling it.
_INSTRUCTION_MAP = {
    schedule.OptimizerStep: _exec_optimizer_step,
    schedule.ReduceGrads: _exec_reduce_grads,
    schedule.ReduceTiedGrads: _exec_reduce_tied_grads,
    schedule.LoadMicroBatch: _exec_load_micro_batch,
    schedule.ForwardPass: _exec_forward_pass,
    schedule.BackwardPass: _exec_backward_pass,
    schedule.SendActivation: _exec_send_activations,
    schedule.RecvActivation: _exec_recv_activations,
    schedule.SendGrad: _exec_send_grads,
    schedule.RecvGrad: _exec_recv_grads,
}
  1130. def _exec_schedule(self, pipe_schedule):
  1131. # Reserve and reset buffers.
  1132. self._reserve_pipe_buffers(pipe_schedule.num_pipe_buffers())
  1133. self.fwd_outputs = []
  1134. # For each step in the schedule
  1135. for step_cmds in pipe_schedule:
  1136. # For each instruction in the step
  1137. for cmd in step_cmds:
  1138. if type(cmd) not in self._INSTRUCTION_MAP:
  1139. raise RuntimeError(
  1140. f'{self.__class__.__name__} does not understand instruction {repr(cmd)}'
  1141. )
  1142. # Equivalent to: self._exec_forward_pass(buffer_id=0)
  1143. self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
  1144. self._exec_instr(**cmd.kwargs)