engine.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from types import MethodType
  5. import torch
  6. from deepspeed import comm as dist
  7. from deepspeed.utils import logger
  8. from deepspeed.utils.timer import ThroughputTimer
  9. from deepspeed.accelerator import get_accelerator
  10. from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
  11. from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
  12. from deepspeed.utils.timer import FORWARD_MICRO_TIMER, FORWARD_GLOBAL_TIMER, BACKWARD_MICRO_TIMER, \
  13. BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_INNER_GLOBAL_TIMER, \
  14. BACKWARD_REDUCE_MICRO_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \
  15. STEP_MICRO_TIMER, STEP_GLOBAL_TIMER
  16. from ..utils import PartitionedTensor
  17. from ..dataloader import RepeatingLoader
  18. from ..zero.config import ZeroStageEnum
  19. from ..activation_checkpointing import checkpointing as ds_checkpointing
  20. from .module import PipelineModule, PipelineError
  21. from . import p2p
  22. from . import schedule
# Debug-logging filters for log_for_device(): -1 matches every stage /
# data-parallel rank, any other value selects that specific id, and the
# default of -2 effectively disables the logging.
TARGET_ID = -2
LOG_STAGE = -2
DATA_PARALLEL_ID = -2

# Names of the wall-clock timers used with ``self.timers(...)`` for the
# pipeline-specific phases (batch IO and the four p2p exchanges).
BATCH_INPUT_TIMER = 'batch_input'
TRAIN_BATCH_TIMER = 'train_batch'
PIPE_SEND_OUTPUT_TIMER = 'pipe_send_output'
PIPE_SEND_GRAD_TIMER = 'pipe_send_grad'
PIPE_RECV_INPUT_TIMER = 'pipe_recv_input'
PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad'
  32. def is_even(number):
  33. return number % 2 == 0
# Module-level accumulators for device-memory bookkeeping.
# NOTE(review): neither is read or written in this chunk — presumably used by
# the engine's mem_status() reporting defined elsewhere in the file; confirm
# before removing.
mem_alloced = 0
mem_cached = 0
  36. def _tensor_bytes(tensor):
  37. return tensor.numel() * tensor.element_size()
class PipelineEngine(DeepSpeedEngine):
    """A training engine for hybrid pipeline, data, and model parallel training.

    This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
    is provided.
    """
    # Wire format for exchanging tensor metadata between stages: a dtype is
    # encoded as its index in this list so it fits in a small integer tensor.
    ID_TO_DTYPE = [
        torch.float32, torch.float64, torch.complex64, torch.complex128, torch.float16, torch.bfloat16, torch.uint8,
        torch.int8, torch.int16, torch.int32, torch.int64, torch.bool
    ]
    # Reverse mapping: dtype object -> wire id.
    DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}
    def __init__(self, has_bool_tensors: bool = False, *super_args, **super_kwargs):
        """Build the pipeline engine on top of :class:`DeepSpeedEngine`.

        Args:
            has_bool_tensors (bool): True if the model exchanges boolean
                tensors (e.g. attention masks) between stages.
            *super_args / **super_kwargs: forwarded to ``DeepSpeedEngine.__init__``.
        """
        super().__init__(*super_args, **super_kwargs)
        assert isinstance(self.module, PipelineModule), "model must base PipelineModule"

        # ZeRO-2/3 partition gradients/parameters, which conflicts with the
        # pipeline engine owning gradient reduction; only stage 0/1 allowed.
        assert self.zero_optimization_stage(
        ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"

        # We schedule the all-reduces, so disable it in super().backward()
        self.enable_backward_allreduce = False
        self.has_bool_tensors = has_bool_tensors
        self.eval_return_logits = False
        self.outputs = None
        # BF16 Optimizer is hardcoded for fp32 gradient accumulation
        self.using_bf16_optimizer = type(self.optimizer) == BF16_Optimizer

        # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
        self.pipeline_enable_backward_allreduce = True

        if self.elasticity_enabled():
            if not self.is_elastic_model_parallel_supported():
                assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
                    " with pipeline parallelism."

        # pipeline step for logging
        self.log_batch_step_id = -1

        self.micro_batch_size = self.train_micro_batch_size_per_gpu()
        self.micro_batches = self.gradient_accumulation_steps()

        # Set Grid and Communication Groups
        self.grid = self.module._grid
        if self.grid.get_global_rank() == 0:
            logger.info(f'CONFIG: micro_batches={self.micro_batches} '
                        f'micro_batch_size={self.micro_batch_size}')

        self.global_rank = self.grid.get_global_rank()

        # Sanity-check that the grid's decomposition matches the engine config.
        assert self.dp_world_size == self.grid.data_parallel_size
        assert self.train_batch_size() == \
            self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size

        # Set Stage Info
        self.num_stages = self.grid.pipe_parallel_size
        self.stage_id = self.grid.get_stage_id()
        self.prev_stage = self.stage_id - 1
        self.next_stage = self.stage_id + 1

        self.data_iterator = None
        self.batch_fn = None

        # Forces is_gradient_accumulation_boundary() to True while the
        # schedule executes a reduce/step instruction.
        self._force_grad_boundary = False

        self.batch_timer = ThroughputTimer(batch_size=self.train_batch_size(),
                                           logging_fn=self.tput_log,
                                           monitor_memory=False,
                                           steps_per_output=self.steps_per_print())

        # PipelineEngine needs to handle data loading specially due to only the first
        # and last stages loading inputs/labels. We construct a sampler that uses
        if self.training_data:
            self._build_data_iter(self.training_data)

        self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
        self.is_data_parallel = self.grid.data_parallel_size > 1
        self.is_model_parallel = self.grid.model_parallel_size > 1

        # Partition input/output buffers
        # XXX temporarily disable while I revert some partition hacks.
        assert isinstance(self._config.pipeline['pipe_partitioned'], bool)
        assert isinstance(self._config.pipeline['grad_partitioned'], bool)
        # Activation/gradient partitioning only makes sense with model (tensor)
        # parallelism, since the slices are scattered over the slice group.
        self.is_pipe_partitioned = self.is_model_parallel and self._config.pipeline['pipe_partitioned']
        self.is_grad_partitioned = self.is_model_parallel and self._config.pipeline['grad_partitioned']
        logger.info(f'is_pipe_partitioned= {self.is_pipe_partitioned} '
                    f'is_grad_partitioned= {self.is_grad_partitioned}')

        # Count trainable parameters on this stage.
        model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
        num_params = sum([p.numel() for p in model_parameters])
        unique_params = num_params
        # Subtract tied parameters if we don't own them
        if self.module.tied_comms:
            tied_params = 0
            for key, d in self.module.tied_comms.items():
                # The lowest rank in the tie group is considered the owner.
                if self.global_rank != min(d['ranks']):
                    tied_params += sum(p.numel() for p in d['module'].parameters())
            unique_params -= tied_params

        # Aggregate parameter counts across the model-parallel group for logging.
        params_tensor = torch.LongTensor(data=[num_params, unique_params]).to(self.device)
        dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
        params_tensor = params_tensor.tolist()
        total_params = params_tensor[0]
        unique_params = params_tensor[1]

        if self.grid.data_parallel_id == 0:
            logger.info(f'RANK={self.global_rank} '
                        f'STAGE={self.stage_id} '
                        f'LAYERS={self.module._local_stop - self.module._local_start} '
                        f'[{self.module._local_start}, {self.module._local_stop}) '
                        f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
                        f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
                        f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')

        #initialize peer-2-peer communication and allreduce groups
        if self.is_pipe_parallel:
            p2p.init_process_groups(self.grid)

        # Pipeline buffers
        self.num_pipe_buffers = 0
        self.pipe_buffers = {
            'inputs': [],  # batch input and received activations
            'labels': [],  # labels from batch input
            'outputs': [],  # activations
            'output_tensors': [],  # tensor object to preserve backward graph
        }

        # Cached recv buffers / metadata; cleared by reset_activation_shape().
        self.pipe_recv_buf = None
        self.grad_layer = None
        self.meta_buffer = None

        self.first_output_send = True
        self.first_gradient_send = True
        self.pipe_partition_input_meta_cache = None
        self.pipe_partition_output_meta_cache = None
        self.pipe_partition_grad_meta_cache = None
        self.grad_partition_grad_layer_meta_cache = None

        #stores the loss for the current micro batch being processed
        self.loss = torch.tensor(0.0).to(self.device)

        #stores the loss for the entire batch
        self.total_loss = None
        self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
        self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)

        if self._config.pipeline['activation_checkpoint_interval'] > 0:
            self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval']
            # set use_reentrant default to True.
            if self._config.pipeline.get('use_reentrant') is None:
                self._config.pipeline['use_reentrant'] = True
            if self._config.pipeline['use_reentrant'] is False:
                # set activation_checkpoint_func to non_reentrant_checkpoint func.
                self.module.activation_checkpoint_func = ds_checkpointing.non_reentrant_checkpoint
                if self.grid.get_global_rank() == 0:
                    logger.info(f'CONFIG: activation_checkpoint_func=non_reentrant_checkpoint')

        self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline

        if self.is_last_stage():
            self.loss_model = self.module.loss_fn

        # assumes only GPT2ModelPipe carries an attention mask between stages —
        # callers can override via set_has_attention_mask().
        self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe'
        # Initialize pipeline communicators. Just send a 0.
        # Even stages send before receiving; odd stages receive first to pair up.
        if is_even(self.stage_id):
            if not self.is_last_stage():
                p2p.send(self.loss, self.next_stage)
            if not self.is_first_stage():
                p2p.recv(self.loss, self.prev_stage)
        else:
            if not self.is_first_stage():
                p2p.recv(self.loss, self.prev_stage)
            if not self.is_last_stage():
                p2p.send(self.loss, self.next_stage)

        # XXX look into timer reporting timing
        # Initialize some timers because of early weirdness.
        if self.wall_clock_breakdown():
            self.timers(FORWARD_MICRO_TIMER).start()
            self.timers(FORWARD_MICRO_TIMER).stop()
            self.timers(BACKWARD_MICRO_TIMER).start()
            self.timers(BACKWARD_MICRO_TIMER).stop()
            self.timers(BACKWARD_INNER_MICRO_TIMER).start()
            self.timers(BACKWARD_INNER_MICRO_TIMER).stop()
            self.timers(BACKWARD_REDUCE_MICRO_TIMER).start()
            self.timers(BACKWARD_REDUCE_MICRO_TIMER).stop()
            self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).start()
            self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).stop()
            self.timers(STEP_MICRO_TIMER).start()
            self.timers(STEP_MICRO_TIMER).stop()
  195. def set_has_attention_mask(self, value):
  196. assert isinstance(value, bool)
  197. self.has_attention_mask = value
  198. def _build_data_iter(self, dataset):
  199. sampler = torch.utils.data.distributed.DistributedSampler(dataset,
  200. num_replicas=self.dp_world_size,
  201. rank=self.mpu.get_data_parallel_rank(),
  202. shuffle=False)
  203. # Build a loader and make it repeating.
  204. pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler)
  205. pipe_dataloader = RepeatingLoader(pipe_dataloader)
  206. self.set_dataloader(pipe_dataloader)
  207. def _exec_reduce_tied_grads(self):
  208. # We need to run this first to write to self.averaged_gradients;
  209. # since this class turns `enable_backward_allreduce` off,
  210. # `self.overlapping_partition_gradients_reduce_epilogue()` defined in the DeepSpeedEngine
  211. # never actually runs. I suspect this is because of efficiency problems; get_flat_partition in
  212. # stage2.py might do something expensive; someone will have to look into that later. But
  213. # in the meantime, this fixes ZeRO2 + Pipelining enough to run a demo. Further profiling
  214. # needed to decide if it actually breaks everything.
  215. # (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944)
  216. if self.zero_optimization_partition_gradients():
  217. self.optimizer.overlapping_partition_gradients_reduce_epilogue()
  218. weight_group_list = self.module.get_tied_weights_and_groups()
  219. for weight, group in weight_group_list:
  220. grad = weight._hp_grad if self.using_bf16_optimizer else weight.grad
  221. dist.all_reduce(grad, group=group)
  222. def _exec_reduce_grads(self):
  223. self._force_grad_boundary = True
  224. if self.pipeline_enable_backward_allreduce:
  225. if self.using_bf16_optimizer:
  226. # PP+BF16 work for ZeRO Stage 1
  227. self._bf16_reduce_grads()
  228. else:
  229. self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
  230. self._force_grad_boundary = False
  231. def _bf16_reduce_grads(self):
  232. # Make our own list of gradients from the optimizer's FP32 grads
  233. grads = []
  234. self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
  235. elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
  236. def _reserve_pipe_buffers(self, num_buffers):
  237. """Ensure that each pipeline buffer has at least ``num_buffers`` slots.
  238. This method only reserves slots and does not allocate tensors.
  239. Args:
  240. num_buffers (int): The number of buffers to reserve.
  241. """
  242. if self.num_pipe_buffers >= num_buffers:
  243. return
  244. num_added = num_buffers - self.num_pipe_buffers
  245. for key in self.pipe_buffers:
  246. self.pipe_buffers[key].extend([None] * num_added)
  247. self.num_pipe_buffers = num_buffers
  248. def reset_activation_shape(self):
  249. """Reset the buffers when the shape of activation and gradient change.
  250. For example, for curriculum learning that changes the seqlen of each
  251. sample, we need to call this whenever the seqlen is going to change.
  252. """
  253. self.first_output_send = True
  254. self.pipe_recv_buf = None
  255. self.grad_layer = None
  256. self.meta_buffer = None
  257. self.pipe_partition_input_meta_cache = None
  258. self.pipe_partition_output_meta_cache = None
  259. self.pipe_partition_grad_meta_cache = None
  260. self.grad_partition_grad_layer_meta_cache = None
    def train_batch(self, data_iter=None):
        """Progress the pipeline to train the next batch of data. The engine will ingest
        ``self.train_batch_size()`` total samples collectively across all workers.

        An iterator over training data should be provided as an argument
        unless ``deepspeed.initialize()`` was provided a training set. In that event,
        the training data will automatically be read.

        .. warning::
            A total of ``self.gradient_accumulation_steps()`` entries will be pulled
            from ``data_iter`` by each pipeline. There must be sufficient
            data left in ``data_iter`` or else a ``StopIteration`` will halt training.

            DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
            that wraps data loaders to automatically restart upon a ``StopIteration``.

        Args:
            data_iter (Iterator, optional): Iterator of training data.

        Returns:
            The arithmetic mean of the losses computed this batch.
        """
        if not torch._C.is_grad_enabled():
            raise RuntimeError(f'train_batch() requires gradients enabled. Use eval_batch() instead.')

        # Curriculum learning could change activation shape
        if self.curriculum_enabled_legacy():
            new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \
                self.global_steps + 1)
            if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step:
                self.reset_activation_shape()
                self.curriculum_scheduler_legacy.first_step = False
            elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \
                self.global_steps):
                self.reset_activation_shape()

        if data_iter is not None:
            self.set_dataiterator(data_iter)

        self.module.train()
        self.total_loss = None
        self._compute_loss = True

        # Do the work: run the full forward/backward/step instruction schedule
        # for one global batch, then aggregate the loss across stages/ranks.
        self.timers(TRAIN_BATCH_TIMER).start()
        sched = schedule.TrainSchedule(micro_batches=self.micro_batches,
                                       stages=self.num_stages,
                                       stage_id=self.stage_id)
        self._exec_schedule(sched)
        self.agg_train_loss = self._aggregate_total_loss()
        self.timers(TRAIN_BATCH_TIMER).stop()

        if self.global_steps % self.steps_per_print() == 0:
            if self.global_rank == 0:
                elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0
                iter_time = elapsed / self.steps_per_print()
                tput = self.train_batch_size() / iter_time
                print(f'steps: {self.global_steps} '
                      f'loss: {self.agg_train_loss:0.4f} '
                      f'iter time (s): {iter_time:0.3f} '
                      f'samples/sec: {tput:0.3f}')
            else:
                # Non-printing ranks still reset the timer so the next window's
                # elapsed time stays comparable across ranks.
                self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True)

        # Monitoring
        if self.global_rank == 0 and self.monitor.enabled:
            self.summary_events = [(f'Train/Samples/train_loss', self.agg_train_loss.mean().item(),
                                    self.global_samples)]
            self.monitor.write_events(self.summary_events)

        if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0:
            self.timers.log([
                PIPE_SEND_OUTPUT_TIMER,
                PIPE_SEND_GRAD_TIMER,
                PIPE_RECV_INPUT_TIMER,
                PIPE_RECV_GRAD_TIMER,
            ])

        # TODO: should return precisely what loss returned and allow others to be queried?
        return self.agg_train_loss
  328. def eval_batch(self,
  329. data_iter,
  330. return_logits=False,
  331. compute_loss=True,
  332. reduce_output='avg',
  333. bcast_loss=True,
  334. num_micro_batches=None):
  335. """Evaluate the pipeline on a batch of data from ``data_iter``. The
  336. engine will evaluate ``self.train_batch_size()`` total samples
  337. collectively across all workers.
  338. This method is equivalent to:
  339. .. code-block:: python
  340. module.eval()
  341. with torch.no_grad():
  342. output = module(batch)
  343. .. warning::
  344. A total of ``self.gradient_accumulation_steps()`` entries will be pulled
  345. from ``data_iter`` by each pipeline. There must be sufficient
  346. data left in ``data_iter`` or else a ``StopIteration`` will halt training.
  347. DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
  348. that wraps data loaders to automatically restart upon a ``StopIteration``.
  349. Args:
  350. data_iter (Iterator): Iterator of data to evaluate.
  351. Returns:
  352. The arithmetic mean of the losses computed this batch.
  353. """
  354. self.eval_return_logits = return_logits
  355. self.module.eval()
  356. # Curriculum learning could change activation shape
  357. if self.curriculum_enabled_legacy():
  358. new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \
  359. self.global_steps + 1)
  360. if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step:
  361. self.reset_activation_shape()
  362. self.curriculum_scheduler_legacy.first_step = False
  363. elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \
  364. self.global_steps):
  365. self.reset_activation_shape()
  366. eval_output = None
  367. self._compute_loss = compute_loss
  368. # Use the provided data iterator
  369. train_iterator = self.data_iterator
  370. self.set_dataiterator(data_iter)
  371. # set the number micro batches in case the user chose value than training
  372. micro_batches = self.micro_batches if num_micro_batches is None else num_micro_batches
  373. # Do the work
  374. sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
  375. stages=self.num_stages,
  376. stage_id=self.stage_id)
  377. # prevent dead-lock with multiple evals sequence
  378. dist.barrier()
  379. with torch.no_grad():
  380. self._exec_schedule(sched)
  381. if self.is_last_stage():
  382. eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output, micro_batches=micro_batches)
  383. if compute_loss and (bcast_loss or self.monitor.enabled):
  384. eval_output = self._bcast_pipe_scalar(eval_output)
  385. if self.global_rank == 0 and self.monitor.enabled:
  386. self.summary_events = [(f'Train/Samples/eval_loss', eval_output.mean().item(), self.global_samples)]
  387. self.monitor.write_events(self.summary_events)
  388. # Restore the training iterator
  389. self.set_dataiterator(train_iterator)
  390. # Reset any buffers that may have been populated during the forward passes.
  391. #ds_checkpointing.reset()
  392. self.eval_return_logits = False
  393. if return_logits:
  394. outputs = self.outputs
  395. self.outputs = None
  396. return eval_output, outputs
  397. return eval_output
    def set_train_batch_size(self, train_batch_size):
        """Adjust the global batch size by increasing or decreasing the number of
        micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
        (i.e., ``train_micro_batch_size_per_gpu``) is not changed.

        Args:
            train_batch_size (int): The new global batch size for training.

        Raises:
            ValueError: if ``train_batch_size`` is not divisible by the
                configured micro-batch size and data parallelism.
        """
        super().set_train_batch_size(train_batch_size)
        # Refresh the cached micro-batch count only after the superclass has
        # recomputed gradient accumulation from the new global batch size.
        self.micro_batches = self.gradient_accumulation_steps()
  410. def is_first_stage(self):
  411. """True if this process is in the first stage in the pipeline."""
  412. return self.stage_id == 0
  413. def is_last_stage(self):
  414. """True if this process is in the last stage in the pipeline."""
  415. return self.stage_id == self.num_stages - 1
  416. def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True, micro_batches=None):
  417. if reduce is None:
  418. return outputs
  419. if reduce.lower() == 'avg':
  420. # first sum over all microbatches
  421. if torch.is_tensor(outputs[0]):
  422. reduced = sum(outputs)
  423. else:
  424. assert isinstance(outputs, (list, tuple))
  425. reduced = [torch.zeros_like(o) for o in outputs[0]]
  426. for idx, out in outputs:
  427. reduced[idx] += out
  428. # Average over the microbatches
  429. reduced = self._scale_loss_by_gas(reduced, eval_micro_batches=micro_batches)
  430. # Average over DP groups
  431. if reduce_dp and self.is_data_parallel:
  432. if torch.is_tensor(reduced):
  433. dist.all_reduce(reduced, group=self.mpu.get_data_parallel_group())
  434. reduced /= self.dp_world_size
  435. else:
  436. for idx in range(len(reduced)):
  437. dist.all_reduce(reduced[idx], group=self.mpu.get_data_parallel_group())
  438. reduced[idx] /= self.dp_world_size
  439. return reduced
  440. else:
  441. raise NotImplementedError(f'reduction type {reduce} not supported.')
  442. def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32):
  443. # Default to last stage (e.g., for broadcasting loss)
  444. if src_rank is None:
  445. src_rank = self.grid.stage_to_global(self.num_stages - 1)
  446. assert src_rank in self.grid.pp_group
  447. if self.global_rank == src_rank:
  448. result = data.clone().detach().type(dtype).to(self.device)
  449. else:
  450. result = torch.Tensor([0.]).type(dtype).to(self.device)
  451. dist.broadcast(tensor=result, src=src_rank, group=self.mpu.get_pipe_parallel_group())
  452. return result
    def _aggregate_total_loss(self):
        """Return the batch loss averaged over data-parallel ranks, visible on
        every pipeline stage.

        The last stage scales the accumulated loss, averages it across the DP
        group, then broadcasts [dp_group_loss, agg_loss] down the pipe group;
        every other stage just receives that broadcast.
        """
        # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
        if self.is_last_stage():
            loss = self._scale_loss_by_gas(self.total_loss)
            self.dp_group_loss = loss.clone().detach()

            ## Average loss across all data-parallel groups
            agg_loss = self.dp_group_loss.clone().detach()
            #print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True)
            if self.is_data_parallel:
                dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
                agg_loss /= self.dp_world_size

            assert self.global_rank in self.grid.pp_group
            # Pack both losses into one tensor so a single broadcast suffices.
            losses = torch.stack([self.dp_group_loss, agg_loss]).float()
            if self.is_pipe_parallel:
                dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
        else:
            # Get loss from last stage
            src_rank = self.grid.stage_to_global(self.num_stages - 1)
            assert src_rank in self.grid.pp_group
            losses = torch.Tensor([0., 0.]).to(self.device)
            # NOTE(review): the sender uses self.mpu.get_pipe_parallel_group()
            # while this receiver uses self.grid.get_pipe_parallel_group() —
            # presumably the same group object; confirm before refactoring.
            dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group())
            self.dp_group_loss = losses[0].clone().detach()
            agg_loss = losses[1].clone().detach()

        return agg_loss
  477. def set_dataloader(self, loader):
  478. """"""
  479. if self.is_first_stage() or self.is_last_stage():
  480. self.training_dataloader = loader
  481. self.data_iterator = iter(self.training_dataloader)
  482. def set_dataiterator(self, iterator):
  483. """ Store an iterator to sample for training data. """
  484. if self.is_first_stage() or self.is_last_stage():
  485. self.training_dataloader = None
  486. self.data_iterator = iterator
    def set_batch_fn(self, fn):
        """Execute a post-processing function on input data.

        The function is applied by ``_next_batch`` to every batch pulled from
        the data iterator before it enters the pipeline.

        Args:
            fn (function): The function to run.
        """
        self.batch_fn = fn
    def is_gradient_accumulation_boundary(self):
        """True if the engine is executing a gradient reduction or optimizer step instruction.

        This is overridden from :class:`DeepSpeedEngine` to force reductions
        and steps when the pipeline engine is instructed to do so
        (``_force_grad_boundary`` is toggled by ``_exec_reduce_grads``).

        Returns:
            bool: whether reductions and optimizer steps should occur.
        """
        return self._force_grad_boundary
  501. def log_for_device(self, *msg):
  502. if LOG_STAGE == self.stage_id or LOG_STAGE == -1:
  503. if DATA_PARALLEL_ID == self.grid.data_parallel_id or DATA_PARALLEL_ID == -1:
  504. print(
  505. f'RANK={dist.get_rank()} '
  506. f'PIPE-ID={self.stage_id} '
  507. f'DATA-ID={self.grid.data_parallel_id} '
  508. f'MBATCH-ID={self.microbatch_id} '
  509. f'STEP-ID={self.log_batch_step_id} '
  510. '::',
  511. *msg,
  512. flush=True)
  513. def tput_log(self, *msg):
  514. if self.global_rank == 0 and self.global_steps % self.steps_per_print() == 0:
  515. print(*msg)
  516. def _next_batch(self):
  517. # If using 3D parallelism, only some first-stage ranks may do IO
  518. batch = None
  519. if self.data_iterator is not None:
  520. batch = next(self.data_iterator)
  521. # Any post-processing, like broadcasting across a slice-parallel group.
  522. if self.batch_fn:
  523. batch = self.batch_fn(batch)
  524. return batch
    def _exec_forward_pass(self, buffer_id):
        """Run the forward pass for one micro-batch held in pipe buffer
        ``buffer_id``: reassemble partitioned inputs, call the module,
        optionally re-partition outputs, and accumulate loss on the last stage.
        """
        self.tput_timer.start()
        self.mem_status('BEFORE FWD', reset_max=True)

        # Clone inputs so the received buffer can be reused by later micro-batches.
        if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
            inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
        else:
            inputs = self.pipe_buffers['inputs'][buffer_id].clone()

        # collect the partitioned input from the previous stage
        if self.is_pipe_partitioned and not self.is_first_stage():
            # inputs[0] carries the partition metadata, inputs[1] the local shard.
            if self.pipe_partition_input_meta_cache is None:
                self.pipe_partition_input_meta_cache = inputs[0].to('cpu')
            part_input = PartitionedTensor.from_meta(meta=self.pipe_partition_input_meta_cache,
                                                     local_part=inputs[1],
                                                     group=self.grid.get_slice_parallel_group())

            inputs = (part_input.full(), *inputs[2:])
            inputs[0].requires_grad = True
            # skip mask
            #inputs[1].requires_grad = True
            part_input = None
            inputs = inputs[0] if len(inputs) == 1 else inputs
            self.pipe_buffers['inputs'][buffer_id] = inputs

        # inputs has no gradient because it is from a cloned tensor
        outputs = super().forward(inputs)

        # Reset activation checkpointing buffers.
        # Need to call this between evaluation iterations
        if not self.module.training:
            ds_checkpointing.reset()

        # Partition the outputs if we are not the last stage
        if self.is_pipe_partitioned and not self.is_last_stage():
            if isinstance(outputs, tuple):
                first_output = outputs[0]
                # TODO: Improve pipe partitioning to pass multiple tensors that require grads
                assert all([torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]])
                outputs_tail = outputs[1:]
            elif torch.is_tensor(outputs):
                first_output = outputs
                outputs_tail = []
            else:
                raise ValueError("expecting a tensor or a tuple of tensors")
            part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())
            # Clear the large output data, but save the computation graph
            first_output.data = torch.zeros(1)
            self.pipe_buffers['output_tensors'][buffer_id] = first_output
            # Inject the partitioned tensor into the output before sending
            outputs = (part.to_meta(), part.data(), *outputs_tail)
            part = None

        self.pipe_buffers['outputs'][buffer_id] = outputs

        # Optionally compute loss on the last device
        if self.is_last_stage():
            if self._compute_loss and self.module.loss_fn is not None:
                labels = self.pipe_buffers['labels'][buffer_id]
                self.loss = self.module.loss_fn(outputs, labels)
            else:
                # Some models just return loss from forward()
                self.loss = outputs
            if self.eval_return_logits:
                # Keep the raw outputs so eval_batch(return_logits=True) can return them.
                self.outputs = outputs

            # Accumulate detached losses into total_loss for batch aggregation;
            # handles both single-tensor and multi-output losses.
            if isinstance(self.loss, torch.Tensor):
                self.fwd_outputs.append(self.loss.detach())

                if self.total_loss is None:
                    self.total_loss = torch.zeros_like(self.loss)
                self.total_loss += self.loss.detach()
            else:
                self.fwd_outputs.append([l.detach() for l in self.loss])

                if self.total_loss is None:
                    self.total_loss = [torch.zeros_like(l) for l in self.loss]
                for idx, l in enumerate(self.loss):
                    self.total_loss[idx] += l.detach()
    def _exec_backward_pass(self, buffer_id):
        """Run backward for the micro-batch in pipe buffer ``buffer_id``.

        The last stage backprops from the scalar loss through DeepSpeed's
        usual engine path; earlier stages backprop their stored forward
        outputs against the gradients received from the next stage
        (``self.grad_layer``).
        """
        assert self.optimizer is not None, "must provide optimizer during " \
            "init in order to use backward"

        self.mem_status('BEFORE BWD', reset_max=True)

        # The last stage just runs backward on the loss using DeepSpeed's typical
        # mechanisms.
        if self.is_last_stage():
            super().backward(self.loss)
            self.mem_status('AFTER BWD')
            return

        outputs = self.pipe_buffers['outputs'][buffer_id]

        if self.wall_clock_breakdown():
            self.timers(BACKWARD_MICRO_TIMER).start()
            self.timers(BACKWARD_GLOBAL_TIMER).start()
            self.timers(BACKWARD_INNER_MICRO_TIMER).start()
            self.timers(BACKWARD_INNER_GLOBAL_TIMER).start()

        # Reconstruct if we previously partitioned the output. We must be
        # careful to also restore the computational graph of the tensors we partitioned.
        if self.is_pipe_partitioned:
            if self.is_grad_partitioned:
                # outputs[0] holds partition metadata, outputs[1] this rank's shard.
                if self.pipe_partition_output_meta_cache is None:
                    self.pipe_partition_output_meta_cache = outputs[0].to('cpu')
                part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_output_meta_cache,
                                                          local_part=outputs[1],
                                                          group=self.grid.get_slice_parallel_group())
                self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
                outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:])
            else:
                # Already restored from partition
                self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
                outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:])

        grad_tensors = self.grad_layer
        if self.is_grad_partitioned:
            # Reassemble the full gradient of the first output from the
            # partitioned pieces received from the next stage.
            #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
            if self.grad_partition_grad_layer_meta_cache is None:
                self.grad_partition_grad_layer_meta_cache = self.grad_layer[0].to('cpu')
            part_grad = PartitionedTensor.from_meta(meta=self.grad_partition_grad_layer_meta_cache,
                                                    local_part=self.grad_layer[1],
                                                    group=self.grid.get_slice_parallel_group())
            grad_tensors = (part_grad.full(), *grad_tensors[2:])
            part_grad = None
            #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')

        if self.using_bf16_optimizer and not self.is_last_stage():
            # manually call because we don't call optimizer.backward()
            self.optimizer.clear_lp_grads()

        # This handles either a single tensor or tuple of tensors.
        # Non-floating-point outputs (e.g. masks) produce no gradients and
        # are filtered out; the counts must then match the received grads.
        if isinstance(outputs, tuple):
            out_tensors = [t for t in outputs if t.is_floating_point()]
            assert len(out_tensors) == len(grad_tensors)
            torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors)
        else:
            torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))

        if self.using_bf16_optimizer and not self.is_last_stage():
            # manually call because we don't call optimizer.backward()
            self.optimizer.update_hp_grads(clear_lp_grads=False)

        # Free up the memory from the output of forward()
        self.pipe_buffers['output_tensors'][buffer_id] = None
        self.pipe_buffers['outputs'][buffer_id] = None
        grad_tensors = None

        if self.wall_clock_breakdown():
            self.timers(BACKWARD_INNER_MICRO_TIMER).stop()
            self.timers(BACKWARD_INNER_GLOBAL_TIMER).stop()
            self.timers(BACKWARD_MICRO_TIMER).stop()
            self.timers(BACKWARD_GLOBAL_TIMER).stop()

        self.mem_status('AFTER BWD')
    def _exec_load_micro_batch(self, buffer_id):
        """Load the next micro-batch into pipe buffer ``buffer_id``.

        The first stage stores the inputs (``batch[0]``) and the last stage
        stores the labels (``batch[1]``); a single-stage pipeline does both.
        Tensors are moved to the engine's device.
        """
        if self.wall_clock_breakdown():
            self.timers(BATCH_INPUT_TIMER).start()

        batch = self._next_batch()

        if self.is_first_stage():
            loaded = None
            if torch.is_tensor(batch[0]):
                loaded = batch[0].clone().to(self.device).detach()
                # Reentrant activation checkpointing requires inputs that
                # require grad for the graph to be rebuilt in backward.
                if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[
                        'use_reentrant']:
                    loaded.requires_grad = loaded.is_floating_point()
            else:
                assert isinstance(batch[0], (tuple, list))
                # Assume list or tuple
                loaded = []
                for x in batch[0]:
                    assert torch.is_tensor(x)
                    mine = x.clone().detach().to(self.device)
                    if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[
                            'use_reentrant']:
                        mine.requires_grad = mine.is_floating_point()
                    loaded.append(mine)
                loaded = tuple(loaded)

            self.pipe_buffers['inputs'][buffer_id] = loaded

        if self.is_last_stage():
            loaded = batch[1]
            if torch.is_tensor(batch[1]):
                loaded = batch[1].to(self.device)
            # XXX: torch 1.6.0 DataLoader will auto convert tuple to list
            elif isinstance(batch[1], (tuple, list)):
                loaded = []
                for x in batch[1]:
                    assert torch.is_tensor(x)
                    x = x.to(self.device).detach()
                    loaded.append(x)
                loaded = tuple(loaded)

            self.pipe_buffers['labels'][buffer_id] = loaded

        if self.wall_clock_breakdown():
            self.timers(BATCH_INPUT_TIMER).stop()
    def _send_tensor_meta(self, buffer, recv_stage):
        """ Communicate metadata about upcoming p2p transfers.

        Metadata is communicated in this order:
            * type (0: tensor, 1: list, 2: tuple)
            * num_tensors if type=list/tuple
                foreach tensor in buffer:
                    * dtype (tuple path only)
                    * ndims
                    * shape

        NOTE: only the tuple (type=2) path sends a dtype per tensor; the
        single-tensor path does not, and its receiver allocates with the
        default dtype. The list path is disabled (assert below).
        """
        send_bytes = 0  # only consumed by the commented-out debug prints below
        if isinstance(buffer, torch.Tensor):
            type_tensor = torch.LongTensor(data=[0]).to(self.device)
            p2p.send(type_tensor, recv_stage)
            send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
            send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
            p2p.send(send_ndims, recv_stage)
            p2p.send(send_shape, recv_stage)
            send_bytes += _tensor_bytes(buffer)
        elif isinstance(buffer, list):
            # Dead code path: deliberately disabled. Lists are not expected
            # here; activations travel as a tensor or a tuple of tensors.
            assert (False)
            type_tensor = torch.LongTensor(data=[1]).to(self.device)
            p2p.send(type_tensor, recv_stage)
            count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
            p2p.send(count_tensor, recv_stage)
            for tensor in buffer:
                assert isinstance(tensor, torch.Tensor)
                send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
                send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
                p2p.send(send_ndims, recv_stage)
                p2p.send(send_shape, recv_stage)
                send_bytes += _tensor_bytes(tensor)
        elif isinstance(buffer, tuple):
            type_tensor = torch.LongTensor(data=[2]).to(self.device)
            p2p.send(type_tensor, recv_stage)
            count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
            p2p.send(count_tensor, recv_stage)
            for idx, tensor in enumerate(buffer):
                assert isinstance(tensor, torch.Tensor)
                send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
                send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
                send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device)
                p2p.send(send_dtype, recv_stage)
                p2p.send(send_ndims, recv_stage)
                p2p.send(send_shape, recv_stage)
                # Useful for performance debugging.
                '''
                new_bytes = _tensor_bytes(tensor)
                send_bytes += _tensor_bytes(tensor)
                # Useful for performance debugging.
                if self.grid.data_parallel_id == 0:
                    print(
                        f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
                    )
                '''
        else:
            raise NotImplementedError(f'Could not send meta type {type(buffer)}')

        # Useful for performance debugging.
        '''
        if self.grid.data_parallel_id == 0:
            print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB')
        '''
  758. def _recv_tensor_meta(self, send_stage):
  759. """Receive metadata about upcoming p2p transfers and return allocated buffers.
  760. Metadata is communicated in this order:
  761. * type (0: tensor, 1: list)
  762. * num_tensors if type=list
  763. foreach tensor in buffer:
  764. * ndims
  765. * shape
  766. Returns:
  767. Allocated buffer for receiving from send_stage.
  768. """
  769. type_tensor = torch.LongTensor(data=[0]).to(self.device)
  770. p2p.recv(type_tensor, send_stage)
  771. recv_type = type_tensor.item()
  772. # A single tensor will be sent.
  773. if recv_type == 0:
  774. recv_ndims = torch.LongTensor(data=[0]).to(self.device)
  775. p2p.recv(recv_ndims, send_stage)
  776. recv_ndims = recv_ndims.item()
  777. recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
  778. p2p.recv(recv_shape, send_stage)
  779. recv_shape = recv_shape.tolist()
  780. return self._allocate_buffer(recv_shape, num_buffers=1)[0]
  781. # List or tuple of tensors
  782. elif recv_type == 1 or recv_type == 2:
  783. count_tensor = torch.LongTensor(data=[0]).to(self.device)
  784. p2p.recv(count_tensor, send_stage)
  785. num_tensors = count_tensor.item()
  786. recv_shapes_and_dtypes = []
  787. for idx in range(num_tensors):
  788. recv_dtype = torch.LongTensor(data=[0]).to(self.device)
  789. p2p.recv(recv_dtype, send_stage)
  790. recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
  791. recv_ndims = torch.LongTensor(data=[0]).to(self.device)
  792. p2p.recv(recv_ndims, send_stage)
  793. recv_ndims = recv_ndims.item()
  794. recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
  795. p2p.recv(recv_shape, send_stage)
  796. recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))
  797. buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
  798. # Convert to tuples if requested.
  799. if recv_type == 2:
  800. buffers = tuple(buffers)
  801. return buffers
  802. else:
  803. raise NotImplementedError(f'Could not receive type {type(recv_type)}')
    def _exec_send_activations(self, buffer_id):
        """Send the activations in pipe buffer ``buffer_id`` to the next stage."""
        if self.wall_clock_breakdown():
            self.timers(PIPE_SEND_OUTPUT_TIMER).start()

        outputs = self.pipe_buffers['outputs'][buffer_id]

        # NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
        # We could do char, but with half() we can eventually flatten with other fp16
        # messages (TODO)
        if self.has_attention_mask or self.has_bool_tensors:
            outputs = list(outputs)
            outputs[-1] = outputs[-1].half()
            outputs = tuple(outputs)

        # Tensor shapes/dtypes only need to be communicated once; subsequent
        # micro-batches reuse the same receive buffers on the other side.
        if self.first_output_send:
            self.first_output_send = False
            self._send_tensor_meta(outputs, self.next_stage)

        if isinstance(outputs, torch.Tensor):
            p2p.send(outputs, self.next_stage)
        elif isinstance(outputs, tuple):
            for idx, buffer in enumerate(outputs):
                p2p.send(buffer, self.next_stage)
        else:
            raise NotImplementedError('Could not send output of type '
                                      f'{type(outputs)}')

        # Restore the boolean tensor
        if self.has_attention_mask or self.has_bool_tensors:
            outputs = list(outputs)
            outputs[-1] = outputs[-1].bool()
            outputs = tuple(outputs)

        if self.wall_clock_breakdown():
            self.timers(PIPE_SEND_OUTPUT_TIMER).stop()
    def _exec_send_grads(self, buffer_id):
        """Send gradients of the inputs in pipe buffer ``buffer_id`` to the
        previous stage, then release the input buffer."""
        if self.wall_clock_breakdown():
            self.timers(PIPE_SEND_GRAD_TIMER).start()

        inputs = self.pipe_buffers['inputs'][buffer_id]

        # Partition the gradient
        if self.is_grad_partitioned:
            if isinstance(inputs, tuple):
                first_input = inputs[0]
                assert all([torch.is_tensor(elt) for elt in inputs[1:]])
                inputs_grad_tail = [elt.grad for elt in inputs[1:]]
            elif torch.is_tensor(inputs):
                first_input = inputs
                inputs_grad_tail = []
            else:
                raise ValueError("expecting a tensor or a tuple of tensors")
            assert torch.is_tensor(first_input)
            part = PartitionedTensor(tensor=first_input.grad, group=self.grid.get_slice_parallel_group())

            # After this, ``inputs`` is no longer the inputs themselves but
            # the (meta, shard, *tail grads) message to be sent.
            inputs = (part.to_meta(), part.data(), *inputs_grad_tail)

        # XXX Terrible hack
        # Drop the attention mask from the input buffer here. It does not have
        # a grad that needs to be communicated. We free the buffer immediately
        # after, so no need to restore it. The receiver also has a hack that skips
        # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
        if self.has_attention_mask or self.has_bool_tensors:
            inputs = list(inputs)
            inputs.pop()
            inputs = tuple(inputs)

        if isinstance(inputs, torch.Tensor):
            assert inputs.grad is not None
            p2p.send(inputs.grad, self.prev_stage)
        else:
            # XXX terrible hacky branch
            if self.is_grad_partitioned:
                # First two sends are partitioned gradient
                p2p.send(inputs[0], self.prev_stage)
                p2p.send(inputs[1], self.prev_stage)
            else:
                for idx, buffer in enumerate(inputs):
                    # Skip tensors that will not produce a grad
                    if not buffer.is_floating_point():
                        assert buffer.grad is None
                        continue
                    assert buffer.grad is not None
                    p2p.send(buffer.grad, self.prev_stage)

        # We can free up the input buffer now
        self.pipe_buffers['inputs'][buffer_id] = None

        if self.wall_clock_breakdown():
            self.timers(PIPE_SEND_GRAD_TIMER).stop()
    def _exec_recv_activations(self, buffer_id):
        """Receive activations from the previous stage into pipe buffer
        ``buffer_id``.

        Receive buffers are allocated once from the metadata exchange and then
        reused; the received data is cloned out of them each micro-batch.
        """
        if self.wall_clock_breakdown():
            self.timers(PIPE_RECV_INPUT_TIMER).start()

        recvd = None

        # Allocate the buffer if necessary
        if self.pipe_recv_buf is None:
            self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)

        if isinstance(self.pipe_recv_buf, torch.Tensor):
            p2p.recv(self.pipe_recv_buf, self.prev_stage)
            recvd = self.pipe_recv_buf.clone().detach()
            recvd.requires_grad = recvd.is_floating_point()
        else:
            assert isinstance(self.pipe_recv_buf, tuple)
            recvd = [None] * len(self.pipe_recv_buf)
            for idx, buffer in enumerate(self.pipe_recv_buf):
                assert torch.is_tensor(buffer)
                # XXX hardcode meta type
                # The first tensor of a partitioned message is LongTensor
                # metadata; receive it into a dedicated long buffer.
                if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long:
                    if self.meta_buffer is None:
                        self.meta_buffer = torch.zeros(buffer.size(), dtype=torch.long, device=self.device)
                    buffer = self.meta_buffer

                p2p.recv(buffer, self.prev_stage)
                recvd[idx] = buffer.clone().detach()

            # NCCL does not like to send torch.BoolTensor types, so un-cast the
            # attention mask
            if self.has_attention_mask or self.has_bool_tensors:
                recvd[-1] = recvd[-1].bool()

            recvd = tuple(recvd)

            for buffer in recvd:
                buffer.requires_grad = buffer.is_floating_point()

        self.pipe_buffers['inputs'][buffer_id] = recvd

        if self.wall_clock_breakdown():
            self.timers(PIPE_RECV_INPUT_TIMER).stop()
    def _exec_recv_grads(self, buffer_id):
        """Receive gradients from the next stage into ``self.grad_layer`` for
        the outputs stored in pipe buffer ``buffer_id``."""
        if self.wall_clock_breakdown():
            self.timers(PIPE_RECV_GRAD_TIMER).start()

        outputs = self.pipe_buffers['outputs'][buffer_id]
        # XXX these shapes are hardcoded for Megatron
        # Restore partitioned output if it was partitioned and we are sending full gradients
        if self.is_pipe_partitioned and not self.is_grad_partitioned:
            if self.pipe_partition_grad_meta_cache is None:
                self.pipe_partition_grad_meta_cache = outputs[0].to('cpu')
            part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_grad_meta_cache,
                                                      local_part=outputs[1],
                                                      group=self.grid.get_slice_parallel_group())
            outputs[0].data = part_output.full()
            outputs = (outputs[0], *outputs[2:])
            # save for backward
            self.pipe_buffers['outputs'][buffer_id] = outputs

        # Allocate gradient if necessary
        if self.grad_layer is None:
            if isinstance(outputs, torch.Tensor):
                s = list(outputs.size())
                self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0]
            else:
                # XXX This is a HACK
                # When we exchange activations/gradients, the two pipe stages
                # need to issue the send/recv with the same buffer sizes or
                # else there is a deadlock. The is_floating_point() filter is
                # used to avoid sending gradients for tensors that do not
                # produce gradients. When TP>1, we partition the first
                # activations/gradients across TP ranks to save communication
                # volume and memory. That partitioned tensor is represented as
                # two tensors: a 1/TPth chunk of the original data and also a
                # small LongTensor storing the metadata used to reconstruct on
                # the other side. When combined, the floating point filter also
                # filtered out the metadata tensor. This quick (hacky) fix just
                # branches on is_grad_partitioned so we don't filter out the
                # metadata tensor.
                if self.is_grad_partitioned:
                    sizes_and_dtypes = [(list(t.size()), t.dtype)
                                        for t in outputs[:2]] + [(list(t.size()), t.dtype)
                                                                 for t in outputs[2:] if t.is_floating_point()]
                else:
                    sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]
                self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]

        if isinstance(self.grad_layer, torch.Tensor):
            p2p.recv(self.grad_layer, self.next_stage)
        else:
            assert isinstance(outputs, tuple)
            for idx, buffer in enumerate(self.grad_layer):
                # XXX GPT-2 hack
                # The first tensor of a partitioned grad message is LongTensor
                # metadata; force the receive buffer to long dtype.
                if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long:
                    buffer.data = torch.zeros(buffer.size(), dtype=torch.long, device=self.device)
                p2p.recv(buffer, self.next_stage)

        if self.wall_clock_breakdown():
            self.timers(PIPE_RECV_GRAD_TIMER).stop()
  968. def _exec_optimizer_step(self, lr_kwargs=None):
  969. if self.wall_clock_breakdown():
  970. self.timers(STEP_MICRO_TIMER).start()
  971. self.timers(STEP_GLOBAL_TIMER).start()
  972. self.mem_status('BEFORE STEP', reset_max=True)
  973. self._force_grad_boundary = True
  974. self._take_model_step(lr_kwargs)
  975. self._force_grad_boundary = False
  976. self.mem_status('AFTER STEP')
  977. if self.global_rank == 0 and self.monitor.enabled:
  978. self.summary_events = [(f'Train/Samples/lr', self.get_lr()[0], self.global_samples)]
  979. if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
  980. self.summary_events.append(
  981. (f'Train/Samples/loss_scale', self.optimizer.cur_scale, self.global_samples))
  982. self.monitor.write_events(self.summary_events)
  983. if self.wall_clock_breakdown():
  984. self.timers(STEP_MICRO_TIMER).stop()
  985. self.timers(STEP_GLOBAL_TIMER).stop()
  986. if self.global_steps % self.steps_per_print() == 0:
  987. self.timers.log([
  988. BATCH_INPUT_TIMER,
  989. FORWARD_MICRO_TIMER,
  990. BACKWARD_MICRO_TIMER,
  991. BACKWARD_INNER_MICRO_TIMER,
  992. BACKWARD_REDUCE_MICRO_TIMER,
  993. STEP_MICRO_TIMER,
  994. ])
  995. if self.global_steps % self.steps_per_print() == 0:
  996. self.timers.log([
  997. FORWARD_GLOBAL_TIMER,
  998. BACKWARD_GLOBAL_TIMER,
  999. BACKWARD_INNER_GLOBAL_TIMER,
  1000. BACKWARD_REDUCE_GLOBAL_TIMER,
  1001. STEP_GLOBAL_TIMER,
  1002. ])
  1003. def _allocate_zeros(self, shape, **kwargs):
  1004. """ Allocate a tensor of zeros on the engine's device.
  1005. Arguments:
  1006. shape: the shape of the tensor to allocate
  1007. kwargs: passed to torch.zeros()
  1008. Returns:
  1009. A tensor from torch.zeros() allocated on self.device.
  1010. """
  1011. if "dtype" not in kwargs:
  1012. if self.fp16_enabled():
  1013. kwargs["dtype"] = torch.half
  1014. if self.bfloat16_enabled():
  1015. kwargs["dtype"] = torch.bfloat16
  1016. return torch.zeros(shape, device=self.device, **kwargs)
  1017. def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
  1018. buffers = []
  1019. if num_buffers == -1:
  1020. num_buffers = self.num_pipe_buffers
  1021. for count in range(num_buffers):
  1022. buffers.append(self._allocate_zeros(shape, **kwargs))
  1023. return buffers
  1024. def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
  1025. buffers = []
  1026. if num_buffers == -1:
  1027. num_buffers = self.num_pipe_buffers
  1028. for count in range(num_buffers):
  1029. buffer = []
  1030. for shape, dtype in shapes_and_dtypes:
  1031. buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad))
  1032. buffers.append(buffer)
  1033. return buffers
    def forward(self, *args, **kwargs):
        """Disabled for pipeline parallel training. See ``train_batch()``. """
        # Pipeline execution is driven by schedules, not direct forward() calls.
        raise PipelineError("Only train_batch() is accessible in pipeline mode.")
    def backward(self, *args, **kwargs):
        """Disabled for pipeline parallel training. See ``train_batch()``. """
        # Pipeline execution is driven by schedules, not direct backward() calls.
        raise PipelineError("Only train_batch() is accessible in pipeline mode.")
    def step(self, *args, **kwargs):
        """Disabled for pipeline parallel training. See ``train_batch()``. """
        # Optimizer stepping is scheduled via OptimizerStep instructions.
        raise PipelineError("Only train_batch() is accessible in pipeline mode.")
    def mem_status(self, msg, print_rank=-1, reset_max=False):
        # Debugging aid for accelerator memory usage. Currently DISABLED: the
        # unconditional return below short-circuits the whole method; the rest
        # is dead code kept for quick re-enabling.
        return
        global mem_alloced, mem_cached
        # NOTE(review): this condition is always True (global_steps cannot
        # equal both 0 and 9), and its body is a no-op placeholder anyway.
        if not self.global_steps == 0 or not self.global_steps == 9:
            #return
            pass
        if self.mpu.get_data_parallel_rank() != 0:
            return
        if self.global_rank != 0:
            return
        rank = self.global_rank
        if print_rank != -1 and rank != print_rank:
            return
        # Ensure outstanding kernels complete so the counters are accurate.
        get_accelerator().synchronize()
        if reset_max:
            get_accelerator().reset_max_memory_cached()
            get_accelerator().reset_max_memory_allocated()
        new_alloced = get_accelerator().memory_allocated()
        new_cached = get_accelerator().memory_cached()
        # Deltas are relative to the previous call (module-level globals).
        delta_alloced = new_alloced - mem_alloced
        delta_cached = new_cached - mem_cached
        mem_cached = new_cached
        mem_alloced = new_alloced
        max_alloced = get_accelerator().max_memory_allocated()
        max_cached = get_accelerator().max_memory_cached()
        # convert to GB for printing
        new_alloced /= 1024**3
        new_cached /= 1024**3
        delta_alloced /= 1024**3
        delta_cached /= 1024**3
        max_alloced /= 1024**3
        max_cached /= 1024**3
        print(
            f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS', msg,
            f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
            f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')
    def module_state_dict(self, exclude_frozen_parameters=False):
        """Override hack to save a pipe model and return the directory path of the save.

        This method should only be called by DeepSpeed's ``save_checkpoint()``. The
        recommended way of saving a ``PipelineModule`` outside of ``save_checkpoint()``
        is ``save_state_dict()``.

        Arguments:
            exclude_frozen_parameters (bool): skip saving frozen parameters.

        Returns:
            None
        """
        assert isinstance(self.module, PipelineModule)
        # Layer-wise checkpointing needs the checkpoint directory prepared by
        # save_checkpoint() before this override runs.
        assert self._curr_ckpt_path is not None, \
            "PipelineEngine expects module_state_dict() to be called from save_checkpoint()"

        self.module.save_state_dict(self._curr_ckpt_path,
                                    checkpoint_engine=self.checkpoint_engine,
                                    exclude_frozen_params=exclude_frozen_parameters)
        return None
    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):
        """Override hack to instead use a directory path.

        This is important because pipeline models checkpoint by layer instead of rank.

        If ``checkpoint['module']`` is not ``None`` or a ``str``, we revert to
        ``super()`` expecting a ``dict``.

        Args:
            checkpoint (dict): loaded checkpoint; only the 'module' entry is used here.
            strict (bool, optional): Strict state loading. Defaults to True.
            custom_load_fn: unsupported with pipeline parallelism; must be None.
            fetch_z3_params (bool): unused in this override.
        """
        assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"
        state_dict = checkpoint['module']
        # An in-memory dict means a conventional (non-layerwise) checkpoint.
        if (state_dict is not None) and (not isinstance(state_dict, str)):
            super().load_module_state_dict(state_dict, strict)
            return

        # Otherwise load layer-by-layer from the checkpoint directory.
        self.module.load_state_dir(load_dir=self._curr_ckpt_path,
                                   strict=strict,
                                   checkpoint_engine=self.checkpoint_engine)
    # A map of PipeInstruction types to methods. Each method will be executed with the
    # kwargs provided to the PipeInstruction from the scheduler.
    # Values are plain functions; _exec_schedule binds them to the engine
    # instance with MethodType before calling.
    _INSTRUCTION_MAP = {
        schedule.OptimizerStep: _exec_optimizer_step,
        schedule.ReduceGrads: _exec_reduce_grads,
        schedule.ReduceTiedGrads: _exec_reduce_tied_grads,
        schedule.LoadMicroBatch: _exec_load_micro_batch,
        schedule.ForwardPass: _exec_forward_pass,
        schedule.BackwardPass: _exec_backward_pass,
        schedule.SendActivation: _exec_send_activations,
        schedule.RecvActivation: _exec_recv_activations,
        schedule.SendGrad: _exec_send_grads,
        schedule.RecvGrad: _exec_recv_grads,
    }
    def _exec_schedule(self, pipe_schedule):
        """Execute the sequence of instructions emitted by ``pipe_schedule``.

        Each instruction is dispatched through ``_INSTRUCTION_MAP`` with the
        kwargs attached to the instruction by the scheduler.
        """
        # Reserve and reset buffers.
        self._reserve_pipe_buffers(pipe_schedule.num_pipe_buffers())
        self.fwd_outputs = []

        # For each step in the schedule
        for step_cmds in pipe_schedule:
            # For each instruction in the step
            for cmd in step_cmds:
                if type(cmd) not in self._INSTRUCTION_MAP:
                    raise RuntimeError(f'{self.__class__.__name__} does not understand instruction {repr(cmd)}')

                # Equivalent to: self._exec_forward_pass(buffer_id=0)
                self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
                self._exec_instr(**cmd.kwargs)