  1. """
  2. "Copyright 2020 The Microsoft DeepSpeed Team.
  3. Licensed under the MIT license.
  4. """
  5. import sys
  6. import os
  7. from collections import defaultdict, OrderedDict
  8. import itertools
  9. import torch
  10. from torch.distributed.distributed_c10d import _get_global_rank
  11. import torch.distributed as dist
  12. import math
  13. from torch._six import inf
  14. from torch.autograd import Variable
  15. from deepspeed.utils.logging import logger
  16. from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
  17. from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim
  18. from deepspeed.runtime.zero.partition_parameters import *
  19. from deepspeed.runtime.zero.partition_parameters import _init_external_params
  20. from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
  21. from deepspeed.ops.adam import DeepSpeedCPUAdam
  22. from deepspeed.ops.op_builder import UtilsBuilder
  23. from deepspeed.runtime.zero.offload_constants import *
  24. from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
  25. from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper
  26. from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
  27. from deepspeed.runtime.constants import OPTIMIZER_STATE_DICT
  28. # Toggle this to true to enable correctness test
  29. # with gradient partitioning and without
  30. pg_correctness_test = False
  31. FWD_MODULE_STACK = list()
  32. from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id, debug_param2name_id_numel, debug_param2name_id_shape_device, debug_module2name_class, printflock, log_rank_file
def print_rank_0(message, debug=False, force=False):
    rank = torch.distributed.get_rank()
    if rank == 0 and (debug or force):
        print(message)
    # other variations
    # - print for all ranks w/o interleaving
    # printflock(f"[{rank}] {message}")
    # - print to log file per rank
    # log_rank_file(rank, message)


def input(msg):
    return
def split_half_float_double(tensors):
    dtypes = [
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor"
    ]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets
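
# Example (sketch, illustrative tensor names): given CUDA tensors h1, h2 (fp16) and f1 (fp32),
#   split_half_float_double([h1, f1, h2]) -> [[h1, h2], [f1]]
# i.e. one bucket per dtype, so each bucket can be flattened and reduced homogeneously.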

def isclose(a, b, rtol=1e-09, atol=0.0):
    return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)


def lcm(x, y):
    from math import gcd  # `fractions.gcd` was removed in Python 3.9; use `math.gcd` instead
    return x * y // gcd(x, y)


def move_to_cpu(tensor_list):
    for tensor in tensor_list:
        tensor.data = tensor.data.cpu()


def get_all_parameters(sub_module, recurse=False):
    return itertools.chain(sub_module.named_parameters(recurse=recurse),
                           sub_module.ds_external_parameters())

#apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module,
                                                    functional,
                                                    backward_function,
                                                    output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        return outputs


#for each tensor in outputs run the forward_function and register backward_function as hook
def _apply_forward_and_backward_to_tensors_only(module,
                                                forward_function,
                                                backward_function,
                                                outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_forward_and_backward_to_tensors_only(
                module,
                forward_function,
                backward_function,
                output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        forward_function(outputs)
        if outputs.requires_grad:
            outputs.register_hook(backward_function)
        return outputs
    else:
        return outputs
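
# Usage sketch (hook names are illustrative): the ZeRO-3 hooks installed later wrap module
# outputs roughly as
#   outputs = _apply_to_tensors_only(module, PreBackwardFunction, _run_before_backward, outputs)
# so every tensor flowing out of a module passes through an autograd Function whose backward()
# can re-gather the module's partitioned parameters before its gradients are computed.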

class ZeROOrderedDict(OrderedDict):
    def __init__(self, parent_module, *args, **kwargs):
        """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.

        Args:
            parent_module (``torch.nn.Module``): the module whose ``_parameters`` collection this dict replaces
        """
        super().__init__(*args, **kwargs)
        self._parent_module = parent_module
        self._in_forward = False

    def __getitem__(self, key):
        param = super().__getitem__(key)

        # Params can be registered as None (e.g., bias)
        if param is None:
            return param

        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if self._parent_module._parameters._in_forward:
                print_rank_0(f'Registering external parameter from getter {key}',
                             force=False)
                register_external_parameter(FWD_MODULE_STACK[-1], param)
                param.all_gather()

        return param

def _inject_parameters(module, cls):
    for module in module.modules():
        if cls == ZeROOrderedDict:
            new_param = cls(parent_module=module)
        else:
            new_param = cls()

        for key, param in module._parameters.items():
            new_param[key] = param
        module._parameters = new_param
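
# Sketch: after _inject_parameters(module, ZeROOrderedDict), each sub-module's ``_parameters``
# is a ZeROOrderedDict, so an attribute access such as ``shared_layer.weight`` inside another
# module's forward can notice a NOT_AVAILABLE ZeRO param, register it as an external parameter
# of the module currently at the top of FWD_MODULE_STACK, and all-gather it on demand.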

# TODO Needs to be implemented
class PrefetchCoordinator(object):
    def __init__(self):
        # step_id keeps track of the number of sub-modules invoked so far
        # the step_id is tracking forward and backward sequence of sub-modules
        self.step_id = 0

        # stores the sequence of sub modules in forward+backward pass
        self.sub_module_trace = []

        # maps sub_module id to submodule objects
        self.id_to_sub_module_map = {}

        # stores the total number of parameters in each sub_module
        self.id_to_sub_module_size_map = {}

        self.trace_completed = False

        self.most_recent_sub_module_step = {}

        # reuse distances
        self.reuse_numel_for_step_id = {}

    def record_trace(self, sub_module):
        if not self.trace_completed:
            self.sub_module_trace.append(sub_module.id)
            self.id_to_sub_module_map[sub_module.id] = sub_module

    def print_trace(self):
        print_rank_0(
            f"The module trace is : {[self.id_to_sub_module_map[module_id].id for module_id in self.sub_module_trace]}"
        )

    def increment_step(self, sub_module):
        self.most_recent_sub_module_step[sub_module.id] = self.step_id
        self.step_id += 1

    def reset_step(self):
        self.step_id = 0

    # returns the next numel parameters that will be used next but are not available or inflight
    def get_params_to_prefetch(self, sub_module, numel=2000000):

        # numel_in_sub_module = 0
        # for name, param in sub_module.named_parameters(recurse=False):
        #     numel_in_sub_module += param.ds_numel
        # #if numel_in_sub_module < (numel // 2):
        #    return []

        # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing
        if sub_module.id != self.sub_module_trace[self.step_id]:
            print_rank_0(
                f"Tracing failed. Prefetching is disabled at sub-module: {debug_module2name_id(sub_module)}"
            )
            return []

        params_to_prefetch = []
        total_numel_to_prefetch = 0

        for i in range(self.step_id, len(self.sub_module_trace)):
            module_id = self.sub_module_trace[i]
            for _, param in get_all_parameters(self.id_to_sub_module_map[module_id]):
                if param.ds_status is ZeroParamStatus.NOT_AVAILABLE and (
                        param.ds_id not in [p.ds_id for p in params_to_prefetch]):
                    params_to_prefetch.append(param)
                    total_numel_to_prefetch += param.ds_numel
                    #print_rank_0(f"Total numel to prefetch: {total_numel_to_prefetch}. Param: {param.ds_shape} and numel {param.ds_numel}, numel limit {numel}")
                    if total_numel_to_prefetch >= numel:  # and total_numel_to_prefetch > (numel_in_sub_module // 2):
                        return params_to_prefetch

        return params_to_prefetch

    # checks if this sub_module will be used again and if so then returns the number of elements
    # in the parameters used between this sub_module and the reuse of this sub_module
    def get_reuse_distance_in_numel(self, sub_module, sub_module_step_id=None):
        #assert is_forward is not None, "is_forward must be set to True for Forward Propagation and False for backward Propagation"
        is_there_reuse = False
        reuse_distance_in_numel = 1000000000000

        # set the appropriate trace
        trace = self.sub_module_trace
        total_steps = len(trace)
        if sub_module_step_id is None:
            sub_module_step_id = self.most_recent_sub_module_step[sub_module.id]

        # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing
        if sub_module.id != trace[sub_module_step_id]:
            print_rank_0(
                f"Tracing failed. Cannot tell if the sub_module: {sub_module.id} is reused"
            )
            return reuse_distance_in_numel

        # return cached value
        if sub_module_step_id in self.reuse_numel_for_step_id:
            return self.reuse_numel_for_step_id[sub_module_step_id]

        start_step = self.step_id
        print_rank_0(f"Step id is {self.step_id} ")
        for step_id in range(start_step, total_steps):
            print_rank_0(f"Trace id {trace[step_id]} and sub_module id {sub_module.id}")
            if sub_module.id == trace[step_id]:
                end_step = step_id

                is_there_reuse = True
                reuse_distance_in_numel = self._distance_in_numel(
                    start_step,
                    end_step,
                    trace)
                break

        self.reuse_numel_for_step_id[sub_module_step_id] = reuse_distance_in_numel

        return reuse_distance_in_numel

    def _distance_in_numel(self, start_step, end_step, trace):
        distance_in_numel = 0
        for step_id in range(start_step, end_step):
            module_id = trace[step_id]
            for _, param in self.id_to_sub_module_map[module_id].named_parameters(recurse=False):
                distance_in_numel += param.ds_numel
            for _, param in self.id_to_sub_module_map[module_id].ds_external_parameters():
                distance_in_numel += param.ds_numel
        return distance_in_numel
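
# Intended flow (sketch): during the first forward/backward pass the coordinator only records
# the sub-module order via record_trace(). Once trace_completed is set, get_params_to_prefetch()
# walks the recorded trace forward from the current step_id and collects NOT_AVAILABLE params
# until roughly ``numel`` elements are gathered, e.g. get_params_to_prefetch(sub_module, numel=2000000).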

class PartitionedParameterCoordinator(object):
    def __init__(self,
                 comm_stream=None,
                 max_reuse_distance_in_numel=500000000,
                 max_available_parameters_in_numel=700000000):

        self.in_flight_handles = []
        self.params_in_flight = []
        self.comm_stream = comm_stream if comm_stream is not None else torch.cuda.current_stream(
        )
        self.prefetch_coordinator = PrefetchCoordinator()
        self.hierarchy = 0

        self.total_available_parameter_numel = 0
        self.max_available_parameters_in_numel = max_available_parameters_in_numel

        # max distance between two use of the module beyond which module is released
        self.max_reuse_distance_in_numel = max_reuse_distance_in_numel

    def _increment_available_parameter_numel(self, increment):
        self.total_available_parameter_numel += increment

    def _decrement_available_parameter_numel(self, decrement):
        self.total_available_parameter_numel -= decrement

    '''-----------------------Tracing and Prefetching ---------------'''

    def record_trace(self, sub_module):
        self.prefetch_coordinator.record_trace(sub_module)

    def finish_tracing(self, print_trace=False):
        self.prefetch_coordinator.trace_completed = True

        if print_trace:
            self.prefetch_coordinator.print_trace()

    #swap in parameter partitions from nvme for those parameters that will be used
    # after the ones that are already being prefetched into full parameters
    def _prefetch_nvme_param_partitions(self, sub_module, params_in_flight):
        numel_in_flight = sum([param.ds_tensor.ds_numel for param in params_in_flight])
        upcoming_param_list = self.prefetch_coordinator.get_params_to_prefetch(
            sub_module,
            numel=2 * numel_in_flight)
        swap_in_params = []
        for param in upcoming_param_list:
            if len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers():
                break
            if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
                swap_in_params.append(param)

        if len(swap_in_params) > 0:
            swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True)

    # Pre fetches the parameters for sub_modules that come after
    # the current sub_module. This call is asynchronous
    def prefetch_next_sub_modules(self, sub_module, numel=5000000, nvme=False):

        params_to_prefetch = []
        if not self.prefetch_coordinator.trace_completed:
            return params_to_prefetch

        # prefetch if there is no current prefetching in flight
        if not self.in_flight_handles and self.total_available_parameter_numel < self.max_available_parameters_in_numel:
            params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch(
                sub_module,
                numel=numel)

            self._all_gather(params_to_prefetch, async_op=True)
            for param in params_to_prefetch:
                param.ds_status = ZeroParamStatus.INFLIGHT

                # keeping track of number of elements consumed by available parameters
                self._increment_available_parameter_numel(param.ds_numel)

            if nvme:
                self._prefetch_nvme_param_partitions(sub_module, params_to_prefetch)

        self._print_prefetch_elements_info(sub_module, params_to_prefetch)
        print_rank_0(
            f"{'--' * self.hierarchy}--PreFetching parameters {[param.ds_id for param in params_to_prefetch]} and available {self.total_available_parameter_numel}, max limit {self.max_available_parameters_in_numel}",
            force=False)

    def _print_prefetch_elements_info(self, sub_module, params_to_prefetch):
        sub_module_numel = 0.0
        for name, param in sub_module.named_parameters(recurse=False):
            sub_module_numel += param.ds_numel
        numel_being_prefetched = 0
        for param in params_to_prefetch:
            numel_being_prefetched = param.ds_numel
        print_rank_0(
            f"{'--' * self.hierarchy}--PreFetching {numel_being_prefetched} numels and number of numel in the next sub module is {sub_module_numel}",
            force=False)

    def increment_step(self, sub_module):
        self.prefetch_coordinator.increment_step(sub_module)

    def reset_step(self):
        self.prefetch_coordinator.reset_step()
    '''----------------------------------------------------------------------'''

    # Fetches the parameters in the sub_module
    # This call is blocking
    def fetch_sub_module(self, sub_module):
        partitioned_params = []
        params_in_flight = False
        print_rank_0(
            f"{'--' * self.hierarchy}Fetching params in module {debug_module2name_class(sub_module)}"
        )
        params_to_fetch = [
            param for _,
            param in sub_module.named_parameters(recurse=False)
        ]
        # print([n for n,p in sub_module.named_parameters(recurse=False)])
        if hasattr(sub_module, 'ds_external_parameters'):
            print_rank_0(
                f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}"
            )
            params_to_fetch += [
                param for _,
                param in sub_module.ds_external_parameters()
            ]
        # for _, param in sub_module.named_parameters(recurse=False):
        for param in params_to_fetch:
            param.ds_active_sub_modules += 1
            print_rank_0(
                f"{'--' * self.hierarchy}--Fetching parameters {debug_param2name_id_shape(param)} with active sub modules {param.ds_active_sub_modules}"
            )

            if param.ds_status == ZeroParamStatus.AVAILABLE:
                print_rank_0(
                    f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is already available"
                )

            if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
                print_rank_0(
                    f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is being fetched"
                )
                partitioned_params.append(param)

                # keeping track of number of elements consumed by available parameters
                self._increment_available_parameter_numel(param.ds_numel)
                print_rank_0(f"Incrementing with parameter id {param.ds_id}")

            if param.ds_status == ZeroParamStatus.INFLIGHT:
                params_in_flight = True
                print_rank_0(
                    f"{'--' * self.hierarchy}--Parameters {debug_param2name_id(param)} is already in flight (prefetched)"
                )
        self.hierarchy += 1

        # parameters are partitioned and need to be allgathered
        self._all_gather(partitioned_params, async_op=False)

        # parameters are inflight and communication needs to be completed
        if partitioned_params or params_in_flight:
            self._synchronize_communication()

        for _, param in sub_module.named_parameters(recurse=False):
            param.ds_status = ZeroParamStatus.AVAILABLE
            print_rank_0(
                f"Param {debug_param2name_id_shape_device(param)} norm={param.norm()}",
                force=False)
        #print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}")
    def release_sub_module(self, sub_module):
        self.hierarchy -= 1
        print_rank_0(
            f"{'--' * self.hierarchy}Releasing params in module {debug_module2name_class(sub_module)}"
        )
        params_to_release = [
            param for _,
            param in sub_module.named_parameters(recurse=False)
        ]

        if hasattr(sub_module, 'ds_external_parameters'):
            #print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}")
            params_to_release += [
                param for _,
                param in sub_module.ds_external_parameters()
            ]

        # for _, param in sub_module.named_parameters(recurse=False):
        for param in params_to_release:
            param.ds_active_sub_modules -= 1
            if not param.ds_active_sub_modules and not self._keep_for_later(
                    sub_module) and not param.ds_persist:
                print_rank_0(
                    f"{'--' * self.hierarchy}--Releasing parameter {debug_param2name_id_numel(param)} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}",
                    force=False)

                # Keeping track of number of elements that are consumed by available parameters
                self._decrement_available_parameter_numel(param.ds_numel)
                see_memory_usage(
                    f"Before releasing param {debug_param2name_id_numel(param)}",
                    force=False)
                param.partition(hierarchy=self.hierarchy)
                see_memory_usage(
                    f"After releasing param {debug_param2name_id_numel(param)}",
                    force=False)

                param.ds_status = ZeroParamStatus.NOT_AVAILABLE
            else:
                print_rank_0(
                    f"{'--' * self.hierarchy}--Did not release param {debug_param2name_id_numel(param)} with active sub modules {param.ds_active_sub_modules}, keep for later={self._keep_for_later(sub_module)} and persistence={param.ds_persist}",
                    force=False)
    def release_and_reset_parameter(self, param):
        param.ds_active_sub_modules = 0
        if param.ds_status == ZeroParamStatus.AVAILABLE:
            print_rank_0(
                f"Releasing unpartitioned param {debug_param2name_id_numel(param)} active sub-modules {param.ds_active_sub_modules} and persistence {param.ds_persist}"
            )
            self._decrement_available_parameter_numel(param.ds_numel)
            param.partition()

    def _keep_for_later(self, sub_module):
        if not self.prefetch_coordinator.trace_completed:
            return False
        if self.max_reuse_distance_in_numel == 0:
            return False
        reuse_distance_in_numel = self.prefetch_coordinator.get_reuse_distance_in_numel(
            sub_module)
        #print_rank_0(f"Reuse distance and numel for sub_module id {sub_module.id} is {reuse_distance_in_numel}")
        return reuse_distance_in_numel < self.max_reuse_distance_in_numel

    def _all_gather(self, partitioned_params, async_op=False):
        with torch.cuda.stream(self.comm_stream):
            handles = partitioned_params[0].all_gather(
                param_list=partitioned_params,
                async_op=async_op,
                hierarchy=self.hierarchy) if partitioned_params else None

        if handles is not None:
            self.in_flight_handles.extend(handles)
            self.params_in_flight.extend(partitioned_params)

    def _synchronize_communication(self, synchronize_streams=True):
        assert len(self.params_in_flight) == len(self.in_flight_handles)
        for handle, param in zip(self.in_flight_handles, self.params_in_flight):
            if handle is not None:
                with torch.cuda.stream(self.comm_stream):
                    handle.wait()
            param.ds_status = ZeroParamStatus.AVAILABLE
        self.comm_stream.synchronize()
        torch.cuda.synchronize() if synchronize_streams else None
        self.in_flight_handles = []
        self.params_in_flight = []
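
# Per-sub-module lifecycle sketch (driven by the hooks set up in setup_zero_stage3_hooks):
#   coordinator.record_trace(sub_module)          # first pass only, builds the trace
#   coordinator.fetch_sub_module(sub_module)      # blocking all-gather of missing params
#   coordinator.prefetch_next_sub_modules(...)    # async all-gather for upcoming sub-modules
#   ... run the sub-module ...
#   coordinator.release_sub_module(sub_module)    # re-partition params unless persistent or reused soon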

class PreBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        if not hasattr(module, "applied_pre_backward_ref_cnt"):
            module.applied_pre_backward_ref_cnt = 0
        module.applied_pre_backward_ref_cnt += 1
        #print(f"After Forward: {ctx.module.__class__.__name__}")
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        #print(f"Before Backward: {ctx.module.__class__.__name__}")
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args

class PostBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        if output.requires_grad:
            #TODO SOME TIMES post backward does not seem to be triggered debug in detail
            #Should only cause increase in memory not correctness issue
            #if output.grad_fn.__class__.__name__ == 'ViewBackward':
            #    ctx.view = True
            #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
            #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
            #if module.ds_grads_remaining == 0:
            #    print(f"Before Forward: {ctx.module.__class__.__name__}")
            module.ds_grads_remaining += 1
            ctx.pre_backward_function = pre_backward_function
        output = output.detach()
        return output

    @staticmethod
    def backward(ctx, *args):
        ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
        if ctx.module.ds_grads_remaining == 0:
            ctx.pre_backward_function(ctx.module)
            #print(f"After Backward: {ctx.module.__class__.__name__}")
        return (None, None) + args


INITIAL_MICRO_STEP_ID = -1
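
# How the two Functions cooperate (sketch): PreBackwardFunction is applied to a module's
# *outputs*, so its backward() fires before the module's own backward and can all-gather the
# module's partitioned params; PostBackwardFunction is applied on the module's *input* path and
# counts ds_grads_remaining, so its backward() fires once the last input gradient has been
# produced and the params/gradients can be released and reduce-scattered.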

class DeepSpeedZeroOptimizer_Stage3(object):
    """
    DeepSpeedZeroOptimizer designed to reduce the memory footprint
    required for training large deep learning models.

    For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
    https://arxiv.org/abs/1910.02054

    For usage examples, refer to TODO: DeepSpeed Tutorial
    """
    def __init__(self,
                 module,
                 init_optimizer,
                 timers,
                 ds_config,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 contiguous_gradients=True,
                 reduce_bucket_size=500000000,
                 prefetch_bucket_size=50000000,
                 max_reuse_distance=1000000000,
                 max_live_parameters=1000000000,
                 param_persistence_threshold=100000,
                 dp_process_group=None,
                 reduce_scatter=True,
                 overlap_comm=False,
                 offload_optimizer_config=None,
                 offload_param_config=None,
                 sub_group_size=1000000000000,
                 mpu=None,
                 clip_grad=0.0,
                 communication_data_type=torch.float16,
                 postscale_gradients=True,
                 gradient_predivide_factor=1.0,
                 gradient_accumulation_steps=1,
                 elastic_checkpoint=False,
                 aio_config=None):

        see_memory_usage("Stage 3 initialize beginning", force=False)

        if dist.get_rank() == 0:
            logger.info(f"Reduce bucket size {reduce_bucket_size}")
            logger.info(f"Allgather bucket size {prefetch_bucket_size}")

        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain same user API from apex.fp16_utils
        # 2. keep common stuff here in case we need to add a new fused optimizer later

        # differences from apex.fp16_utils:
        # - assume all model params in fp16
        # - assume all params requires grad
        # - flat by groups, not keeping state. TODO: remove state explicitly?
        # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        # Load pre-built or JIT compile (un)flatten ops
        util_ops = UtilsBuilder().load()
        self.flatten = util_ops.flatten
        self.unflatten = util_ops.unflatten
        self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
        self._global_grad_norm = 0.

        self.optimizer_swapper = None
        self.swap_optimizer = False

        self.offload_optimizer = False
        self.offload_optimizer_pin_memory = False
        self.offload_optimizer_fast_init = False
        self.offload_param = False
        self.offload_param_pin_memory = False
        self.params_in_nvme_and_cpu = False
        self.max_params_in_cpu = 0

        self._configure_offloading(offload_optimizer_config, offload_param_config)

        self._convert_to_zero_parameters(ds_config, module, mpu)

        for m in module.modules():
            _init_external_params(m)

        self.module = module
        self.elastic_checkpoint = elastic_checkpoint
        self.overlap_comm = overlap_comm

        # Replace ._parameters with a new class to enable auto-registration of
        # external parameters
        _inject_parameters(module, ZeROOrderedDict)

        if self.overlap_comm:
            self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda()

        self.deepspeed_adam_offload = (self.offload_optimizer
                                       and type(init_optimizer) == DeepSpeedCPUAdam)

        self.device = torch.cuda.current_device(
        ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE

        ############################################################################

        see_memory_usage("Before Partitioned Parameter Coordinator", force=False)

        fetch_stream = torch.cuda.Stream() if self.overlap_comm else None
        self.param_coordinator = PartitionedParameterCoordinator(
            comm_stream=fetch_stream,
            max_reuse_distance_in_numel=int(max_reuse_distance),
            max_available_parameters_in_numel=int(max_live_parameters))

        see_memory_usage("After Partitioned Parameter Coordinator", force=False)

        #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream())

        #-------------Stage 3 Setup-------------------#
        # parameters smaller than the threshold will be collectively gathered at the
        # end of the optimizer step and will be kept till the end of the backward pass
        # TODO maybe worth just replicating these parameters and doing all reduce for them
        self.persistence_threshold = int(param_persistence_threshold)

        self.persistent_parameters = self.persistent_parameters()

        self.setup_zero_stage3_hooks()

        #resetting ds_tensor just in case parameters have been changed after initialization
        #example .half() or .to()
        #self.reset_ds_tensor()
        #---------------------------------------------#

        self.timers = timers

        self.reduce_scatter = reduce_scatter

        self.dp_process_group = dp_process_group

        self.partition_count = dist.get_world_size(group=self.dp_process_group)

        if mpu is None:
            self.model_parallel_group = None
            self.model_parallel_rank = 0
        else:
            self.model_parallel_group = mpu.get_model_parallel_group()
            self.model_parallel_rank = mpu.get_model_parallel_rank()

        self.overflow = False
        self.clip_grad = clip_grad
        self.communication_data_type = communication_data_type
        self.gradient_predivide_factor = gradient_predivide_factor
        self.postscale_gradients = postscale_gradients
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.micro_step_id = INITIAL_MICRO_STEP_ID

        if self.reduce_scatter:
            assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-3 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
            assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-3 with reduce scatter enabled"
            assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-3 with reduce scatter enabled"

        # Holds the model parameters
        # The param.data may not hold any meaningful data
        # when param's status is NOT_AVAILABLE or INFLIGHT
        self.fp16_groups = []

        # Holds partitioned parameters
        self.fp16_partitioned_groups = []

        # Holds a fused and flattened copy of the parameters
        self.fp16_partitioned_groups_flat = []
        self.fp16_partitioned_groups_flat_numel = []

        #defragmented pinned memory
        self.param_groups_fp16_flat_cpu_memory = []

        #a single 32-bit partition of the parallel partitioned parameters
        #that this process will update
        self.fp32_partitioned_groups_flat = []
        self.next_swappable_fp32_partitioned_groups = []

        # number of elements per partition in each group
        self.partition_size = []

        self.all_reduce_print = False

        self.prefetch_elements = int(prefetch_bucket_size)

        # padding on each partition for alignment purposes
        self.groups_padding = []

        self.sub_group_size = sub_group_size

        self.sub_group_to_group_id = {}

        see_memory_usage("Before creating fp16 partitions", force=False)
        self._create_fp16_partitions_with_defragmentation()
        num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
        see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
                         force=False)

        # Optimizer tensor swapping
        if self.swap_optimizer:
            self._configure_tensor_swapping(offload_optimizer_config, aio_config)

        see_memory_usage("Before creating fp32 partitions", force=False)
        if not isinstance(self.optimizer, DummyOptim):
            self._create_fp32_partitions()
        see_memory_usage("After creating fp32 partitions", force=False)
        dist.barrier()

        # To support pipelined optimizer swapping
        if not isinstance(init_optimizer, DummyOptim):
            self._create_next_swappable_fp32_groups()

        see_memory_usage("Before initializing optimizer states", force=False)
        if not isinstance(init_optimizer, DummyOptim):
            self.initialize_optimizer_states()
        see_memory_usage("After initializing optimizer states", force=False)
        dist.barrier()

        if dist.get_rank() == 0:
            logger.info(f"optimizer state initialized")

        self.reduce_bucket_size = int(reduce_bucket_size)

        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)

        self.reduction_stream = torch.cuda.Stream(
        ) if self.overlap_comm else torch.cuda.current_stream()
        self.callback_queued = False
        self.copy_grad_stream = torch.cuda.Stream()

        self.param_dict = {}

        # map between param_id and bool to specify if a param is in this partition
        self.is_param_in_current_partition = {}

        self.contiguous_gradients = contiguous_gradients
        self.extra_large_param_to_reduce = None
        self.grads_in_ipg_bucket = []
        self.params_in_ipg_bucket = []
        self.elements_in_ipg_bucket = 0
        self.params_already_reduced = []
        self.is_gradient_accumulation_boundary = True
        self._release_ipg_buffers()
        self.previous_reduced_grads = None

        # simplified param id
        self.param_id = {}

        count = 0
        for i, params_group in enumerate(self.fp16_groups):
            for param in params_group:
                unique_id = id(param)
                self.param_id[unique_id] = count
                self.param_dict[count] = param
                self.params_already_reduced.append(False)
                count = count + 1

        #Largest partitioned param
        largest_partitioned_param_numel = max([
            max([
                max(tensor.numel(),
                    tensor.ds_numel) for tensor in fp16_partitioned_group
            ]) for fp16_partitioned_group in self.fp16_partitioned_groups
        ])
        print_rank_0(
            f'Largest partitioned param numel = {largest_partitioned_param_numel}',
            force=False)

        see_memory_usage(f"Before Set Grad positions", force=False)

        self.grad_position = {}
        self.set_grad_positions()
        see_memory_usage(f"Before CPU Offload initialization", force=False)

        self.grads_in_partition = None

        if self.offload_optimizer:
            self.accumulated_grads_in_cpu = {}
            self.norm_for_param_grads = {}
            self.local_overflow = False
            self.temp_grad_buffer_for_gpu_offload = torch.zeros(
                largest_partitioned_param_numel,
                device=torch.cuda.current_device(),
                dtype=self.dtype)
            self.temp_grad_gpu_buffer = torch.zeros(largest_partitioned_param_numel,
                                                    device=torch.cuda.current_device(),
                                                    dtype=self.dtype)
        see_memory_usage(f"After CPU Offload initialization", force=False)

        # stores if a partition has been reduced in this step
        self.is_partition_reduced = {}

        # stores if a grad in a partition has been computed or not
        self.is_grad_computed = {}

        # will store the averaged gradients required by this partition
        self.averaged_gradients = {}

        #creates backward hooks for gradient partitioning
        self.create_reduce_and_remove_grad_hooks()

        #exit(0)

        # we may have a way of fusing dynamic scale. Do not support for now
        if self.dtype == torch.float or not dynamic_loss_scale:
            loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale

            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(scale=loss_scale_value)
            cur_iter = 0
        else:
            if dynamic_loss_args is None:
                self.loss_scaler = DynamicLossScaler()
            else:
                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)

            self.dynamic_loss_scale = True

        self.debug_fp16_grads = [{} for _ in self.fp16_groups]

        if dist.get_rank(group=self.dp_process_group) == 0:
            see_memory_usage(f"After initializing ZeRO optimizer", force=False)
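
    # Usage sketch (illustrative): deepspeed.initialize() normally constructs this optimizer from
    # the ``"zero_optimization": {"stage": 3, ...}`` section of the DeepSpeed config; the
    # constructor arguments above (reduce_bucket_size, prefetch_bucket_size, max_live_parameters,
    # max_reuse_distance, param_persistence_threshold, offload_*_config, ...) mirror those knobs.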
    def _configure_offloading(self, offload_optimizer_config, offload_param_config):
        ###################### offload optimizer setup ##################################
        if offload_optimizer_config is not None:
            self.offload_optimizer = True
            self.offload_optimizer_pin_memory = offload_optimizer_config[
                OFFLOAD_OPTIMIZER_PIN_MEMORY]
            self.swap_optimizer = offload_optimizer_config[
                OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE
            self.offload_optimizer_fast_init = offload_optimizer_config[
                OFFLOAD_OPTIMIZER_FAST_INIT]

        ###################### offload param setup ##################################
        if offload_param_config is not None:
            if not isinstance(self.optimizer, DummyOptim):
                assert self.offload_optimizer, "parameter offload is only available with optimizer state offload"
            self.offload_param = True
            self.offload_param_pin_memory = offload_param_config[
                OFFLOAD_PARAM_PIN_MEMORY]
            self.params_in_nvme_and_cpu = offload_param_config[
                OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE
            self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU]
            print_rank_0(
                f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}",
                force=False)
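
    # Illustrative sketch: both configs are plain dicts keyed by the constants imported from
    # deepspeed.runtime.zero.offload_constants, e.g. (hypothetical values)
    #   offload_optimizer_config = {OFFLOAD_OPTIMIZER_DEVICE: OFFLOAD_NVME_DEVICE,
    #                               OFFLOAD_OPTIMIZER_PIN_MEMORY: True,
    #                               OFFLOAD_OPTIMIZER_FAST_INIT: False, ...}
    #   offload_param_config = {OFFLOAD_PARAM_DEVICE: OFFLOAD_CPU_DEVICE,
    #                           OFFLOAD_PARAM_PIN_MEMORY: True, ...}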
    def _convert_to_zero_parameters(self, ds_config, module, mpu):
        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
        if non_zero_params:
            zero_params = [p for p in module.parameters() if is_zero_param(p)]
            if zero_params:
                zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
            else:
                group = None
                if mpu:
                    group = mpu.get_data_parallel_group()

                if self.params_in_nvme_and_cpu:
                    remote_device = OFFLOAD_NVME_DEVICE
                elif self.offload_param:
                    remote_device = OFFLOAD_CPU_DEVICE
                else:
                    remote_device = None

                Init(module=module,
                     data_parallel_group=group,
                     dtype=self.dtype,
                     config_dict_or_path=ds_config,
                     remote_device=remote_device,
                     pin_memory=self.offload_param_pin_memory,
                     mpu=mpu)
    def _configure_tensor_swapping(self, offload_optimizer_config, aio_config):
        nvme_swap_folder = os.path.join(
            offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH],
            'zero_stage_3')
        os.makedirs(nvme_swap_folder, exist_ok=True)
        if torch.distributed.get_rank() == 0:
            logger.info(f'Tensor Swapping: Adding optimizer tensors')

        swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[
            OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper

        self.optimizer_swapper = swapper_type(
            swap_config=offload_optimizer_config,
            aio_config=aio_config,
            base_folder=nvme_swap_folder,
            optimizer=self.optimizer,
            largest_numel=max(self.fp16_partitioned_groups_flat_numel),
            device=self.device,
            dtype=torch.float32,
            timers=self.timers)
    def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
        '''If flat buffer is None then the parameters in the param_list are
        not copied to the flat buffer. This is because they exceed max_params_in_cpu.
        Some of these parameters may already be in CPU in unflattened buffers,
        or they may be in GPU, or they may be in NVME. If they are in NVME, then
        they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are
        needed during training.'''
        if flat_buffer is None:
            # this dst buffer is on NVMe, so skip this
            return

        start = 0
        for param in param_list:
            src = param.ds_tensor
            dest = flat_buffer.narrow(0, start, src.ds_numel)
            start = start + src.ds_numel
            '''if the parameter was initialized in nvme then bring it to the destination buffer directly'''
            if src.status == PartitionedParamStatus.NOT_AVAILABLE:
                print_rank_0(
                    f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU"
                )
                param.nvme_swapper.swap_into_buffer(param, dest)
                src.data = dest.data
                src.status = PartitionedParamStatus.AVAILABLE
            else:
                assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here"
                if not avoid_copy:
                    dest.data.copy_(src.data)
                src.data = dest.data

            # Final location must be gpu/cpu in this case
            param.ds_tensor.final_location = 'not-nvme'
    def _create_param_groups_fp16_flat_cpu_memory(self):

        aggregate_params_count = 0

        for j, param_group in enumerate(self.optimizer.param_groups):
            params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']])

            flat_buffer_size = params_in_group

            if self.params_in_nvme_and_cpu and \
                    aggregate_params_count + params_in_group > self.max_params_in_cpu:

                flat_buffer_size = max(0,
                                       self.max_params_in_cpu - aggregate_params_count)

            aggregate_params_count += params_in_group

            if flat_buffer_size > 0:
                print_rank_0(f"group {j} flat buffer size {flat_buffer_size}",
                             force=False)
                self.param_groups_fp16_flat_cpu_memory.append(
                    torch.empty(int(flat_buffer_size),
                                dtype=self.dtype,
                                pin_memory=True))
            else:
                print_rank_0(
                    f"No flat buffer size. Param group size was {params_in_group}",
                    force=False)

                self.param_groups_fp16_flat_cpu_memory.append(
                    torch.empty(1,
                                dtype=self.dtype))
  848. def _create_fp16_partitions_with_defragmentation(self):
  849. dist.barrier()
  850. partition_id = dist.get_rank(group=self.dp_process_group)
  851. create_fp16_flat_reuse_buffer = False
  852. largest_partition_numel = []
  853. max_partition_numel = 0
  854. #create a flat CPU memory allocation for each param group
  855. if self.offload_param:
  856. self._create_param_groups_fp16_flat_cpu_memory()
  857. # loop to deal with groups
  858. for j, param_group in enumerate(self.optimizer.param_groups):
  859. sub_groups = self._create_fp16_sub_groups(param_group['params'])
  860. print_rank_0(f'fp16 group {j} has {len(sub_groups)} subgroups', force=False)
  861. flat_offset = 0
  862. for sub_group in sub_groups:
  863. i = len(self.fp16_groups)
  864. # push this group to list before modify
  865. self.fp16_groups.append(sub_group)
  866. self.sub_group_to_group_id[i] = j
  867. # comment out for zero_to_fp32 debug
  868. # if torch.distributed.get_rank() == 0:
  869. # for param in self.fp16_groups[i]:
  870. # print(f"{debug_param2name_id_shape(param)} {param.ds_shape}")
  871. #These are the list of the partitioned parameters
  872. self.fp16_partitioned_groups.append(
  873. [param.ds_tensor for param in self.fp16_groups[i]])
  874. total_elements = sum(
  875. [t.ds_numel for t in self.fp16_partitioned_groups[i]])
  876. self.fp16_partitioned_groups_flat_numel.append(total_elements)
  877. if total_elements > max_partition_numel:
  878. largest_partition_numel = [
  879. t.ds_numel for t in self.fp16_partitioned_groups[i]
  880. ]
  881. max_partition_numel = total_elements
  882. print_rank_0(
  883. f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}"
  884. )
  885. # Record padding required to align group to world size (only applies to last rank)
  886. if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
  887. padding = [p.padding_size() for p in self.fp16_groups[i]]
  888. else:
  889. padding = [0] * len(self.fp16_groups[i])
  890. self.groups_padding.append(padding)
  891. #not sure why apex was cloning the weights before flattening
  892. #removing cloning here
  893. see_memory_usage(f"Before Flattening param subgroup {i}", force=False)
  894. #all partitioned parameters remain in GPU during training
  895. if not self.offload_param:
  896. see_memory_usage(f"Before moving param subgroup group {i} to CPU",
  897. force=False)
  898. #move all the parameters to cpu to free up GPU space for creating flat buffer
  899. move_to_cpu(self.fp16_partitioned_groups[i])
  900. see_memory_usage(f"After moving param subgroup {i} to CPU",
  901. force=False)
  902. #create flat buffer in CPU and move to GPU
  903. self.fp16_partitioned_groups_flat.append(
  904. self.flatten_dense_tensors_aligned(
  905. self.fp16_partitioned_groups[i],
  906. 1).cuda(torch.cuda.current_device()))
  907. see_memory_usage(
  908. f"After flattening and moving param subgroup {i} to GPU",
  909. force=False)
  910. #all partitioned parameters are in CPU during training
  911. else:
  912. print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}")
  913. #Flat buffer may not be available for parameters that reside in NVME
  914. if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[
  915. j].numel():
  916. fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[
  917. j].narrow(0,
  918. flat_offset,
  919. total_elements)
  920. print_rank_0(
  921. f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}",
  922. force=False)
923. #these parameters reside in NVMe and do not fit in the flat CPU buffer
  924. elif self.params_in_nvme_and_cpu:
  925. fp16_partitioned_group_flat = None
  926. print_rank_0(
  927. f"No flat buffer for sub group {i} of {total_elements} elements",
  928. force=False)
  929. else:
930. assert False, "Either params are in NVMe, or they are in CPU memory. This code path should not be triggered. Please check your max_params_in_cpu and params_in_nvme configs"
  931. self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat)
  932. flat_offset += total_elements
  933. # move param to flat buffer for both param offload on/off
  934. self._move_to_flat_buffer(self.fp16_groups[i],
  935. self.fp16_partitioned_groups_flat[i],
  936. avoid_copy=not self.offload_param)
  937. see_memory_usage(f"After Flattening param group {i}", force=False)
938. #create a pinned memory buffer to be used for swapping out params to NVMe after the optimizer step
  939. if self.fp16_partitioned_groups_flat[-1] is None:
  940. create_fp16_flat_reuse_buffer = True
  941. see_memory_usage(f"After Flattening param subgroup {i}", force=False)
  942. if create_fp16_flat_reuse_buffer:
  943. assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty'
  944. self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space(
  945. largest_partition_numel)
  946. def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id):
  947. offset = 0
  948. elements_in_sub_group = sum(
  949. [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]])
  950. assert (flat_buffer.numel() == elements_in_sub_group)
  951. for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]):
  952. dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel)
  953. if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
  954. print_rank_0(
  955. f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}"
  956. )
  957. param.nvme_swapper.swap_in([param], async_op=False)
  958. dest.data.copy_(partitioned_param.data)
  959. param.nvme_swapper.remove_partition_and_release_buffers([param])
  960. print_rank_0(f"Swapping in {param.ds_id} done")
  961. else:
  962. dest.data.copy_(partitioned_param.data)
  963. offset += partitioned_param.ds_numel
  964. def _create_next_swappable_fp32_groups(self):
  965. reverse_order_indices = [
  966. i for i in range(len(self.fp32_partitioned_groups_flat))
  967. ]
  968. reverse_order_indices.reverse()
  969. next_group = None
  970. for i in reverse_order_indices:
  971. self.next_swappable_fp32_partitioned_groups.append(next_group)
  972. if self._swappable_optimizer_subgroup(i):
  973. next_group = self.fp32_partitioned_groups_flat[i]
  974. self.next_swappable_fp32_partitioned_groups.reverse()
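# --- Illustrative sketch (hypothetical flags): _create_next_swappable_fp32_groups walks the
# sub-groups in reverse, recording the most recent swappable group seen, which yields for each
# index the next swappable group that comes after it (or None if there is none).
def _example_next_swappable(swappable_flags):
    next_group, result = None, []
    for i in reversed(range(len(swappable_flags))):
        result.append(next_group)          # the group that follows index i, recorded before updating
        if swappable_flags[i]:
            next_group = i
    result.reverse()
    return result
# e.g. flags [False, True, False, True] -> next swappable indices [1, 3, 3, None]
assert _example_next_swappable([False, True, False, True]) == [1, 3, 3, None]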
  975. def _get_sub_group_partitions(self, sub_group_id):
  976. sub_group_partitions = []
  977. for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]):
  978. if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
  979. swap_path = param.nvme_swapper.get_path(param, True)
  980. sub_group_partitions.append((partitioned_param,
  981. param.ds_tensor.ds_numel,
  982. swap_path))
  983. else:
  984. sub_group_partitions.append((partitioned_param,
  985. partitioned_param.ds_numel,
  986. None))
  987. return sub_group_partitions
  988. def _create_fp32_partitions(self):
  989. cpu_memory_usage = 0
  990. cpu_memory_sub_groups = 0
  991. nvme_memory_usage = 0
  992. num_swappable_partitions = 0
  993. num_swap_from_nvme_partitions = 0
  994. num_swap_from_cpu_partitions = 0
  995. swap_from_nvme_memory_usage = 0
  996. swap_from_cpu_memory_usage = 0
  997. GIGA_BYTES = (1024**3)
  998. swappable_fp32_tensors = []
  999. swappable_fp16_src_tensors = []
  1000. nvme_fp16_partitions_info = []
  1001. nvme_fp16_num_elems = []
  1002. nvme_fp32_dest_tensors = []
  1003. fp32_element_size = torch.tensor([], dtype=torch.float32).element_size()
  1004. for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
  1005. num_elements = self.fp16_partitioned_groups_flat_numel[i]
  1006. # a partition of the fp32 master weights that will be updated by this process
  1007. if self._swappable_optimizer_subgroup(i):
  1008. self.fp32_partitioned_groups_flat.append(torch.Tensor())
  1009. nvme_memory_usage += (fp32_element_size * num_elements)
  1010. num_swappable_partitions += 1
  1011. if self.params_in_nvme_and_cpu and tensor is None:
  1012. num_swap_from_nvme_partitions += 1
  1013. swap_from_nvme_memory_usage += (fp32_element_size * num_elements)
  1014. if self.offload_optimizer_fast_init:
  1015. sub_group_partitions = self._get_sub_group_partitions(i)
  1016. nvme_fp16_partitions_info.append(sub_group_partitions)
  1017. nvme_fp16_num_elems.append(num_elements)
  1018. nvme_fp32_dest_tensors.append(
  1019. self.fp32_partitioned_groups_flat[i])
  1020. else:
  1021. unpinned_fp32_buffer = torch.empty(num_elements,
  1022. device=self.device,
  1023. dtype=torch.float)
  1024. self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
  1025. self.optimizer_swapper.initialize_parameters(
  1026. parameters=[self.fp32_partitioned_groups_flat[i]],
  1027. src_tensors=[unpinned_fp32_buffer])
  1028. else:
  1029. num_swap_from_cpu_partitions += 1
  1030. swap_from_cpu_memory_usage += (fp32_element_size * num_elements)
  1031. swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i])
  1032. swappable_fp16_src_tensors.append(
  1033. self.fp16_partitioned_groups_flat[i])
  1034. else:
  1035. cpu_memory_usage += (fp32_element_size * num_elements)
  1036. cpu_memory_sub_groups += 1
  1037. if self.params_in_nvme_and_cpu and tensor is None:
  1038. unpinned_fp32_buffer = torch.empty(num_elements,
  1039. device=self.device,
  1040. dtype=torch.float)
  1041. self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
  1042. self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer)
  1043. else:
  1044. self.fp32_partitioned_groups_flat.append(
  1045. self.fp16_partitioned_groups_flat[i].to(
  1046. self.device).clone().float().detach())
  1047. self.fp32_partitioned_groups_flat[
  1048. i].requires_grad = True # keep this in case internal optimizer uses it
  1049. if len(swappable_fp32_tensors) > 0:
  1050. self.optimizer_swapper.initialize_parameters(
  1051. parameters=swappable_fp32_tensors,
  1052. src_tensors=swappable_fp16_src_tensors)
  1053. if len(nvme_fp32_dest_tensors) > 0:
  1054. fp16_pinned_buffers = self.fp16_groups[0][
  1055. 0].nvme_swapper.reserve_available_buffers()
  1056. assert len(fp16_pinned_buffers) > 0
  1057. self.optimizer_swapper.initialize_from_swapped_fp16_params(
  1058. fp16_partitions_info=nvme_fp16_partitions_info,
  1059. fp16_num_elems=nvme_fp16_num_elems,
  1060. fp16_pinned_buffers=fp16_pinned_buffers,
  1061. fp32_parameters=nvme_fp32_dest_tensors)
  1062. self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers()
  1063. nvme_gigabytes = nvme_memory_usage / GIGA_BYTES
  1064. print_rank_0(
  1065. f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB',
  1066. force=False)
  1067. if self.params_in_nvme_and_cpu:
  1068. print_rank_0(
  1069. f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB',
  1070. force=False)
  1071. print_rank_0(
  1072. f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB',
  1073. force=False)
  1074. cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES
  1075. print_rank_0(
  1076. f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB',
  1077. force=False)
  1078. # Clear for on-the-fly population before the optimizer step
  1079. for param_group in self.optimizer.param_groups:
  1080. param_group['params'] = []
  1081. def _create_fp16_sub_groups(self, params_group):
  1082. params_group_numel = sum([param.partitioned_size() for param in params_group])
  1083. sub_group_size = self.sub_group_size
  1084. if sub_group_size is None or sub_group_size >= params_group_numel:
  1085. return [params_group]
  1086. sub_groups = []
  1087. sub_group = []
  1088. local_sub_group_size = 0
  1089. for param in params_group:
  1090. sub_group.append(param)
  1091. local_sub_group_size += param.partitioned_size()
  1092. if local_sub_group_size >= sub_group_size or id(param) == id(
  1093. params_group[-1]):
  1094. sub_groups.append(sub_group)
  1095. sub_group = []
  1096. local_sub_group_size = 0
  1097. return sub_groups
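# --- Illustrative sketch (hypothetical sizes): _create_fp16_sub_groups greedily packs
# parameters until the running partitioned size reaches sub_group_size, then starts a new
# sub group; the trailing partial group is always flushed because the last param is checked.
def _example_sub_groups(partitioned_sizes, sub_group_size):
    groups, current, current_size = [], [], 0
    for idx, size in enumerate(partitioned_sizes):
        current.append(idx)
        current_size += size
        if current_size >= sub_group_size or idx == len(partitioned_sizes) - 1:
            groups.append(current)
            current, current_size = [], 0
    return groups
# e.g. sizes [4, 4, 4, 4, 4] with sub_group_size=8 -> [[0, 1], [2, 3], [4]]
assert _example_sub_groups([4, 4, 4, 4, 4], 8) == [[0, 1], [2, 3], [4]]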
  1098. # def reset_ds_tensor(self):
  1099. # for name, param in self.module.named_parameters(recurse=True):
  1100. # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible"
  1101. # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now"
  1102. # param.ds_tensor.data = param.data
  1103. def setup_zero_stage3_hooks(self):
  1104. self.hierarchy = 0
  1105. self._register_hooks_recursively(self.module)
  1106. #reset step at the beginning of forward
  1107. def _pre_forward_hook(module, *args):
  1108. self.param_coordinator.reset_step()
  1109. #reset step if in inference mode
  1110. def _end_of_forward_hook(module, *args):
  1111. if not torch._C.is_grad_enabled():
  1112. self.param_coordinator.reset_step()
  1113. #likely one of them should be enough but just to be safe
  1114. self.module.register_forward_hook(_end_of_forward_hook)
  1115. self.module.register_forward_pre_hook(_pre_forward_hook)
  1116. # Add top module to stack trace
  1117. global FWD_MODULE_STACK
  1118. FWD_MODULE_STACK.append(self.module)
  1119. def persistent_parameters(self):
  1120. persistent_params = []
  1121. total_persistent_parameters = 0
  1122. params_count = 0
  1123. for _, param in self.module.named_parameters(recurse=True):
  1124. if param.ds_numel < self.persistence_threshold:
  1125. params_count += 1
  1126. param.ds_persist = True
  1127. persistent_params.append(param)
  1128. total_persistent_parameters += param.ds_numel
  1129. print_rank_0(
  1130. f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
  1131. force=False)
  1132. return persistent_params
  1133. def _register_hooks_recursively(self, module, count=[0]):
  1134. my_count = count[0]
  1135. module.id = my_count
  1136. #print(f"{module.__class__} : {module.id}")
  1137. for child in module.children():
  1138. count[0] = count[0] + 1
  1139. self._register_hooks_recursively(child, count=count)
  1140. def _pre_forward_module_hook(module, *args):
  1141. self.pre_sub_module_forward_function(module)
  1142. def _post_forward_module_hook(module, input, output):
  1143. global FWD_MODULE_STACK
  1144. FWD_MODULE_STACK.pop()
  1145. if output is None:
  1146. output = []
  1147. elif not isinstance(output, (list, tuple)):
  1148. if torch.is_tensor(output):
  1149. output = [output]
  1150. else:
  1151. #print(f'got UNKNOWN type {type(output)}')
  1152. outputs = []
  1153. output = output if isinstance(output, dict) else vars(output)
  1154. for name, val in output.items():
  1155. if not name.startswith('__') and torch.is_tensor(val):
  1156. outputs.append(val)
  1157. output = outputs
  1158. #print(f'convert output to {output}')
  1159. for item in filter(lambda item: is_zero_param(item), output):
  1160. if not any(id(item) in m._external_params for m in FWD_MODULE_STACK):
  1161. item.ds_active_sub_modules += 1
  1162. module_to_register = FWD_MODULE_STACK[-1]
  1163. print_rank_0(
  1164. f'Registering dangling parameter for module {module_to_register.__class__.__name__}.',
  1165. force=False)
  1166. register_external_parameter(module_to_register, item)
1167. # It's possible that the parameter was already external to the completed module. If so, remove
1168. # the registration, as it will be covered by the outer module instead.
  1169. if id(item) in module._external_params:
  1170. print_rank_0(
  1171. f' Unregistering nested dangling parameter from module {module.__class__.__name__}',
  1172. force=False)
  1173. unregister_external_parameter(module, item)
  1174. item.all_gather()
  1175. self.post_sub_module_forward_function(module)
  1176. def _pre_backward_module_hook(module, inputs, output):
  1177. def _run_before_backward_function(sub_module):
  1178. # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
  1179. # before doing backwards, so each backward will need a pre-fetch - using reference
  1180. # counting to support this scenario
  1181. #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
  1182. if sub_module.applied_pre_backward_ref_cnt > 0:
  1183. self.pre_sub_module_backward_function(sub_module)
  1184. sub_module.applied_pre_backward_ref_cnt -= 1
  1185. #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
  1186. return _apply_to_tensors_only(module,
  1187. PreBackwardFunction,
  1188. _run_before_backward_function,
  1189. output)
  1190. #This is an alternate to doing _post_backward_module_hook
  1191. #it uses tensor.register_hook instead of using torch.autograd.Function
  1192. def _alternate_post_backward_module_hook(module, inputs):
  1193. module.ds_grads_remaining = 0
  1194. #print(f"Before Forward {module.__class__.__name__}")
  1195. def _run_after_backward_hook(*unused):
  1196. module.ds_grads_remaining = module.ds_grads_remaining - 1
  1197. if module.ds_grads_remaining == 0:
  1198. #print(f"After backward {module.__class__.__name__}")
  1199. self.post_sub_module_backward_function(module)
  1200. def _run_before_forward_function(input):
  1201. if input.requires_grad:
  1202. module.ds_grads_remaining += 1
  1203. return _apply_forward_and_backward_to_tensors_only(
  1204. module,
  1205. _run_before_forward_function,
  1206. _run_after_backward_hook,
  1207. inputs)
  1208. def _post_backward_module_hook(module, inputs):
  1209. module.ds_grads_remaining = 0
  1210. def _run_after_backward_function(sub_module):
  1211. if sub_module.ds_grads_remaining == 0:
  1212. self.post_sub_module_backward_function(sub_module)
  1213. return _apply_to_tensors_only(module,
  1214. PostBackwardFunction,
  1215. _run_after_backward_function,
  1216. inputs)
  1217. # Pre forward hook
  1218. module.register_forward_pre_hook(_pre_forward_module_hook)
  1219. # Post forward hook
  1220. module.register_forward_hook(_post_forward_module_hook)
  1221. # Pre backward hook
  1222. module.register_forward_hook(_pre_backward_module_hook)
  1223. # post backward hook
  1224. module.register_forward_pre_hook(_post_backward_module_hook)
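# --- Illustrative sketch (toy module, not this engine): the registrations above rely on
# PyTorch's standard forward pre-/post-hooks; a pre-hook fires before Module.forward and a
# post-hook fires after it, which is how parameters get fetched and released per sub-module.
import torch
_events = []
_layer = torch.nn.Linear(2, 2)
_layer.register_forward_pre_hook(lambda module, inputs: _events.append('pre'))
_layer.register_forward_hook(lambda module, inputs, output: _events.append('post'))
_layer(torch.randn(1, 2))
assert _events == ['pre', 'post']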
  1225. def pre_sub_module_forward_function(self, sub_module):
  1226. see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
  1227. force=False)
  1228. global FWD_MODULE_STACK
  1229. FWD_MODULE_STACK.append(sub_module)
  1230. self.param_coordinator.record_trace(sub_module)
  1231. self.param_coordinator.fetch_sub_module(sub_module)
  1232. see_memory_usage(
  1233. f"Before sub module function {sub_module.__class__.__name__} after fetch",
  1234. force=False)
  1235. self.param_coordinator.prefetch_next_sub_modules(
  1236. sub_module,
  1237. numel=self.prefetch_elements,
  1238. nvme=self.params_in_nvme_and_cpu)
  1239. see_memory_usage(
  1240. f"Before sub module function {sub_module.__class__.__name__} after prefetch",
  1241. force=False)
  1242. self.param_coordinator.increment_step(sub_module)
  1243. def post_sub_module_forward_function(self, sub_module):
  1244. see_memory_usage(
  1245. f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
  1246. force=False)
  1247. self.param_coordinator.release_sub_module(sub_module)
  1248. see_memory_usage(
  1249. f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
  1250. force=False)
  1251. def pre_sub_module_backward_function(self, sub_module):
  1252. self.param_coordinator.record_trace(sub_module)
  1253. self.param_coordinator.fetch_sub_module(sub_module)
  1254. self.param_coordinator.prefetch_next_sub_modules(sub_module,
  1255. numel=self.prefetch_elements)
  1256. self.param_coordinator.increment_step(sub_module)
  1257. def post_sub_module_backward_function(self, sub_module):
  1258. see_memory_usage(
  1259. f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
  1260. force=False)
  1261. self.param_coordinator.release_sub_module(sub_module)
  1262. see_memory_usage(
  1263. f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
  1264. force=False)
  1265. def _release_ipg_buffers(self):
  1266. if self.contiguous_gradients:
  1267. self.ipg_buffer = None
  1268. if not self.offload_optimizer and self.is_gradient_accumulation_boundary:
  1269. self.grads_in_partition = None
  1270. self.grads_in_partition_offset = 0
  1271. def _optimizer_step(self, sub_group_id):
  1272. param_group_id = self.sub_group_to_group_id[sub_group_id]
  1273. fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
  1274. fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
  1275. self.optimizer.param_groups[param_group_id]['params'] = [fp32_param]
  1276. self.optimizer.step()
  1277. self.optimizer.param_groups[param_group_id]['params'] = []
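# --- Illustrative sketch (plain SGD on toy tensors, not this optimizer): _optimizer_step
# temporarily points the wrapped optimizer's param group at the single fp32 sub-group being
# stepped, runs step(), then empties the group so only one sub-group's state is live at a time.
import torch
_placeholder = torch.nn.Parameter(torch.zeros(1))
_opt = torch.optim.SGD([_placeholder], lr=0.1)
_fp32_param = torch.nn.Parameter(torch.ones(4))
_fp32_param.grad = torch.full((4,), 0.5)
_opt.param_groups[0]['params'] = [_fp32_param]   # swap the sub-group in
_opt.step()                                      # 1.0 - 0.1 * 0.5 = 0.95
_opt.param_groups[0]['params'] = []              # release it again
assert torch.allclose(_fp32_param, torch.full((4,), 0.95))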
  1278. def _swappable_optimizer_subgroup(self, sub_group_id):
  1279. if not self.swap_optimizer:
  1280. return False
  1281. return self.optimizer_swapper.swappable_tensor(
  1282. None,
  1283. numel=self.fp16_partitioned_groups_flat_numel[sub_group_id])
  1284. def _partitioned_params_swap_out(self, i):
  1285. offset = 0
  1286. fp32_param = self.fp32_partitioned_groups_flat[i]
  1287. assert fp32_param is not None, \
1288. f'fp32 parameter of sub_group {i} is None'
  1289. swap_fp16_params = []
  1290. swap_fp32_params = []
  1291. for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]):
  1292. src = fp32_param.narrow(0, offset, partitioned_param.ds_numel)
  1293. if partitioned_param.status == PartitionedParamStatus.AVAILABLE:
  1294. partitioned_param.data.copy_(src.data)
  1295. else:
  1296. swap_fp32_params.append(src)
  1297. swap_fp16_params.append(param)
  1298. offset += partitioned_param.ds_numel
  1299. if len(swap_fp16_params):
  1300. swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(
  1301. dst_fp16_params=swap_fp16_params,
  1302. src_fp32_params=swap_fp32_params)
  1303. def initialize_optimizer_states(self):
  1304. num_subgroups = len(self.fp16_groups)
  1305. largest_numel = max(
  1306. [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups])
  1307. gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype
  1308. gradient_buffer = torch.zeros(int(largest_numel),
  1309. dtype=gradient_dtype,
  1310. device=self.device)
  1311. timers = self.timers
  1312. timer_names = set()
  1313. if self.swap_optimizer:
  1314. self.optimizer_swapper.init_timers()
  1315. INIT_OPTIMIZER_TIMER = 'init_optimizer_state'
  1316. timer_names.add(INIT_OPTIMIZER_TIMER)
  1317. self.start_timers([INIT_OPTIMIZER_TIMER])
  1318. for i, group in enumerate(self.fp16_groups):
  1319. swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i)
  1320. swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None
  1321. num_elements = int(self.fp16_partitioned_groups_flat_numel[i])
  1322. see_memory_usage(
  1323. f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
  1324. force=False)
  1325. if swappable_optimizer_subgroup:
  1326. self._optimizer_states_and_gradient_swap_in(i, timer_names)
  1327. if self.offload_optimizer and not swappable_optimizer_subgroup:
  1328. subgroup_gradient_buffer = torch.zeros(num_elements,
  1329. dtype=gradient_dtype,
  1330. device=self.device)
  1331. if self.offload_optimizer_pin_memory:
  1332. subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory()
  1333. self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer
  1334. else:
  1335. self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(
  1336. 0,
  1337. 0,
  1338. num_elements)
  1339. self._optimizer_step(i)
  1340. if swappable_param_subgroup:
  1341. self._partitioned_params_swap_out(i)
  1342. if swappable_optimizer_subgroup:
  1343. self._optimizer_states_and_gradient_swap_out(i, timer_names)
  1344. see_memory_usage(
  1345. f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
  1346. force=False)
  1347. self.stop_timers([INIT_OPTIMIZER_TIMER])
  1348. self.log_timers(timer_names)
  1349. if self.swap_optimizer:
  1350. self.optimizer_swapper.log_timers()
  1351. if not self.offload_optimizer:
  1352. for group in self.fp32_partitioned_groups_flat:
  1353. group.grad = None
  1354. # Reset steps
  1355. return
  1356. #########################################################################
  1357. #########################ZeRO Partition Gradients########################
  1358. #########################################################################
  1359. def get_first_param_index(self, group_id, param_group, partition_id):
  1360. for index, param in enumerate(param_group):
  1361. param_id = self.get_param_id(param)
  1362. if partition_id in self.param_to_partition_ids[group_id][param_id]:
  1363. return index
  1364. return None
  1365. def initialize_gradient_partitioning_data_structures(self):
  1366. total_partitions = dist.get_world_size(group=self.dp_process_group)
  1367. for i, param_group in enumerate(self.fp16_groups):
  1368. self.param_to_partition_ids[i] = {}
  1369. self.is_partition_reduced[i] = {}
  1370. self.total_grads_in_partition[i] = {}
  1371. self.remaining_grads_in_partition[i] = {}
  1372. self.is_grad_computed[i] = {}
  1373. self.grad_partition_insertion_offset[i] = {}
  1374. self.grad_start_offset[i] = {}
  1375. self.first_param_index_in_partition[i] = {}
  1376. for partition_id in range(total_partitions):
  1377. self.is_grad_computed[i][partition_id] = {}
  1378. self.grad_partition_insertion_offset[i][partition_id] = {}
  1379. self.grad_start_offset[i][partition_id] = {}
  1380. self.initialize_gradient_partition(i, param_group, partition_id)
  1381. self.is_partition_reduced[i][partition_id] = False
  1382. self.first_param_index_in_partition[i][
  1383. partition_id] = self.get_first_param_index(
  1384. i,
  1385. param_group,
  1386. partition_id)
  1387. def independent_gradient_partition_epilogue(self):
  1388. self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
  1389. self.reduce_ipg_grads()
  1390. self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
  1391. if self.overlap_comm:
  1392. self.reduction_stream.synchronize()
  1393. with torch.cuda.stream(self.reduction_stream):
  1394. self.partition_previous_reduced_grads()
  1395. # if dist.get_rank() == 0:
  1396. # logger.info("Params already reduced %s", self.params_already_reduced)
  1397. for i in range(len(self.params_already_reduced)):
  1398. self.params_already_reduced[i] = False
  1399. #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad
  1400. #TODO: use a similar code path for both cpu_offload and non-cpu offload
  1401. if not self.offload_optimizer:
  1402. for i, sub_group in enumerate(self.fp16_groups):
  1403. self.averaged_gradients[i] = [
  1404. torch.zeros_like(param.ds_tensor) if param.grad is None else
  1405. param.grad.data.narrow(0,
  1406. 0,
  1407. param.ds_tensor.numel())
  1408. for param in sub_group
  1409. ]
  1410. # self.averaged_gradients[i] = self.get_flat_partition(
  1411. # self.fp16_groups[i],
  1412. # 0,
  1413. # self.fp32_partitioned_groups_flat[i].numel(),
  1414. # return_tensor_list=True)
  1415. self._release_ipg_buffers()
  1416. see_memory_usage(f"End ipg_epilogue", force=False)
1417. # resets all partitions to not reduced
1418. # sets remaining grads to the total number of grads in each partition
1419. # sets is_grad_computed to False for all grads in the partition
  1420. def reset_partition_gradient_structures(self):
  1421. total_partitions = dist.get_world_size(group=self.dp_process_group)
  1422. for i, _ in enumerate(self.fp16_groups):
  1423. for partition_id in range(total_partitions):
  1424. self.is_partition_reduced[i][partition_id] = False
  1425. self.remaining_grads_in_partition[i][
  1426. partition_id] = self.total_grads_in_partition[i][partition_id]
  1427. for param_id in self.is_grad_computed[i][partition_id]:
  1428. self.is_grad_computed[i][partition_id][param_id] = False
  1429. def initialize_gradient_partition(self, i, param_group, partition_id):
  1430. def set_key_value_list(dictionary, key, value):
  1431. if key in dictionary:
  1432. dictionary[key].append(value)
  1433. else:
  1434. dictionary[key] = [value]
  1435. def increment_value(dictionary, key):
  1436. if key in dictionary:
  1437. dictionary[key] += 1
  1438. else:
  1439. dictionary[key] = 1
  1440. partition_size = self.partition_size[i]
  1441. start_index = partition_size * partition_id
  1442. end_index = partition_size * (partition_id + 1)
  1443. current_index = 0
  1444. first_offset = 0
  1445. for param in param_group:
  1446. param_size = param.numel()
  1447. param_id = self.get_param_id(param)
  1448. if (current_index >= start_index and current_index < end_index):
  1449. set_key_value_list(self.param_to_partition_ids[i],
  1450. param_id,
  1451. partition_id)
  1452. increment_value(self.total_grads_in_partition[i], partition_id)
  1453. self.is_grad_computed[i][partition_id][param_id] = False
  1454. self.grad_partition_insertion_offset[i][partition_id][
  1455. param_id] = current_index - start_index
  1456. self.grad_start_offset[i][partition_id][param_id] = 0
  1457. elif start_index > current_index and start_index < (current_index +
  1458. param_size):
  1459. assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
  1460. first_offset = start_index - current_index
  1461. set_key_value_list(self.param_to_partition_ids[i],
  1462. param_id,
  1463. partition_id)
  1464. increment_value(self.total_grads_in_partition[i], partition_id)
  1465. self.is_grad_computed[i][partition_id][param_id] = False
  1466. self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
  1467. self.grad_start_offset[i][partition_id][param_id] = first_offset
  1468. current_index = current_index + param_size
  1469. def overlapping_partition_gradients_reduce_epilogue(self):
  1470. self.independent_gradient_partition_epilogue()
  1471. self.zero_grad()
  1472. def create_reduce_and_remove_grad_hooks(self):
  1473. print_rank_0(f'[Begin] Create gradient reduction hooks')
  1474. self.grad_accs = []
  1475. for i, param_group in enumerate(self.fp16_groups):
  1476. for param in param_group:
  1477. if param.requires_grad:
  1478. #print_rank_0(f" Before all gather {param.device}, {param.shape}")
  1479. # The hook must be created in un-partitioned parameter
  1480. param.all_gather()
  1481. #print(f"After all gather {param.device}, {param.shape}")
  1482. def wrapper(param, i):
  1483. param_tmp = param.expand_as(param)
  1484. grad_acc = param_tmp.grad_fn.next_functions[0][0]
  1485. def reduce_partition_and_remove_grads(*notneeded):
  1486. self.reduce_ready_partitions_and_remove_grads(param, i)
  1487. grad_acc.register_hook(reduce_partition_and_remove_grads)
  1488. self.grad_accs.append(grad_acc)
  1489. #print(f"param grad fn {param.expand_as(param).grad_fn}")
  1490. wrapper(param, i)
  1491. # Partition the parameter after creating the hook
  1492. param.partition()
  1493. print_rank_0(f'[End] Create gradient reduction hooks')
  1494. def get_param_id(self, param):
  1495. unique_id = id(param)
  1496. return self.param_id[unique_id]
  1497. def report_ipg_memory_usage(self, tag, param_elems):
  1498. elem_count = self.elements_in_ipg_bucket + param_elems
  1499. percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
  1500. see_memory_usage(
  1501. f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}",
  1502. force=False)
1503. ############### Independent Partition Gradient ########################
  1504. def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
  1505. #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True)
1506. # Because the ipg bucket is initialized with a random placeholder tensor, we must
1507. # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket >
1508. # 0). Otherwise, if the incoming param.ds_numel is large, this branch may get triggered on
1509. # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be
1510. # empty, while reduction_list will contain that garbage data.
  1511. if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size:
  1512. self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
  1513. param.ds_numel)
  1514. self.reduce_ipg_grads()
  1515. if self.contiguous_gradients and self.overlap_comm:
  1516. # Swap ipg_index between 0 and 1
  1517. self.ipg_index = 1 - self.ipg_index
  1518. self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
  1519. param.ds_numel)
  1520. param_id = self.get_param_id(param)
  1521. assert self.params_already_reduced[param_id] == False, \
  1522. f"The parameter {param_id} has already been reduced. \
  1523. Gradient computed twice for this partition. \
  1524. Multiple gradient reduction is currently not supported"
  1525. # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
  1526. if param.ds_numel > self.reduce_bucket_size:
  1527. self.extra_large_param_to_reduce = param
  1528. elif self.contiguous_gradients:
  1529. #print_rank_0("before new grad tensor move")
  1530. new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
  1531. 0,
  1532. self.elements_in_ipg_bucket,
  1533. param.ds_numel)
  1534. #print_rank_0("after new grad tensor move")
  1535. new_grad_tensor.copy_(param.grad.view(-1))
  1536. param.grad.data = new_grad_tensor.data.view_as(param.grad)
  1537. self.elements_in_ipg_bucket += param.ds_numel
  1538. self.grads_in_ipg_bucket.append(param.grad)
  1539. self.params_in_ipg_bucket.append((i, param, param_id))
  1540. self.report_ipg_memory_usage("End ipg_remove_grads", 0)
  1541. def gradient_reduction_w_predivide(self, tensor):
  1542. dp_world_size = dist.get_world_size(group=self.dp_process_group)
  1543. tensor_to_allreduce = tensor
  1544. if self.communication_data_type != tensor.dtype:
  1545. tensor_to_allreduce = tensor.to(self.communication_data_type)
  1546. if self.postscale_gradients:
  1547. if self.gradient_predivide_factor != 1.0:
  1548. tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
  1549. dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
  1550. if self.gradient_predivide_factor != dp_world_size:
  1551. tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
  1552. else:
  1553. tensor_to_allreduce.div_(dp_world_size)
  1554. dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
  1555. if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
  1556. tensor.copy_(tensor_to_allreduce)
  1557. return tensor
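# --- Illustrative sketch (assumed standalone helper, local lists stand in for dist ranks):
# whichever branch of gradient_reduction_w_predivide runs, the net effect is an average over
# the data-parallel group. With post-scaling, each rank divides by the predivide factor before
# the sum-all-reduce and then rescales by predivide_factor / world_size; with pre-scaling it
# divides by world_size up front. Both paths end at sum(grads) / world_size.
def _example_predivide_average(per_rank_grads, predivide_factor):
    world_size = len(per_rank_grads)
    pre_scaled = [g / predivide_factor for g in per_rank_grads]
    summed = sum(pre_scaled)                       # stands in for dist.all_reduce (SUM)
    return summed * (predivide_factor / world_size)
assert _example_predivide_average([2.0, 4.0, 6.0, 8.0], predivide_factor=2.0) == 5.0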
  1558. def average_tensor(self, tensors, params_to_reduce):
  1559. with torch.cuda.stream(self.reduction_stream):
  1560. if not self.reduce_scatter:
  1561. for tensor in tensors:
  1562. self.gradient_reduction_w_predivide(tensor)
  1563. return
  1564. for tensor in tensors:
  1565. tensor.div_(dist.get_world_size(group=self.dp_process_group))
  1566. # reduction resulting with each rank only holding the gradient partition it owns
  1567. # This could either be a reduce scatter or a reduce op depending on how
1568. # parameters are partitioned. The method is implemented by the
1569. # DeepSpeed param extensions to the PyTorch parameter, so it's up to
  1570. # the extension to define what happens here
  1571. params_to_reduce[0].reduce_gradients_at_owner(
  1572. param_list=params_to_reduce,
  1573. hierarchy=self.param_coordinator.hierarchy)
  1574. def set_grad_positions(self):
  1575. for i, group in enumerate(self.fp16_groups):
  1576. current_offset = 0
  1577. for param in group:
  1578. param_id = self.get_param_id(param)
  1579. num_elements = param.ds_tensor.ds_numel
  1580. self.grad_position[param_id] = [
  1581. int(i),
  1582. int(current_offset),
  1583. int(num_elements)
  1584. ]
  1585. #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}")
  1586. current_offset += num_elements
  1587. def async_accumulate_grad_in_cpu_via_gpu(self, param, acc_grad_cpu_partition):
1588. # copy to a preexisting buffer to avoid the memory allocation penalty
  1589. dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
  1590. 0,
  1591. 0,
  1592. param.ds_tensor.ds_numel)
  1593. if self.micro_step_id > 0:
  1594. dest_buffer.copy_(acc_grad_cpu_partition.view(-1), non_blocking=True)
  1595. param.grad.data.view(-1).add_(dest_buffer)
  1596. # at the boundary we will send 32bit directly
  1597. if not self.is_gradient_accumulation_boundary:
  1598. acc_grad_cpu_partition.data.copy_(param.grad.data.view(-1),
  1599. non_blocking=True)
  1600. def _constant_buffered_norm2(self, input, buffer_size=250000000):
  1601. norm = None
  1602. for part in input.view(-1).split(buffer_size):
  1603. if norm is None:
  1604. norm = part.data.double().norm(2)**2.0
  1605. else:
  1606. norm += part.data.double().norm(2)**2.0
  1607. return norm**0.5
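# --- Illustrative sketch: _constant_buffered_norm2 accumulates sum-of-squares chunk by chunk
# in double precision and takes a single square root at the end, matching a direct 2-norm while
# only touching one buffer-sized chunk at a time. The buffer size below is a toy value.
import torch
def _example_buffered_norm2(t, buffer_size=4):
    total = None
    for part in t.view(-1).split(buffer_size):
        sq = part.double().norm(2) ** 2
        total = sq if total is None else total + sq
    return total ** 0.5
_x = torch.arange(10, dtype=torch.float32)
assert torch.isclose(_example_buffered_norm2(_x).float(), _x.norm(2))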
  1608. def set_norm_for_param_grad_in_gpu(self, param):
  1609. param_id = self.get_param_id(param)
  1610. #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2)
  1611. #Using a more memory efficient version
  1612. self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad)
  1613. def update_overflow_tracker_for_param_grad(self, param):
  1614. #Credit to our user David Minn
  1615. if param.grad is not None:
  1616. if self.overlap_comm:
  1617. self.gpu_sum = self.gpu_sum + param.grad.data.float().sum()
  1618. elif self._has_inf_or_nan(param.grad.data):
  1619. self.local_overflow = True
  1620. def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor):
  1621. with torch.cuda.stream(self.copy_grad_stream):
  1622. param_id = self.get_param_id(param)
  1623. src_tensor = param.grad.view(-1).float()
  1624. #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}")
  1625. fp32_grad_tensor.copy_(src_tensor, non_blocking=True)
  1626. param.grad = None
  1627. def complete_grad_norm_calculation_for_cpu_offload(self, params):
  1628. total_norm = 0.0
  1629. norm_type = 2.0
  1630. for p in params:
  1631. if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
  1632. param_id = self.get_param_id(p)
  1633. if param_id in self.norm_for_param_grads.keys():
  1634. param_norm = self.norm_for_param_grads[param_id]
  1635. total_norm += param_norm.item()**2
  1636. # Sum across all model parallel GPUs.
  1637. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  1638. torch.distributed.all_reduce(total_norm_cuda,
  1639. op=torch.distributed.ReduceOp.SUM,
  1640. group=self.dp_process_group)
  1641. self._model_parallel_all_reduce(tensor=total_norm_cuda,
  1642. op=torch.distributed.ReduceOp.SUM)
  1643. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  1644. if total_norm == float(
  1645. 'inf') or total_norm == -float('inf') or total_norm != total_norm:
  1646. total_norm = -1
  1647. return total_norm
  1648. def partition_previous_reduced_grads(self):
  1649. if not self.previous_reduced_grads:
  1650. return
  1651. if self.offload_optimizer:
  1652. allocate_grads_in_partition = self.grads_in_partition is None\
  1653. and self.gradient_accumulation_steps > 1
  1654. else:
  1655. allocate_grads_in_partition = self.grads_in_partition is None
  1656. if allocate_grads_in_partition:
  1657. self.grads_in_partition = []
  1658. for i, group in enumerate(self.fp16_groups):
  1659. total_size = 0
  1660. for param_in_partition in group:
  1661. total_size += param_in_partition.ds_tensor.ds_numel
  1662. see_memory_usage(
  1663. f"group {i} before creating {total_size} reduced gradients into partition",
  1664. force=False)
  1665. if self.offload_param_pin_memory:
  1666. self.grads_in_partition.append(
  1667. torch.zeros(int(total_size),
  1668. dtype=self.dtype,
  1669. device=self.device).pin_memory())
  1670. else:
  1671. self.grads_in_partition.append(
  1672. torch.zeros(int(total_size),
  1673. dtype=self.dtype,
  1674. device=self.device))
  1675. see_memory_usage(
  1676. f"group {i} after creating {total_size} reduced gradients into partition",
  1677. force=False)
  1678. if self.offload_optimizer:
  1679. offload_fp32_gradients = {}
  1680. offload_fp32_offsets = {}
  1681. with torch.cuda.stream(self.copy_grad_stream):
  1682. self.reduction_stream.synchronize()
  1683. for param in self.previous_reduced_grads:
  1684. [i,
  1685. dest_offset,
  1686. num_elements] = self.grad_position[self.get_param_id(param)]
  1687. if self.offload_optimizer:
  1688. param.partition_gradients(
  1689. partition_buffers=self.temp_grad_gpu_buffer)
  1690. #with torch.cuda.stream(self.copy_grad_stream):
  1691. # self.reduction_stream.synchronize()
  1692. if self.gradient_accumulation_steps > 1:
  1693. # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
  1694. fp16_grad_tensor = self.grads_in_partition[i].narrow(
  1695. 0,
  1696. dest_offset,
  1697. num_elements)
  1698. self.async_accumulate_grad_in_cpu_via_gpu(
  1699. param,
  1700. fp16_grad_tensor)
  1701. if self.is_gradient_accumulation_boundary:
  1702. self.set_norm_for_param_grad_in_gpu(param)
  1703. self.update_overflow_tracker_for_param_grad(param)
  1704. if self._swappable_optimizer_subgroup(i):
  1705. if not i in offload_fp32_gradients.keys():
  1706. offload_fp32_gradients[i] = []
  1707. offload_fp32_offsets[i] = []
  1708. offload_fp32_gradients[i].append(param.grad.view(-1).float())
  1709. param.grad = None
  1710. offload_fp32_offsets[i].append(dest_offset)
  1711. else:
  1712. fp32_grad_tensor = self.fp32_partitioned_groups_flat[
  1713. i].grad.narrow(0,
  1714. dest_offset,
  1715. num_elements)
  1716. self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(
  1717. param,
  1718. fp32_grad_tensor)
  1719. else:
  1720. # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
  1721. fp16_grad_tensor = self.grads_in_partition[i].narrow(
  1722. 0,
  1723. dest_offset,
  1724. num_elements)
  1725. param.partition_gradients(
  1726. partition_buffers=fp16_grad_tensor,
  1727. accumulate=True if self.micro_step_id > 0 else False)
  1728. if self.offload_optimizer and self.swap_optimizer:
  1729. for i in offload_fp32_gradients.keys():
  1730. self.optimizer_swapper.swap_out_gradients(
  1731. parameter=self.fp32_partitioned_groups_flat[i],
  1732. gradient_offsets=offload_fp32_offsets[i],
  1733. gradient_tensors=offload_fp32_gradients[i])
  1734. self.previous_reduced_grads = []
  1735. def reduce_ipg_grads(self, extra_param=None):
  1736. if self.overlap_comm:
  1737. self.reduction_stream.synchronize()
  1738. with torch.cuda.stream(self.reduction_stream):
  1739. self.partition_previous_reduced_grads()
  1740. params_to_reduce = [param for i, param, param_id in self.params_in_ipg_bucket]
  1741. #print(f"Params in ipg bucket {self.params_in_ipg_bucket}")
  1742. #print(f"Reducing {[(debug_param2name_id_shape(param), param.grad) for param in params_to_reduce]}")
  1743. #exit(0)
  1744. if self.contiguous_gradients:
  1745. reduction_list = [self.ipg_buffer[self.ipg_index]]
  1746. if self.extra_large_param_to_reduce is not None:
  1747. reduction_list.append(self.extra_large_param_to_reduce.grad)
  1748. self.extra_large_param_to_reduce = None
  1749. self.average_tensor(reduction_list, params_to_reduce)
  1750. else:
  1751. self.buffered_reduce_fallback(
  1752. None,
  1753. self.grads_in_ipg_bucket,
  1754. elements_per_buffer=self.elements_in_ipg_bucket)
  1755. for _, param, param_id in self.params_in_ipg_bucket:
  1756. self.params_already_reduced[param_id] = True
  1757. self.previous_reduced_grads = params_to_reduce
  1758. self.grads_in_ipg_bucket = []
  1759. self.params_in_ipg_bucket = []
  1760. self.elements_in_ipg_bucket = 0
  1761. #####################################################################
  1762. def reduce_ready_partitions_and_remove_grads(self, param, i):
  1763. #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True)
  1764. self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  1765. def zero_reduced_gradients(self, partition_id, i):
  1766. def are_all_related_partitions_reduced(params_id):
  1767. for partition_id in self.param_to_partition_ids[i][params_id]:
  1768. if not self.is_partition_reduced[i][partition_id]:
  1769. return False
  1770. return True
  1771. for params_id in self.is_grad_computed[i][partition_id]:
  1772. if are_all_related_partitions_reduced(params_id):
  1773. self.param_dict[params_id].grad = None
  1774. def flatten_and_print(self, message, tensors, start=0, n=5):
  1775. flatten_tensor = self.flatten(tensors)
  1776. def print_func():
  1777. logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
  1778. self.sequential_execution(print_func, message)
  1779. def get_grads_to_reduce(self, i, partition_id):
  1780. def get_reducible_portion(key):
  1781. grad = self.param_dict[key].grad
  1782. total_elements = grad.numel()
  1783. start = self.grad_start_offset[i][partition_id][key]
  1784. num_elements = min(
  1785. total_elements - start,
  1786. self.partition_size[i] -
  1787. self.grad_partition_insertion_offset[i][partition_id][key])
  1788. if not pg_correctness_test:
  1789. if num_elements == total_elements:
  1790. return grad
  1791. else:
  1792. return grad.contiguous().view(-1).narrow(0,
  1793. int(start),
  1794. int(num_elements))
  1795. else:
  1796. if num_elements == total_elements:
  1797. return grad.clone()
  1798. else:
  1799. return grad.clone().contiguous().view(-1).narrow(
  1800. 0,
  1801. int(start),
  1802. int(num_elements))
  1803. grads_to_reduce = []
  1804. for key in self.is_grad_computed[i][partition_id]:
  1805. grad = get_reducible_portion(key)
  1806. grads_to_reduce.append(grad)
  1807. return grads_to_reduce
  1808. def sequential_execution(self, function, message, group=None):
  1809. if group is None:
  1810. group = self.dp_process_group
  1811. if dist.get_rank(group=group) == 0:
  1812. logger.info(message)
  1813. for id in range(dist.get_world_size(group=group)):
  1814. if id == dist.get_rank(group=group):
  1815. function()
  1816. dist.barrier(group=group)
  1817. def set_none_gradients_to_zero(self, i, partition_id):
  1818. for param_id in self.is_grad_computed[i][partition_id]:
  1819. param = self.param_dict[param_id]
  1820. if param.grad is None:
1821. param.grad = torch.zeros_like(param)
  1822. ######################Reduction Related Methods##############################
  1823. def allreduce_bucket(self,
  1824. bucket,
  1825. communication_data_type=torch.float16,
  1826. rank=None,
  1827. log=None):
1828. rank = None  # force a full all-reduce; the rank argument above is currently ignored
  1829. tensor = self.flatten(bucket)
  1830. tensor_to_allreduce = tensor
  1831. if pg_correctness_test:
  1832. communication_data_type = torch.float32
  1833. if communication_data_type != tensor.dtype:
  1834. tensor_to_allreduce = tensor.to(communication_data_type)
  1835. tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
  1836. if rank is None:
  1837. # "All Reducing"
  1838. dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
  1839. else:
  1840. global_rank = _get_global_rank(self.dp_process_group, rank)
  1841. dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
  1842. if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
  1843. if rank is None or rank == dist.get_rank(group=self.dp_process_group):
  1844. tensor.copy_(tensor_to_allreduce)
  1845. return tensor
  1846. # if rank is specified do a reduction instead of an allreduce
  1847. def allreduce_and_copy(self, small_bucket, rank=None, log=None):
  1848. with torch.cuda.stream(self.reduction_stream):
  1849. allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
  1850. if rank is None or rank == dist.get_rank(group=self.dp_process_group):
  1851. for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
  1852. buf.copy_(synced)
  1853. def allreduce_no_retain(self,
  1854. bucket,
  1855. numel_per_bucket=500000000,
  1856. rank=None,
  1857. log=None):
  1858. small_bucket = []
  1859. numel = 0
  1860. for tensor in bucket:
  1861. small_bucket.append(tensor)
  1862. numel = numel + tensor.numel()
  1863. if numel > numel_per_bucket:
  1864. self.allreduce_and_copy(small_bucket, rank=rank, log=None)
  1865. small_bucket = []
  1866. if len(small_bucket) > 0:
  1867. self.allreduce_and_copy(small_bucket, rank=rank, log=log)
  1868. # allows using reduction of gradients instead of using all_reduce
  1869. def buffered_reduce_fallback(self,
  1870. rank,
  1871. grads,
  1872. elements_per_buffer=500000000,
  1873. log=None):
  1874. split_buckets = split_half_float_double(grads)
  1875. for i, bucket in enumerate(split_buckets):
  1876. self.allreduce_no_retain(bucket,
  1877. numel_per_bucket=elements_per_buffer,
  1878. rank=rank,
  1879. log=log)
  1880. #############################################################################
  1881. #############################################################################
  1882. #############################################################################
  1883. # views the tensor as multiple partitions and returns
  1884. # those partitions
  1885. def get_data_parallel_partitions(self, tensor):
  1886. partitions = []
  1887. dp = dist.get_world_size(group=self.dp_process_group)
  1888. dp_id = dist.get_rank(group=self.dp_process_group)
  1889. total_num_elements = tensor.numel()
  1890. base_size = total_num_elements // dp
  1891. remaining = total_num_elements % dp
  1892. start = 0
  1893. for id in range(dp):
  1894. partition_size = base_size
  1895. if id < remaining:
  1896. partition_size = partition_size + 1
  1897. partitions.append(tensor.narrow(0, start, partition_size))
  1898. start = start + partition_size
  1899. return partitions
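# --- Illustrative sketch (hypothetical standalone helper mirroring the method above): every
# rank gets base_size = numel // dp elements, and the first (numel % dp) ranks get one extra,
# so the narrowed views tile the flat tensor exactly.
import torch
def _example_dp_partitions(tensor, dp):
    base_size, remaining = divmod(tensor.numel(), dp)
    partitions, start = [], 0
    for rank in range(dp):
        size = base_size + (1 if rank < remaining else 0)
        partitions.append(tensor.narrow(0, start, size))
        start += size
    return partitions
# e.g. a 10-element tensor over dp=4 yields partition sizes [3, 3, 2, 2]
assert [p.numel() for p in _example_dp_partitions(torch.arange(10.), 4)] == [3, 3, 2, 2]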
  1900. def get_partition_info(self, tensor_list, partition_size, partition_id):
  1901. params_in_partition = []
  1902. params_not_in_partition = []
  1903. start_index = partition_size * partition_id
  1904. end_index = partition_size * (partition_id + 1)
  1905. current_index = 0
  1906. first_offset = 0
  1907. for tensor in tensor_list:
  1908. tensor_size = tensor.numel()
  1909. if (current_index >= start_index and current_index < end_index):
  1910. params_in_partition.append(tensor)
  1911. elif start_index > current_index and start_index < (current_index +
  1912. tensor_size):
  1913. params_in_partition.append(tensor)
  1914. assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
  1915. first_offset = start_index - current_index
  1916. else:
  1917. params_not_in_partition.append(tensor)
  1918. current_index = current_index + tensor_size
  1919. return params_in_partition, params_not_in_partition, first_offset
  1920. def zero_grad(self, set_grads_to_None=True):
  1921. """
  1922. Zero FP16 parameter grads.
  1923. """
  1924. # FP32 grad should never exist.
  1925. # For speed, set model fp16 grad to None by default
  1926. for group in self.fp16_groups:
  1927. for p in group:
  1928. if set_grads_to_None:
  1929. p.grad = None
  1930. else:
  1931. if p.grad is not None:
  1932. p.grad.detach_()
  1933. p.grad.zero_()
  1934. def _model_parallel_all_reduce(self, tensor, op):
  1935. """ Perform all reduce within model parallel group, if any.
  1936. """
  1937. if self.model_parallel_group is None:
  1938. pass
  1939. else:
  1940. torch.distributed.all_reduce(tensor=tensor,
  1941. op=op,
  1942. group=self.model_parallel_group)
  1943. def get_grad_norm_direct(self, gradients, params, norm_type=2):
  1944. """Clips gradient norm of an iterable of parameters.
  1945. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  1946. added functionality to handle model parallel parameters. Note that
  1947. the gradients are modified in place.
  1948. Arguments:
  1949. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  1950. single Tensor that will have gradients normalized
  1951. max_norm (float or int): max norm of the gradients
  1952. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  1953. infinity norm.
  1954. Returns:
  1955. Total norm of the parameters (viewed as a single vector).
  1956. """
  1957. norm_type = float(norm_type)
  1958. if norm_type == inf:
  1959. total_norm = max(g.data.abs().max() for g in gradients)
  1960. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  1961. torch.distributed.all_reduce(total_norm_cuda,
  1962. op=torch.distributed.ReduceOp.MAX,
  1963. group=self.dp_process_group)
  1964. # Take max across all GPUs.
  1965. self._model_parallel_all_reduce(tensor=total_norm_cuda,
  1966. op=torch.distributed.ReduceOp.MAX)
  1967. total_norm = total_norm_cuda[0].item()
  1968. else:
  1969. total_norm = 0.0
  1970. # if dist.get_rank() == 0:
  1971. # logger.info(f"Total Norm beginning {total_norm}")
  1972. for g, p in zip(gradients, params):
  1973. if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
  1974. param_norm = g.data.double().norm(2)
  1975. total_norm += param_norm.item()**2
  1976. # Sum across all model parallel GPUs.
  1977. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  1978. torch.distributed.all_reduce(total_norm_cuda,
  1979. op=torch.distributed.ReduceOp.SUM,
  1980. group=self.dp_process_group)
  1981. self._model_parallel_all_reduce(tensor=total_norm_cuda,
  1982. op=torch.distributed.ReduceOp.SUM)
  1983. total_norm = total_norm_cuda[0].item()**(1. / norm_type)
  1984. if total_norm == float(
  1985. 'inf') or total_norm == -float('inf') or total_norm != total_norm:
  1986. total_norm = -1
  1987. return total_norm
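# --- Illustrative sketch (toy gradients, local sums stand in for the dp/mp all-reduces): for
# the L2 branch above, adding per-gradient squared norms and taking one square root reproduces
# the norm of the concatenated gradients.
import torch
_grads = [torch.ones(3), 2.0 * torch.ones(4)]
_total_norm = sum(g.double().norm(2) ** 2 for g in _grads) ** 0.5
assert torch.isclose(_total_norm.float(), torch.cat([g.view(-1) for g in _grads]).norm(2))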
  1988. # creates a flat fused tensor from the tensor list starting at the first_offset
  1989. # in the first tensor of the list. If there are not enough elements in the tensor
  1990. # list then the flat tensor will be padded with zeros
  1991. def get_flat_partition(self,
  1992. tensor_list,
  1993. first_offset,
  1994. partition_size,
  1995. return_tensor_list=False):
  1996. flat_tensor_list = []
  1997. current_size = 0
  1998. for i, tensor in enumerate(tensor_list):
  1999. if tensor.grad is None:
  2000. tensor.grad = torch.zeros_like(tensor)
  2001. tensor = tensor.grad
  2002. num_elements = tensor.numel()
  2003. tensor_offset = 0
  2004. # we need to offset to get to the right element
  2005. if i == 0 and first_offset > 0:
  2006. tensor_offset = first_offset
  2007. num_elements = num_elements - tensor_offset
2008. # we don't need all elements of the tensor
  2009. if num_elements > (partition_size - current_size):
  2010. num_elements = partition_size - current_size
  2011. # we need a narrow view of the tensor based on the tensor offset and number of elements that
  2012. # we need from this tensor
  2013. if tensor_offset > 0 or num_elements < tensor.numel():
  2014. flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
  2015. 0,
  2016. int(tensor_offset),
  2017. int(num_elements)))
  2018. else:
  2019. flat_tensor_list.append(tensor)
  2020. current_size = current_size + num_elements
2021. # this means it's the last partition and does not align with the dp boundary. We need to pad before flattening
  2022. if current_size < partition_size:
  2023. flat_tensor_list.append(
  2024. torch.zeros(int(partition_size - current_size),
  2025. dtype=tensor_list[0].dtype,
  2026. device=tensor_list[0].device))
  2027. if return_tensor_list:
  2028. return flat_tensor_list
  2029. return self.flatten(flat_tensor_list)
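# --- Illustrative sketch (hypothetical sizes): when the tensor list runs out before the
# partition is full, get_flat_partition pads with zeros so every rank's flat partition has
# exactly partition_size elements.
import torch
_pieces = [torch.ones(3), torch.ones(2)]      # 5 real elements
_partition_size = 8
_short = _partition_size - sum(p.numel() for p in _pieces)
_flat = torch.cat([p.view(-1) for p in _pieces] + [torch.zeros(_short)])
assert _flat.numel() == _partition_size and bool(_flat[5:].eq(0).all())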
  2030. def free_grad_in_param_list(self, param_list):
  2031. for p in param_list:
  2032. p.grad = None
  2033. def reset_cpu_buffers(self):
  2034. self.norm_for_param_grads = {}
  2035. self.local_overflow = False
  2036. def log_timers(self, timer_names):
  2037. if self.timers is None:
  2038. return
  2039. self.timers.log(names=list(timer_names))
  2040. def start_timers(self, timer_names):
  2041. if self.timers is None:
  2042. return
  2043. for name in timer_names:
  2044. self.timers(name).start()
  2045. def stop_timers(self, timer_names):
  2046. if self.timers is None:
  2047. return
  2048. for name in timer_names:
  2049. self.timers(name).stop()
  2050. def _pre_step(self):
  2051. self.micro_step_id = INITIAL_MICRO_STEP_ID
  2052. print_rank_0(f"Inside Step function")
  2053. see_memory_usage(f"In step before checking overflow", force=False)
  2054. print_rank_0("Finished Tracing at Beginning of Step")
  2055. self.param_coordinator.hierarchy = 0
  2056. self.param_coordinator.finish_tracing(print_trace=True)
  2057. self.param_coordinator.reset_step()
  2058. print_rank_0("Finished Tracing at Beginning of Step")
  2059. def _get_norm_groups(self):
  2060. norm_groups = []
  2061. for i, group in enumerate(self.fp16_groups):
  2062. if self.offload_optimizer:
  2063. norm_groups.append(
  2064. self.complete_grad_norm_calculation_for_cpu_offload(
  2065. self.fp16_groups[i]))
  2066. else:
  2067. norm_groups.append(
  2068. self.get_grad_norm_direct(self.averaged_gradients[i],
  2069. self.fp16_groups[i]))
  2070. return norm_groups
  2071. def _prepare_fp32_grad_for_sub_group(self, sub_group_id):
  2072. partition_id = dist.get_rank(group=self.dp_process_group)
  2073. single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to(
  2074. self.fp32_partitioned_groups_flat[sub_group_id].dtype)
  2075. assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
  2076. "averaged gradients have different number of elements that partition size {} {} {} {}".format(
  2077. single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id)
  2078. self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition
2079. # release all the gradients since we have already created the necessary copy in dp_grad_partition
  2080. self.zero_grad()
  2081. self.averaged_gradients[sub_group_id] = None

    def _prepare_sub_group(self, sub_group_id, timer_names=None):
        # avoid a shared mutable default argument
        timer_names = set() if timer_names is None else timer_names
        see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}',
                         force=False)
        if self._swappable_optimizer_subgroup(sub_group_id):
            self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names)
        elif not self.offload_optimizer:
            self._prepare_fp32_grad_for_sub_group(sub_group_id)
        see_memory_usage(f'After prepare optimizer sub group {sub_group_id}',
                         force=False)

    def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=None):
        # avoid a shared mutable default argument
        timer_names = set() if timer_names is None else timer_names
        param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id]
        fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id])
        assert self._swappable_optimizer_subgroup(sub_group_id), \
            f'Parameter {fp32_param_id} of numel={param_length} is not swappable'

        OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state'
        see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}',
                         force=False)

        self.start_timers([OPTIMIZER_SWAP_IN_STATE])
        self.optimizer_swapper.swap_in_optimizer_state(
            parameter=self.fp32_partitioned_groups_flat[sub_group_id],
            async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id])
        self.stop_timers([OPTIMIZER_SWAP_IN_STATE])
        timer_names.add(OPTIMIZER_SWAP_IN_STATE)

        see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}',
                         force=False)

    def _release_sub_group(self, sub_group_id, timer_names=None):
        # avoid a shared mutable default argument
        timer_names = set() if timer_names is None else timer_names
        see_memory_usage(f'Before release optimizer sub group {sub_group_id}',
                         force=False)
        # get rid of the fp32 gradients. Not needed anymore
        if not self.offload_optimizer:
            self.fp32_partitioned_groups_flat[sub_group_id].grad = None

        if self._swappable_optimizer_subgroup(sub_group_id):
            self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names)
        see_memory_usage(f'After release optimizer sub group {sub_group_id}',
                         force=False)

    # create a flat tensor aligned at the alignment boundary
    def flatten_dense_tensors_aligned(self, tensor_list, alignment):
        num_elements = 0
        for tens in tensor_list:
            num_elements = num_elements + tens.numel()

        remaining = num_elements % alignment

        if remaining:
            elements_to_add = alignment - remaining
            pad_tensor = torch.zeros(elements_to_add,
                                     device=tensor_list[0].device,
                                     dtype=tensor_list[0].dtype)
            padded_tensor_list = tensor_list + [pad_tensor]
            num_elements = num_elements + elements_to_add
        else:
            padded_tensor_list = tensor_list

        return self.flatten(padded_tensor_list)
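
    # A minimal sketch (not part of the optimizer) of the padding arithmetic used
    # above, assuming hypothetical standalone tensors `a` and `b`:
    #
    #   a = torch.ones(5)           # 5 elements
    #   b = torch.ones(6)           # 6 elements -> 11 total
    #   alignment = 8               # e.g. the data-parallel world size
    #   remaining = 11 % alignment  # 3, so 8 - 3 = 5 zeros are appended
    #   # the flattened result has 16 elements, a multiple of `alignment`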

    def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=None):
        # avoid a shared mutable default argument
        timer_names = set() if timer_names is None else timer_names
        param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id]
        fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id])
        assert self._swappable_optimizer_subgroup(sub_group_id), \
            f'Parameter {fp32_param_id} of numel={param_length} is not swappable'

        OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state'

        see_memory_usage(
            f'post-step Before swapping out optimizer tensors {sub_group_id}',
            force=False)
        self.start_timers([OPTIMIZER_SWAP_OUT_STATE])

        self.optimizer_swapper.swap_out_optimizer_state(
            parameter=self.fp32_partitioned_groups_flat[sub_group_id],
            async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id]
            is not None)

        self.stop_timers([OPTIMIZER_SWAP_OUT_STATE])
        see_memory_usage(
            f'post-step After swapping out optimizer tensors {sub_group_id}',
            force=False)
        timer_names.add(OPTIMIZER_SWAP_OUT_STATE)

        # get rid of the fp32 gradients. Not needed anymore
        self.fp32_partitioned_groups_flat[sub_group_id].grad = None

    def _unflatten_partitioned_parameters(self, sub_group_id):
        updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id],
                                        self.fp16_partitioned_groups[sub_group_id])

        for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
            partitioned_param.data = q.data

    def _overflow_clean_up(self, prev_scale):
        see_memory_usage('After overflow before clearing gradients', force=False)
        self.zero_grad()

        if self.offload_optimizer:
            self.reset_cpu_buffers()
        else:
            self.averaged_gradients = {}

        see_memory_usage('After overflow after clearing gradients', force=False)

        if torch.distributed.get_rank() == 0:
            logger.info(
                "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
                "reducing to {}".format(dist.get_rank(),
                                        prev_scale,
                                        self.loss_scale))

    def _overflow_check_and_loss_scale_update(self):
        # First compute the norm for all groups so we know if there is an overflow
        self.check_overflow()

        # loss scaling related computation
        prev_scale = self.loss_scale
        self._update_scale(self.overflow)

        if self.overflow:
            self._overflow_clean_up(prev_scale)

        return self.overflow

    def _post_step(self, timer_names=None):
        # avoid a shared mutable default argument
        timer_names = set() if timer_names is None else timer_names
        if self.offload_optimizer:
            self.reset_cpu_buffers()

        # Gather the persisting parameters
        if len(self.persistent_parameters) > 0:
            self.persistent_parameters[0].all_gather(self.persistent_parameters)

        if self.swap_optimizer:
            self.optimizer_swapper.log_timers()

        self.log_timers(timer_names)

        see_memory_usage('After zero_optimizer step', force=False)
        print_rank_0(f"------------------Finishing Step-----------------------")

    def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
        if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
            self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
                self.fp32_partitioned_groups_flat[sub_group_id].data)

            # unflatten fp16 parameter subgroup
            self._unflatten_partitioned_parameters(sub_group_id)
        else:
            self._partitioned_params_swap_out(sub_group_id)

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        self._pre_step()

        # check for overflow and adjust the loss scale accordingly
        if self._overflow_check_and_loss_scale_update():
            if self.swap_optimizer:
                self.optimizer_swapper.log_timers()
            return

        norm_groups = self._get_norm_groups()
        self._global_grad_norm = get_global_norm(norm_list=norm_groups)

        timer_names = set()
        timer_names.add('optimizer_step')
        self.start_timers(['optimizer_step'])

        # update parameters one sub group at a time
        for sub_group_id, group in enumerate(self.fp16_groups):
            # prepare optimizer states, gradients and fp32 parameters for update
            self._prepare_sub_group(sub_group_id, timer_names)

            # scale the fp32 gradients
            self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm)

            # apply the optimizer step on the sub group and copy fp32 parameters to fp16
            self._optimizer_step(sub_group_id)

            # put fp16 parameters in the appropriate location
            self._reassign_or_swap_out_partitioned_parameters(sub_group_id)

            # release memory or swap out optimizer states of fp32 parameters
            self._release_sub_group(sub_group_id, timer_names)

        self.stop_timers(['optimizer_step'])
        self._post_step(timer_names)
        return
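
    # A minimal call-order sketch (assumptions: `model` and `optimizer` refer to a
    # module and an instance of this optimizer wired up elsewhere; whether an
    # explicit zero_grad() is needed depends on the surrounding engine):
    #
    #   loss = model(batch)          # forward pass under ZeRO-3 partitioning
    #   optimizer.backward(loss)     # scales the loss and accumulates fp16 grads
    #   optimizer.step()             # per-sub-group unscale/clip + fp32 update + fp16 copy-back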

    def dump_pre_step_gradients(self, debug_fp32_grads):
        # Dump gradient norms for debugging
        for i, _ in enumerate(self.fp16_groups):
            print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC')
            for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]):
                param_id = self.get_param_id(fp16_param)
                fp16_grad_norm = self.debug_fp16_grads[i][param_id]

                fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad]
                norm_list = [fp16_grad_norm, fp32_grad_norm]
                print(f'Pre-Step Norms {i} {param_id} = {norm_list}')

    def dump_post_step_gradients(self):
        # Dump gradient norms for debugging
        for i, group in enumerate(self.fp16_groups):
            print(
                f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT')
            unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i])
            unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i],
                                         self.fp16_groups[i])
            for j, p in enumerate(self.fp16_groups[i]):
                param_id = self.get_param_id(p)
                param_norm = float(p.data.float().norm(2))
                ds_norm = float(p.ds_tensor.data.float().norm(2))

                unflat_norm = [
                    float(t.data.float().norm(2))
                    for t in [unflat_fp16[j],
                              unflat_fp32[j]]
                ]
                norm_list = [param_norm, ds_norm] + unflat_norm

                print(f'Post-Step Norms {i} {param_id} = {norm_list}')

    def unscale_and_clip_grads(self, sub_group_id, total_norm):
        grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad]

        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.loss_scale

        for grad in grad_groups_flat:
            if isinstance(grad, list):
                sub_partitions = grad
                for g in sub_partitions:
                    g.data.mul_(1. / combined_scale)
            else:
                grad.data.mul_(1. / combined_scale)
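
    # A worked example of the combined scale above (hypothetical numbers, not
    # taken from this file): with loss_scale = 1024, clip_grad = 1.0 and a
    # measured total_norm of 2048 (which still includes the loss scale),
    #
    #   clip = (2048 / 1024 + 1e-6) / 1.0 ~= 2.0   # unscaled norm exceeds the limit
    #   combined_scale = 2.0 * 1024 = 2048
    #   grad *= 1 / 2048                           # unscales and clips in one multiply
    #
    # If the unscaled norm were <= clip_grad, clip <= 1 and only 1/loss_scale is applied.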

    def _check_overflow(self, partition_gradients=True):
        self.overflow = self.has_overflow(partition_gradients)

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params, is_grad_list=False):
        for p in params:
            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                return True
        return False

    def has_overflow_partitioned_grads_serial(self):
        for i in range(len(self.fp16_groups)):
            for j, grad in enumerate(self.averaged_gradients[i]):
                if grad is not None and self._has_inf_or_nan(grad.data, j):
                    return True
        return False

    def has_overflow(self, partition_gradients=True):
        if partition_gradients:
            if self.overlap_comm:
                self.local_overflow = self._has_inf_or_nan(self.gpu_sum)
                self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda()

            overflow = self.local_overflow if self.offload_optimizer else self.has_overflow_partitioned_grads_serial()
            #overflow = self.has_overflow_partitioned_grads_serial()
            overflow_gpu = torch.cuda.ByteTensor([overflow])
            torch.distributed.all_reduce(overflow_gpu,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=self.dp_process_group)
        else:
            params = []
            for group in self.fp16_groups:
                for param in group:
                    params.append(param)

            overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
            overflow_gpu = torch.cuda.ByteTensor([overflow])

        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        self._model_parallel_all_reduce(tensor=overflow_gpu,
                                        op=torch.distributed.ReduceOp.MAX)

        overflow = overflow_gpu[0].item()
        return bool(overflow)

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x, j=None):
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent versions of Pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if instance is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
                return True
            return False
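
    # A minimal sketch of the check above on standalone tensors (illustrative
    # inputs only; `opt` stands for an instance of this optimizer class, which is
    # not shown in this excerpt):
    #
    #   opt._has_inf_or_nan(torch.tensor([1.0, float('inf')]))   # -> True
    #   opt._has_inf_or_nan(torch.tensor([1.0, float('nan')]))   # -> True
    #   opt._has_inf_or_nan(torch.tensor([1.0, 2.0]))            # -> False
    #
    # The final `cpu_sum != cpu_sum` test is the standard NaN check: NaN is the
    # only float value that does not compare equal to itself.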

    def backward(self, loss, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        self.micro_step_id += 1
        print_rank_0(
            f"Total fully available parameters {self.param_coordinator.total_available_parameter_numel}"
        )

        if self.swap_optimizer:
            self.optimizer_swapper.pre_backward()

        see_memory_usage(f"Before backward", force=False)

        if self.contiguous_gradients:
            self.ipg_buffer = []
            buf_0 = torch.empty(self.reduce_bucket_size,
                                dtype=self.dtype,
                                device=torch.cuda.current_device())
            self.ipg_buffer.append(buf_0)

            # Use double buffers to avoid data access conflict when overlap_comm is enabled.
            if self.overlap_comm:
                buf_1 = torch.empty(self.reduce_bucket_size,
                                    dtype=self.dtype,
                                    device=torch.cuda.current_device())
                self.ipg_buffer.append(buf_1)
            self.ipg_index = 0

        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

        # Partition any parameters that were not partitioned during backward.
        # Parameters of modules whose inputs do not require grad do not trigger
        # the post-backward hook and would therefore remain unpartitioned.
        self._partition_all_parameters()

        if self.swap_optimizer:
            self.optimizer_swapper.post_backward()

    def _partition_all_parameters(self):
        for name, param in self.module.named_parameters(recurse=True):
            self.param_coordinator.release_and_reset_parameter(param)

    def check_overflow(self, partition_gradients=True):
        self._check_overflow(partition_gradients)

    def _update_scale(self, has_overflow=False):
        self.loss_scaler.update_scale(has_overflow)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)
    cur_scale = property(_get_loss_scale, _set_loss_scale)

    def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings):
        # Remove paddings from the flattened tensor
        individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors)
        lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)]
        # use `length` rather than shadowing the builtin `len`
        lean_tensors = [t[:length] for t, length in zip(individual_tensors, lean_lengths)]
        #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}')
        return lean_tensors

    # TODO: REVISIT this for stage 3
    def get_lean_optimizer_state(self):
        # Return optimizer states after removing paddings.
        # This method assumes that each param group contains a single flattened tensor.
        optimizer_groups_state = []

        for i, group in enumerate(self.optimizer.param_groups):
            p = group['params'][0]
            lean_state = {}
            for key, value in self.optimizer.state[p].items():
                if torch.is_tensor(value):
                    padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]]
                    lean_state[key] = self._get_lean_tensors(
                        value,
                        self.fp16_partitioned_groups[i],
                        self.groups_padding[i])
                    lean_flat_len = sum([t.numel() for t in lean_state[key]])
                else:
                    lean_state[key] = value

            optimizer_groups_state.append(lean_state)

        return optimizer_groups_state

    def get_groups_without_padding(self, groups_with_padding):
        # Return group tensors after removing the paddings added for alignment to DP world size.
        groups_without_padding = []
        for i, group in enumerate(groups_with_padding):
            lean_group = self._get_lean_tensors(group,
                                                self.fp16_partitioned_groups[i],
                                                self.groups_padding[i])
            groups_without_padding.append(lean_group)

        return groups_without_padding

    def _set_fp32_optimizer_param_groups(self):
        for sub_group_id, _ in enumerate(self.fp16_groups):
            param_group_id = self.sub_group_to_group_id[sub_group_id]
            self.optimizer.param_groups[param_group_id]['params'].append(
                self.fp32_partitioned_groups_flat[sub_group_id])

    def _clear_fp32_optimizer_param_groups(self):
        for param_group in self.optimizer.param_groups:
            param_group['params'] = []

    def _rigid_state_dict(self):
        state_dict = {}
        state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
        state_dict['partition_count'] = self.partition_count

        self._set_fp32_optimizer_param_groups()
        state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
        state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat
        self._clear_fp32_optimizer_param_groups()

        return state_dict

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        if self.elastic_checkpoint:
            raise NotImplementedError(
                "ZeRO-3 does not yet support elastic checkpointing, please disable for now."
            )

        if self.swap_optimizer or self.params_in_nvme_and_cpu:
            raise NotImplementedError(
                "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now."
            )

        return self._rigid_state_dict()

    # Restore base optimizer fp32 weights from checkpoint by:
    # 1) Merging fp32 weights from checkpoints of all partitions
    # 2) Extracting fp32 weights for current partition from merged weights
    # 3) Using extracted weights to update base optimizer weights directly.
    def _restore_from_fp32_weights(self, all_state_dict):
        flat_local_partition = []
        for i in range(len(self.fp32_partitioned_groups_flat)):
            merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict]
            flat_local_partition.append(self._get_flattened_partition(merged_partitions))

        for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition):
            current.data.copy_(saved.data)

    # Restore base optimizer fp32 weights from ZeRO fp16 weights
    def _restore_from_fp16_weights(self):
        for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat):
            fp32_partition.data.copy_(fp16_partitions.data)

    # Refresh the fp32 master params from the fp16 copies.
    def refresh_fp32_params(self):
        self._restore_from_fp16_weights()

    # Extract flattened partition for current rank from all partitions
    def _get_flattened_partition(self, all_partition_states):
        partition_id = dist.get_rank(group=self.dp_process_group)
        alignment = dist.get_world_size(group=self.dp_process_group)

        param_partitions = [[] for _ in range(len(all_partition_states[0]))]
        for i, partition in enumerate(all_partition_states):
            for j, param in enumerate(partition):
                param_partitions[j].append(param)

        local_state_partitions = []
        for param_index, param_slices in enumerate(param_partitions):
            flattened_merged_tensor = self.flatten_dense_tensors_aligned(
                param_slices,
                alignment)
            new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor)
            local_state_partitions.append(new_partitions[partition_id])

        if torch.is_tensor(local_state_partitions[0]):
            return self.flatten_dense_tensors_aligned(local_state_partitions, alignment)

        # Assume non-tensor states are not partitioned and equal across ranks, so return the first one
        return local_state_partitions[0]
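
    # A schematic example of the merge/re-partition above (hypothetical inputs:
    # a checkpoint saved by two ranks, loaded onto a world size of two):
    #
    #   all_partition_states = [saved_rank0_states, saved_rank1_states]
    #   # param_partitions[j] collects the j-th state tensor from every saved rank;
    #   # each collection is flattened back into one aligned tensor, and
    #   # get_data_parallel_partitions() re-splits it so this rank keeps only
    #   # new_partitions[partition_id] -- the slice the base optimizer needs here.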

    # Restore base optimizer state from checkpoint by
    # 1) Merging optimizer state from checkpoints of all partitions
    # 2) Extracting optimizer state for current partition from the merged state
    # 3) Using the extracted value to directly update the base optimizer.
    def _restore_base_optimizer_state(self, all_state_dict):
        base_optimizer_group_states = []
        for i in range(len(self.optimizer.param_groups)):
            partition_states = {}
            all_partition_group_states = [
                sd['base_optimizer_state'][i] for sd in all_state_dict
            ]
            for key in all_partition_group_states[0].keys():
                all_partition_states = [
                    all_states[key] for all_states in all_partition_group_states
                ]
                partition_states[key] = self._get_flattened_partition(
                    all_partition_states)
            base_optimizer_group_states.append(partition_states)

        for i, group in enumerate(self.optimizer.param_groups):
            p = group['params'][0]
            for key, saved in base_optimizer_group_states[i].items():
                if torch.is_tensor(self.optimizer.state[p][key]):
                    self.optimizer.state[p][key].data.copy_(saved.data)
                else:
                    self.optimizer.state[p][key] = saved

    def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
        # I think it should actually be ok to reload the optimizer before the model.
        self.loss_scaler = state_dict['loss_scaler']
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.overflow = state_dict['overflow']

        if load_optimizer_states:
            self._set_fp32_optimizer_param_groups()
            self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
            self._clear_fp32_optimizer_param_groups()

        # restore fp32 partitions
        for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']):
            curr_param.data.copy_(saved_param.data)

        # restore fp16 partitions from fp32
        for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
            fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
            fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
            fp16_param.data.copy_(fp32_param.data)

        # update fp16 unflattened params
        for sub_group_id in range(len(self.fp16_partitioned_groups_flat)):
            updated_params = self.unflatten(
                self.fp16_partitioned_groups_flat[sub_group_id],
                self.fp16_partitioned_groups[sub_group_id])

            for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
                partitioned_param.data = q.data

    # TODO: Support different/changing load/save DP degree.
    def load_state_dict(self,
                        state_dict_list,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False):
        r"""Loads a ZeRO checkpoint created by an earlier call to state_dict().

        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Arguments:
            state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
                Note that the number of saved partitions may differ from the number of loading partitions to support
                changing GPU counts, specifically DP world size, between saving and loading checkpoints.
            load_optimizer_states: Boolean indicating whether or not to load base optimizer states.
            load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from the fp32
                copies in the checkpoints (no precision loss) or from the model's fp16 copies (with precision loss).

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        if self.elastic_checkpoint:
            raise NotImplementedError(
                "ZeRO-3 does not yet support elastic checkpointing, please disable for now."
            )

        if self.swap_optimizer or self.params_in_nvme_and_cpu:
            raise NotImplementedError(
                "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now."
            )

        self._rigid_load_state_dict(
            state_dict_list[dist.get_rank(group=self.dp_process_group)],
            load_optimizer_states=load_optimizer_states)

        if len(self.persistent_parameters) > 0:
            self.persistent_parameters[0].partition(self.persistent_parameters)
            self.persistent_parameters[0].all_gather(self.persistent_parameters)

    def save_checkpoint_prologue(self):
        self._partition_all_parameters()

    def save_checkpoint_epilogue(self):
        if len(self.persistent_parameters) > 0:
            self.persistent_parameters[0].all_gather(self.persistent_parameters)


def _handle_overflow(cpu_sum, x, i):
    import math
    rank = torch.distributed.get_rank()
    if rank == 0:
        t_i = -1
        for v_i, v in enumerate(x.data.contiguous().view(-1)):
            if not math.isfinite(float(v)):
                t_i = v_i
                break
        logger.info(
            f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
        )


def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5):

    total_gpus = num_nodes * num_gpus_per_node
    gpus_factor = 1 / num_nodes
    largest_layer_memory = (4 * largest_layer_params)

    if cpu_offload:
        if cpu_offload_params:
            gpu_mem = largest_layer_memory

            if zero_init:
                cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(4 * num_gpus_per_node,
                                             18 * gpus_factor) * additional_buffer_factor
        else:
            gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus)

            if zero_init:
                cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(4 * num_gpus_per_node,
                                             16 * gpus_factor) * additional_buffer_factor
    else:
        gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus)

        if zero_init:
            cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor
        else:
            cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
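
# A worked example of the estimator above (hypothetical inputs: 1e9 total params,
# 1e8 params in the largest layer, 1 node x 8 GPUs, cpu_offload=True,
# cpu_offload_params=True, zero_init=True):
#
#   largest_layer_memory = 4 * 1e8               = 4e8 bytes   (~0.37GB per GPU)
#   cpu_mem = 1e9 * 18 * (1 / num_nodes) * 1.5   = 2.7e10 bytes (~25.1GB host RAM)
#
# The 18 bytes/param here is the code's per-parameter footprint when both params
# and optimizer state are offloaded (fp16 and fp32 copies plus Adam-style
# optimizer state); 1.5 is the additional_buffer_factor safety margin.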

def model_to_params(model):
    # shared params calculated only once
    total_params = sum(
        dict((p.data_ptr(),
              p.numel()) for p in model.parameters()).values())

    largest_layer_params = 0
    for m in model.modules():
        # assuming no shared params within a single layer
        layer_params = sum(p.numel() for p in m.parameters(recurse=False))
        largest_layer_params = max(largest_layer_params, layer_params)

    return total_params, largest_layer_params
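
# A minimal usage sketch (the toy model below is illustrative only, not part of
# this file):
#
#   model = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.Linear(4096, 1024))
#   total, largest = model_to_params(model)
#   # total   == 1024*4096 + 4096 + 4096*1024 + 1024  (all unique parameter elements)
#   # largest == 1024*4096 + 4096                     (the bigger Linear, counted without recursion)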

import math


def estimate_zero3_model_states_mem_needs_all_live(model,
                                                   num_gpus_per_node=1,
                                                   num_nodes=1,
                                                   additional_buffer_factor=1.5):
    """
    Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients
    for a given ``model`` and hardware setup.

    If you have an actual model object, use this function and everything will be derived
    automatically.

    If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass
    the ``total_params`` and ``largest_layer_params`` explicitly.

    Args:
        - ``model``: ``nn.Module`` object
        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)
    """

    total_params, largest_layer_params = model_to_params(model)

    estimate_zero3_model_states_mem_needs_all_cold(
        total_params=total_params,
        largest_layer_params=largest_layer_params,
        num_gpus_per_node=num_gpus_per_node,
        num_nodes=num_nodes,
        additional_buffer_factor=additional_buffer_factor)


def estimate_zero3_model_states_mem_needs_all_cold(total_params,
                                                   largest_layer_params,
                                                   num_gpus_per_node=1,
                                                   num_nodes=1,
                                                   additional_buffer_factor=1.5):
    """
    Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients
    for a given ``model`` and hardware setup.

    If it's a hypothetical model, use this function where you have to pass
    the ``total_params`` and ``largest_layer_params`` explicitly.

    If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything
    will be derived automatically.

    Args:
        - ``total_params``: total model params
        - ``largest_layer_params``: largest layer's params
        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)
    """
    def format_options(cpu_offload, cpu_offload_params, zero_init):
        enabled = []
        padded_cpu_str = f'{OFFLOAD_CPU_DEVICE:4}'
        param_device = padded_cpu_str if cpu_offload_params else "none"
        enabled.append(f"{OFFLOAD_PARAM}={param_device}")
        optimizer_device = padded_cpu_str if cpu_offload else "none"
        enabled.append(f"{OFFLOAD_OPTIMIZER}={optimizer_device}")
        enabled.append(f"zero_init={1 if zero_init else 0}")
        return ", ".join(enabled)

    nodes_str = "nodes" if num_nodes > 1 else "node"
    gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
    print(
        "Estimated memory needed for params, optim states and gradients for a:\n"
        f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
        f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params."
    )
    print(" per CPU | per GPU | Options")
    for cpu_offload in [True, False]:
        for cpu_offload_params in [True, False]:
            if not cpu_offload and cpu_offload_params:
                continue
            for zero_init in [True, False]:
                cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs(
                    total_params=total_params,
                    largest_layer_params=largest_layer_params,
                    num_gpus_per_node=num_gpus_per_node,
                    num_nodes=num_nodes,
                    cpu_offload=cpu_offload,
                    cpu_offload_params=cpu_offload_params,
                    zero_init=zero_init,
                    additional_buffer_factor=additional_buffer_factor)

                options_str = format_options(cpu_offload=cpu_offload,
                                             cpu_offload_params=cpu_offload_params,
                                             zero_init=zero_init)

                print(
                    f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}")