- # Copyright 2019 The Microsoft DeepSpeed Team
- from types import MethodType
- import torch
- from deepspeed import comm as dist
- from deepspeed.utils import logger
- from deepspeed.utils.timer import ThroughputTimer
- from deepspeed.accelerator import get_accelerator
- from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
- from ..utils import PartitionedTensor
- from ..dataloader import RepeatingLoader
- from .module import PipelineModule, PipelineError
- from . import p2p
- from . import schedule
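- # Debug-logging controls for log_for_device(): -1 enables output on all stages /
- # data-parallel ranks, a non-negative value targets that id, and -2 disables it.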
- TARGET_ID = -2
- LOG_STAGE = -2
- DATA_PARALLEL_ID = -2
- def is_even(number):
- return number % 2 == 0
- mem_alloced = 0
- mem_cached = 0
- def _tensor_bytes(tensor):
- return tensor.numel() * tensor.element_size()
- class PipelineEngine(DeepSpeedEngine):
- """ A training engine hybrid pipeline, data, and model parallel training.
- This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
- is provided.
- """
- ID_TO_DTYPE = [
- torch.float32,
- torch.float64,
- torch.complex64,
- torch.complex128,
- torch.float16,
- torch.bfloat16,
- torch.uint8,
- torch.int8,
- torch.int16,
- torch.int32,
- torch.int64,
- torch.bool
- ]
- DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}
- def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
- super().__init__(*super_args, **super_kwargs)
-         assert isinstance(self.module, PipelineModule), "model must be an instance of PipelineModule"
- assert self.zero_optimization_stage() < 2, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"
- # We schedule the all-reduces, so disable it in super().backward()
- self.enable_backward_allreduce = False
- self.has_bool_tensors = has_bool_tensors
- self.eval_return_logits = False
- self.outputs = None
- # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
- self.pipeline_enable_backward_allreduce = True
- if self.elasticity_enabled():
- if not self.is_elastic_model_parallel_supported():
- assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
- " with pipeline parallelism."
- # pipeline step for logging
- self.log_batch_step_id = -1
- self.micro_batch_size = self.train_micro_batch_size_per_gpu()
- self.micro_batches = self.gradient_accumulation_steps()
- # Set Grid and Communication Groups
- self.grid = self.module._grid
- if self.grid.get_global_rank() == 0:
- logger.info(f'CONFIG: micro_batches={self.micro_batches} '
- f'micro_batch_size={self.micro_batch_size}')
- self.global_rank = self.grid.get_global_rank()
- assert self.dp_world_size == self.grid.data_parallel_size
- assert self.train_batch_size() == \
- self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size
-         # Set Stage Info
- self.num_stages = self.grid.pipe_parallel_size
- self.stage_id = self.grid.get_stage_id()
- self.prev_stage = self.stage_id - 1
- self.next_stage = self.stage_id + 1
- self.data_iterator = None
- self.batch_fn = None
- self._force_grad_boundary = False
- self.batch_timer = ThroughputTimer(batch_size=self.train_batch_size(),
- logging_fn=self.tput_log,
- monitor_memory=False,
- steps_per_output=self.steps_per_print())
-         # PipelineEngine needs to handle data loading specially because only the first
-         # and last stages load inputs/labels. We construct a distributed sampler that
-         # shards the dataset across data-parallel ranks.
- if self.training_data:
- self._build_data_iter(self.training_data)
- self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
- self.is_data_parallel = self.grid.data_parallel_size > 1
- self.is_model_parallel = self.grid.model_parallel_size > 1
- # Partition input/output buffers
- # XXX temporarily disable while I revert some partition hacks.
- self.is_pipe_partitioned = self.is_model_parallel
- self.is_grad_partitioned = self.is_model_parallel
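-         # When model (tensor) parallelism is enabled, activations and gradients that
-         # cross stage boundaries are partitioned across the slice-parallel group (via
-         # PartitionedTensor) to reduce communication volume and memory.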
- model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
- num_params = sum([p.numel() for p in model_parameters])
- unique_params = num_params
- # Subtract tied parameters if we don't own them
- if self.module.tied_comms:
- tied_params = 0
- for key, d in self.module.tied_comms.items():
- if self.global_rank != min(d['ranks']):
- tied_params += sum(p.numel() for p in d['module'].parameters())
- unique_params -= tied_params
- params_tensor = torch.LongTensor(data=[num_params,
- unique_params]).to(self.device)
- dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
- params_tensor = params_tensor.tolist()
- total_params = params_tensor[0]
- unique_params = params_tensor[1]
- if self.grid.data_parallel_id == 0:
- logger.info(f'RANK={self.global_rank} '
- f'STAGE={self.stage_id} '
- f'LAYERS={self.module._local_stop - self.module._local_start} '
- f'[{self.module._local_start}, {self.module._local_stop}) '
- f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
- f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
- f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')
-         # Initialize peer-to-peer communication and allreduce groups
- if self.is_pipe_parallel:
- p2p.init_process_groups(self.grid)
- # Pipeline buffers
- self.num_pipe_buffers = 0
- self.pipe_buffers = {
- 'inputs' : [], # batch input and received activations
- 'labels' : [], # labels from batch input
- 'outputs' : [], # activations
- 'output_tensors' : [], # tensor object to preserve backward graph
- }
- self.pipe_recv_buf = None
- self.grad_layer = None
- self.meta_buffer = None
- self.first_output_send = True
- self.first_gradient_send = True
-         # Stores the loss for the current micro-batch being processed
-         self.loss = torch.tensor(0.0).to(self.device)
-         # Stores the loss for the entire batch
- self.total_loss = None
- self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
- self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
- if self._config.pipeline['activation_checkpoint_interval'] > 0:
- self.module.activation_checkpoint_interval = self._config.pipeline[
- 'activation_checkpoint_interval']
- self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline
- if self.is_last_stage():
- self.loss_model = self.module.loss_fn
- self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe'
- # Initialize pipeline communicators. Just send a 0.
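-         # Even stages send first and odd stages receive first so that neighboring
-         # stages pair up and the p2p handshake cannot deadlock.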
- if is_even(self.stage_id):
- if not self.is_last_stage():
- p2p.send(self.loss, self.next_stage)
- if not self.is_first_stage():
- p2p.recv(self.loss, self.prev_stage)
- else:
- if not self.is_first_stage():
- p2p.recv(self.loss, self.prev_stage)
- if not self.is_last_stage():
- p2p.send(self.loss, self.next_stage)
- # XXX look into timer reporting timing
- # Initialize some timers because of early weirdness.
- if self.wall_clock_breakdown():
- self.timers('forward_microstep').start()
- self.timers('forward_microstep').stop()
- self.timers('backward_microstep').start()
- self.timers('backward_microstep').stop()
- self.timers('backward_inner_microstep').start()
- self.timers('backward_inner_microstep').stop()
- self.timers('backward_allreduce_microstep').start()
- self.timers('backward_allreduce_microstep').stop()
- self.timers('backward_allreduce').start()
- self.timers('backward_allreduce').stop()
- self.timers('step_microstep').start()
- self.timers('step_microstep').stop()
- def set_has_attention_mask(self, value):
- assert isinstance(value, bool)
- self.has_attention_mask = value
- def _build_data_iter(self, dataset):
- sampler = torch.utils.data.distributed.DistributedSampler(
- dataset,
- num_replicas=self.dp_world_size,
- rank=self.mpu.get_data_parallel_rank(),
- shuffle=False)
- # Build a loader and make it repeating.
- pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler)
- pipe_dataloader = RepeatingLoader(pipe_dataloader)
- self.set_dataloader(pipe_dataloader)
- def _exec_reduce_tied_grads(self):
- # We need to run this first to write to self.averaged_gradients;
- # since this class turns `enable_backward_allreduce` off,
- # `self.overlapping_partition_gradients_reduce_epilogue()` defined in the DeepSpeedEngine
- # never actually runs. I suspect this is because of efficiency problems; get_flat_partition in
- # stage2.py might do something expensive; someone will have to look into that later. But
- # in the meantime, this fixes ZeRO2 + Pipelining enough to run a demo. Further profiling
- # needed to decide if it actually breaks everything.
- # (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944)
- if self.zero_optimization_partition_gradients():
- self.optimizer.overlapping_partition_gradients_reduce_epilogue()
- weight_group_list = self.module.get_tied_weights_and_groups()
- for weight, group in weight_group_list:
- grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad
- dist.all_reduce(grad, group=group)
- def _exec_reduce_grads(self):
- self._force_grad_boundary = True
- if self.pipeline_enable_backward_allreduce:
- if self.bfloat16_enabled():
- if self.zero_optimization_stage() == 0:
- self._bf16_reduce_grads()
- else:
- assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported"
- raise NotImplementedError()
- else:
- self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
- self._force_grad_boundary = False
- def _bf16_reduce_grads(self):
- # Make our own list of gradients from the optimizer's FP32 grads
- grads = []
- self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
- elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
- def _reserve_pipe_buffers(self, num_buffers):
- """Ensure that each pipeline buffer has at least ``num_buffers`` slots.
- This method only reserves slots and does not allocate tensors.
- Args:
- num_buffers (int): The number of buffers to reserve.
- """
- if self.num_pipe_buffers >= num_buffers:
- return
- num_added = num_buffers - self.num_pipe_buffers
- for key in self.pipe_buffers:
- self.pipe_buffers[key].extend([None] * num_added)
- self.num_pipe_buffers = num_buffers
- def reset_activation_shape(self):
- """Reset the buffers when the shape of activation and gradient change.
- For example, for curriculum learning that changes the seqlen of each
- sample, we need to call this whenever the seqlen is going to change.
- """
- self.first_output_send = True
- self.pipe_recv_buf = None
- self.grad_layer = None
- self.meta_buffer = None
- def train_batch(self, data_iter=None):
- """Progress the pipeline to train the next batch of data. The engine will ingest
- ``self.train_batch_size()`` total samples collectively across all workers.
-         An iterator over training data should be provided as an argument
- unless ``deepspeed.initialize()`` was provided a training set. In that event,
- the training data will automatically be read.
- .. warning::
- A total of ``self.gradient_accumulation_steps()`` entries will be pulled
- from ``data_iter`` by each pipeline. There must be sufficient
- data left in ``data_iter`` or else a ``StopIteration`` will halt training.
- DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
- that wraps data loaders to automatically restart upon a ``StopIteration``.
- Args:
- data_iter (Iterator, optional): Iterator of training data.
- Returns:
- The arithmetic mean of the losses computed this batch.
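-         Example (a minimal sketch; ``engine``, ``train_loader`` and ``num_steps`` are illustrative):
-         .. code-block:: python
-             train_iter = iter(deepspeed.utils.RepeatingLoader(train_loader))
-             for _ in range(num_steps):
-                 loss = engine.train_batch(data_iter=train_iter)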
- """
- if not torch._C.is_grad_enabled():
- raise RuntimeError(
- f'train_batch() requires gradients enabled. Use eval_batch() instead.')
- # Curriculum learning could change activation shape
- if self.curriculum_enabled_legacy():
- new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \
- self.global_steps + 1)
- if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step:
- self.reset_activation_shape()
- self.curriculum_scheduler_legacy.first_step = False
- elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \
- self.global_steps):
- self.reset_activation_shape()
- if data_iter:
- self.set_dataiterator(data_iter)
- self.module.train()
- self.total_loss = None
- self._compute_loss = True
- # Do the work
- self.timers('train_batch').start()
- sched = schedule.TrainSchedule(micro_batches=self.micro_batches,
- stages=self.num_stages,
- stage_id=self.stage_id)
- self._exec_schedule(sched)
- self.agg_train_loss = self._aggregate_total_loss()
- self.timers('train_batch').stop()
- if self.global_steps % self.steps_per_print() == 0:
- if self.global_rank == 0:
- elapsed = self.timers('train_batch').elapsed(reset=True) / 1000.0
- iter_time = elapsed / self.steps_per_print()
- tput = self.train_batch_size() / iter_time
- print(f'steps: {self.global_steps} '
- f'loss: {self.agg_train_loss:0.4f} '
- f'iter time (s): {iter_time:0.3f} '
- f'samples/sec: {tput:0.3f}')
- # Monitoring
- if self.global_rank == 0 and self.monitor.enabled:
- self.summary_events = [(f'Train/Samples/train_loss',
- self.agg_train_loss.mean().item(),
- self.global_samples)]
- self.monitor.write_events(self.summary_events)
- if self.wall_clock_breakdown(
- ) and self.global_steps % self.steps_per_print() == 0:
- self.timers.log([
- 'pipe_send_output',
- 'pipe_send_grad',
- 'pipe_recv_input',
- 'pipe_recv_grad'
- ])
- # TODO: should return precisely what loss returned and allow others to be queried?
- return self.agg_train_loss
- def eval_batch(self,
- data_iter,
- return_logits=False,
- compute_loss=True,
- reduce_output='avg'):
- """Evaluate the pipeline on a batch of data from ``data_iter``. The
- engine will evaluate ``self.train_batch_size()`` total samples
- collectively across all workers.
- This method is equivalent to:
- .. code-block:: python
- module.eval()
- with torch.no_grad():
- output = module(batch)
- .. warning::
- A total of ``self.gradient_accumulation_steps()`` entries will be pulled
- from ``data_iter`` by each pipeline. There must be sufficient
- data left in ``data_iter`` or else a ``StopIteration`` will halt training.
- DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
- that wraps data loaders to automatically restart upon a ``StopIteration``.
- Args:
-             data_iter (Iterator): Iterator of data to evaluate.
-             return_logits (bool, optional): Also return the outputs of the last stage. Defaults to ``False``.
-             compute_loss (bool, optional): Whether to compute the loss on the last stage. Defaults to ``True``.
-             reduce_output (str, optional): Reduction applied to micro-batch outputs; only ``'avg'`` (or ``None`` for no reduction) is supported. Defaults to ``'avg'``.
- Returns:
- The arithmetic mean of the losses computed this batch.
- """
- self.eval_return_logits = return_logits
- self.module.eval()
- # Curriculum learning could change activation shape
- if self.curriculum_enabled_legacy():
- new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \
- self.global_steps + 1)
- if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step:
- self.reset_activation_shape()
- self.curriculum_scheduler_legacy.first_step = False
- elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \
- self.global_steps):
- self.reset_activation_shape()
- eval_output = None
- self._compute_loss = compute_loss
- # Use the provided data iterator
- train_iterator = self.data_iterator
- self.set_dataiterator(data_iter)
- # Do the work
- sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
- stages=self.num_stages,
- stage_id=self.stage_id)
-         # Prevent deadlock when running multiple evaluations in sequence
- dist.barrier()
- with torch.no_grad():
- self._exec_schedule(sched)
- if self.is_last_stage():
- eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output)
- if compute_loss:
- eval_output = self._bcast_pipe_scalar(eval_output)
- if self.global_rank == 0 and self.monitor.enabled:
- self.summary_events = [(f'Train/Samples/eval_loss',
- eval_output.mean().item(),
- self.global_samples)]
- self.monitor.write_events(self.summary_events)
- # Restore the training iterator
- self.set_dataiterator(train_iterator)
- # Reset any buffers that may have been populated during the forward passes.
- #ds_checkpointing.reset()
- self.eval_return_logits = False
- if return_logits:
- outputs = self.outputs
- self.outputs = None
- return eval_output, outputs
- return eval_output
- def set_train_batch_size(self, train_batch_size):
- """Adjust the global batch size by increasing or decreasing the number of
- micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
- (i.e., ``train_micro_batch_size_per_gpu``) is not changed.
- Args:
- train_batch_size (int): The new global batch size for training.
- Raises:
- ValueError: if ``train_batch_size`` is not divisible by the
- configured micro-batch size and data parallelism.
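-         Example (illustrative numbers): with ``train_micro_batch_size_per_gpu=2`` and
-         4 data-parallel pipelines, ``set_train_batch_size(32)`` results in
-         ``32 / (2 * 4) = 4`` gradient accumulation steps.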
- """
- super().set_train_batch_size(train_batch_size)
- self.micro_batches = self.gradient_accumulation_steps()
- def is_first_stage(self):
- """True if this process is in the first stage in the pipeline."""
- return self.stage_id == 0
- def is_last_stage(self):
- """True if this process is in the last stage in the pipeline."""
- return self.stage_id == self.num_stages - 1
- def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True):
- if reduce is None:
- return outputs
- if reduce.lower() == 'avg':
- # first sum over all microbatches
- if torch.is_tensor(outputs[0]):
- reduced = sum(outputs)
- else:
- assert isinstance(outputs, (list, tuple))
- reduced = [torch.zeros_like(o) for o in outputs[0]]
-                 # Each element of ``outputs`` is a tuple/list of per-micro-batch tensors.
-                 for out in outputs:
-                     for idx, o in enumerate(out):
-                         reduced[idx] += o
- # Average over the microbatches
- reduced = self._scale_loss_by_gas(reduced)
- # Average over DP groups
- if reduce_dp and self.is_data_parallel:
- if torch.is_tensor(reduced):
- dist.all_reduce(reduced, group=self.mpu.get_data_parallel_group())
- reduced /= self.dp_world_size
- else:
- for idx in range(len(reduced)):
- dist.all_reduce(reduced[idx],
- group=self.mpu.get_data_parallel_group())
- reduced[idx] /= self.dp_world_size
- return reduced
- else:
- raise NotImplementedError(f'reduction type {reduce} not supported.')
- def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32):
- # Default to last stage (e.g., for broadcasting loss)
- if src_rank is None:
- src_rank = self.grid.stage_to_global(self.num_stages - 1)
- assert src_rank in self.grid.pp_group
- if self.global_rank == src_rank:
- result = data.clone().detach()
- else:
- result = torch.Tensor([0.]).type(dtype).to(self.device)
- dist.broadcast(tensor=result,
- src=src_rank,
- group=self.mpu.get_pipe_parallel_group())
- return result
- def _aggregate_total_loss(self):
- # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
- if self.is_last_stage():
- loss = self._scale_loss_by_gas(self.total_loss)
- self.dp_group_loss = loss.clone().detach()
- ## Average loss across all data-parallel groups
- agg_loss = self.dp_group_loss.clone().detach()
- #print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True)
- if self.is_data_parallel:
- dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
- agg_loss /= self.dp_world_size
- assert self.global_rank in self.grid.pp_group
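-             # Pack the local DP-group loss and the DP-averaged loss into one tensor so
-             # a single broadcast delivers both to the rest of the pipeline group.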
- losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
- dist.broadcast(tensor=losses,
- src=self.global_rank,
- group=self.mpu.get_pipe_parallel_group())
- else:
- # Get loss from last stage
- src_rank = self.grid.stage_to_global(self.num_stages - 1)
- assert src_rank in self.grid.pp_group
- losses = torch.Tensor([0., 0.]).to(self.device)
- dist.broadcast(tensor=losses,
- src=src_rank,
- group=self.grid.get_pipe_parallel_group())
- self.dp_group_loss = losses[0].clone().detach()
- agg_loss = losses[1].clone().detach()
- return agg_loss
- def set_dataloader(self, loader):
- """"""
- if self.is_first_stage() or self.is_last_stage():
- self.training_dataloader = loader
- self.data_iterator = iter(self.training_dataloader)
- def set_dataiterator(self, iterator):
- """ Store an iterator to sample for training data. """
- if self.is_first_stage() or self.is_last_stage():
- self.training_dataloader = None
- self.data_iterator = iterator
- def set_batch_fn(self, fn):
- """Execute a post-processing function on input data.
- Args:
- fn (function): The function to run.
- """
- self.batch_fn = fn
- def is_gradient_accumulation_boundary(self):
- """True if the engine is executing a gradient reduction or optimizer step instruction.
- This is overridden from :class:`DeepSpeedEngine` to force reductions
- and steps when the pipeline engine is instructed to do so.
- Returns:
- bool: whether reductions and optimizer steps should occur.
- """
- return self._force_grad_boundary
- def log_for_device(self, *msg):
- if LOG_STAGE == self.stage_id or LOG_STAGE == -1:
- if DATA_PARALLEL_ID == self.grid.data_parallel_id or DATA_PARALLEL_ID == -1:
- print(
- f'RANK={dist.get_rank()} '
- f'PIPE-ID={self.stage_id} '
- f'DATA-ID={self.grid.data_parallel_id} '
- f'MBATCH-ID={self.microbatch_id} '
- f'STEP-ID={self.log_batch_step_id} '
- '::',
- *msg,
- flush=True)
- def tput_log(self, *msg):
- if self.global_rank == 0 and self.global_steps % self.steps_per_print() == 0:
- print(*msg)
- def _next_batch(self):
- # If using 3D parallelism, only some first-stage ranks may do IO
- batch = None
- if self.data_iterator is not None:
- batch = next(self.data_iterator)
- # Any post-processing, like broadcasting across a slice-parallel group.
- if self.batch_fn:
- batch = self.batch_fn(batch)
- return batch
- def _exec_forward_pass(self, buffer_id):
- self.tput_timer.start()
- self.mem_status('BEFORE FWD', reset_max=True)
- if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
- inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
- else:
- inputs = self.pipe_buffers['inputs'][buffer_id].clone()
- # collect the partitioned input from the previous stage
- if self.is_pipe_partitioned and not self.is_first_stage():
- part_input = PartitionedTensor.from_meta(
- meta=inputs[0],
- local_part=inputs[1],
- group=self.grid.get_slice_parallel_group())
- inputs = (part_input.full(), *inputs[2:])
- inputs[0].requires_grad = True
- # skip mask
- #inputs[1].requires_grad = True
- part_input = None
- inputs = inputs[0] if len(inputs) == 1 else inputs
- self.pipe_buffers['inputs'][buffer_id] = inputs
-         # Zero out the gradients each time we use the tensor because only the data in
-         # the tensor changes across batches.
- self._zero_grads(inputs)
- outputs = super().forward(inputs)
- # Partition the outputs if we are not the last stage
- if self.is_pipe_partitioned and not self.is_last_stage():
- if isinstance(outputs, tuple):
- first_output = outputs[0]
- # TODO: Improve pipe partitioning to pass multiple tensors that require grads
- assert all([
- torch.is_tensor(elt) and elt.requires_grad is False
- for elt in outputs[1:]
- ])
- outputs_tail = outputs[1:]
- elif torch.is_tensor(outputs):
- first_output = outputs
- outputs_tail = []
- else:
- raise ValueError("expecting a tensor or a tuple of tensors")
- part = PartitionedTensor(tensor=first_output,
- group=self.grid.get_slice_parallel_group())
- # Clear the large output data, but save the computation graph
- first_output.data = torch.zeros(1)
- self.pipe_buffers['output_tensors'][buffer_id] = first_output
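-             # Keeping ``first_output`` in ``output_tensors`` preserves its autograd graph
-             # so _exec_backward_pass() can restore the full data and run backward later.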
- # Inject the partitioned tensor into the output before sending
- outputs = (part.to_meta(), part.data(), *outputs_tail)
- part = None
- self.pipe_buffers['outputs'][buffer_id] = outputs
- # Optionally compute loss on the last device
- if self.is_last_stage():
- if self._compute_loss and self.module.loss_fn is not None:
- labels = self.pipe_buffers['labels'][buffer_id]
- self.loss = self.module.loss_fn(outputs, labels)
- else:
- # Some models just return loss from forward()
- self.loss = outputs
- if self.eval_return_logits:
- self.outputs = outputs
- if isinstance(self.loss, torch.Tensor):
- self.fwd_outputs.append(self.loss.detach())
- if self.total_loss is None:
- self.total_loss = torch.zeros_like(self.loss)
- self.total_loss += self.loss.detach()
- else:
- self.fwd_outputs.append([l.detach() for l in self.loss])
- if self.total_loss is None:
- self.total_loss = [torch.zeros_like(l) for l in self.loss]
- for idx, l in enumerate(self.loss):
- self.total_loss[idx] += l.detach()
- def _exec_backward_pass(self, buffer_id):
- assert self.optimizer is not None, "must provide optimizer during " \
- "init in order to use backward"
- self.mem_status('BEFORE BWD', reset_max=True)
- # The last stage just runs backward on the loss using DeepSpeed's typical
- # mechanisms.
- if self.is_last_stage():
- super().backward(self.loss)
- self.mem_status('AFTER BWD')
- return
- outputs = self.pipe_buffers['outputs'][buffer_id]
- if self.wall_clock_breakdown():
- self.timers('backward_microstep').start()
- self.timers('backward').start()
- self.timers('backward_inner_microstep').start()
- self.timers('backward_inner').start()
- # Reconstruct if we previously partitioned the output. We must be
- # careful to also restore the computational graph of the tensors we partitioned.
- if self.is_pipe_partitioned:
- if self.is_grad_partitioned:
- part_output = PartitionedTensor.from_meta(
- meta=outputs[0],
- local_part=outputs[1],
- group=self.grid.get_slice_parallel_group())
- self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
- outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:])
- else:
- # Already restored from partition
- self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
- outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:])
- grad_tensors = self.grad_layer
- if self.is_grad_partitioned:
- #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
- part_grad = PartitionedTensor.from_meta(
- meta=self.grad_layer[0],
- local_part=self.grad_layer[1],
- group=self.grid.get_slice_parallel_group())
- grad_tensors = (part_grad.full(), *grad_tensors[2:])
- part_grad = None
- #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
- if self.bfloat16_enabled() and not self.is_last_stage():
- # manually call because we don't call optimizer.backward()
- self.optimizer.clear_lp_grads()
- # This handles either a single tensor or tuple of tensors.
- if isinstance(outputs, tuple):
- out_tensors = [t for t in outputs if t.is_floating_point()]
- assert len(out_tensors) == len(grad_tensors)
- torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors)
- else:
- torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))
- if self.bfloat16_enabled() and not self.is_last_stage():
- # manually call because we don't call optimizer.backward()
- self.optimizer.update_hp_grads(clear_lp_grads=False)
- # Free up the memory from the output of forward()
- self.pipe_buffers['output_tensors'][buffer_id] = None
- self.pipe_buffers['outputs'][buffer_id] = None
- grad_tensors = None
- if self.wall_clock_breakdown():
- self.timers('backward_inner').stop()
- self.timers('backward_inner_microstep').stop()
- self.timers('backward').stop()
- self.timers('backward_microstep').stop()
- self.mem_status('AFTER BWD')
- def _exec_load_micro_batch(self, buffer_id):
- if self.wall_clock_breakdown():
- self.timers('batch_input').start()
- batch = self._next_batch()
- if self.is_first_stage():
- loaded = None
- if torch.is_tensor(batch[0]):
- loaded = batch[0].clone().to(self.device).detach()
- loaded.requires_grad = loaded.is_floating_point()
- else:
- assert isinstance(batch[0], tuple)
- # Assume list or tuple
- loaded = []
- for x in batch[0]:
- assert torch.is_tensor(x)
- mine = x.clone().detach().to(self.device)
- mine.requires_grad = mine.is_floating_point()
- loaded.append(mine)
- loaded = tuple(loaded)
- self.pipe_buffers['inputs'][buffer_id] = loaded
- if self.is_last_stage():
- loaded = batch[1]
- if torch.is_tensor(batch[1]):
- loaded = batch[1].to(self.device)
- elif isinstance(batch[1], tuple):
- loaded = []
- for x in batch[1]:
- assert torch.is_tensor(x)
- x = x.to(self.device).detach()
- loaded.append(x)
- loaded = tuple(loaded)
- self.pipe_buffers['labels'][buffer_id] = loaded
- if self.wall_clock_breakdown():
- self.timers('batch_input').stop()
- def _send_tensor_meta(self, buffer, recv_stage):
- """ Communicate metadata about upcoming p2p transfers.
- Metadata is communicated in this order:
-         * type (0: tensor, 1: list, 2: tuple)
-         * num_tensors if type=list or tuple
-         foreach tensor in buffer:
-             * dtype (only if type=tuple)
-             * ndims
-             * shape
- """
- send_bytes = 0
- if isinstance(buffer, torch.Tensor):
- type_tensor = torch.LongTensor(data=[0]).to(self.device)
- p2p.send(type_tensor, recv_stage)
- send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
- send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
- p2p.send(send_ndims, recv_stage)
- p2p.send(send_shape, recv_stage)
- send_bytes += _tensor_bytes(buffer)
- elif isinstance(buffer, list):
- assert (False)
- type_tensor = torch.LongTensor(data=[1]).to(self.device)
- p2p.send(type_tensor, recv_stage)
- count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
- p2p.send(count_tensor, recv_stage)
- for tensor in buffer:
- assert isinstance(tensor, torch.Tensor)
- send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
- send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
- p2p.send(send_ndims, recv_stage)
- p2p.send(send_shape, recv_stage)
- send_bytes += _tensor_bytes(tensor)
- elif isinstance(buffer, tuple):
- type_tensor = torch.LongTensor(data=[2]).to(self.device)
- p2p.send(type_tensor, recv_stage)
- count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
- p2p.send(count_tensor, recv_stage)
- for idx, tensor in enumerate(buffer):
- assert isinstance(tensor, torch.Tensor)
- send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
- send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
- send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(
- self.device)
- p2p.send(send_dtype, recv_stage)
- p2p.send(send_ndims, recv_stage)
- p2p.send(send_shape, recv_stage)
- # Useful for performance debugging.
- '''
- new_bytes = _tensor_bytes(tensor)
- send_bytes += _tensor_bytes(tensor)
- # Useful for performance debugging.
- if self.grid.data_parallel_id == 0:
- print(
- f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
- )
- '''
- else:
- raise NotImplementedError(f'Could not send meta type {type(buffer)}')
- # Useful for performance debugging.
- '''
- if self.grid.data_parallel_id == 0:
- print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB')
- '''
- def _recv_tensor_meta(self, send_stage):
- """Receive metadata about upcoming p2p transfers and return allocated buffers.
- Metadata is communicated in this order:
-         * type (0: tensor, 1: list, 2: tuple)
-         * num_tensors if type=list or tuple
-         foreach tensor in buffer:
-             * dtype (only if type=tuple)
-             * ndims
-             * shape
- Returns:
- Allocated buffer for receiving from send_stage.
- """
- type_tensor = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(type_tensor, send_stage)
- recv_type = type_tensor.item()
- # A single tensor will be sent.
- if recv_type == 0:
- recv_ndims = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(recv_ndims, send_stage)
- recv_ndims = recv_ndims.item()
- recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
- p2p.recv(recv_shape, send_stage)
- recv_shape = recv_shape.tolist()
- return self._allocate_buffer(recv_shape, num_buffers=1)[0]
- # List or tuple of tensors
- elif recv_type == 1 or recv_type == 2:
- count_tensor = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(count_tensor, send_stage)
- num_tensors = count_tensor.item()
- recv_shapes_and_dtypes = []
- for idx in range(num_tensors):
- recv_dtype = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(recv_dtype, send_stage)
- recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
- recv_ndims = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(recv_ndims, send_stage)
- recv_ndims = recv_ndims.item()
- recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
- p2p.recv(recv_shape, send_stage)
- recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))
- buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
- # Convert to tuples if requested.
- if recv_type == 2:
- buffers = tuple(buffers)
- return buffers
- else:
-             raise NotImplementedError(f'Could not receive meta type {recv_type}')
- def _exec_send_activations(self, buffer_id):
- if self.wall_clock_breakdown():
- self.timers('pipe_send_output').start()
- outputs = self.pipe_buffers['outputs'][buffer_id]
- # NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
- # We could do char, but with half() we can eventually flatten with other fp16
- # messages (TODO)
- if self.has_attention_mask or self.has_bool_tensors:
- outputs = list(outputs)
- outputs[-1] = outputs[-1].half()
- outputs = tuple(outputs)
- if self.first_output_send:
- self.first_output_send = False
- self._send_tensor_meta(outputs, self.next_stage)
- if isinstance(outputs, torch.Tensor):
- p2p.send(outputs, self.next_stage)
- elif isinstance(outputs, tuple):
- for idx, buffer in enumerate(outputs):
- p2p.send(buffer, self.next_stage)
- else:
- raise NotImplementedError('Could not send output of type '
- f'{type(outputs)}')
- # Restore the boolean tensor
- if self.has_attention_mask or self.has_bool_tensors:
- outputs = list(outputs)
- outputs[-1] = outputs[-1].bool()
- outputs = tuple(outputs)
- if self.wall_clock_breakdown():
- self.timers('pipe_send_output').stop()
- def _exec_send_grads(self, buffer_id):
- if self.wall_clock_breakdown():
- self.timers('pipe_send_grad').start()
- inputs = self.pipe_buffers['inputs'][buffer_id]
- # Partition the gradient
- if self.is_grad_partitioned:
- if isinstance(inputs, tuple):
- first_input = inputs[0]
- assert all([torch.is_tensor(elt) for elt in inputs[1:]])
- inputs_grad_tail = [
- elt.grad for elt in inputs[1:] if elt.grad is not None
- ]
- elif torch.is_tensor(inputs):
- first_input = inputs
- inputs_grad_tail = []
- else:
- raise ValueError("expecting a tensor or a tuple of tensors")
- assert torch.is_tensor(first_input)
- part = PartitionedTensor(tensor=first_input.grad,
- group=self.grid.get_slice_parallel_group())
- inputs = (part.to_meta(), part.data(), *inputs_grad_tail)
- # XXX Terrible hack
- # Drop the attention mask from the input buffer here. It does not have
- # a grad that needs to be communicated. We free the buffer immediately
- # after, so no need to restore it. The receiver also has a hack that skips
- # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
- if self.has_attention_mask or self.has_bool_tensors:
- inputs = list(inputs)
- inputs.pop()
- inputs = tuple(inputs)
- if isinstance(inputs, torch.Tensor):
- assert inputs.grad is not None
- p2p.send(inputs.grad, self.prev_stage)
- else:
- # XXX terrible hacky branch
- if self.is_grad_partitioned:
- # First two sends are partitioned gradient
- p2p.send(inputs[0], self.prev_stage)
- p2p.send(inputs[1], self.prev_stage)
- else:
- for idx, buffer in enumerate(inputs):
- # Skip tensors that will not produce a grad
- if not buffer.is_floating_point():
- assert buffer.grad is None
- continue
- assert buffer.grad is not None
- p2p.send(buffer.grad, self.prev_stage)
- # We can free up the input buffer now
- self.pipe_buffers['inputs'][buffer_id] = None
- if self.wall_clock_breakdown():
- self.timers('pipe_send_grad').stop()
- def _exec_recv_activations(self, buffer_id):
- if self.wall_clock_breakdown():
- self.timers('pipe_recv_input').start()
- recvd = None
- # Allocate the buffer if necessary
- if self.pipe_recv_buf is None:
- self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)
- if isinstance(self.pipe_recv_buf, torch.Tensor):
- p2p.recv(self.pipe_recv_buf, self.prev_stage)
- recvd = self.pipe_recv_buf.clone().detach()
- recvd.requires_grad = recvd.is_floating_point()
- else:
- assert isinstance(self.pipe_recv_buf, tuple)
- recvd = [None] * len(self.pipe_recv_buf)
- for idx, buffer in enumerate(self.pipe_recv_buf):
- assert torch.is_tensor(buffer)
-                 # XXX hardcoded meta type: when activations are pipe-partitioned, the
-                 # first buffer carries PartitionedTensor metadata and must be torch.long.
- if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long:
- if self.meta_buffer is None:
- self.meta_buffer = torch.zeros(buffer.size(),
- dtype=torch.long,
- device=self.device)
- buffer = self.meta_buffer
- p2p.recv(buffer, self.prev_stage)
- recvd[idx] = buffer.clone().detach()
- # NCCL does not like to send torch.BoolTensor types, so un-cast the
- # attention mask
- if self.has_attention_mask or self.has_bool_tensors:
- recvd[-1] = recvd[-1].bool()
- recvd = tuple(recvd)
- for buffer in recvd:
- buffer.requires_grad = buffer.is_floating_point()
- self.pipe_buffers['inputs'][buffer_id] = recvd
- if self.wall_clock_breakdown():
- self.timers('pipe_recv_input').stop()
- def _exec_recv_grads(self, buffer_id):
- if self.wall_clock_breakdown():
- self.timers('pipe_recv_grad').start()
- outputs = self.pipe_buffers['outputs'][buffer_id]
- # XXX these shapes are hardcoded for Megatron
- # Restore partitioned output if it was partitioned and we are sending full gradients
- if self.is_pipe_partitioned and not self.is_grad_partitioned:
- part_output = PartitionedTensor.from_meta(
- meta=outputs[0],
- local_part=outputs[1],
- group=self.grid.get_slice_parallel_group())
- outputs[0].data = part_output.full()
- outputs = (outputs[0], *outputs[2:])
- # save for backward
- self.pipe_buffers['outputs'][buffer_id] = outputs
- # Allocate gradient if necessary
- if self.grad_layer is None:
- if isinstance(outputs, torch.Tensor):
- s = list(outputs.size())
- self.grad_layer = self._allocate_buffer(s,
- dtype=outputs.dtype,
- num_buffers=1)[0]
- else:
- # XXX This is a HACK
- # When we exchange activations/gradients, the two pipe stages
- # need to issue the send/recv with the same buffer sizes or
- # else there is a deadlock. The is_floating_point() filter is
- # used to avoid sending gradients for tensors that do not
- # produce gradients. When TP>1, we partition the first
- # activations/gradients across TP ranks to save communication
- # volume and memory. That partitioned tensor is represented as
- # two tensors: a 1/TPth chunk of the original data and also a
- # small LongTensor storing the metadata used to reconstruct on
- # the other side. When combined, the floating point filter also
- # filtered out the metadata tensor. This quick (hacky) fix just
- # branches on is_grad_partitioned so we don't filter out the
- # metadata tensor.
- if self.is_grad_partitioned:
- sizes_and_dtypes = [
- (list(t.size()),
- t.dtype) for t in outputs[:2]
- ] + [(list(t.size()),
- t.dtype) for t in outputs[2:] if t.is_floating_point()]
- else:
- sizes_and_dtypes = [(list(t.size()),
- t.dtype) for t in outputs
- if t.is_floating_point()]
- self.grad_layer = self._allocate_buffers(sizes_and_dtypes,
- num_buffers=1)[0]
- if isinstance(self.grad_layer, torch.Tensor):
- p2p.recv(self.grad_layer, self.next_stage)
- else:
- assert isinstance(outputs, tuple)
- for idx, buffer in enumerate(self.grad_layer):
- # XXX GPT-2 hack
- if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long:
- buffer.data = torch.zeros(buffer.size(),
- dtype=torch.long,
- device=self.device)
- p2p.recv(buffer, self.next_stage)
- if self.wall_clock_breakdown():
- self.timers('pipe_recv_grad').stop()
- def _exec_optimizer_step(self, lr_kwargs=None):
- if self.wall_clock_breakdown():
- self.timers('step_microstep').start()
- self.timers('step').start()
- self.mem_status('BEFORE STEP', reset_max=True)
- self._force_grad_boundary = True
- self._take_model_step(lr_kwargs)
- self._force_grad_boundary = False
- self.mem_status('AFTER STEP')
- if self.global_rank == 0 and self.monitor.enabled:
- self.summary_events = [(f'Train/Samples/lr',
- self.get_lr()[0],
- self.global_samples)]
- if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
- self.summary_events.append((f'Train/Samples/loss_scale',
- self.optimizer.cur_scale,
- self.global_samples))
- self.monitor.write_events(self.summary_events)
- if self.wall_clock_breakdown():
- self.timers('step_microstep').stop()
- self.timers('step').stop()
- if self.global_steps % self.steps_per_print() == 0:
- self.timers.log([
- 'batch_input',
- 'forward_microstep',
- 'backward_microstep',
- 'backward_inner_microstep',
- 'backward_allreduce_microstep',
- 'backward_tied_allreduce_microstep',
- 'step_microstep'
- ])
- if self.global_steps % self.steps_per_print() == 0:
- self.timers.log([
- 'forward',
- 'backward',
- 'backward_inner',
- 'backward_allreduce',
- 'step'
- ])
- def _zero_grads(self, inputs):
- if isinstance(inputs, torch.Tensor):
- if inputs.grad is not None:
- inputs.grad.data.zero_()
- else:
- for t in inputs:
- if t.grad is not None:
- t.grad.data.zero_()
- def _allocate_zeros(self, shape, **kwargs):
- """ Allocate a tensor of zeros on the engine's device.
- Arguments:
- shape: the shape of the tensor to allocate
- kwargs: passed to torch.zeros()
- Returns:
- A tensor from torch.zeros() allocated on self.device.
- """
- if "dtype" not in kwargs:
- if self.fp16_enabled():
- kwargs["dtype"] = torch.half
- if self.bfloat16_enabled():
- kwargs["dtype"] = torch.bfloat16
- return torch.zeros(shape, device=self.device, **kwargs)
- def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
- buffers = []
- if num_buffers == -1:
- num_buffers = self.num_pipe_buffers
- for count in range(num_buffers):
- buffers.append(self._allocate_zeros(shape, **kwargs))
- return buffers
- def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
- buffers = []
- if num_buffers == -1:
- num_buffers = self.num_pipe_buffers
- for count in range(num_buffers):
- buffer = []
- for shape, dtype in shapes_and_dtypes:
- buffer.append(
- self._allocate_zeros(shape,
- dtype=dtype,
- requires_grad=requires_grad))
- buffers.append(buffer)
- return buffers
- def forward(self, *args, **kwargs):
- """Disabled for pipeline parallel training. See ``train_batch()``. """
- raise PipelineError("Only train_batch() is accessible in pipeline mode.")
- def backward(self, *args, **kwargs):
- """Disabled for pipeline parallel training. See ``train_batch()``. """
- raise PipelineError("Only train_batch() is accessible in pipeline mode.")
- def step(self, *args, **kwargs):
- """Disabled for pipeline parallel training. See ``train_batch()``. """
- raise PipelineError("Only train_batch() is accessible in pipeline mode.")
- def mem_status(self, msg, print_rank=-1, reset_max=False):
- return
- global mem_alloced, mem_cached
- if not self.global_steps == 0 or not self.global_steps == 9:
- #return
- pass
- if self.mpu.get_data_parallel_rank() != 0:
- return
- if self.global_rank != 0:
- return
- rank = self.global_rank
- if print_rank != -1 and rank != print_rank:
- return
- get_accelerator().synchronize()
- if reset_max:
- get_accelerator().reset_max_memory_cached()
- get_accelerator().reset_max_memory_allocated()
- new_alloced = get_accelerator().memory_allocated()
- new_cached = get_accelerator().memory_cached()
- delta_alloced = new_alloced - mem_alloced
- delta_cached = new_cached - mem_cached
- mem_cached = new_cached
- mem_alloced = new_alloced
- max_alloced = get_accelerator().max_memory_allocated()
- max_cached = get_accelerator().max_memory_cached()
- # convert to GB for printing
- new_alloced /= 1024**3
- new_cached /= 1024**3
- delta_alloced /= 1024**3
- delta_cached /= 1024**3
- max_alloced /= 1024**3
- max_cached /= 1024**3
- print(
- f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS',
- msg,
- f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
- f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)'
- )
- def module_state_dict(self):
- """Override hack to save a pipe model and return the directory path of the save.
- This method should only be called by DeepSpeed's ``save_checkpoint()``. The
- recommended way of saving a ``PipelineModule`` outside of ``save_checkpoint()``
- is ``save_state_dict()``.
- Returns:
- None
- """
- assert isinstance(self.module, PipelineModule)
- assert self._curr_ckpt_path is not None, \
- "PipelineEngine expects module_state_dict() to be called from save_checkpoint()"
- self.module.save_state_dict(self._curr_ckpt_path,
- checkpoint_engine=self.checkpoint_engine)
- return None
- def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
- """Override hack to instead use a directory path.
- This is important because pipeline models checkpoint by layer instead of rank.
- If ``state_dict`` is not ``None`` or a ``str``, we revert to ``super()`` expecting a ``dict``.
- Args:
- state_dict (str, None): unused
- strict (bool, optional): Strict state loading. Defaults to True.
- """
- assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"
- if (state_dict is not None) and (not isinstance(state_dict, str)):
- super().load_module_state_dict(state_dict, strict)
- return
- self.module.load_state_dir(load_dir=self._curr_ckpt_path,
- strict=strict,
- checkpoint_engine=self.checkpoint_engine)
- # A map of PipeInstruction types to methods. Each method will be executed with the
- # kwargs provided to the PipeInstruction from the scheduler.
- _INSTRUCTION_MAP = {
- schedule.OptimizerStep: _exec_optimizer_step,
- schedule.ReduceGrads: _exec_reduce_grads,
- schedule.ReduceTiedGrads: _exec_reduce_tied_grads,
- schedule.LoadMicroBatch: _exec_load_micro_batch,
- schedule.ForwardPass: _exec_forward_pass,
- schedule.BackwardPass: _exec_backward_pass,
- schedule.SendActivation: _exec_send_activations,
- schedule.RecvActivation: _exec_recv_activations,
- schedule.SendGrad: _exec_send_grads,
- schedule.RecvGrad: _exec_recv_grads,
- }
- def _exec_schedule(self, pipe_schedule):
- # Reserve and reset buffers.
- self._reserve_pipe_buffers(pipe_schedule.num_pipe_buffers())
- self.fwd_outputs = []
- # For each step in the schedule
- for step_cmds in pipe_schedule:
- # For each instruction in the step
- for cmd in step_cmds:
- if type(cmd) not in self._INSTRUCTION_MAP:
- raise RuntimeError(
- f'{self.__class__.__name__} does not understand instruction {repr(cmd)}'
- )
- # Equivalent to: self._exec_forward_pass(buffer_id=0)
- self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
- self._exec_instr(**cmd.kwargs)