- """
- "Copyright 2020 The Microsoft DeepSpeed Team.
- Licensed under the MIT license.
- """
- import sys
- import os
- from collections import defaultdict, OrderedDict
- import itertools
- import torch
- from torch.distributed.distributed_c10d import _get_global_rank
- import torch.distributed as dist
- import math
- from torch._six import inf
- from torch.autograd import Variable
- from deepspeed.utils.logging import logger
- from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
- from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim
- from deepspeed.runtime.zero.partition_parameters import *
- from deepspeed.runtime.zero.partition_parameters import _init_external_params
- from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
- from deepspeed.ops.adam import DeepSpeedCPUAdam
- from deepspeed.ops.op_builder import UtilsBuilder
- from deepspeed.runtime.zero.offload_constants import *
- from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
- from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper
- from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
- from deepspeed.runtime.constants import OPTIMIZER_STATE_DICT
- # Toggle this to true to enable correctness test
- # with gradient partitioning and without
- pg_correctness_test = False
- FWD_MODULE_STACK = list()
- from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id, debug_param2name_id_numel, debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name_class, printflock, log_rank_file
- def print_rank_0(message, debug=False, force=False):
- rank = torch.distributed.get_rank()
- if rank == 0 and (debug or force):
- print(message)
- # other variations
- # - print for all ranks w/o interleaving
- # printflock(f"[{rank}] {message}")
- # - print to log file per rank
- # log_rank_file(rank, message)
- # overrides the built-in input() so any leftover debug pauses become no-ops
- def input(msg):
- return
- def split_half_float_double(tensors):
- dtypes = [
- "torch.cuda.HalfTensor",
- "torch.cuda.FloatTensor",
- "torch.cuda.DoubleTensor"
- ]
- buckets = []
- for i, dtype in enumerate(dtypes):
- bucket = [t for t in tensors if t.type() == dtype]
- if bucket:
- buckets.append(bucket)
- return buckets
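- # Usage sketch (illustrative, not from the original source): given CUDA
- # tensors h1, h2 (fp16) and f1 (fp32),
- #   split_half_float_double([h1, f1, h2]) -> [[h1, h2], [f1]]
- # i.e. one bucket per dtype that actually appears, in half/float/double order.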
- def isclose(a, b, rtol=1e-09, atol=0.0):
- return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
- def lcm(x, y):
- from math import gcd # fractions.gcd was removed in Python 3.9; math.gcd is available since 3.5
- return x * y // gcd(x, y)
- def move_to_cpu(tensor_list):
- for tensor in tensor_list:
- tensor.data = tensor.data.cpu()
- def get_all_parameters(sub_module, recurse=False):
- return itertools.chain(sub_module.named_parameters(recurse=recurse),
- sub_module.ds_external_parameters())
- # apply a torch.autograd.Function that calls backward_function to tensors in outputs
- def _apply_to_tensors_only(module, functional, backward_function, outputs):
- if type(outputs) is tuple:
- touched_outputs = []
- for output in outputs:
- touched_output = _apply_to_tensors_only(module,
- functional,
- backward_function,
- output)
- touched_outputs.append(touched_output)
- return tuple(touched_outputs)
- elif type(outputs) is torch.Tensor:
- return functional.apply(module, backward_function, outputs)
- else:
- return outputs
- # for each tensor in outputs, run forward_function and register backward_function as a hook
- def _apply_forward_and_backward_to_tensors_only(module,
- forward_function,
- backward_function,
- outputs):
- if type(outputs) is tuple:
- touched_outputs = []
- for output in outputs:
- touched_output = _apply_forward_and_backward_to_tensors_only(
- module,
- forward_function,
- backward_function,
- output)
- touched_outputs.append(touched_output)
- return tuple(touched_outputs)
- elif type(outputs) is torch.Tensor:
- forward_function(outputs)
- if outputs.requires_grad:
- outputs.register_hook(backward_function)
- return outputs
- else:
- return outputs
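- # Hedged sketch of how these helpers are meant to be used (hook names here
- # are hypothetical):
- #   outputs = _apply_to_tensors_only(module, PreBackwardFunction,
- #                                    _pre_backward_hook, outputs)
- # recursively walks tuples, wraps each tensor with the autograd Function, and
- # leaves non-tensors untouched, so _pre_backward_hook(module) fires when
- # autograd later reaches those outputs in the backward pass.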
- class ZeROOrderedDict(OrderedDict):
- def __init__(self, parent_module, *args, **kwargs):
- """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
- Args:
- parent_module (``collections.OrderedDict``): the collection to replace
- """
- super().__init__(*args, **kwargs)
- self._parent_module = parent_module
- self._in_forward = False
- def __getitem__(self, key):
- param = super().__getitem__(key)
- # Params can be registered as None (e.g., bias)
- if param is None:
- return param
- if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
- if self._parent_module._parameters._in_forward:
- print_rank_0(f'Registering external parameter from getter {key}',
- force=False)
- register_external_parameter(FWD_MODULE_STACK[-1], param)
- param.all_gather()
- return param
- def _inject_parameters(module, cls):
- for sub_module in module.modules():
- if cls == ZeROOrderedDict:
- new_param = cls(parent_module=sub_module)
- else:
- new_param = cls()
- for key, param in sub_module._parameters.items():
- new_param[key] = param
- sub_module._parameters = new_param
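- # A minimal sketch of what injection achieves (toy example, assuming a plain
- # torch module):
- #   net = torch.nn.Linear(2, 2)
- #   _inject_parameters(net, ZeROOrderedDict)
- #   assert isinstance(net._parameters, ZeROOrderedDict)
- # From then on, attribute access like net.weight resolves through
- # ZeROOrderedDict.__getitem__, which can all-gather a NOT_AVAILABLE param on
- # demand and register it as an external parameter of the module currently in
- # forward (FWD_MODULE_STACK[-1]).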
- # TODO Needs to be implemented
- class PrefetchCoordinator(object):
- def __init__(self):
- # step_id keeps track of the number of sub-modules invoked so far
- # the step_id is tracking forward and backward sequence of sub-modules
- self.step_id = 0
- # stores the sequence of sub modules in forward+backward pass
- self.sub_module_trace = []
- # maps sub_module id to submodule objects
- self.id_to_sub_module_map = {}
- # stores the total number of parameters in each sub_module
- self.id_to_sub_module_size_map = {}
- self.trace_completed = False
- self.most_recent_sub_module_step = {}
- # reuse distances
- self.reuse_numel_for_step_id = {}
- def record_trace(self, sub_module):
- if not self.trace_completed:
- self.sub_module_trace.append(sub_module.id)
- self.id_to_sub_module_map[sub_module.id] = sub_module
- def print_trace(self):
- print_rank_0(
- f"The module trace is : {[self.id_to_sub_module_map[module_id].id for module_id in self.sub_module_trace]}"
- )
- def increment_step(self, sub_module):
- self.most_recent_sub_module_step[sub_module.id] = self.step_id
- self.step_id += 1
- def reset_step(self):
- self.step_id = 0
- # returns parameters, totaling up to numel elements, that will be used next but are not yet available or in flight
- def get_params_to_prefetch(self, sub_module, numel=2000000):
- # numel_in_sub_module = 0
- # for name, param in sub_module.named_parameters(recurse=False):
- # numel_in_sub_module += param.ds_numel
- # #if numel_in_sub_module < (numel // 2):
- # return []
- # tracing failed: the sub_module seen at this step_id must match the sub_module recorded during tracing
- if sub_module.id != self.sub_module_trace[self.step_id]:
- print_rank_0(
- f"Tracing failed. Prefetching is disabled at sub-module: {debug_module2name_id(sub_module)}"
- )
- return []
- params_to_prefetch = []
- total_numel_to_prefetch = 0
- for i in range(self.step_id, len(self.sub_module_trace)):
- module_id = self.sub_module_trace[i]
- for _, param in get_all_parameters(self.id_to_sub_module_map[module_id]):
- if param.ds_status is ZeroParamStatus.NOT_AVAILABLE and (
- param.ds_id not in [p.ds_id for p in params_to_prefetch]):
- params_to_prefetch.append(param)
- total_numel_to_prefetch += param.ds_numel
- #print_rank_0(f"Total numel to prefetch: {total_numel_to_prefetch}. Param: {param.ds_shape} and numel {param.ds_numel}, numel limit {numel}")
- if total_numel_to_prefetch >= numel: # and total_numel_to_prefetch > (numel_in_sub_module // 2):
- return params_to_prefetch
- return params_to_prefetch
- # checks whether this sub_module will be used again and, if so, returns the number of
- # parameter elements used between this sub_module and its reuse
- def get_reuse_distance_in_numel(self, sub_module, sub_module_step_id=None):
- #assert is_forward is not None, "is_forward must be set to True for Forward Propagation and False for backward Propagation"
- is_there_reuse = False
- reuse_distance_in_numel = 1000000000000
- # set the appropriate trace
- trace = self.sub_module_trace
- total_steps = len(trace)
- if sub_module_step_id is None:
- sub_module_step_id = self.most_recent_sub_module_step[sub_module.id]
- # tracing failed: the sub_module seen at this step_id must match the sub_module recorded during tracing
- if sub_module.id != trace[sub_module_step_id]:
- print_rank_0(
- f"Tracing failed. Cannot tell if the sub_module: {sub_module.id} is reused"
- )
- return reuse_distance_in_numel
- # return cached value
- if sub_module_step_id in self.reuse_numel_for_step_id:
- return self.reuse_numel_for_step_id[sub_module_step_id]
- start_step = self.step_id
- print_rank_0(f"Step id is {self.step_id} ")
- for step_id in range(start_step, total_steps):
- print_rank_0(f"Trace id {trace[step_id]} and sub_module id {sub_module.id}")
- if sub_module.id == trace[step_id]:
- end_step = step_id
- is_there_reuse = True
- reuse_distance_in_numel = self._distance_in_numel(
- start_step,
- end_step,
- trace)
- break
- self.reuse_numel_for_step_id[sub_module_step_id] = reuse_distance_in_numel
- return reuse_distance_in_numel
- def _distance_in_numel(self, start_step, end_step, trace):
- distance_in_numel = 0
- for step_id in range(start_step, end_step):
- module_id = trace[step_id]
- for _, param in self.id_to_sub_module_map[module_id].named_parameters(recurse=False):
- distance_in_numel += param.ds_numel
- for _, param in self.id_to_sub_module_map[module_id].ds_external_parameters():
- distance_in_numel += param.ds_numel
- return distance_in_numel
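- # Trace record/replay in a nutshell (hedged illustration): during the first
- # forward+backward pass record_trace() is called once per sub-module, e.g.
- # yielding sub_module_trace = [0, 1, 2, 2, 1, 0] for a 3-module network whose
- # backward revisits modules in reverse. Later iterations replay the same
- # step_id sequence, so get_params_to_prefetch() can look ahead in the trace
- # and get_reuse_distance_in_numel() can measure, in parameter elements, how
- # far away the next use of a module is.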
- class PartitionedParameterCoordinator(object):
- def __init__(self,
- comm_stream=None,
- max_reuse_distance_in_numel=500000000,
- max_available_parameters_in_numel=700000000):
- self.in_flight_handles = []
- self.params_in_flight = []
- self.comm_stream = comm_stream if comm_stream is not None else torch.cuda.current_stream(
- )
- self.prefetch_coordinator = PrefetchCoordinator()
- self.hierarchy = 0
- self.total_available_parameter_numel = 0
- self.max_available_parameters_in_numel = max_available_parameters_in_numel
- # max distance between two uses of a module, beyond which the module is released
- self.max_reuse_distance_in_numel = max_reuse_distance_in_numel
- def _increment_available_parameter_numel(self, increment):
- self.total_available_parameter_numel += increment
- def _decrement_available_parameter_numel(self, decrement):
- self.total_available_parameter_numel -= decrement
- '''-----------------------Tracing and Prefetching ---------------'''
- def record_trace(self, sub_module):
- self.prefetch_coordinator.record_trace(sub_module)
- def finish_tracing(self, print_trace=False):
- self.prefetch_coordinator.trace_completed = True
- if print_trace:
- self.prefetch_coordinator.print_trace()
- # swap in parameter partitions from NVMe for parameters that will be used
- # after the ones that are already being prefetched into full parameters
- def _prefetch_nvme_param_partitions(self, sub_module, params_in_flight):
- numel_in_flight = sum([param.ds_tensor.ds_numel for param in params_in_flight])
- upcoming_param_list = self.prefetch_coordinator.get_params_to_prefetch(
- sub_module,
- numel=2 * numel_in_flight)
- swap_in_params = []
- for param in upcoming_param_list:
- if len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers():
- break
- if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
- swap_in_params.append(param)
- if len(swap_in_params) > 0:
- swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True)
- # Prefetches the parameters for sub_modules that come after
- # the current sub_module. This call is asynchronous
- def prefetch_next_sub_modules(self, sub_module, numel=5000000, nvme=False):
- params_to_prefetch = []
- if not self.prefetch_coordinator.trace_completed:
- return params_to_prefetch
- # prefetch if there is no current prefetching in flight
- if not self.in_flight_handles and self.total_available_parameter_numel < self.max_available_parameters_in_numel:
- params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch(
- sub_module,
- numel=numel)
- self._all_gather(params_to_prefetch, async_op=True)
- for param in params_to_prefetch:
- param.ds_status = ZeroParamStatus.INFLIGHT
- # keeping track of number of elements consumed by available parameters
- self._increment_available_parameter_numel(param.ds_numel)
- if nvme:
- self._prefetch_nvme_param_partitions(sub_module, params_to_prefetch)
- self._print_prefetch_elements_info(sub_module, params_to_prefetch)
- print_rank_0(
- f"{'--' * self.hierarchy}--PreFetching parameters {[param.ds_id for param in params_to_prefetch]} and available {self.total_available_parameter_numel}, max limit {self.max_available_parameters_in_numel}",
- force=False)
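- # Example trigger (illustrative numbers): with numel=5_000_000, prefetching
- # only starts when (a) the trace is complete, (b) nothing is already in
- # flight, and (c) fewer than max_available_parameters_in_numel elements are
- # currently materialized; it then all-gathers upcoming params until roughly
- # 5M elements are in flight, optionally also swapping their NVMe partitions in.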
- def _print_prefetch_elements_info(self, sub_module, params_to_prefetch):
- sub_module_numel = 0.0
- for name, param in sub_module.named_parameters(recurse=False):
- sub_module_numel += param.ds_numel
- numel_being_prefetched = 0
- for param in params_to_prefetch:
- numel_being_prefetched += param.ds_numel
- print_rank_0(
- f"{'--' * self.hierarchy}--PreFetching {numel_being_prefetched} numels and number of numel in the next sub module is {sub_module_numel}",
- force=False)
- def increment_step(self, sub_module):
- self.prefetch_coordinator.increment_step(sub_module)
- def reset_step(self):
- self.prefetch_coordinator.reset_step()
- '''----------------------------------------------------------------------'''
- # Fetches the parameters in the sub_module
- # This call is blocking
- def fetch_sub_module(self, sub_module):
- partitioned_params = []
- params_in_flight = False
- print_rank_0(
- f"{'--' * self.hierarchy}Fetching params in module {debug_module2name_class(sub_module)}"
- )
- params_to_fetch = [
- param for _,
- param in sub_module.named_parameters(recurse=False)
- ]
- # print([n for n,p in sub_module.named_parameters(recurse=False)])
- if hasattr(sub_module, 'ds_external_parameters'):
- print_rank_0(
- f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}"
- )
- params_to_fetch += [
- param for _,
- param in sub_module.ds_external_parameters()
- ]
- # for _, param in sub_module.named_parameters(recurse=False):
- for param in params_to_fetch:
- param.ds_active_sub_modules += 1
- print_rank_0(
- f"{'--' * self.hierarchy}--Fetching parameters {debug_param2name_id_shape(param)} with active sub modules {param.ds_active_sub_modules}"
- )
- if param.ds_status == ZeroParamStatus.AVAILABLE:
- print_rank_0(
- f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is already available"
- )
- if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
- print_rank_0(
- f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is being fetched"
- )
- partitioned_params.append(param)
- # keeping track of number of elements consumed by available parameters
- self._increment_available_parameter_numel(param.ds_numel)
- print_rank_0(f"Incrementing with parameter id {param.ds_id}")
- if param.ds_status == ZeroParamStatus.INFLIGHT:
- params_in_flight = True
- print_rank_0(
- f"{'--' * self.hierarchy}--Parameters {debug_param2name_id(param)} is already in flight (prefetched)"
- )
- self.hierarchy += 1
- # parameters are partitioned and need to be allgathered
- self._all_gather(partitioned_params, async_op=False)
- # parameters are inflight and communication needs to be completed
- if partitioned_params or params_in_flight:
- self._synchronize_communication()
- for _, param in sub_module.named_parameters(recurse=False):
- param.ds_status = ZeroParamStatus.AVAILABLE
- print_rank_0(
- f"Param {debug_param2name_id_shape_device(param)} norm={param.norm()}",
- force=False)
- #print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}")
- def release_sub_module(self, sub_module):
- self.hierarchy -= 1
- print_rank_0(
- f"{'--' * self.hierarchy}Releasing params in module {debug_module2name_class(sub_module)}"
- )
- params_to_release = [
- param for _,
- param in sub_module.named_parameters(recurse=False)
- ]
- if hasattr(sub_module, 'ds_external_parameters'):
- #print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}")
- params_to_release += [
- param for _,
- param in sub_module.ds_external_parameters()
- ]
- # for _, param in sub_module.named_parameters(recurse=False):
- for param in params_to_release:
- param.ds_active_sub_modules -= 1
- if not param.ds_active_sub_modules and not self._keep_for_later(
- sub_module) and not param.ds_persist:
- print_rank_0(
- f"{'--' * self.hierarchy}--Releasing parameter {debug_param2name_id_numel(param)} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}",
- force=False)
- # Keeping track of number of elements that are consumed by available parameters
- self._decrement_available_parameter_numel(param.ds_numel)
- see_memory_usage(
- f"Before releasing param {debug_param2name_id_numel(param)}",
- force=False)
- param.partition(hierarchy=self.hierarchy)
- see_memory_usage(
- f"After releasing param {debug_param2name_id_numel(param)}",
- force=False)
- param.ds_status = ZeroParamStatus.NOT_AVAILABLE
- else:
- print_rank_0(
- f"{'--' * self.hierarchy}--Did not release param {debug_param2name_id_numel(param)} with active sub modules {param.ds_active_sub_modules}, keep for later={self._keep_for_later(sub_module)} and persistence={param.ds_persist}",
- force=False)
- def release_and_reset_parameter(self, param):
- param.ds_active_sub_modules = 0
- if param.ds_status == ZeroParamStatus.AVAILABLE:
- print_rank_0(
- f"Releasing unpartitioned param {debug_param2name_id_numel(param)} active sub-modules {param.ds_active_sub_modules} and persistence {param.ds_persist}"
- )
- self._decrement_available_parameter_numel(param.ds_numel)
- param.partition()
- def _keep_for_later(self, sub_module):
- if not self.prefetch_coordinator.trace_completed:
- return False
- if self.max_reuse_distance_in_numel == 0:
- return False
- reuse_distance_in_numel = self.prefetch_coordinator.get_reuse_distance_in_numel(
- sub_module)
- #print_rank_0(f"Reuse distance and numel for sub_module id {sub_module.id} is {reuse_distance_in_numel}")
- return reuse_distance_in_numel < self.max_reuse_distance_in_numel
- def _all_gather(self, partitioned_params, async_op=False):
- with torch.cuda.stream(self.comm_stream):
- handles = partitioned_params[0].all_gather(
- param_list=partitioned_params,
- async_op=async_op,
- hierarchy=self.hierarchy) if partitioned_params else None
- if handles is not None:
- self.in_flight_handles.extend(handles)
- self.params_in_flight.extend(partitioned_params)
- def _synchronize_communication(self, synchronize_streams=True):
- assert len(self.params_in_flight) == len(self.in_flight_handles)
- for handle, param in zip(self.in_flight_handles, self.params_in_flight):
- if handle is not None:
- with torch.cuda.stream(self.comm_stream):
- handle.wait()
- param.ds_status = ZeroParamStatus.AVAILABLE
- self.comm_stream.synchronize()
- if synchronize_streams:
- torch.cuda.synchronize()
- self.in_flight_handles = []
- self.params_in_flight = []
- class PreBackwardFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, module, pre_backward_function, outputs):
- ctx.module = module
- ctx.pre_backward_function = pre_backward_function
- if not hasattr(module, "applied_pre_backward_ref_cnt"):
- module.applied_pre_backward_ref_cnt = 0
- module.applied_pre_backward_ref_cnt += 1
- #print(f"After Forward: {ctx.module.__class__.__name__}")
- outputs = outputs.detach()
- return outputs
- @staticmethod
- def backward(ctx, *args):
- #print(f"Before Backward: {ctx.module.__class__.__name__}")
- ctx.pre_backward_function(ctx.module)
- return (None, None) + args
- class PostBackwardFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, module, pre_backward_function, output):
- ctx.module = module
- if output.requires_grad:
- #TODO: sometimes post-backward does not seem to be triggered; debug in detail
- #Should only cause increased memory use, not a correctness issue
- #if output.grad_fn.__class__.__name__ == 'ViewBackward':
- # ctx.view=True
- # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
- #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
- #if module.ds_grads_remaining == 0:
- # print(f"Before Forward: {ctx.module.__class__.__name__}")
- module.ds_grads_remaining += 1
- ctx.pre_backward_function = pre_backward_function
- output = output.detach()
- return output
- @staticmethod
- def backward(ctx, *args):
- ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
- if ctx.module.ds_grads_remaining == 0:
- ctx.pre_backward_function(ctx.module)
- #print(f"After Backward: {ctx.module.__class__.__name__}")
- return (None, None) + args
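- # Why autograd.Function wrappers rather than register_hook alone: detaching
- # and re-emitting the module outputs inserts PreBackwardFunction /
- # PostBackwardFunction nodes into the autograd graph itself, so the callbacks
- # fire exactly once per backward traversal of those outputs.
- # PreBackwardFunction.backward runs before the module's gradients are
- # computed (used to re-gather its partitioned params), while
- # PostBackwardFunction counts ds_grads_remaining down so params are released
- # only after the gradient of the last wrapped output has been produced.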
- INITIAL_MICRO_STEP_ID = -1
- class DeepSpeedZeroOptimizer_Stage3(object):
- """
- DeepSpeedZeroOptimizer designed to reduce the memory footprint
- required for training large deep learning models.
- For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- https://arxiv.org/abs/1910.02054
- For usage examples, refer to TODO: DeepSpeed Tutorial
- """
- def __init__(self,
- module,
- init_optimizer,
- timers,
- ds_config,
- static_loss_scale=1.0,
- dynamic_loss_scale=False,
- dynamic_loss_args=None,
- verbose=True,
- contiguous_gradients=True,
- reduce_bucket_size=500000000,
- prefetch_bucket_size=50000000,
- max_reuse_distance=1000000000,
- max_live_parameters=1000000000,
- param_persistence_threshold=100000,
- dp_process_group=None,
- reduce_scatter=True,
- overlap_comm=False,
- offload_optimizer_config=None,
- offload_param_config=None,
- sub_group_size=1000000000000,
- mpu=None,
- clip_grad=0.0,
- communication_data_type=torch.float16,
- postscale_gradients=True,
- gradient_predivide_factor=1.0,
- gradient_accumulation_steps=1,
- elastic_checkpoint=False,
- aio_config=None):
- see_memory_usage("Stage 3 initialize beginning", force=False)
- if dist.get_rank() == 0:
- logger.info(f"Reduce bucket size {reduce_bucket_size}")
- logger.info(f"Allgather bucket size {prefetch_bucket_size}")
- # The fused optimizer does all the work. We need this layer for two reasons:
- # 1. maintain the same user API as apex.fp16_utils
- # 2. keep common stuff here in case we need to add a new fused optimizer later
- # differences from apex.fp16_utils:
- # - assume all model params are fp16
- # - assume all params require grad
- # - flat by groups, not keeping state. TODO: remove state explicitly?
- # - master grad and unflattened master weights never exist. TODO: a way to save out unflat master?
- if not torch.cuda.is_available():
- raise SystemError("Cannot use fp16 without CUDA.")
- self.optimizer = init_optimizer
- # Load pre-built or JIT compile (un)flatten ops
- util_ops = UtilsBuilder().load()
- self.flatten = util_ops.flatten
- self.unflatten = util_ops.unflatten
- self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
- self._global_grad_norm = 0.
- self.optimizer_swapper = None
- self.swap_optimizer = False
- self.offload_optimizer = False
- self.offload_optimizer_pin_memory = False
- self.offload_optimizer_fast_init = False
- self.offload_param = False
- self.offload_param_pin_memory = False
- self.params_in_nvme_and_cpu = False
- self.max_params_in_cpu = 0
- self._configure_offloading(offload_optimizer_config, offload_param_config)
- self._convert_to_zero_parameters(ds_config, module, mpu)
- for m in module.modules():
- _init_external_params(m)
- self.module = module
- self.elastic_checkpoint = elastic_checkpoint
- self.overlap_comm = overlap_comm
- # Replace ._parameters with a new class to enable auto-registration of
- # external parameters
- _inject_parameters(module, ZeROOrderedDict)
- if self.overlap_comm:
- self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda()
- self.deepspeed_adam_offload = (self.offload_optimizer
- and type(init_optimizer) == DeepSpeedCPUAdam)
- self.device = torch.cuda.current_device(
- ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE
- ############################################################################
- see_memory_usage("Before Partitioned Parameter Coordinator", force=False)
- fetch_stream = torch.cuda.Stream() if self.overlap_comm else None
- self.param_coordinator = PartitionedParameterCoordinator(
- comm_stream=fetch_stream,
- max_reuse_distance_in_numel=int(max_reuse_distance),
- max_available_parameters_in_numel=int(max_live_parameters))
- see_memory_usage("After Partitioned Parameter Coordinator", force=False)
- #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream())
- #-------------Stage 3 Setup-------------------#
- # parameters smaller than the threshold will be collectively gathered at the
- # end of the optimizer step and will be kept till the end of the backward pass
- # TODO maybe worth just replicating these parameters and doing all reduce for them
- self.persistence_threshold = int(param_persistence_threshold)
- self.persistent_parameters = self.persistent_parameters()
- self.setup_zero_stage3_hooks()
- #resetting ds_tensor just in case parameters have been changed after initialization
- #example .half() or .to()
- #self.reset_ds_tensor()
- #---------------------------------------------#
- self.timers = timers
- self.reduce_scatter = reduce_scatter
- self.dp_process_group = dp_process_group
- self.partition_count = dist.get_world_size(group=self.dp_process_group)
- if mpu is None:
- self.model_parallel_group = None
- self.model_parallel_rank = 0
- else:
- self.model_parallel_group = mpu.get_model_parallel_group()
- self.model_parallel_rank = mpu.get_model_parallel_rank()
- self.overflow = False
- self.clip_grad = clip_grad
- self.communication_data_type = communication_data_type
- self.gradient_predivide_factor = gradient_predivide_factor
- self.postscale_gradients = postscale_gradients
- self.gradient_accumulation_steps = gradient_accumulation_steps
- self.micro_step_id = INITIAL_MICRO_STEP_ID
- if self.reduce_scatter:
- assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-3 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
- assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-3 with reduce scatter enabled"
- assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-3 with reduce scatter enabled"
- # Holds the model parameters
- # The param.data may not hold any meaningful data
- # when the param's status is NOT_AVAILABLE or INFLIGHT
- self.fp16_groups = []
- # Hold partitioned parameters
- self.fp16_partitioned_groups = []
- # Holds a fused and flattened copy of the parameters
- self.fp16_partitioned_groups_flat = []
- self.fp16_partitioned_groups_flat_numel = []
- #defragmented pinned memory
- self.param_groups_fp16_flat_cpu_memory = []
- #a single 32-bit partition of the parallel partitioned parameters
- #that this process will update
- self.fp32_partitioned_groups_flat = []
- self.next_swappable_fp32_partitioned_groups = []
- # number of elements per partition in each group
- self.partition_size = []
- self.all_reduce_print = False
- self.prefetch_elements = int(prefetch_bucket_size)
- # padding on each partition for alignment purposes
- self.groups_padding = []
- self.sub_group_size = sub_group_size
- self.sub_group_to_group_id = {}
- see_memory_usage("Before creating fp16 partitions", force=False)
- self._create_fp16_partitions_with_defragmentation()
- num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
- see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
- force=False)
- # Optimizer tensor swapping
- if self.swap_optimizer:
- self._configure_tensor_swapping(offload_optimizer_config, aio_config)
- see_memory_usage("Before creating fp32 partitions", force=False)
- if not isinstance(self.optimizer, DummyOptim):
- self._create_fp32_partitions()
- see_memory_usage("After creating fp32 partitions", force=False)
- dist.barrier()
- # To support pipelined optimizer swapping
- if not isinstance(init_optimizer, DummyOptim):
- self._create_next_swappable_fp32_groups()
- see_memory_usage("Before initializing optimizer states", force=False)
- if not isinstance(init_optimizer, DummyOptim):
- self.initialize_optimizer_states()
- see_memory_usage("After initializing optimizer states", force=False)
- dist.barrier()
- if dist.get_rank() == 0:
- logger.info(f"optimizer state initialized")
- self.reduce_bucket_size = int(reduce_bucket_size)
- self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
- self.reduction_stream = torch.cuda.Stream(
- ) if self.overlap_comm else torch.cuda.current_stream()
- self.callback_queued = False
- self.copy_grad_stream = torch.cuda.Stream()
- self.param_dict = {}
- # map between param_id and bool to specify if a param is in this partition
- self.is_param_in_current_partition = {}
- self.contiguous_gradients = contiguous_gradients
- self.extra_large_param_to_reduce = None
- self.grads_in_ipg_bucket = []
- self.params_in_ipg_bucket = []
- self.elements_in_ipg_bucket = 0
- self.params_already_reduced = []
- self.is_gradient_accumulation_boundary = True
- self._release_ipg_buffers()
- self.previous_reduced_grads = None
- # simplified param id
- self.param_id = {}
- count = 0
- for i, params_group in enumerate(self.fp16_groups):
- for param in params_group:
- unique_id = id(param)
- self.param_id[unique_id] = count
- self.param_dict[count] = param
- self.params_already_reduced.append(False)
- count = count + 1
- #Largest partitioned param
- largest_partitioned_param_numel = max([
- max([
- max(tensor.numel(),
- tensor.ds_numel) for tensor in fp16_partitioned_group
- ]) for fp16_partitioned_group in self.fp16_partitioned_groups
- ])
- print_rank_0(
- f'Largest partitioned param numel = {largest_partitioned_param_numel}',
- force=False)
- see_memory_usage(f"Before Set Grad positions", force=False)
- self.grad_position = {}
- self.set_grad_positions()
- see_memory_usage(f"Before CPU Offload initialization", force=False)
- self.grads_in_partition = None
- if self.offload_optimizer:
- self.accumulated_grads_in_cpu = {}
- self.norm_for_param_grads = {}
- self.local_overflow = False
- self.temp_grad_buffer_for_gpu_offload = torch.zeros(
- largest_partitioned_param_numel,
- device=torch.cuda.current_device(),
- dtype=self.dtype)
- self.temp_grad_gpu_buffer = torch.zeros(largest_partitioned_param_numel,
- device=torch.cuda.current_device(),
- dtype=self.dtype)
- see_memory_usage(f"After CPU Offload initialization", force=False)
- # stores if a partition has been reduced in this step
- self.is_partition_reduced = {}
- # stores if a grad in a partition has been computed or not
- self.is_grad_computed = {}
- # will store the averaged gradients required by this partition
- self.averaged_gradients = {}
- #creates backward hooks for gradient partitioning
- self.create_reduce_and_remove_grad_hooks()
- #exit(0)
- # we may have a way of fusing dynamic loss scaling; not supported for now
- if self.dtype == torch.float or not dynamic_loss_scale:
- loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale
- self.dynamic_loss_scale = False
- self.loss_scaler = LossScaler(scale=loss_scale_value)
- cur_iter = 0
- else:
- if dynamic_loss_args is None:
- self.loss_scaler = DynamicLossScaler()
- else:
- self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
- self.dynamic_loss_scale = True
- self.debug_fp16_grads = [{} for _ in self.fp16_groups]
- if dist.get_rank(group=self.dp_process_group) == 0:
- see_memory_usage(f"After initializing ZeRO optimizer", force=False)
- def _configure_offloading(self, offload_optimizer_config, offload_param_config):
- ###################### offload optimizer setup ##################################
- if offload_optimizer_config is not None:
- self.offload_optimizer = True
- self.offload_optimizer_pin_memory = offload_optimizer_config[
- OFFLOAD_OPTIMIZER_PIN_MEMORY]
- self.swap_optimizer = offload_optimizer_config[
- OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE
- self.offload_optimizer_fast_init = offload_optimizer_config[
- OFFLOAD_OPTIMIZER_FAST_INIT]
- ###################### offload param setup ##################################
- if offload_param_config is not None:
- if not isinstance(self.optimizer, DummyOptim):
- assert self.offload_optimizer, "parameter offload is only available with optimizer state offload"
- self.offload_param = True
- self.offload_param_pin_memory = offload_param_config[
- OFFLOAD_PARAM_PIN_MEMORY]
- self.params_in_nvme_and_cpu = offload_param_config[
- OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE
- self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU]
- print_rank_0(
- f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}",
- force=False)
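- # A hedged sketch of the config this consumes (key names mirror the
- # OFFLOAD_* constants imported above; values are illustrative):
- #   "zero_optimization": {
- #     "stage": 3,
- #     "offload_optimizer": {"device": "nvme", "pin_memory": true, "fast_init": false},
- #     "offload_param": {"device": "nvme", "pin_memory": true, "max_in_cpu": 1e9}
- #   }
- # offload_optimizer.device == "nvme" enables optimizer-state swapping
- # (self.swap_optimizer); offload_param.device == "nvme" sets
- # self.params_in_nvme_and_cpu.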
- def _convert_to_zero_parameters(self, ds_config, module, mpu):
- non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
- if non_zero_params:
- zero_params = [p for p in module.parameters() if is_zero_param(p)]
- if zero_params:
- zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
- else:
- group = None
- if mpu:
- group = mpu.get_data_parallel_group()
- if self.params_in_nvme_and_cpu:
- remote_device = OFFLOAD_NVME_DEVICE
- elif self.offload_param:
- remote_device = OFFLOAD_CPU_DEVICE
- else:
- remote_device = None
- Init(module=module,
- data_parallel_group=group,
- dtype=self.dtype,
- config_dict_or_path=ds_config,
- remote_device=remote_device,
- pin_memory=self.offload_param_pin_memory,
- mpu=mpu)
- def _configure_tensor_swapping(self, offload_optimizer_config, aio_config):
- nvme_swap_folder = os.path.join(
- offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH],
- 'zero_stage_3')
- os.makedirs(nvme_swap_folder, exist_ok=True)
- if torch.distributed.get_rank() == 0:
- logger.info(f'Tensor Swapping: Adding optimizer tensors')
- swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[
- OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper
- self.optimizer_swapper = swapper_type(
- swap_config=offload_optimizer_config,
- aio_config=aio_config,
- base_folder=nvme_swap_folder,
- optimizer=self.optimizer,
- largest_numel=max(self.fp16_partitioned_groups_flat_numel),
- device=self.device,
- dtype=torch.float32,
- timers=self.timers)
- def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
- '''If flat_buffer is None then the parameters in param_list are
- not copied to the flat buffer. This is because they exceed max_params_in_cpu.
- Some of these parameters may already be in CPU in unflattened buffers,
- or they may be in GPU or NVMe. If they are in NVMe, then
- they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are
- needed during training.'''
- if flat_buffer is None:
- # this dst buffer is on NVMe, so skip this
- return
- start = 0
- for param in param_list:
- src = param.ds_tensor
- dest = flat_buffer.narrow(0, start, src.ds_numel)
- start = start + src.ds_numel
- # if the parameter was initialized on NVMe, bring it into the destination buffer directly
- if src.status == PartitionedParamStatus.NOT_AVAILABLE:
- print_rank_0(
- f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU"
- )
- param.nvme_swapper.swap_into_buffer(param, dest)
- src.data = dest.data
- src.status = PartitionedParamStatus.AVAILABLE
- else:
- assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here"
- if not avoid_copy:
- dest.data.copy_(src.data)
- src.data = dest.data
- # Final location must be gpu/cpu in this case
- param.ds_tensor.final_location = 'not-nvme'
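- # Net effect per param (summarizing the branches above): an NVMe-resident
- # partition is swapped directly into its slice of flat_buffer, an available
- # partition is copied in unless avoid_copy is set, and in every case
- # param.ds_tensor ends up aliasing the flat buffer so the old scattered
- # storage can be freed.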
- def _create_param_groups_fp16_flat_cpu_memory(self):
- aggregate_params_count = 0
- for j, param_group in enumerate(self.optimizer.param_groups):
- params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']])
- flat_buffer_size = params_in_group
- if self.params_in_nvme_and_cpu and \
- aggregate_params_count + params_in_group > self.max_params_in_cpu:
- flat_buffer_size = max(0,
- self.max_params_in_cpu - aggregate_params_count)
- aggregate_params_count += params_in_group
- if flat_buffer_size > 0:
- print_rank_0(f"group {j} flat buffer size {flat_buffer_size}",
- force=False)
- self.param_groups_fp16_flat_cpu_memory.append(
- torch.empty(int(flat_buffer_size),
- dtype=self.dtype,
- pin_memory=True))
- else:
- print_rank_0(
- f"No flat buffer size. Param group size was {params_in_group}",
- force=False)
- self.param_groups_fp16_flat_cpu_memory.append(
- torch.empty(1,
- dtype=self.dtype))
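- # Worked example (illustrative numbers): with max_params_in_cpu = 1.0e9 and
- # two param groups of 0.7e9 and 0.6e9 partitioned elements, group 0 gets a
- # pinned flat buffer of 0.7e9 elements, group 1 is capped at
- # 1.0e9 - 0.7e9 = 0.3e9 elements, and group 1's remaining 0.3e9 elements
- # stay on NVMe with no CPU flat buffer.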
- def _create_fp16_partitions_with_defragmentation(self):
- dist.barrier()
- partition_id = dist.get_rank(group=self.dp_process_group)
- create_fp16_flat_reuse_buffer = False
- largest_partition_numel = []
- max_partition_numel = 0
- #create a flat CPU memory allocation for each param group
- if self.offload_param:
- self._create_param_groups_fp16_flat_cpu_memory()
- # loop to deal with groups
- for j, param_group in enumerate(self.optimizer.param_groups):
- sub_groups = self._create_fp16_sub_groups(param_group['params'])
- print_rank_0(f'fp16 group {j} has {len(sub_groups)} subgroups', force=False)
- flat_offset = 0
- for sub_group in sub_groups:
- i = len(self.fp16_groups)
- # push this group to list before modify
- self.fp16_groups.append(sub_group)
- self.sub_group_to_group_id[i] = j
- # comment out for zero_to_fp32 debug
- # if torch.distributed.get_rank() == 0:
- # for param in self.fp16_groups[i]:
- # print(f"{debug_param2name_id_shape(param)} {param.ds_shape}")
- #These are the list of the partitioned parameters
- self.fp16_partitioned_groups.append(
- [param.ds_tensor for param in self.fp16_groups[i]])
- total_elements = sum(
- [t.ds_numel for t in self.fp16_partitioned_groups[i]])
- self.fp16_partitioned_groups_flat_numel.append(total_elements)
- if total_elements > max_partition_numel:
- largest_partition_numel = [
- t.ds_numel for t in self.fp16_partitioned_groups[i]
- ]
- max_partition_numel = total_elements
- print_rank_0(
- f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}"
- )
- # Record padding required to align group to world size (only applies to last rank)
- if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
- padding = [p.padding_size() for p in self.fp16_groups[i]]
- else:
- padding = [0] * len(self.fp16_groups[i])
- self.groups_padding.append(padding)
- #not sure why apex was cloning the weights before flattening
- #removing cloning here
- see_memory_usage(f"Before Flattening param subgroup {i}", force=False)
- #all partitioned parameters remain in GPU during training
- if not self.offload_param:
- see_memory_usage(f"Before moving param subgroup group {i} to CPU",
- force=False)
- #move all the parameters to cpu to free up GPU space for creating flat buffer
- move_to_cpu(self.fp16_partitioned_groups[i])
- see_memory_usage(f"After moving param subgroup {i} to CPU",
- force=False)
- #create flat buffer in CPU and move to GPU
- self.fp16_partitioned_groups_flat.append(
- self.flatten_dense_tensors_aligned(
- self.fp16_partitioned_groups[i],
- 1).cuda(torch.cuda.current_device()))
- see_memory_usage(
- f"After flattening and moving param subgroup {i} to GPU",
- force=False)
- #all partitioned parameters are in CPU during training
- else:
- print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}")
- #Flat buffer may not be available for parameters that reside in NVME
- if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[
- j].numel():
- fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[
- j].narrow(0,
- flat_offset,
- total_elements)
- print_rank_0(
- f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}",
- force=False)
- # these parameters reside in NVMe; no CPU flat buffer is allocated for them
- elif self.params_in_nvme_and_cpu:
- fp16_partitioned_group_flat = None
- print_rank_0(
- f"No flat buffer for sub group {i} of {total_elements} elements",
- force=False)
- else:
- assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs"
- self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat)
- flat_offset += total_elements
- # move param to flat buffer for both param offload on/off
- self._move_to_flat_buffer(self.fp16_groups[i],
- self.fp16_partitioned_groups_flat[i],
- avoid_copy=not self.offload_param)
- see_memory_usage(f"After Flattening param group {i}", force=False)
- #create pinned memory to be used for swapping out params to NVMe after the optimizer step
- if self.fp16_partitioned_groups_flat[-1] is None:
- create_fp16_flat_reuse_buffer = True
- see_memory_usage(f"After Flattening param subgroup {i}", force=False)
- if create_fp16_flat_reuse_buffer:
- assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty'
- self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space(
- largest_partition_numel)
- def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id):
- offset = 0
- elements_in_sub_group = sum(
- [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]])
- assert (flat_buffer.numel() == elements_in_sub_group)
- for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]):
- dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel)
- if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
- print_rank_0(
- f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}"
- )
- param.nvme_swapper.swap_in([param], async_op=False)
- dest.data.copy_(partitioned_param.data)
- param.nvme_swapper.remove_partition_and_release_buffers([param])
- print_rank_0(f"Swapping in {param.ds_id} done")
- else:
- dest.data.copy_(partitioned_param.data)
- offset += partitioned_param.ds_numel
- def _create_next_swappable_fp32_groups(self):
- reverse_order_indices = [
- i for i in range(len(self.fp32_partitioned_groups_flat))
- ]
- reverse_order_indices.reverse()
- next_group = None
- for i in reverse_order_indices:
- self.next_swappable_fp32_partitioned_groups.append(next_group)
- if self._swappable_optimizer_subgroup(i):
- next_group = self.fp32_partitioned_groups_flat[i]
- self.next_swappable_fp32_partitioned_groups.reverse()
- def _get_sub_group_partitions(self, sub_group_id):
- sub_group_partitions = []
- for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]):
- if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
- swap_path = param.nvme_swapper.get_path(param, True)
- sub_group_partitions.append((partitioned_param,
- param.ds_tensor.ds_numel,
- swap_path))
- else:
- sub_group_partitions.append((partitioned_param,
- partitioned_param.ds_numel,
- None))
- return sub_group_partitions
- def _create_fp32_partitions(self):
- cpu_memory_usage = 0
- cpu_memory_sub_groups = 0
- nvme_memory_usage = 0
- num_swappable_partitions = 0
- num_swap_from_nvme_partitions = 0
- num_swap_from_cpu_partitions = 0
- swap_from_nvme_memory_usage = 0
- swap_from_cpu_memory_usage = 0
- GIGA_BYTES = (1024**3)
- swappable_fp32_tensors = []
- swappable_fp16_src_tensors = []
- nvme_fp16_partitions_info = []
- nvme_fp16_num_elems = []
- nvme_fp32_dest_tensors = []
- fp32_element_size = torch.tensor([], dtype=torch.float32).element_size()
- for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
- num_elements = self.fp16_partitioned_groups_flat_numel[i]
- # a partition of the fp32 master weights that will be updated by this process
- if self._swappable_optimizer_subgroup(i):
- self.fp32_partitioned_groups_flat.append(torch.Tensor())
- nvme_memory_usage += (fp32_element_size * num_elements)
- num_swappable_partitions += 1
- if self.params_in_nvme_and_cpu and tensor is None:
- num_swap_from_nvme_partitions += 1
- swap_from_nvme_memory_usage += (fp32_element_size * num_elements)
- if self.offload_optimizer_fast_init:
- sub_group_partitions = self._get_sub_group_partitions(i)
- nvme_fp16_partitions_info.append(sub_group_partitions)
- nvme_fp16_num_elems.append(num_elements)
- nvme_fp32_dest_tensors.append(
- self.fp32_partitioned_groups_flat[i])
- else:
- unpinned_fp32_buffer = torch.empty(num_elements,
- device=self.device,
- dtype=torch.float)
- self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
- self.optimizer_swapper.initialize_parameters(
- parameters=[self.fp32_partitioned_groups_flat[i]],
- src_tensors=[unpinned_fp32_buffer])
- else:
- num_swap_from_cpu_partitions += 1
- swap_from_cpu_memory_usage += (fp32_element_size * num_elements)
- swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i])
- swappable_fp16_src_tensors.append(
- self.fp16_partitioned_groups_flat[i])
- else:
- cpu_memory_usage += (fp32_element_size * num_elements)
- cpu_memory_sub_groups += 1
- if self.params_in_nvme_and_cpu and tensor is None:
- unpinned_fp32_buffer = torch.empty(num_elements,
- device=self.device,
- dtype=torch.float)
- self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
- self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer)
- else:
- self.fp32_partitioned_groups_flat.append(
- self.fp16_partitioned_groups_flat[i].to(
- self.device).clone().float().detach())
- self.fp32_partitioned_groups_flat[
- i].requires_grad = True # keep this in case internal optimizer uses it
- if len(swappable_fp32_tensors) > 0:
- self.optimizer_swapper.initialize_parameters(
- parameters=swappable_fp32_tensors,
- src_tensors=swappable_fp16_src_tensors)
- if len(nvme_fp32_dest_tensors) > 0:
- fp16_pinned_buffers = self.fp16_groups[0][
- 0].nvme_swapper.reserve_available_buffers()
- assert len(fp16_pinned_buffers) > 0
- self.optimizer_swapper.initialize_from_swapped_fp16_params(
- fp16_partitions_info=nvme_fp16_partitions_info,
- fp16_num_elems=nvme_fp16_num_elems,
- fp16_pinned_buffers=fp16_pinned_buffers,
- fp32_parameters=nvme_fp32_dest_tensors)
- self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers()
- nvme_gigabytes = nvme_memory_usage / GIGA_BYTES
- print_rank_0(
- f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB',
- force=False)
- if self.params_in_nvme_and_cpu:
- print_rank_0(
- f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB',
- force=False)
- print_rank_0(
- f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB',
- force=False)
- cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES
- print_rank_0(
- f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB',
- force=False)
- # Clear for on-the-fly population before the optimizer step
- for param_group in self.optimizer.param_groups:
- param_group['params'] = []
- def _create_fp16_sub_groups(self, params_group):
- params_group_numel = sum([param.partitioned_size() for param in params_group])
- sub_group_size = self.sub_group_size
- if sub_group_size is None or sub_group_size >= params_group_numel:
- return [params_group]
- sub_groups = []
- sub_group = []
- local_sub_group_size = 0
- for param in params_group:
- sub_group.append(param)
- local_sub_group_size += param.partitioned_size()
- if local_sub_group_size >= sub_group_size or param is params_group[-1]:
- sub_groups.append(sub_group)
- sub_group = []
- local_sub_group_size = 0
- return sub_groups
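- # Illustrative sketch (not part of the original source): how the chunking
- # above behaves on a toy list of per-param partition sizes. The helper name
- # `_example_sub_groups` and the plain-int "params" are hypothetical.
- def _example_sub_groups(sizes, sub_group_size):
-     sub_groups, current, current_numel = [], [], 0
-     for idx, numel in enumerate(sizes):
-         current.append(numel)
-         current_numel += numel
-         # flush when the threshold is reached or on the last param
-         if current_numel >= sub_group_size or idx == len(sizes) - 1:
-             sub_groups.append(current)
-             current, current_numel = [], 0
-     return sub_groups
- assert _example_sub_groups([4, 4, 4, 4], sub_group_size=6) == [[4, 4], [4, 4]]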
- # def reset_ds_tensor(self):
- # for name, param in self.module.named_parameters(recurse=True):
- # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible"
- # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now"
- # param.ds_tensor.data = param.data
- def setup_zero_stage3_hooks(self):
- self.hierarchy = 0
- self._register_hooks_recursively(self.module)
- #reset step at the beginning of forward
- def _pre_forward_hook(module, *args):
- self.param_coordinator.reset_step()
- #reset step if in inference mode
- def _end_of_forward_hook(module, *args):
- if not torch.is_grad_enabled():
- self.param_coordinator.reset_step()
- # one of these hooks is likely sufficient, but register both to be safe
- self.module.register_forward_hook(_end_of_forward_hook)
- self.module.register_forward_pre_hook(_pre_forward_hook)
- # Add top module to stack trace
- global FWD_MODULE_STACK
- FWD_MODULE_STACK.append(self.module)
- def persistent_parameters(self):
- persistent_params = []
- total_persistent_parameters = 0
- params_count = 0
- for _, param in self.module.named_parameters(recurse=True):
- if param.ds_numel < self.persistence_threshold:
- params_count += 1
- param.ds_persist = True
- persistent_params.append(param)
- total_persistent_parameters += param.ds_numel
- print_rank_0(
- f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
- force=False)
- return persistent_params
- def _register_hooks_recursively(self, module, count=[0]):  # note: mutable default used as a shared counter across the recursion
- my_count = count[0]
- module.id = my_count
- #print(f"{module.__class__} : {module.id}")
- for child in module.children():
- count[0] = count[0] + 1
- self._register_hooks_recursively(child, count=count)
- def _pre_forward_module_hook(module, *args):
- self.pre_sub_module_forward_function(module)
- def _post_forward_module_hook(module, input, output):
- global FWD_MODULE_STACK
- FWD_MODULE_STACK.pop()
- if output is None:
- output = []
- elif not isinstance(output, (list, tuple)):
- if torch.is_tensor(output):
- output = [output]
- else:
- #print(f'got UNKNOWN type {type(output)}')
- outputs = []
- output = output if isinstance(output, dict) else vars(output)
- for name, val in output.items():
- if not name.startswith('__') and torch.is_tensor(val):
- outputs.append(val)
- output = outputs
- #print(f'convert output to {output}')
- for item in filter(lambda item: is_zero_param(item), output):
- if not any(id(item) in m._external_params for m in FWD_MODULE_STACK):
- item.ds_active_sub_modules += 1
- module_to_register = FWD_MODULE_STACK[-1]
- print_rank_0(
- f'Registering dangling parameter for module {module_to_register.__class__.__name__}.',
- force=False)
- register_external_parameter(module_to_register, item)
- # It's possible that the parameter was already external to the completed module. If so,
- # remove the registration as it will be covered by the outer module instead.
- if id(item) in module._external_params:
- print_rank_0(
- f' Unregistering nested dangling parameter from module {module.__class__.__name__}',
- force=False)
- unregister_external_parameter(module, item)
- item.all_gather()
- self.post_sub_module_forward_function(module)
- def _pre_backward_module_hook(module, inputs, output):
- def _run_before_backward_function(sub_module):
- # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
- # before doing backwards, so each backward will need a pre-fetch - using reference
- # counting to support this scenario
- #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
- if sub_module.applied_pre_backward_ref_cnt > 0:
- self.pre_sub_module_backward_function(sub_module)
- sub_module.applied_pre_backward_ref_cnt -= 1
- #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
- return _apply_to_tensors_only(module,
- PreBackwardFunction,
- _run_before_backward_function,
- output)
- # This is an alternative to _post_backward_module_hook;
- # it uses tensor.register_hook instead of torch.autograd.Function
- def _alternate_post_backward_module_hook(module, inputs):
- module.ds_grads_remaining = 0
- #print(f"Before Forward {module.__class__.__name__}")
- def _run_after_backward_hook(*unused):
- module.ds_grads_remaining = module.ds_grads_remaining - 1
- if module.ds_grads_remaining == 0:
- #print(f"After backward {module.__class__.__name__}")
- self.post_sub_module_backward_function(module)
- def _run_before_forward_function(input):
- if input.requires_grad:
- module.ds_grads_remaining += 1
- return _apply_forward_and_backward_to_tensors_only(
- module,
- _run_before_forward_function,
- _run_after_backward_hook,
- inputs)
- def _post_backward_module_hook(module, inputs):
- module.ds_grads_remaining = 0
- def _run_after_backward_function(sub_module):
- if sub_module.ds_grads_remaining == 0:
- self.post_sub_module_backward_function(sub_module)
- return _apply_to_tensors_only(module,
- PostBackwardFunction,
- _run_after_backward_function,
- inputs)
- # Pre forward hook
- module.register_forward_pre_hook(_pre_forward_module_hook)
- # Post forward hook
- module.register_forward_hook(_post_forward_module_hook)
- # Pre backward hook
- module.register_forward_hook(_pre_backward_module_hook)
- # post backward hook
- module.register_forward_pre_hook(_post_backward_module_hook)
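- # Illustrative sketch (not part of the original source): the hook wiring
- # above in miniature -- a forward pre-hook fires before each submodule's
- # forward (where ZeRO-3 fetches parameters) and a forward hook fires after
- # it (where parameters are released). `net` and `events` are hypothetical.
- import torch
- net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())
- events = []
- for m in net.children():
-     m.register_forward_pre_hook(lambda mod, inp: events.append(('pre', type(mod).__name__)))
-     m.register_forward_hook(lambda mod, inp, out: events.append(('post', type(mod).__name__)))
- net(torch.randn(1, 2))
- # events == [('pre', 'Linear'), ('post', 'Linear'), ('pre', 'ReLU'), ('post', 'ReLU')]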
- def pre_sub_module_forward_function(self, sub_module):
- see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
- force=False)
- global FWD_MODULE_STACK
- FWD_MODULE_STACK.append(sub_module)
- self.param_coordinator.record_trace(sub_module)
- self.param_coordinator.fetch_sub_module(sub_module)
- see_memory_usage(
- f"Before sub module function {sub_module.__class__.__name__} after fetch",
- force=False)
- self.param_coordinator.prefetch_next_sub_modules(
- sub_module,
- numel=self.prefetch_elements,
- nvme=self.params_in_nvme_and_cpu)
- see_memory_usage(
- f"Before sub module function {sub_module.__class__.__name__} after prefetch",
- force=False)
- self.param_coordinator.increment_step(sub_module)
- def post_sub_module_forward_function(self, sub_module):
- see_memory_usage(
- f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
- force=False)
- self.param_coordinator.release_sub_module(sub_module)
- see_memory_usage(
- f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
- force=False)
- def pre_sub_module_backward_function(self, sub_module):
- self.param_coordinator.record_trace(sub_module)
- self.param_coordinator.fetch_sub_module(sub_module)
- self.param_coordinator.prefetch_next_sub_modules(sub_module,
- numel=self.prefetch_elements)
- self.param_coordinator.increment_step(sub_module)
- def post_sub_module_backward_function(self, sub_module):
- see_memory_usage(
- f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
- force=False)
- self.param_coordinator.release_sub_module(sub_module)
- see_memory_usage(
- f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
- force=False)
- def _release_ipg_buffers(self):
- if self.contiguous_gradients:
- self.ipg_buffer = None
- if not self.offload_optimizer and self.is_gradient_accumulation_boundary:
- self.grads_in_partition = None
- self.grads_in_partition_offset = 0
- def _optimizer_step(self, sub_group_id):
- param_group_id = self.sub_group_to_group_id[sub_group_id]
- fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
- fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
- self.optimizer.param_groups[param_group_id]['params'] = [fp32_param]
- self.optimizer.step()
- self.optimizer.param_groups[param_group_id]['params'] = []
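- # Illustrative sketch (not part of the original source): stepping a torch
- # optimizer on one flat parameter at a time by temporarily swapping it into
- # an otherwise-empty param group, mirroring _optimizer_step above.
- import torch
- flat_params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
- opt = torch.optim.SGD([{'params': []}], lr=0.1)
- for p in flat_params:
-     p.grad = torch.ones_like(p)
-     opt.param_groups[0]['params'] = [p]
-     opt.step()  # only `p` is updated; optimizer state stays per-tensor
-     opt.param_groups[0]['params'] = []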
- def _swappable_optimizer_subgroup(self, sub_group_id):
- if not self.swap_optimizer:
- return False
- return self.optimizer_swapper.swappable_tensor(
- None,
- numel=self.fp16_partitioned_groups_flat_numel[sub_group_id])
- def _partitioned_params_swap_out(self, i):
- offset = 0
- fp32_param = self.fp32_partitioned_groups_flat[i]
- assert fp32_param is not None, \
- f'fp32 parameters of sub_group {i} is None'
- swap_fp16_params = []
- swap_fp32_params = []
- for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]):
- src = fp32_param.narrow(0, offset, partitioned_param.ds_numel)
- if partitioned_param.status == PartitionedParamStatus.AVAILABLE:
- partitioned_param.data.copy_(src.data)
- else:
- swap_fp32_params.append(src)
- swap_fp16_params.append(param)
- offset += partitioned_param.ds_numel
- if len(swap_fp16_params):
- swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(
- dst_fp16_params=swap_fp16_params,
- src_fp32_params=swap_fp32_params)
- def initialize_optimizer_states(self):
- num_subgroups = len(self.fp16_groups)
- largest_numel = max(
- [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups])
- gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype
- gradient_buffer = torch.zeros(int(largest_numel),
- dtype=gradient_dtype,
- device=self.device)
- timers = self.timers
- timer_names = set()
- if self.swap_optimizer:
- self.optimizer_swapper.init_timers()
- INIT_OPTIMIZER_TIMER = 'init_optimizer_state'
- timer_names.add(INIT_OPTIMIZER_TIMER)
- self.start_timers([INIT_OPTIMIZER_TIMER])
- for i, group in enumerate(self.fp16_groups):
- swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i)
- swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None
- num_elements = int(self.fp16_partitioned_groups_flat_numel[i])
- see_memory_usage(
- f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
- force=False)
- if swappable_optimizer_subgroup:
- self._optimizer_states_and_gradient_swap_in(i, timer_names)
- if self.offload_optimizer and not swappable_optimizer_subgroup:
- subgroup_gradient_buffer = torch.zeros(num_elements,
- dtype=gradient_dtype,
- device=self.device)
- if self.offload_optimizer_pin_memory:
- subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory()
- self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer
- else:
- self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(
- 0,
- 0,
- num_elements)
- self._optimizer_step(i)
- if swappable_param_subgroup:
- self._partitioned_params_swap_out(i)
- if swappable_optimizer_subgroup:
- self._optimizer_states_and_gradient_swap_out(i, timer_names)
- see_memory_usage(
- f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
- force=False)
- self.stop_timers([INIT_OPTIMIZER_TIMER])
- self.log_timers(timer_names)
- if self.swap_optimizer:
- self.optimizer_swapper.log_timers()
- if not self.offload_optimizer:
- for group in self.fp32_partitioned_groups_flat:
- group.grad = None
- # Reset steps
- return
- #########################################################################
- #########################ZeRO Partition Gradients########################
- #########################################################################
- def get_first_param_index(self, group_id, param_group, partition_id):
- for index, param in enumerate(param_group):
- param_id = self.get_param_id(param)
- if partition_id in self.param_to_partition_ids[group_id][param_id]:
- return index
- return None
- def initialize_gradient_partitioning_data_structures(self):
- total_partitions = dist.get_world_size(group=self.dp_process_group)
- for i, param_group in enumerate(self.fp16_groups):
- self.param_to_partition_ids[i] = {}
- self.is_partition_reduced[i] = {}
- self.total_grads_in_partition[i] = {}
- self.remaining_grads_in_partition[i] = {}
- self.is_grad_computed[i] = {}
- self.grad_partition_insertion_offset[i] = {}
- self.grad_start_offset[i] = {}
- self.first_param_index_in_partition[i] = {}
- for partition_id in range(total_partitions):
- self.is_grad_computed[i][partition_id] = {}
- self.grad_partition_insertion_offset[i][partition_id] = {}
- self.grad_start_offset[i][partition_id] = {}
- self.initialize_gradient_partition(i, param_group, partition_id)
- self.is_partition_reduced[i][partition_id] = False
- self.first_param_index_in_partition[i][
- partition_id] = self.get_first_param_index(
- i,
- param_group,
- partition_id)
- def independent_gradient_partition_epilogue(self):
- self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
- self.reduce_ipg_grads()
- self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
- if self.overlap_comm:
- self.reduction_stream.synchronize()
- with torch.cuda.stream(self.reduction_stream):
- self.partition_previous_reduced_grads()
- # if dist.get_rank() == 0:
- # logger.info("Params already reduced %s", self.params_already_reduced)
- for i in range(len(self.params_already_reduced)):
- self.params_already_reduced[i] = False
- #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad
- #TODO: use a similar code path for both cpu_offload and non-cpu offload
- if not self.offload_optimizer:
- for i, sub_group in enumerate(self.fp16_groups):
- self.averaged_gradients[i] = [
- torch.zeros_like(param.ds_tensor) if param.grad is None else
- param.grad.data.narrow(0,
- 0,
- param.ds_tensor.numel())
- for param in sub_group
- ]
- # self.averaged_gradients[i] = self.get_flat_partition(
- # self.fp16_groups[i],
- # 0,
- # self.fp32_partitioned_groups_flat[i].numel(),
- # return_tensor_list=True)
- self._release_ipg_buffers()
- see_memory_usage(f"End ipg_epilogue", force=False)
- # resets all partitions to not-reduced,
- # sets remaining grads to the total number of grads in each partition,
- # and marks every grad in each partition as not yet computed
- def reset_partition_gradient_structures(self):
- total_partitions = dist.get_world_size(group=self.dp_process_group)
- for i, _ in enumerate(self.fp16_groups):
- for partition_id in range(total_partitions):
- self.is_partition_reduced[i][partition_id] = False
- self.remaining_grads_in_partition[i][
- partition_id] = self.total_grads_in_partition[i][partition_id]
- for param_id in self.is_grad_computed[i][partition_id]:
- self.is_grad_computed[i][partition_id][param_id] = False
- def initialize_gradient_partition(self, i, param_group, partition_id):
- def set_key_value_list(dictionary, key, value):
- if key in dictionary:
- dictionary[key].append(value)
- else:
- dictionary[key] = [value]
- def increment_value(dictionary, key):
- if key in dictionary:
- dictionary[key] += 1
- else:
- dictionary[key] = 1
- partition_size = self.partition_size[i]
- start_index = partition_size * partition_id
- end_index = partition_size * (partition_id + 1)
- current_index = 0
- first_offset = 0
- for param in param_group:
- param_size = param.numel()
- param_id = self.get_param_id(param)
- if start_index <= current_index < end_index:
- set_key_value_list(self.param_to_partition_ids[i],
- param_id,
- partition_id)
- increment_value(self.total_grads_in_partition[i], partition_id)
- self.is_grad_computed[i][partition_id][param_id] = False
- self.grad_partition_insertion_offset[i][partition_id][
- param_id] = current_index - start_index
- self.grad_start_offset[i][partition_id][param_id] = 0
- elif current_index < start_index < (current_index + param_size):
- assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
- first_offset = start_index - current_index
- set_key_value_list(self.param_to_partition_ids[i],
- param_id,
- partition_id)
- increment_value(self.total_grads_in_partition[i], partition_id)
- self.is_grad_computed[i][partition_id][param_id] = False
- self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
- self.grad_start_offset[i][partition_id][param_id] = first_offset
- current_index = current_index + param_size
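- # Illustrative sketch (not part of the original source): the boundary
- # bookkeeping above on toy sizes. A param is recorded if it begins inside
- # the partition window, or if the window begins inside the param (at most
- # one such param, hence the first_offset assert). Names are hypothetical.
- def _example_partition_membership(param_sizes, partition_size, partition_id):
-     start, end = partition_size * partition_id, partition_size * (partition_id + 1)
-     members, current = [], 0
-     for pid, size in enumerate(param_sizes):
-         if start <= current < end:                # param begins inside the window
-             members.append((pid, 0))
-         elif current < start < current + size:    # window begins inside the param
-             members.append((pid, start - current))
-         current += size
-     return members
- assert _example_partition_membership([4, 4, 4], 6, 1) == [(1, 2), (2, 0)]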
- def overlapping_partition_gradients_reduce_epilogue(self):
- self.independent_gradient_partition_epilogue()
- self.zero_grad()
- def create_reduce_and_remove_grad_hooks(self):
- print_rank_0(f'[Begin] Create gradient reduction hooks')
- self.grad_accs = []
- for i, param_group in enumerate(self.fp16_groups):
- for param in param_group:
- if param.requires_grad:
- #print_rank_0(f" Before all gather {param.device}, {param.shape}")
- # The hook must be created on the un-partitioned parameter
- param.all_gather()
- #print(f"After all gather {param.device}, {param.shape}")
- def wrapper(param, i):
- param_tmp = param.expand_as(param)
- grad_acc = param_tmp.grad_fn.next_functions[0][0]
- def reduce_partition_and_remove_grads(*notneeded):
- self.reduce_ready_partitions_and_remove_grads(param, i)
- grad_acc.register_hook(reduce_partition_and_remove_grads)
- self.grad_accs.append(grad_acc)
- #print(f"param grad fn {param.expand_as(param).grad_fn}")
- wrapper(param, i)
- # Partition the parameter after creating the hook
- param.partition()
- print_rank_0(f'[End] Create gradient reduction hooks')
- def get_param_id(self, param):
- unique_id = id(param)
- return self.param_id[unique_id]
- def report_ipg_memory_usage(self, tag, param_elems):
- elem_count = self.elements_in_ipg_bucket + param_elems
- percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
- see_memory_usage(
- f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}",
- force=False)
- ############### Independent Partition Gradient ########################
- def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
- #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True)
- # Because the ipg bucket is initialized with a random placeholder tensor, we must
- # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket >
- # 0). Otherwise, if the incoming param.ds_numel is large, this branch may get triggered on
- # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be
- # empty while reduction_list holds that garbage data.
- if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size:
- self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
- param.ds_numel)
- self.reduce_ipg_grads()
- if self.contiguous_gradients and self.overlap_comm:
- # Swap ipg_index between 0 and 1
- self.ipg_index = 1 - self.ipg_index
- self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
- param.ds_numel)
- param_id = self.get_param_id(param)
- assert self.params_already_reduced[param_id] == False, \
- f"The parameter {param_id} has already been reduced. \
- Gradient computed twice for this partition. \
- Multiple gradient reduction is currently not supported"
- # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
- if param.ds_numel > self.reduce_bucket_size:
- self.extra_large_param_to_reduce = param
- elif self.contiguous_gradients:
- #print_rank_0("before new grad tensor move")
- new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
- 0,
- self.elements_in_ipg_bucket,
- param.ds_numel)
- #print_rank_0("after new grad tensor move")
- new_grad_tensor.copy_(param.grad.view(-1))
- param.grad.data = new_grad_tensor.data.view_as(param.grad)
- self.elements_in_ipg_bucket += param.ds_numel
- self.grads_in_ipg_bucket.append(param.grad)
- self.params_in_ipg_bucket.append((i, param, param_id))
- self.report_ipg_memory_usage("End ipg_remove_grads", 0)
- def gradient_reduction_w_predivide(self, tensor):
- dp_world_size = dist.get_world_size(group=self.dp_process_group)
- tensor_to_allreduce = tensor
- if self.communication_data_type != tensor.dtype:
- tensor_to_allreduce = tensor.to(self.communication_data_type)
- if self.postscale_gradients:
- if self.gradient_predivide_factor != 1.0:
- tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
- dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
- if self.gradient_predivide_factor != dp_world_size:
- tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
- else:
- tensor_to_allreduce.div_(dp_world_size)
- dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
- if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
- tensor.copy_(tensor_to_allreduce)
- return tensor
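- # Illustrative check (not part of the original source): pre-dividing by a
- # factor f and post-multiplying by f / world_size is arithmetically the
- # same as dividing the all-reduced sum by world_size; pre-division trades
- # the scaling into two steps to gain overflow headroom in fp16.
- f, world_size, g = 16.0, 64.0, 3.0
- assert (g / f) * (f / world_size) == g / world_size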
- def average_tensor(self, tensors, params_to_reduce):
- with torch.cuda.stream(self.reduction_stream):
- if not self.reduce_scatter:
- for tensor in tensors:
- self.gradient_reduction_w_predivide(tensor)
- return
- for tensor in tensors:
- tensor.div_(dist.get_world_size(group=self.dp_process_group))
- # a reduction that results in each rank holding only the gradient partition it owns.
- # This could either be a reduce scatter or a reduce op depending on how
- # parameters are partitioned. The method is implemented by the
- # DeepSpeed param extensions to the pytorch parameter, so it's up to
- # the extension to define what happens here
- params_to_reduce[0].reduce_gradients_at_owner(
- param_list=params_to_reduce,
- hierarchy=self.param_coordinator.hierarchy)
- def set_grad_positions(self):
- for i, group in enumerate(self.fp16_groups):
- current_offset = 0
- for param in group:
- param_id = self.get_param_id(param)
- num_elements = param.ds_tensor.ds_numel
- self.grad_position[param_id] = [
- int(i),
- int(current_offset),
- int(num_elements)
- ]
- #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}")
- current_offset += num_elements
- def async_accumulate_grad_in_cpu_via_gpu(self, param, acc_grad_cpu_partition):
- # copy to a pre-existing buffer to avoid the memory allocation penalty
- dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
- 0,
- 0,
- param.ds_tensor.ds_numel)
- if self.micro_step_id > 0:
- dest_buffer.copy_(acc_grad_cpu_partition.view(-1), non_blocking=True)
- param.grad.data.view(-1).add_(dest_buffer)
- # at the gradient accumulation boundary the 32-bit gradients are sent directly, so skip the fp16 copy
- if not self.is_gradient_accumulation_boundary:
- acc_grad_cpu_partition.data.copy_(param.grad.data.view(-1),
- non_blocking=True)
- def _constant_buffered_norm2(self, input, buffer_size=250000000):
- norm = None
- for part in input.view(-1).split(buffer_size):
- if norm is None:
- norm = part.data.double().norm(2)**2.0
- else:
- norm += part.data.double().norm(2)**2.0
- return norm**0.5
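- # Illustrative check (not part of the original source): the buffered norm
- # above matches a direct 2-norm while only materializing fixed-size double
- # buffers (buffer_size elements at a time) instead of a full fp64 copy.
- import torch
- x = torch.randn(1000)
- chunked = sum(part.double().norm(2) ** 2 for part in x.view(-1).split(250)) ** 0.5
- assert torch.allclose(chunked, x.double().norm(2))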
- def set_norm_for_param_grad_in_gpu(self, param):
- param_id = self.get_param_id(param)
- #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2)
- #Using a more memory efficient version
- self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad)
- def update_overflow_tracker_for_param_grad(self, param):
- #Credit to our user David Minn
- if param.grad is not None:
- if self.overlap_comm:
- self.gpu_sum = self.gpu_sum + param.grad.data.float().sum()
- elif self._has_inf_or_nan(param.grad.data):
- self.local_overflow = True
- def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor):
- with torch.cuda.stream(self.copy_grad_stream):
- param_id = self.get_param_id(param)
- src_tensor = param.grad.view(-1).float()
- #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}")
- fp32_grad_tensor.copy_(src_tensor, non_blocking=True)
- param.grad = None
- def complete_grad_norm_calculation_for_cpu_offload(self, params):
- total_norm = 0.0
- norm_type = 2.0
- for p in params:
- if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
- param_id = self.get_param_id(p)
- if param_id in self.norm_for_param_grads.keys():
- param_norm = self.norm_for_param_grads[param_id]
- total_norm += param_norm.item()**2
- # Sum across all model parallel GPUs.
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=self.dp_process_group)
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM)
- total_norm = total_norm_cuda[0].item()**(1. / norm_type)
- if total_norm == float(
- 'inf') or total_norm == -float('inf') or total_norm != total_norm:
- total_norm = -1
- return total_norm
- def partition_previous_reduced_grads(self):
- if not self.previous_reduced_grads:
- return
- if self.offload_optimizer:
- allocate_grads_in_partition = self.grads_in_partition is None\
- and self.gradient_accumulation_steps > 1
- else:
- allocate_grads_in_partition = self.grads_in_partition is None
- if allocate_grads_in_partition:
- self.grads_in_partition = []
- for i, group in enumerate(self.fp16_groups):
- total_size = 0
- for param_in_partition in group:
- total_size += param_in_partition.ds_tensor.ds_numel
- see_memory_usage(
- f"group {i} before creating {total_size} reduced gradients into partition",
- force=False)
- if self.offload_param_pin_memory:
- self.grads_in_partition.append(
- torch.zeros(int(total_size),
- dtype=self.dtype,
- device=self.device).pin_memory())
- else:
- self.grads_in_partition.append(
- torch.zeros(int(total_size),
- dtype=self.dtype,
- device=self.device))
- see_memory_usage(
- f"group {i} after creating {total_size} reduced gradients into partition",
- force=False)
- if self.offload_optimizer:
- offload_fp32_gradients = {}
- offload_fp32_offsets = {}
- with torch.cuda.stream(self.copy_grad_stream):
- self.reduction_stream.synchronize()
- for param in self.previous_reduced_grads:
- [i,
- dest_offset,
- num_elements] = self.grad_position[self.get_param_id(param)]
- if self.offload_optimizer:
- param.partition_gradients(
- partition_buffers=self.temp_grad_gpu_buffer)
- #with torch.cuda.stream(self.copy_grad_stream):
- # self.reduction_stream.synchronize()
- if self.gradient_accumulation_steps > 1:
- # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
- fp16_grad_tensor = self.grads_in_partition[i].narrow(
- 0,
- dest_offset,
- num_elements)
- self.async_accumulate_grad_in_cpu_via_gpu(
- param,
- fp16_grad_tensor)
- if self.is_gradient_accumulation_boundary:
- self.set_norm_for_param_grad_in_gpu(param)
- self.update_overflow_tracker_for_param_grad(param)
- if self._swappable_optimizer_subgroup(i):
- if i not in offload_fp32_gradients:
- offload_fp32_gradients[i] = []
- offload_fp32_offsets[i] = []
- offload_fp32_gradients[i].append(param.grad.view(-1).float())
- param.grad = None
- offload_fp32_offsets[i].append(dest_offset)
- else:
- fp32_grad_tensor = self.fp32_partitioned_groups_flat[
- i].grad.narrow(0,
- dest_offset,
- num_elements)
- self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(
- param,
- fp32_grad_tensor)
- else:
- # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
- fp16_grad_tensor = self.grads_in_partition[i].narrow(
- 0,
- dest_offset,
- num_elements)
- param.partition_gradients(
- partition_buffers=fp16_grad_tensor,
- accumulate=self.micro_step_id > 0)
- if self.offload_optimizer and self.swap_optimizer:
- for i in offload_fp32_gradients.keys():
- self.optimizer_swapper.swap_out_gradients(
- parameter=self.fp32_partitioned_groups_flat[i],
- gradient_offsets=offload_fp32_offsets[i],
- gradient_tensors=offload_fp32_gradients[i])
- self.previous_reduced_grads = []
- def reduce_ipg_grads(self, extra_param=None):
- if self.overlap_comm:
- self.reduction_stream.synchronize()
- with torch.cuda.stream(self.reduction_stream):
- self.partition_previous_reduced_grads()
- params_to_reduce = [param for i, param, param_id in self.params_in_ipg_bucket]
- #print(f"Params in ipg bucket {self.params_in_ipg_bucket}")
- #print(f"Reducing {[(debug_param2name_id_shape(param), param.grad) for param in params_to_reduce]}")
- #exit(0)
- if self.contiguous_gradients:
- reduction_list = [self.ipg_buffer[self.ipg_index]]
- if self.extra_large_param_to_reduce is not None:
- reduction_list.append(self.extra_large_param_to_reduce.grad)
- self.extra_large_param_to_reduce = None
- self.average_tensor(reduction_list, params_to_reduce)
- else:
- self.buffered_reduce_fallback(
- None,
- self.grads_in_ipg_bucket,
- elements_per_buffer=self.elements_in_ipg_bucket)
- for _, param, param_id in self.params_in_ipg_bucket:
- self.params_already_reduced[param_id] = True
- self.previous_reduced_grads = params_to_reduce
- self.grads_in_ipg_bucket = []
- self.params_in_ipg_bucket = []
- self.elements_in_ipg_bucket = 0
- #####################################################################
- def reduce_ready_partitions_and_remove_grads(self, param, i):
- #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True)
- self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
- def zero_reduced_gradients(self, partition_id, i):
- def are_all_related_partitions_reduced(params_id):
- for partition_id in self.param_to_partition_ids[i][params_id]:
- if not self.is_partition_reduced[i][partition_id]:
- return False
- return True
- for params_id in self.is_grad_computed[i][partition_id]:
- if are_all_related_partitions_reduced(params_id):
- self.param_dict[params_id].grad = None
- def flatten_and_print(self, message, tensors, start=0, n=5):
- flatten_tensor = self.flatten(tensors)
- def print_func():
- logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
- self.sequential_execution(print_func, message)
- def get_grads_to_reduce(self, i, partition_id):
- def get_reducible_portion(key):
- grad = self.param_dict[key].grad
- total_elements = grad.numel()
- start = self.grad_start_offset[i][partition_id][key]
- num_elements = min(
- total_elements - start,
- self.partition_size[i] -
- self.grad_partition_insertion_offset[i][partition_id][key])
- if not pg_correctness_test:
- if num_elements == total_elements:
- return grad
- else:
- return grad.contiguous().view(-1).narrow(0,
- int(start),
- int(num_elements))
- else:
- if num_elements == total_elements:
- return grad.clone()
- else:
- return grad.clone().contiguous().view(-1).narrow(
- 0,
- int(start),
- int(num_elements))
- grads_to_reduce = []
- for key in self.is_grad_computed[i][partition_id]:
- grad = get_reducible_portion(key)
- grads_to_reduce.append(grad)
- return grads_to_reduce
- def sequential_execution(self, function, message, group=None):
- if group is None:
- group = self.dp_process_group
- if dist.get_rank(group=group) == 0:
- logger.info(message)
- for rank_id in range(dist.get_world_size(group=group)):
- if rank_id == dist.get_rank(group=group):
- function()
- dist.barrier(group=group)
- def set_none_gradients_to_zero(self, i, partition_id):
- for param_id in self.is_grad_computed[i][partition_id]:
- param = self.param_dict[param_id]
- if param.grad is None:
- param.grad = torch.zeros_like(param)
- ######################Reduction Related Methods##############################
- def allreduce_bucket(self,
- bucket,
- communication_data_type=torch.float16,
- rank=None,
- log=None):
- rank = None  # force the all-reduce path; the reduce-to-rank path below is currently unused
- tensor = self.flatten(bucket)
- tensor_to_allreduce = tensor
- if pg_correctness_test:
- communication_data_type = torch.float32
- if communication_data_type != tensor.dtype:
- tensor_to_allreduce = tensor.to(communication_data_type)
- tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
- if rank is None:
- # "All Reducing"
- dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
- else:
- global_rank = _get_global_rank(self.dp_process_group, rank)
- dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
- if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
- if rank is None or rank == dist.get_rank(group=self.dp_process_group):
- tensor.copy_(tensor_to_allreduce)
- return tensor
- # if rank is specified do a reduction instead of an allreduce
- def allreduce_and_copy(self, small_bucket, rank=None, log=None):
- with torch.cuda.stream(self.reduction_stream):
- allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
- if rank is None or rank == dist.get_rank(group=self.dp_process_group):
- for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
- buf.copy_(synced)
- def allreduce_no_retain(self,
- bucket,
- numel_per_bucket=500000000,
- rank=None,
- log=None):
- small_bucket = []
- numel = 0
- for tensor in bucket:
- small_bucket.append(tensor)
- numel = numel + tensor.numel()
- if numel > numel_per_bucket:
- self.allreduce_and_copy(small_bucket, rank=rank, log=None)
- small_bucket = []
- if len(small_bucket) > 0:
- self.allreduce_and_copy(small_bucket, rank=rank, log=log)
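- # Illustrative sketch (not part of the original source): the greedy
- # bucketing above on element counts alone -- tensors accumulate until the
- # bucket exceeds numel_per_bucket, then the bucket is flushed (all-reduced
- # and copied back), with any leftover bucket flushed at the end.
- def _example_buckets(numels, numel_per_bucket):
-     buckets, bucket, count = [], [], 0
-     for n in numels:
-         bucket.append(n)
-         count += n
-         if count > numel_per_bucket:
-             buckets.append(bucket)
-             bucket, count = [], 0
-     if bucket:
-         buckets.append(bucket)
-     return buckets
- assert _example_buckets([3, 3, 3, 3], numel_per_bucket=5) == [[3, 3], [3, 3]]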
- # allows using reduction of gradients instead of using all_reduce
- def buffered_reduce_fallback(self,
- rank,
- grads,
- elements_per_buffer=500000000,
- log=None):
- split_buckets = split_half_float_double(grads)
- for i, bucket in enumerate(split_buckets):
- self.allreduce_no_retain(bucket,
- numel_per_bucket=elements_per_buffer,
- rank=rank,
- log=log)
- #############################################################################
- #############################################################################
- #############################################################################
- # views the tensor as multiple partitions and returns
- # those partitions
- def get_data_parallel_partitions(self, tensor):
- partitions = []
- dp = dist.get_world_size(group=self.dp_process_group)
- dp_id = dist.get_rank(group=self.dp_process_group)
- total_num_elements = tensor.numel()
- base_size = total_num_elements // dp
- remaining = total_num_elements % dp
- start = 0
- for rank_id in range(dp):
- partition_size = base_size
- if rank_id < remaining:
- partition_size = partition_size + 1
- partitions.append(tensor.narrow(0, start, partition_size))
- start = start + partition_size
- return partitions
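- # Illustrative sketch (not part of the original source): the size rule
- # above gives the first (numel % dp) ranks one extra element, so partition
- # sizes differ by at most one and sum to the tensor's numel.
- def _example_partition_sizes(total_num_elements, dp):
-     base, remaining = divmod(total_num_elements, dp)
-     return [base + 1 if r < remaining else base for r in range(dp)]
- assert _example_partition_sizes(10, 4) == [3, 3, 2, 2]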
- def get_partition_info(self, tensor_list, partition_size, partition_id):
- params_in_partition = []
- params_not_in_partition = []
- start_index = partition_size * partition_id
- end_index = partition_size * (partition_id + 1)
- current_index = 0
- first_offset = 0
- for tensor in tensor_list:
- tensor_size = tensor.numel()
- if start_index <= current_index < end_index:
- params_in_partition.append(tensor)
- elif current_index < start_index < (current_index + tensor_size):
- params_in_partition.append(tensor)
- assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
- first_offset = start_index - current_index
- else:
- params_not_in_partition.append(tensor)
- current_index = current_index + tensor_size
- return params_in_partition, params_not_in_partition, first_offset
- def zero_grad(self, set_grads_to_None=True):
- """
- Zero FP16 parameter grads.
- """
- # FP32 grad should never exist.
- # For speed, set model fp16 grad to None by default
- for group in self.fp16_groups:
- for p in group:
- if set_grads_to_None:
- p.grad = None
- else:
- if p.grad is not None:
- p.grad.detach_()
- p.grad.zero_()
- def _model_parallel_all_reduce(self, tensor, op):
- """ Perform all reduce within model parallel group, if any.
- """
- if self.model_parallel_group is None:
- pass
- else:
- torch.distributed.all_reduce(tensor=tensor,
- op=op,
- group=self.model_parallel_group)
- def get_grad_norm_direct(self, gradients, params, norm_type=2):
- """Compute the gradient norm of an iterable of parameters.
- Adapted from torch.nn.utils.clip_grad.clip_grad_norm_, with added
- functionality to handle model parallel parameters. Unlike the torch
- version, the gradients are not modified in place.
- Arguments:
- gradients (Iterable[Tensor]): gradients whose norm is computed
- params (Iterable[Tensor]): parameters corresponding to the gradients
- norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
- infinity norm.
- Returns:
- Total norm of the gradients (viewed as a single vector).
- """
- norm_type = float(norm_type)
- if norm_type == inf:
- total_norm = max(g.data.abs().max() for g in gradients)
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=self.dp_process_group)
- # Take max across all GPUs.
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX)
- total_norm = total_norm_cuda[0].item()
- else:
- total_norm = 0.0
- # if dist.get_rank() == 0:
- # logger.info(f"Total Norm beginning {total_norm}")
- for g, p in zip(gradients, params):
- if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
- param_norm = g.data.double().norm(2)
- total_norm += param_norm.item()**2
- # Sum across all model parallel GPUs.
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=self.dp_process_group)
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM)
- total_norm = total_norm_cuda[0].item()**(1. / norm_type)
- if total_norm == float(
- 'inf') or total_norm == -float('inf') or total_norm != total_norm:
- total_norm = -1
- return total_norm
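- # Illustrative check (not part of the original source): the 2-norm over
- # shards equals the square root of the (all-reduced) sum of squared local
- # norms, which is what the SUM reduction above computes across ranks.
- import torch
- shards = [torch.randn(5) for _ in range(4)]
- total = sum(s.norm(2) ** 2 for s in shards) ** 0.5
- assert torch.allclose(total, torch.cat(shards).norm(2))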
- # creates a flat fused tensor from the tensor list starting at the first_offset
- # in the first tensor of the list. If there are not enough elements in the tensor
- # list then the flat tensor will be padded with zeros
- def get_flat_partition(self,
- tensor_list,
- first_offset,
- partition_size,
- return_tensor_list=False):
- flat_tensor_list = []
- current_size = 0
- for i, tensor in enumerate(tensor_list):
- if tensor.grad is None:
- tensor.grad = torch.zeros_like(tensor)
- tensor = tensor.grad
- num_elements = tensor.numel()
- tensor_offset = 0
- # we need to offset to get to the right element
- if i == 0 and first_offset > 0:
- tensor_offset = first_offset
- num_elements = num_elements - tensor_offset
- # we don't need all elements of the tensor
- if num_elements > (partition_size - current_size):
- num_elements = partition_size - current_size
- # we need a narrow view of the tensor based on the tensor offset and number of elements that
- # we need from this tensor
- if tensor_offset > 0 or num_elements < tensor.numel():
- flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
- 0,
- int(tensor_offset),
- int(num_elements)))
- else:
- flat_tensor_list.append(tensor)
- current_size = current_size + num_elements
- # this means it's the last partition and it does not align with the dp boundary, so we need to pad before flattening
- if current_size < partition_size:
- flat_tensor_list.append(
- torch.zeros(int(partition_size - current_size),
- dtype=tensor_list[0].dtype,
- device=tensor_list[0].device))
- if return_tensor_list:
- return flat_tensor_list
- return self.flatten(flat_tensor_list)
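- # Illustrative sketch (not part of the original source): what the method
- # above assembles, shown via an explicit concatenation on toy tensors --
- # partition_size elements starting at first_offset into the first tensor.
- # (The real method narrows each grad instead of materializing the concat.)
- import torch
- grads = [torch.arange(4.0), torch.arange(4.0, 10.0)]
- first_offset, partition_size = 2, 5
- flat = torch.cat([g.view(-1) for g in grads])[first_offset:first_offset + partition_size]
- assert flat.tolist() == [2.0, 3.0, 4.0, 5.0, 6.0]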
- def free_grad_in_param_list(self, param_list):
- for p in param_list:
- p.grad = None
- def reset_cpu_buffers(self):
- self.norm_for_param_grads = {}
- self.local_overflow = False
- def log_timers(self, timer_names):
- if self.timers is None:
- return
- self.timers.log(names=list(timer_names))
- def start_timers(self, timer_names):
- if self.timers is None:
- return
- for name in timer_names:
- self.timers(name).start()
- def stop_timers(self, timer_names):
- if self.timers is None:
- return
- for name in timer_names:
- self.timers(name).stop()
- def _pre_step(self):
- self.micro_step_id = INITIAL_MICRO_STEP_ID
- print_rank_0(f"Inside Step function")
- see_memory_usage(f"In step before checking overflow", force=False)
- print_rank_0("Finished Tracing at Beginning of Step")
- self.param_coordinator.hierarchy = 0
- self.param_coordinator.finish_tracing(print_trace=True)
- self.param_coordinator.reset_step()
- print_rank_0("Finished Tracing at Beginning of Step")
- def _get_norm_groups(self):
- norm_groups = []
- for i, group in enumerate(self.fp16_groups):
- if self.offload_optimizer:
- norm_groups.append(
- self.complete_grad_norm_calculation_for_cpu_offload(
- self.fp16_groups[i]))
- else:
- norm_groups.append(
- self.get_grad_norm_direct(self.averaged_gradients[i],
- self.fp16_groups[i]))
- return norm_groups
- def _prepare_fp32_grad_for_sub_group(self, sub_group_id):
- partition_id = dist.get_rank(group=self.dp_process_group)
- single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to(
- self.fp32_partitioned_groups_flat[sub_group_id].dtype)
- assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
- "averaged gradients have different number of elements that partition size {} {} {} {}".format(
- single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id)
- self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition
- # release all the gradients since we have already created the necessary copy in dp_grad_partition
- self.zero_grad()
- self.averaged_gradients[sub_group_id] = None
- def _prepare_sub_group(self, sub_group_id, timer_names=set()):
- see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}',
- force=False)
- if self._swappable_optimizer_subgroup(sub_group_id):
- self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names)
- elif not self.offload_optimizer:
- self._prepare_fp32_grad_for_sub_group(sub_group_id)
- see_memory_usage(f'After prepare optimizer sub group {sub_group_id}',
- force=False)
- def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()):
- param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id]
- fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id])
- assert self._swappable_optimizer_subgroup(sub_group_id), \
- f'Parameter {fp32_param_id} of numel={param_length} is not swappable'
- OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state'
- see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}',
- force=False)
- self.start_timers([OPTIMIZER_SWAP_IN_STATE])
- self.optimizer_swapper.swap_in_optimizer_state(
- parameter=self.fp32_partitioned_groups_flat[sub_group_id],
- async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id])
- self.stop_timers([OPTIMIZER_SWAP_IN_STATE])
- timer_names.add(OPTIMIZER_SWAP_IN_STATE)
- see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}',
- force=False)
- def _release_sub_group(self, sub_group_id, timer_names=set()):
- see_memory_usage(f'Before release optimizer sub group {sub_group_id}',
- force=False)
- # get rid of the fp32 gradients. Not needed anymore
- if not self.offload_optimizer:
- self.fp32_partitioned_groups_flat[sub_group_id].grad = None
- if self._swappable_optimizer_subgroup(sub_group_id):
- self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names)
- see_memory_usage(f'After release optimizer sub group {sub_group_id}',
- force=False)
- # create a flat tensor aligned at the alignment boundary
- def flatten_dense_tensors_aligned(self, tensor_list, alignment):
- num_elements = 0
- for tens in tensor_list:
- num_elements = num_elements + tens.numel()
- remaining = num_elements % alignment
- if remaining:
- elements_to_add = alignment - remaining
- pad_tensor = torch.zeros(elements_to_add,
- device=tensor_list[0].device,
- dtype=tensor_list[0].dtype)
- padded_tensor_list = tensor_list + [pad_tensor]
- num_elements = num_elements + elements_to_add
- else:
- padded_tensor_list = tensor_list
- return self.flatten(padded_tensor_list)
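- # Illustrative sketch (not part of the original source): padding so the
- # flattened numel is a multiple of the alignment, using torch.cat in place
- # of the flatten utility.
- import torch
- tensor_list = [torch.ones(3), torch.ones(5)]  # 8 elements total
- alignment = 16
- remaining = sum(t.numel() for t in tensor_list) % alignment
- if remaining:
-     tensor_list = tensor_list + [torch.zeros(alignment - remaining)]
- flat = torch.cat([t.view(-1) for t in tensor_list])
- assert flat.numel() == 16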
- def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()):
- param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id]
- fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id])
- assert self._swappable_optimizer_subgroup(sub_group_id), \
- f'Parameter {fp32_param_id} of numel={param_length} is not swappable'
- OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state'
- see_memory_usage(
- f'post-step Before swapping out optimizer tensors {sub_group_id}',
- force=False)
- self.start_timers([OPTIMIZER_SWAP_OUT_STATE])
- self.optimizer_swapper.swap_out_optimizer_state(
- parameter=self.fp32_partitioned_groups_flat[sub_group_id],
- async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id]
- is not None)
- self.stop_timers([OPTIMIZER_SWAP_OUT_STATE])
- see_memory_usage(
- f'post-step After swapping out optimizer tensors {sub_group_id}',
- force=False)
- timer_names.add(OPTIMIZER_SWAP_OUT_STATE)
- # get rid of the fp32 gradients. Not needed anymore
- self.fp32_partitioned_groups_flat[sub_group_id].grad = None
- def _unflatten_partitioned_parameters(self, sub_group_id):
- updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id],
- self.fp16_partitioned_groups[sub_group_id])
- for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
- partitioned_param.data = q.data
- def _overflow_clean_up(self, prev_scale):
- see_memory_usage('After overflow before clearing gradients', force=False)
- self.zero_grad()
- if self.offload_optimizer:
- self.reset_cpu_buffers()
- else:
- self.averaged_gradients = {}
- see_memory_usage('After overflow after clearing gradients', force=False)
- if torch.distributed.get_rank() == 0:
- logger.info(
- "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
- "reducing to {}".format(dist.get_rank(),
- prev_scale,
- self.loss_scale))
- def _overflow_check_and_loss_scale_update(self):
- # First compute norm for all group so we know if there is overflow
- self.check_overflow()
- #loss scaling related computation
- prev_scale = self.loss_scale
- self._update_scale(self.overflow)
- if self.overflow:
- self._overflow_clean_up(prev_scale)
- return self.overflow
- def _post_step(self, timer_names=set()):
- if self.offload_optimizer:
- self.reset_cpu_buffers()
- # Gather persistent parameters
- if len(self.persistent_parameters) > 0:
- self.persistent_parameters[0].all_gather(self.persistent_parameters)
- if self.swap_optimizer:
- self.optimizer_swapper.log_timers()
- self.log_timers(timer_names)
- see_memory_usage('After zero_optimizer step', force=False)
- print_rank_0(f"------------------Finishing Step-----------------------")
- def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
- if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
- self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
- self.fp32_partitioned_groups_flat[sub_group_id].data)
- #unflatten fp16 parameter subgroup
- self._unflatten_partitioned_parameters(sub_group_id)
- else:
- self._partitioned_params_swap_out(sub_group_id)
- def step(self, closure=None):
- """
- Not supporting closure.
- """
- self._pre_step()
- #checks for overflow, adjust the loss scale accordingly
- if self._overflow_check_and_loss_scale_update():
- if self.swap_optimizer:
- self.optimizer_swapper.log_timers()
- return
- norm_groups = self._get_norm_groups()
- self._global_grad_norm = get_global_norm(norm_list=norm_groups)
- timer_names = set()
- timer_names.add('optimizer_step')
- self.start_timers(['optimizer_step'])
- #update parameters one sub group at a time
- for sub_group_id, group in enumerate(self.fp16_groups):
- #prepare optimizer states, gradients and fp32 parameters for update
- self._prepare_sub_group(sub_group_id, timer_names)
- #scale the fp32 gradients
- self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm)
- #apply the optimizer step on the sub group and copy fp32 parameters to fp16
- self._optimizer_step(sub_group_id)
- #put fp16 parameters in appropriate location
- self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
- #release memory or swap out optimizer states of fp32 parameters
- self._release_sub_group(sub_group_id, timer_names)
- self.stop_timers(['optimizer_step'])
- self._post_step(timer_names)
- return
- def dump_pre_step_gradients(self, debug_fp32_grads):
- # Dump gradient norms for debugging
- for i, _ in enumerate(self.fp16_groups):
- print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC')
- for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]):
- param_id = self.get_param_id(fp16_param)
- fp16_grad_norm = self.debug_fp16_grads[i][param_id]
- fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad]
- norm_list = [fp16_grad_norm, fp32_grad_norm]
- print(f'Pre-Step Norms {i} {param_id} = {norm_list}')
- def dump_post_step_gradients(self):
- # Dump gradient norms for debugging
- for i, group in enumerate(self.fp16_groups):
- print(
- f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT')
- unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i])
- unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i],
- self.fp16_groups[i])
- for j, p in enumerate(self.fp16_groups[i]):
- param_id = self.get_param_id(p)
- param_norm = float(p.data.float().norm(2))
- ds_norm = float(p.ds_tensor.data.float().norm(2))
- unflat_norm = [
- float(t.data.float().norm(2))
- for t in [unflat_fp16[j],
- unflat_fp32[j]]
- ]
- norm_list = [param_norm, ds_norm] + unflat_norm
- print(f'Post-Step Norms {i} {param_id} = {norm_list}')
- def unscale_and_clip_grads(self, sub_group_id, total_norm):
- grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad]
- # compute combined scale factor for this group
- combined_scale = self.loss_scale
- if self.clip_grad > 0.:
- # norm is in fact norm*scale
- clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
- if clip > 1:
- combined_scale = clip * self.loss_scale
- for grad in grad_groups_flat:
- if isinstance(grad, list):
- sub_partitions = grad
- for g in sub_partitions:
- g.data.mul_(1. / combined_scale)
- else:
- grad.data.mul_(1. / combined_scale)
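- # Illustrative worked example (not part of the original source): with
- # loss_scale=1024, a scaled total_norm of 4096 and clip_grad=1.0, the
- # unscaled norm is ~4.0, so clip ~= 4.0 > 1 and combined_scale becomes
- # 4 * 1024 = 4096; multiplying grads by 1/4096 both removes the loss scale
- # and clips the gradient norm to ~1.0 in a single multiply.
- loss_scale, total_norm, clip_grad = 1024.0, 4096.0, 1.0
- clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
- combined_scale = clip * loss_scale if clip > 1 else loss_scale
- assert abs(combined_scale - 4096.0) < 1.0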
- def _check_overflow(self, partition_gradients=True):
- self.overflow = self.has_overflow(partition_gradients)
- # `params` is a list / generator of torch.Variable
- def has_overflow_serial(self, params, is_grad_list=False):
- for p in params:
- if p.grad is not None and self._has_inf_or_nan(p.grad.data):
- return True
- return False
- def has_overflow_partitioned_grads_serial(self):
- for i in range(len(self.fp16_groups)):
- for j, grad in enumerate(self.averaged_gradients[i]):
- if grad is not None and self._has_inf_or_nan(grad.data, j):
- return True
- return False
- def has_overflow(self, partition_gradients=True):
- if partition_gradients:
- if self.overlap_comm:
- self.local_overflow = self._has_inf_or_nan(self.gpu_sum)
- self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda()
- overflow = self.local_overflow if self.offload_optimizer else self.has_overflow_partitioned_grads_serial(
- )
- #overflow = self.has_overflow_partitioned_grads_serial()
- overflow_gpu = torch.cuda.ByteTensor([overflow])
- torch.distributed.all_reduce(overflow_gpu,
- op=torch.distributed.ReduceOp.MAX,
- group=self.dp_process_group)
- else:
- params = []
- for group in self.fp16_groups:
- for param in group:
- params.append(param)
- overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
- overflow_gpu = torch.cuda.ByteTensor([overflow])
- # Since each model parallel GPU carries only part of the model,
- # make sure overflow flag is synced across all the model parallel GPUs
- self._model_parallel_all_reduce(tensor=overflow_gpu,
- op=torch.distributed.ReduceOp.MAX)
- overflow = overflow_gpu[0].item()
- return bool(overflow)
- # `x` is a torch.Tensor
- @staticmethod
- def _has_inf_or_nan(x, j=None):
- try:
- # if x is half, the .float() incurs an additional deep copy, but it's necessary if
- # Pytorch's .sum() creates a one-element tensor of the same type as x
- # (which is true for some recent versions of pytorch).
- cpu_sum = float(x.float().sum())
- # More efficient version that can be used if .sum() returns a Python scalar
- # cpu_sum = float(x.sum())
- except RuntimeError as instance:
- # We want to check if inst is actually an overflow exception.
- # RuntimeError could come from a different error.
- # If so, we still want the exception to propagate.
- if "value cannot be converted" not in instance.args[0]:
- raise
- return True
- else:
- if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
- return True
- return False
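- # Illustrative check (not part of the original source): why a sum-based
- # test detects overflow -- a single inf or nan poisons the whole reduction,
- # and NaN is the only value unequal to itself.
- import torch
- assert float(torch.tensor([1.0, float('inf')]).sum()) == float('inf')
- s = float(torch.tensor([1.0, float('nan')]).sum())
- assert s != s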

    def backward(self, loss, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        self.micro_step_id += 1
        print_rank_0(
            f"Total fully available parameters {self.param_coordinator.total_available_parameter_numel}"
        )

        if self.swap_optimizer:
            self.optimizer_swapper.pre_backward()

        see_memory_usage(f"Before backward", force=False)

        if self.contiguous_gradients:
            self.ipg_buffer = []
            buf_0 = torch.empty(self.reduce_bucket_size,
                                dtype=self.dtype,
                                device=torch.cuda.current_device())
            self.ipg_buffer.append(buf_0)

            # Use double buffers to avoid data access conflict when overlap_comm is enabled.
            if self.overlap_comm:
                buf_1 = torch.empty(self.reduce_bucket_size,
                                    dtype=self.dtype,
                                    device=torch.cuda.current_device())
                self.ipg_buffer.append(buf_1)
            self.ipg_index = 0

        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

        # Partition any parameters that were not repartitioned during backward:
        # modules whose inputs do not require grad never trigger the post-backward
        # hook, so their parameters would otherwise remain unpartitioned.
        self._partition_all_parameters()

        if self.swap_optimizer:
            self.optimizer_swapper.post_backward()
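    # The loss-scaling step above is equivalent to this sketch (hypothetical
    # loss tensor; DeepSpeed's loss scaler additionally manages dynamic scale
    # growth/backoff around it):
    #
    #   scaled_loss = loss.float() * loss_scale
    #   scaled_loss.backward(retain_graph=retain_graph)
    #   # fp16 leaves now hold grad * loss_scale in their .grad attributes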

    def _partition_all_parameters(self):
        for name, param in self.module.named_parameters(recurse=True):
            self.param_coordinator.release_and_reset_parameter(param)

    def check_overflow(self, partition_gradients=True):
        self._check_overflow(partition_gradients)

    def _update_scale(self, has_overflow=False):
        self.loss_scaler.update_scale(has_overflow)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)
    cur_scale = property(_get_loss_scale, _set_loss_scale)
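    # With these promotions the wrapper can be driven like a plain torch.optim
    # optimizer (hypothetical usage sketch, e.g. from an LR scheduler):
    #
    #   for group in fp16_optimizer_instance.param_groups:   # inner optimizer's groups
    #       group['lr'] *= 0.1
    #   fp16_optimizer_instance.loss_scale = 2 ** 15         # sets loss_scaler.cur_scale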

    def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings):
        # Remove paddings from flattened tensor
        individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors)
        lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)]
        lean_tensors = [t[:length] for t, length in zip(individual_tensors, lean_lengths)]
        return lean_tensors
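    # Padding-removal sketch (hypothetical shapes): two tensors padded for DP
    # alignment are trimmed back to their logical lengths.
    #
    #   padded = [torch.zeros(4), torch.zeros(8)]   # unflattened, aligned views
    #   paddings = [1, 3]                           # alignment padding per tensor
    #   lean = [t[:t.numel() - p] for t, p in zip(padded, paddings)]
    #   [t.numel() for t in lean]                   # -> [3, 5]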

    #TODO REVISIT this for stage 3
    def get_lean_optimizer_state(self):
        # Return optimizer states after removing paddings.
        # This method assumes that each param group contains a single flattened tensor.
        optimizer_groups_state = []

        for i, group in enumerate(self.optimizer.param_groups):
            p = group['params'][0]
            lean_state = {}
            for key, value in self.optimizer.state[p].items():
                if torch.is_tensor(value):
                    lean_state[key] = self._get_lean_tensors(
                        value,
                        self.fp16_partitioned_groups[i],
                        self.groups_padding[i])
                else:
                    lean_state[key] = value

            optimizer_groups_state.append(lean_state)

        return optimizer_groups_state

    def get_groups_without_padding(self, groups_with_padding):
        # Return group tensor after removing paddings added for alignment to DP world size.
        groups_without_padding = []
        for i, group in enumerate(groups_with_padding):
            lean_group = self._get_lean_tensors(group,
                                                self.fp16_partitioned_groups[i],
                                                self.groups_padding[i])
            groups_without_padding.append(lean_group)

        return groups_without_padding

    def _set_fp32_optimizer_param_groups(self):
        for sub_group_id, _ in enumerate(self.fp16_groups):
            param_group_id = self.sub_group_to_group_id[sub_group_id]
            self.optimizer.param_groups[param_group_id]['params'].append(
                self.fp32_partitioned_groups_flat[sub_group_id])

    def _clear_fp32_optimizer_param_groups(self):
        for param_group in self.optimizer.param_groups:
            param_group['params'] = []

    def _rigid_state_dict(self):
        state_dict = {}
        state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
        state_dict['partition_count'] = self.partition_count

        self._set_fp32_optimizer_param_groups()
        state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
        state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat
        self._clear_fp32_optimizer_param_groups()

        return state_dict

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained PyTorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        if self.elastic_checkpoint:
            raise NotImplementedError(
                "ZeRO-3 does not yet support elastic checkpointing, please disable for now."
            )

        if self.swap_optimizer or self.params_in_nvme_and_cpu:
            raise NotImplementedError(
                "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now."
            )

        return self._rigid_state_dict()

    # Restore base optimizer fp32 weights from checkpoint by:
    # 1) Merging fp32 weights from checkpoints of all partitions
    # 2) Extracting fp32 weights for current partition from merged weights
    # 3) Using extracted weights to update base optimizer weights directly.
    # NOTE: this merge path is only exercised by elastic checkpointing, which
    # state_dict/load_state_dict currently reject with NotImplementedError.
    def _restore_from_fp32_weights(self, all_state_dict):
        flat_local_partition = []
        for i in range(len(self.fp32_partitioned_groups_flat)):
            merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict]
            flat_local_partition.append(self._get_flattened_partition(merged_partitions))

        for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition):
            current.data.copy_(saved.data)

    # Restore base optimizer fp32 weights from ZeRO fp16 weights
    def _restore_from_fp16_weights(self):
        for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat):
            fp32_partition.data.copy_(fp16_partitions.data)

    # Refresh the fp32 master params from the fp16 copies.
    def refresh_fp32_params(self):
        self._restore_from_fp16_weights()

    # Extract flattened partition for current rank from all partitions
    def _get_flattened_partition(self, all_partition_states):
        partition_id = dist.get_rank(group=self.dp_process_group)
        alignment = dist.get_world_size(group=self.dp_process_group)

        # Regroup the per-rank slices by state tensor index
        param_partitions = [[] for _ in range(len(all_partition_states[0]))]
        for partition in all_partition_states:
            for j, param in enumerate(partition):
                param_partitions[j].append(param)

        local_state_partitions = []
        for param_slices in param_partitions:
            flattened_merged_tensor = self.flatten_dense_tensors_aligned(
                param_slices,
                alignment)
            new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor)
            local_state_partitions.append(new_partitions[partition_id])

        if torch.is_tensor(local_state_partitions[0]):
            return self.flatten_dense_tensors_aligned(local_state_partitions, alignment)

        # Assume non-tensor states are not partitioned and equal across ranks, so return first one
        return local_state_partitions[0]
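    # Merge-then-repartition sketch (hypothetical 2-rank layout, pseudocode
    # names): each rank's checkpoint holds one slice per state tensor; slices
    # are re-joined per tensor, flattened with DP alignment, and this rank's
    # shard is cut back out, so save-time and load-time partitionings agree.
    #
    #   all_partition_states = [[s0_rank0, s1_rank0],    # from rank 0's checkpoint
    #                           [s0_rank1, s1_rank1]]    # from rank 1's checkpoint
    #   merged_s0 = flatten_aligned([s0_rank0, s0_rank1])  # full tensor for state 0
    #   shards = split_into(world_size, merged_s0)
    #   local = shards[dist.get_rank()]                  # this rank's shard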

    # Restore base optimizer state from checkpoint by
    # 1) Merging optimizer state from checkpoints of all partitions
    # 2) Extracting optimizer state for current partition from the merged state
    # 3) Using the extracted value to directly update the base optimizer.
    def _restore_base_optimizer_state(self, all_state_dict):
        base_optimizer_group_states = []
        for i in range(len(self.optimizer.param_groups)):
            partition_states = {}
            all_partition_group_states = [
                sd['base_optimizer_state'][i] for sd in all_state_dict
            ]
            for key in all_partition_group_states[0].keys():
                all_partition_states = [
                    all_states[key] for all_states in all_partition_group_states
                ]
                partition_states[key] = self._get_flattened_partition(
                    all_partition_states)
            base_optimizer_group_states.append(partition_states)

        for i, group in enumerate(self.optimizer.param_groups):
            p = group['params'][0]
            for key, saved in base_optimizer_group_states[i].items():
                if torch.is_tensor(self.optimizer.state[p][key]):
                    self.optimizer.state[p][key].data.copy_(saved.data)
                else:
                    self.optimizer.state[p][key] = saved

    def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
        # It should be safe to reload the optimizer state before the model weights.
        self.loss_scaler = state_dict['loss_scaler']
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.overflow = state_dict['overflow']

        if load_optimizer_states:
            self._set_fp32_optimizer_param_groups()
            self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
            self._clear_fp32_optimizer_param_groups()

        # restore fp32 partitions
        for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']):
            curr_param.data.copy_(saved_param.data)

        # restore fp16 partitions from fp32
        for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
            fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
            fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
            fp16_param.data.copy_(fp32_param.data)

        # update fp16 unflattened params
        for sub_group_id in range(len(self.fp16_partitioned_groups_flat)):
            updated_params = self.unflatten(
                self.fp16_partitioned_groups_flat[sub_group_id],
                self.fp16_partitioned_groups[sub_group_id])

            for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
                partitioned_param.data = q.data

    # TODO: Support different/changing load/save DP degree.
    def load_state_dict(self,
                        state_dict_list,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False):
        r"""Loads a ZeRO checkpoint created by an earlier call to state_dict().

        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Arguments:
            state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
                Note that the number of saved partitions may differ from the number of loading
                partitions, to support changing GPU count (specifically, DP world size)
                between saving and loading checkpoints.
            load_optimizer_states: Boolean indicating whether or not to load base optimizer states.
            load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights
                from the fp32 copies in the checkpoints (no precision loss) or from the model's
                fp16 copies (with precision loss).

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        if self.elastic_checkpoint:
            raise NotImplementedError(
                "ZeRO-3 does not yet support elastic checkpointing, please disable for now."
            )

        if self.swap_optimizer or self.params_in_nvme_and_cpu:
            raise NotImplementedError(
                "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now."
            )

        self._rigid_load_state_dict(
            state_dict_list[dist.get_rank(group=self.dp_process_group)],
            load_optimizer_states=load_optimizer_states)

        if len(self.persistent_parameters) > 0:
            self.persistent_parameters[0].partition(self.persistent_parameters)
            self.persistent_parameters[0].all_gather(self.persistent_parameters)

    def save_checkpoint_prologue(self):
        self._partition_all_parameters()

    def save_checkpoint_epilogue(self):
        if len(self.persistent_parameters) > 0:
            self.persistent_parameters[0].all_gather(self.persistent_parameters)


def _handle_overflow(cpu_sum, x, i):
    import math
    rank = torch.distributed.get_rank()
    if rank == 0:
        t_i = -1
        for v_i, v in enumerate(x.data.contiguous().view(-1)):
            if not math.isfinite(float(v)):
                t_i = v_i
                break
        logger.info(
            f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
        )


def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5):

    total_gpus = num_nodes * num_gpus_per_node
    gpus_factor = 1 / num_nodes
    largest_layer_memory = (4 * largest_layer_params)

    if cpu_offload:
        if cpu_offload_params:
            gpu_mem = largest_layer_memory

            if zero_init:
                cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(4 * num_gpus_per_node,
                                             18 * gpus_factor) * additional_buffer_factor
        else:
            gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus)

            if zero_init:
                cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(4 * num_gpus_per_node,
                                             16 * gpus_factor) * additional_buffer_factor
    else:
        gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus)

        if zero_init:
            cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor
        else:
            cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
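

# Worked example of the formulas above (hypothetical 1B-param model with a
# 100M-param largest layer, 1 node x 8 GPUs, both offloads on, zero_init=True):
#
#   largest_layer_memory = 4 * 100e6      # 0.4 GB held on GPU to gather one layer
#   gpu_mem = largest_layer_memory        # ~0.37 GiB per GPU
#   cpu_mem = 1e9 * 18 * (1 / 1) * 1.5    # 18 bytes/param: fp16 param (2) plus
#                                         # fp32 param, grad, momentum, variance
#                                         # (4 each), times the 1.5 buffer
#                                         # factor ~= 25.2 GiB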


def model_to_params(model):
    # shared params are counted only once, keyed by storage pointer
    total_params = sum(
        dict((p.data_ptr(),
              p.numel()) for p in model.parameters()).values())

    largest_layer_params = 0
    for m in model.modules():
        # assuming no shared params within a single layer
        layer_params = sum(p.numel() for p in m.parameters(recurse=False))
        largest_layer_params = max(largest_layer_params, layer_params)

    return total_params, largest_layer_params
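

# Sketch of the shared-parameter dedup above (hypothetical tied embedding):
# tied weights share storage, so keying by data_ptr() counts them once.
#
#   import torch
#   emb = torch.nn.Embedding(10, 4)
#   head = torch.nn.Linear(4, 10, bias=False)
#   head.weight = emb.weight                        # weight tying
#   naive = sum(p.numel() for m in (emb, head)
#               for p in m.parameters())            # 80, double-counts the tie
#   deduped = sum(dict((p.data_ptr(), p.numel()) for m in (emb, head)
#                      for p in m.parameters()).values())   # 40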


def estimate_zero3_model_states_mem_needs_all_live(model,
                                                   num_gpus_per_node=1,
                                                   num_nodes=1,
                                                   additional_buffer_factor=1.5):
    """
    Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients
    for a given ``model`` and hardware setup.

    If you have an actual model object, use this function and everything will be derived
    automatically.

    If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass
    the ``total_params`` and ``largest_layer_params`` explicitly.

    Args:
        - ``model``: ``nn.Module`` object
        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)
    """

    total_params, largest_layer_params = model_to_params(model)

    estimate_zero3_model_states_mem_needs_all_cold(
        total_params=total_params,
        largest_layer_params=largest_layer_params,
        num_gpus_per_node=num_gpus_per_node,
        num_nodes=num_nodes,
        additional_buffer_factor=additional_buffer_factor)
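

# Usage sketch (assumes the Hugging Face ``transformers`` package is installed;
# any ``nn.Module`` works the same way):
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("t5-large")
#   estimate_zero3_model_states_mem_needs_all_live(model,
#                                                  num_gpus_per_node=8,
#                                                  num_nodes=1)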


def estimate_zero3_model_states_mem_needs_all_cold(total_params,
                                                   largest_layer_params,
                                                   num_gpus_per_node=1,
                                                   num_nodes=1,
                                                   additional_buffer_factor=1.5):
    """
    Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients
    for a given ``model`` and hardware setup.

    If it's a hypothetical model, use this function where you have to pass
    the ``total_params`` and ``largest_layer_params`` explicitly.

    If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything
    will be derived automatically.

    Args:
        - ``total_params``: total model params
        - ``largest_layer_params``: largest layer's params
        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)
    """
    def format_options(cpu_offload, cpu_offload_params, zero_init):
        enabled = []
        padded_cpu_str = f'{OFFLOAD_CPU_DEVICE:4}'
        param_device = padded_cpu_str if cpu_offload_params else "none"
        enabled.append(f"{OFFLOAD_PARAM}={param_device}")
        optimizer_device = padded_cpu_str if cpu_offload else "none"
        enabled.append(f"{OFFLOAD_OPTIMIZER}={optimizer_device}")
        enabled.append(f"zero_init={1 if zero_init else 0}")
        return ", ".join(enabled)

    nodes_str = "nodes" if num_nodes > 1 else "node"
    gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
    print(
        "Estimated memory needed for params, optim states and gradients for a:\n"
        f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
        f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params."
    )
    print("  per CPU  |  per GPU |   Options")
    for cpu_offload in [True, False]:
        for cpu_offload_params in [True, False]:
            if not cpu_offload and cpu_offload_params:
                # param offload requires optimizer offload; skip the invalid combo
                continue
            for zero_init in [True, False]:
                cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs(
                    total_params=total_params,
                    largest_layer_params=largest_layer_params,
                    num_gpus_per_node=num_gpus_per_node,
                    num_nodes=num_nodes,
                    cpu_offload=cpu_offload,
                    cpu_offload_params=cpu_offload_params,
                    zero_init=zero_init,
                    additional_buffer_factor=additional_buffer_factor)

                options_str = format_options(cpu_offload=cpu_offload,
                                             cpu_offload_params=cpu_offload_params,
                                             zero_init=zero_init)
                print(f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}")