engine.py 159 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import re
  6. import stat
  7. import torch
  8. import hashlib
  9. from collections import defaultdict, OrderedDict, deque
  10. from shutil import copyfile
  11. import gc
  12. from torch.nn.modules import Module
  13. from torch.nn.parameter import Parameter
  14. from torch.optim import Optimizer
  15. from torch.optim.lr_scheduler import _LRScheduler
  16. from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
  17. from typing import Callable, Dict, Union, Iterable
  18. import deepspeed
  19. from deepspeed import comm as dist
  20. from deepspeed.runtime.utils import see_memory_usage, DummyOptim
  21. from .zero.offload_config import OffloadDeviceEnum
  22. from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
  23. from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
  24. from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
  25. from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
  26. from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION
  27. from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
  28. from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
  29. from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
  30. from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
  31. ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
  32. TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER
  33. from deepspeed.runtime.dataloader import DeepSpeedDataLoader
  34. from deepspeed.runtime.constants import \
  35. ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
  36. PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \
  37. DATA_PARALLEL_GROUP, GLOBAL_RANK
  38. from deepspeed.runtime.zero.config import ZeroStageEnum
  39. from deepspeed.compression import compression_scheduler
  40. from deepspeed.compression.constants import \
  41. WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \
  42. WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \
  43. WEIGHT_QUANTIZE_ENABLED, \
  44. WEIGHT_QUANTIZE_GROUPS, \
  45. WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \
  46. WEIGHT_QUANTIZE_CHANGE_RATIO, \
  47. WEIGHT_QUANTIZE_TYPE, \
  48. WEIGHT_QUANTIZE_ROUNDING, \
  49. WEIGHT_QUANTIZE_VERBOSE, \
  50. WEIGHT_QUANTIZE_KERNEL
  51. from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS
  52. from deepspeed.runtime.sparse_tensor import SparseTensor
  53. from deepspeed.runtime import lr_schedules
  54. from deepspeed.utils import groups
  55. from deepspeed.utils import logger, log_dist, instrument_w_nvtx
  56. from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
  57. FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
  58. STEP_MICRO_TIMER, \
  59. FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \
  60. STEP_GLOBAL_TIMER
  61. from deepspeed.utils.debug import debug_extract_module_and_param_names
  62. from deepspeed.monitor.monitor import MonitorMaster
  63. from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
  64. from deepspeed.runtime.utils import clip_grad_norm_
  65. from deepspeed.runtime.eigenvalue import Eigenvalue
  66. from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \
  67. DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \
  68. CURRICULUM_LEARNING_ENABLED, DATA_SAMPLING_NUM_WORKERS, RANDOM_LTD, \
  69. RANDOM_LTD_ENABLED, RANDOM_LTD_LAYER_ID, RANDOM_LTD_LAYER_NUM, \
  70. RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \
  71. RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY
  72. from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
  73. from deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler
  74. from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict
  75. from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop
  76. from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
  77. from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
  78. from .pipe.module import PipelineModule
  79. from .utils import get_ma_status
  80. from ..ops.adam import FusedAdam
  81. from ..moe.sharded_moe import TopKGate, MOELayer
  82. from ..moe.layer import MoE
  83. from ..moe.utils import is_moe_param
  84. from ..git_version_info import version
  85. from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
  86. from deepspeed.utils.logging import print_json_dist, print_configuration
  87. from deepspeed.accelerator import get_accelerator
  88. from deepspeed.runtime.config import DtypeEnum
  89. MEMORY_OPT_ALLREDUCE_SIZE = 500000000
  90. DeepSpeedOptimizerCallable = \
  91. Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer]
  92. DeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler]
  93. try:
  94. import apex
  95. from apex import amp
  96. APEX_INSTALLED = True
  97. except ImportError:
  98. # Fail silently so we don't spam logs unnecessarily if user isn't using amp
  99. APEX_INSTALLED = False
  100. def split_half_float_double_sparse(tensors):
  101. device_type = get_accelerator().device_name()
  102. supported_types = [
  103. "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
  104. "torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type),
  105. SparseTensor.type()
  106. ]
  107. for t in tensors:
  108. assert t.type() in supported_types, f"attempting to reduce an unsupported grad type: {t.type()}"
  109. buckets = []
  110. for i, dtype in enumerate(supported_types):
  111. bucket = [t for t in tensors if t.type() == dtype]
  112. if bucket:
  113. buckets.append((dtype, bucket))
  114. return buckets
  115. class EngineTimers(object):
  116. r"""Wallclock timers for DeepSpeedEngine"""
  117. def __init__(self, enable_micro_timers, enable_global_timers):
  118. self.forward_timers = []
  119. self.backward_timers = []
  120. self.backward_inner_timers = []
  121. self.backward_reduce_timers = []
  122. self.step_timers = []
  123. self.global_timers = []
  124. self.micro_timers = []
  125. if enable_micro_timers:
  126. self.forward_timers += [FORWARD_MICRO_TIMER]
  127. self.backward_timers += [BACKWARD_MICRO_TIMER]
  128. self.backward_inner_timers += [BACKWARD_INNER_MICRO_TIMER]
  129. self.backward_reduce_timers += [BACKWARD_REDUCE_MICRO_TIMER]
  130. self.step_timers += [STEP_MICRO_TIMER]
  131. self.micro_timers += [
  132. FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER,
  133. STEP_MICRO_TIMER
  134. ]
  135. if enable_global_timers:
  136. self.forward_timers += [FORWARD_GLOBAL_TIMER]
  137. self.backward_timers += [BACKWARD_GLOBAL_TIMER]
  138. self.backward_inner_timers += [BACKWARD_INNER_GLOBAL_TIMER]
  139. self.backward_reduce_timers += [BACKWARD_REDUCE_GLOBAL_TIMER]
  140. self.step_timers += [STEP_GLOBAL_TIMER]
  141. self.global_timers += [
  142. FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER,
  143. STEP_GLOBAL_TIMER
  144. ]
  145. class DeepSpeedEngine(Module):
  146. r"""DeepSpeed engine for training."""
    def __init__(
        self,
        args,
        model,
        optimizer=None,
        model_parameters=None,
        training_data=None,
        lr_scheduler=None,
        mpu=None,
        dist_init_required=None,
        collate_fn=None,
        config=None,
        config_class=None,
        dont_change_device=False,
    ):
        """Wrap ``model`` and configure distributed state, optimizer, scheduler and optional features.

        Args:
            args: user/launcher arguments; passed to ``_do_args_sanity_check``,
                ``_configure_with_arguments`` and ``_set_distributed_vars``.
            model: the client module to wrap. Its ``named_parameters()`` are recorded
                in ``self.param_names`` and it is handed to ``_configure_distributed_model``.
            optimizer: optional client optimizer, kept as ``self.client_optimizer``.
                An optimizer is configured when this is truthy or the config names one.
            model_parameters: parameters for the optimizer. Defaults to
                ``self.module.parameters()`` and is materialized into a list.
            training_data: optional dataset. When truthy, ``deepspeed_io`` builds
                ``self.training_dataloader``; otherwise that attribute is None.
            lr_scheduler: optional client LR scheduler. It is passed to
                ``_configure_lr_scheduler`` only on the optimizer path.
            mpu: optional model-parallelism helper, kept as ``self.mpu``. When given
                and elasticity is enabled, the elastic config must support model
                parallelism or an AssertionError is raised.
            dist_init_required: forwarded to ``_configure_checkpointing``.
            collate_fn: stored as ``self.collate_fn``.
            config: raw config, stored as ``self.config``.
            config_class: parsed config object, stored as ``self._config``. The config
                accessor methods of this class read from it.
            dont_change_device: stored as ``self.dont_change_device``.
        """
        super(DeepSpeedEngine, self).__init__()
        self.dont_change_device = dont_change_device
        self.client_optimizer = optimizer
        self.client_lr_scheduler = lr_scheduler
        self.training_data = training_data
        self.collate_fn = collate_fn
        self.mpu = mpu
        self.all_to_all_group = None
        self.data_parallel_group = None
        # Step / sample counters start at zero for a fresh engine.
        self.global_steps = 0
        self.global_samples = 0
        self.micro_steps = 0
        self.skipped_steps = 0
        self.gradient_average = True
        self.warn_unscaled_loss = True
        self.config = config
        self._config = config_class
        self.loaded_checkpoint_mp_world_size = None
        self.loaded_checkpoint_dp_world_size = None
        self.enable_backward_allreduce = True
        self.progressive_layer_drop = None
        self.eigenvalue = None
        self.block_eigenvalue = None
        self.gas_boundary_ctr = 0
        self.dist_backend = get_accelerator().communication_backend_name()
        self.has_moe_layers = False
        self.num_experts = []
        self.gate_modules = []
        self.moe_layers = []
        self._step_applied = False
        # Cached gradient norm returned by get_global_grad_norm(); None until set elsewhere.
        self._global_grad_norm = None
        self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend.
        self.checkpoint_engine = None
        self._is_gradient_accumulation_boundary = None
        self.scale_wrt_gas = None
        self.losses = 0.0
        # for debug purposes - can then debug print: debug_get_module_name(module)
        debug_extract_module_and_param_names(model)
        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
        self.param_names = {param: name for name, param in model.named_parameters()}
        # Validate/apply arguments before anything reads the config.
        self._do_args_sanity_check(args)
        self._configure_with_arguments(args, mpu)
        self._do_sanity_check()
        see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
        if mpu is not None:
            if self.elasticity_enabled():
                if not self.is_elastic_model_parallel_supported():
                    # This branch only runs while elasticity is enabled, so the assertion
                    # always fails here: it is used to raise for unsupported setups.
                    assert not self.elasticity_enabled(), ("Elasticity is not currently supported"
                                                           " with model parallelism.")
        self._set_distributed_vars(args)
        dist.configure(self._config)
        self.monitor = MonitorMaster(self._config.monitor_config)
        see_memory_usage(
            f"DeepSpeed Engine: Before configure distributed model",
            force=self.memory_breakdown(),
        )
        self.pipeline_parallelism = isinstance(model, PipelineModule)
        # Configure distributed model
        self._configure_distributed_model(model)
        # Needs self.module (set by the call above): counts parameters for autotuning.
        self._get_model_parameters()
        see_memory_usage(f"DeepSpeed Engine: After configure distributed model")
        # Configure wall clock timers
        self.timers = SynchronizedWallClockTimer()
        # Throughput timer
        self.tput_timer = ThroughputTimer(
            batch_size=self.train_batch_size(),
            steps_per_output=self.steps_per_print(),
            monitor_memory=False,
        )
        log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0])
        if self.flops_profiler_enabled():
            self.flops_profiler = FlopsProfiler(self.module, self, self.flops_profiler_recompute_fwd_factor())
        if training_data:
            self.training_dataloader = self.deepspeed_io(training_data)
        else:
            self.training_dataloader = None
        # Configure optimizer and scheduler
        self.optimizer = None
        self.basic_optimizer = None
        self.lr_scheduler = None
        has_optimizer = False
        # A client optimizer or an optimizer named in the config both count.
        if optimizer or self.optimizer_name():
            has_optimizer = True
        # If no parameters given by init default to module parameters
        if model_parameters is None:
            model_parameters = self.module.parameters()
        # Convert model parameters from generator to list
        if not isinstance(model_parameters, list):
            model_parameters = list(model_parameters)
        if has_optimizer:
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler(lr_scheduler)
            self._report_progress(0)
        elif self.zero_optimization():
            # no optim selected but zero is enabled
            self.optimizer = self._configure_zero_optimizer(optimizer=None)
        elif self.bfloat16_enabled():
            self.optimizer = self._configure_bf16_optimizer(optimizer=None)
        # Hook optimizer for snip_momentum pruning
        # NOTE(review): assumes one of the branches above produced an optimizer; when
        # none did, self.optimizer is still None and the attribute assignment fails.
        if hasattr(model, 'pruners'):
            from ..compression.helper import rewrite_optimizer_step
            self.optimizer.pruners = model.pruners
            rewrite_optimizer_step(self.optimizer)
        # Bookkeeping for sparse support
        self.sparse_tensor_module_names = set()
        # if self.sparse_gradients_enabled():
        # Records "<name>.weight" for every Embedding/EmbeddingBag when sparse gradients
        # are enabled. The flag is re-checked per module but does not depend on the loop.
        for name, module in self.module.named_modules():
            if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled():
                self.sparse_tensor_module_names.add(name + ".weight")
                logger.info("Will convert {} to sparse tensor during training".format(name))
        self.save_non_zero_checkpoint = False
        self.save_zero_checkpoint = False
        # Checkpointing is not configured when the "optimizer" is the ZeRO offload wrapper.
        if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
            self._configure_checkpointing(dist_init_required)
        # Optional training features, each gated by its config flag.
        if self.eigenvalue_enabled():
            self.eigenvalue = self._configure_eigenvalue()
        if self.pld_enabled():
            self.progressive_layer_drop = self._configure_progressive_layer_drop()
        if self.curriculum_enabled_legacy():
            self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy()
        if self.random_ltd_enabled():
            # Random-LTD config is augmented with the engine's batch sizes before use.
            random_ltd_config = self.random_ltd_config()
            random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size()
            random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu()
            self.random_ltd_scheduler = self._configure_random_ltd_scheduler(random_ltd_config)
        # Engine timers
        # Micro timers follow wall_clock_breakdown; global timers are also needed by the flops profiler.
        self.engine_timers = EngineTimers(enable_micro_timers=self.wall_clock_breakdown(),
                                          enable_global_timers=self.wall_clock_breakdown()
                                          or self.flops_profiler_enabled())
        if self.global_rank == 0:
            self._config.print("DeepSpeedEngine configuration")
            if self.dump_state():
                print_configuration(self, "DeepSpeedEngine")
        # Use torch (un)flatten ops
        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors
  298. def destroy(self):
  299. if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
  300. self.optimizer.destroy()
  301. def _get_model_parameters(self):
  302. if self.autotuning_profile_model_info():
  303. self.autotuning_model_info = {}
  304. num_params = 0
  305. trainable_num_params = 0
  306. for p in self.module.parameters():
  307. # since user code might call deepspeed.zero.Init() before deepspeed.initialize(), need to check the attribute to check if the parameter is partitioned in zero 3 already or not
  308. n = 0
  309. if hasattr(p, "ds_tensor"): # if the parameter is partitioned in zero 3
  310. n += p.ds_numel
  311. else: # if the parameter is not partitioned in zero 3 yet
  312. n += p.numel()
  313. num_params += n
  314. if p.requires_grad:
  315. trainable_num_params += n
  316. if self.global_rank == 0:
  317. self.autotuning_model_info["num_params"] = num_params * self.mp_world_size
  318. self.autotuning_model_info["trainable_num_params"] = trainable_num_params * self.mp_world_size
  319. logger.info(f"model parameter = {num_params}")
  320. def get_batch_info(self):
  321. """Get all training batch related settings.
  322. Returns:
  323. train_batch_size (int): The effective training batch size. This is the amount of data
  324. samples that leads to one step of model update.
  325. train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one
  326. step (without gradient accumulation).
  327. gradient_accumulation_steps (int): Number of training steps to accumulate gradients
  328. before averaging and applying them.
  329. """
  330. return (
  331. self.train_batch_size,
  332. self.train_micro_batch_size_per_gpu,
  333. self.gradient_accumulation_steps,
  334. )
  335. def set_train_batch_size(self, train_batch_size):
  336. """Adjust the global batch size by increasing or decreasing the number of
  337. micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
  338. (i.e., ``train_micro_batch_size_per_gpu``) is not changed.
  339. Args:
  340. train_batch_size (int): The new global batch size for training.
  341. Raises:
  342. ValueError: if ``train_batch_size`` is not divisible by the
  343. configured micro-batch size and data parallelism.
  344. """
  345. if train_batch_size % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) != 0:
  346. #print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}')
  347. raise ValueError(f'Train batch size must be divisible by micro-batch data parallelism')
  348. new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * self.dp_world_size)
  349. # overwrite config
  350. self._config.train_batch_size = train_batch_size
  351. self._config.gradient_accumulation_steps = new_gas
  352. def set_train_micro_batch_size(self, micro_batch_size):
  353. """Adjust the micro batch size(i.e., the micro batch size in every data parallel group),
  354. while keep the gradient accumulation steps the same.
  355. Args:
  356. micro_batch_size (int): The new micro batch size for training.
  357. """
  358. # overwrite config
  359. new_global_batch_size = micro_batch_size * self._config.gradient_accumulation_steps * self.dp_world_size
  360. self._config.train_batch_size = new_global_batch_size
  361. self._config.train_micro_batch_size_per_gpu = micro_batch_size
  362. def set_data_post_process_func(self, post_process_func):
  363. if self.training_dataloader is not None:
  364. self.training_dataloader.post_process_func = post_process_func
  365. def set_custom_curriculum_learning_schedule(self, schedule_func_dict):
  366. if self.training_dataloader is not None and self.curriculum_learning_enabled():
  367. self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict)
  368. def get_global_grad_norm(self) -> float:
  369. """Return the 2-norm of all gradients. If there is model parallelism,
  370. the norm will be global.
  371. The computed norm will be cached and reused until the next step() pass.
  372. .. note::
  373. In the presence of model parallelism, this is a collective call
  374. and acts as a barrier among ``mpu.get_model_parallel_group()``.
  375. Returns:
  376. float: norm
  377. """
  378. return self._global_grad_norm
  def __getattr__(self, name):
      """Pass through attributes defined in the model if they are not overridden by ds-engine.

      Python calls this only after normal attribute lookup on the engine has
      failed. Names listed by ``dir(self)`` are re-fetched from the engine.
      Otherwise, if a client model has been stored in
      ``self.__dict__['module']`` and lists ``name`` in its ``dir()``, that
      attribute is returned.

      Raises:
          AttributeError: if neither the engine nor the wrapped model has
              ``name``, including before any model has been stored.
      """
      if name in dir(self):
          # NOTE(review): normal lookup already failed for this name, so this
          # getattr re-enters __getattr__ and can recurse (e.g. for entries that
          # only live in nn.Module registries or a property raising
          # AttributeError). Confirm whether delegating to the base class's
          # __getattr__ is intended here.
          return getattr(self, name)
      # Read 'module' straight from __dict__ so this lookup cannot recurse into
      # __getattr__. There is deliberately no empty-dict placeholder: one made
      # names such as 'keys', 'items' or 'get' resolve to methods of a
      # throwaway dict instead of raising AttributeError.
      if "module" in self.__dict__:
          _module = self.__dict__['module']
          if name in dir(_module):
              return getattr(_module, name)
      raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
  def checkpoint_tag_validation_enabled(self):
      """Return ``checkpoint_tag_validation_enabled`` from the engine config."""
      config = self._config
      return config.checkpoint_tag_validation_enabled
  def checkpoint_tag_validation_fail(self):
      """Return ``checkpoint_tag_validation_fail`` from the engine config."""
      config = self._config
      return config.checkpoint_tag_validation_fail
  def elasticity_enabled(self):
      """Return ``elasticity_enabled`` from the engine config."""
      config = self._config
      return config.elasticity_enabled
  def is_elastic_model_parallel_supported(self):
      """Return whether ``elastic_model_parallel_size`` evenly divides ``num_gpus_per_node``.

      Returns:
          bool: ``False`` when elasticity is disabled (previously ``None`` was
          returned implicitly). Otherwise ``True`` iff the configured
          ``num_gpus_per_node`` is a multiple of ``elastic_model_parallel_size``.
      """
      if not self.elasticity_enabled():
          return False
      cfg = self._config
      # TODO: find the number of GPUs per node automatically instead of reading it from the config.
      return cfg.num_gpus_per_node % cfg.elastic_model_parallel_size == 0
  def pld_enabled(self):
      """Return ``pld_enabled`` from the engine config."""
      config = self._config
      return config.pld_enabled
  def pld_params(self):
      """Return ``pld_params`` from the engine config."""
      config = self._config
      return config.pld_params
  def pld_theta(self):
      """Return the ``PLD_THETA`` entry of :meth:`pld_params`."""
      params = self.pld_params()
      return params[PLD_THETA]
  def pld_gamma(self):
      """Return the ``PLD_GAMMA`` entry of :meth:`pld_params`."""
      params = self.pld_params()
      return params[PLD_GAMMA]
  def eigenvalue_enabled(self):
      """Return ``eigenvalue_enabled`` from the engine config."""
      config = self._config
      return config.eigenvalue_enabled
  def eigenvalue_verbose(self):
      """Return ``eigenvalue_verbose`` from the engine config."""
      config = self._config
      return config.eigenvalue_verbose
  def eigenvalue_max_iter(self):
      """Return ``eigenvalue_max_iter`` from the engine config."""
      config = self._config
      return config.eigenvalue_max_iter
  def eigenvalue_tol(self):
      """Return ``eigenvalue_tol`` from the engine config."""
      config = self._config
      return config.eigenvalue_tol
  def eigenvalue_stability(self):
      """Return ``eigenvalue_stability`` from the engine config."""
      config = self._config
      return config.eigenvalue_stability
  def eigenvalue_gas_boundary_resolution(self):
      """Return ``eigenvalue_gas_boundary_resolution`` from the engine config."""
      config = self._config
      return config.eigenvalue_gas_boundary_resolution
  def eigenvalue_layer_name(self):
      """Return ``eigenvalue_layer_name`` from the engine config."""
      config = self._config
      return config.eigenvalue_layer_name
  def eigenvalue_layer_num(self):
      """Return ``eigenvalue_layer_num`` from the engine config."""
      config = self._config
      return config.eigenvalue_layer_num
  def curriculum_enabled_legacy(self):
      """Return ``curriculum_enabled_legacy`` from the engine config."""
      config = self._config
      return config.curriculum_enabled_legacy
  def curriculum_params_legacy(self):
      """Return ``curriculum_params_legacy`` from the engine config."""
      config = self._config
      return config.curriculum_params_legacy
  def data_efficiency_enabled(self):
      """Return ``data_efficiency_enabled`` from the engine config."""
      config = self._config
      return config.data_efficiency_enabled
  def data_efficiency_config(self):
      """Return ``data_efficiency_config`` from the engine config."""
      config = self._config
      return config.data_efficiency_config
  def data_sampling_enabled(self):
      """Return ``data_efficiency_config[DATA_SAMPLING][DATA_SAMPLING_ENABLED]``."""
      sampling = self._config.data_efficiency_config[DATA_SAMPLING]
      return sampling[DATA_SAMPLING_ENABLED]
  def data_sampling_config(self):
      """Return the ``DATA_SAMPLING`` section of the data-efficiency config."""
      efficiency = self._config.data_efficiency_config
      return efficiency[DATA_SAMPLING]
  def curriculum_learning_enabled(self):
      """Return ``data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]``."""
      curriculum = self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING]
      return curriculum[CURRICULUM_LEARNING_ENABLED]
  def curriculum_learning_config(self):
      """Return the ``CURRICULUM_LEARNING`` section under ``DATA_SAMPLING`` of the data-efficiency config."""
      sampling = self._config.data_efficiency_config[DATA_SAMPLING]
      return sampling[CURRICULUM_LEARNING]
  def random_ltd_enabled(self):
      """Return ``data_efficiency_config[DATA_ROUTING][RANDOM_LTD][RANDOM_LTD_ENABLED]``."""
      ltd = self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD]
      return ltd[RANDOM_LTD_ENABLED]
  def random_ltd_config(self):
      """Return the ``RANDOM_LTD`` section under ``DATA_ROUTING`` of the data-efficiency config."""
      routing = self._config.data_efficiency_config[DATA_ROUTING]
      return routing[RANDOM_LTD]
  def random_ltd_initialize(self):
      """Initialise the selected ``RandomLayerTokenDrop`` layers of the client model.

      Walks ``self.module.named_modules()`` in order. For each
      ``RandomLayerTokenDrop`` whose name contains the next pending id from
      ``sorted(random_ltd_config[RANDOM_LTD_LAYER_ID])``, it calls
      ``layer.init_config(random_ltd_config, self.random_ltd_scheduler, index)``,
      where ``index`` counts the layers initialised so far.

      Raises:
          AssertionError: if random-LTD is not enabled, or if layer-token LR
              scheduling is requested while a client LR scheduler is set.
          ValueError: if the number of initialised layers differs from
              ``random_ltd_config[RANDOM_LTD_LAYER_NUM]``, or if layer-token LR
              scheduling is requested (not supported yet).
      """
      assert self.random_ltd_enabled()
      random_ltd_config = self.random_ltd_config()
      random_ltd_queue = deque(sorted(random_ltd_config[RANDOM_LTD_LAYER_ID]))
      count = 0
      for name, layer in self.module.named_modules():
          if isinstance(layer, RandomLayerTokenDrop):
              # NOTE(review): this is a substring test on the module name, so an id
              # such as 1 also matches names containing "10", "11", ...; only the
              # sorted consumption order guards against mismatches. Confirm the
              # naming scheme of the wrapped layers.
              if random_ltd_queue and str(random_ltd_queue[0]) in name:  ###[1,2,3]
                  layer.init_config(random_ltd_config, self.random_ltd_scheduler, count)
                  random_ltd_queue.popleft()
                  count += 1

      if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count:
          # Implicit literal concatenation: a backslash continuation inside the
          # literal would embed the next source line's indentation in the message.
          raise ValueError(f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must be '
                           f'equivalent to the len of random_ltd_layer_id {count}')

      if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
          assert self.client_lr_scheduler is None
          # TODO: once supported, set self.lr_scheduler to
          # lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler).
          raise ValueError('not yet support')
  def wall_clock_breakdown(self):
      """Return ``wall_clock_breakdown`` from the engine config."""
      config = self._config
      return config.wall_clock_breakdown
  def flops_profiler_enabled(self):
      """Return whether the flops profiler is on.

      Gives ``flops_profiler_config.enabled`` when it is truthy. Otherwise it
      gives the result of :meth:`autotuning_enabled` (plain ``or`` semantics).
      """
      enabled = self._config.flops_profiler_config.enabled
      return enabled or self.autotuning_enabled()
  def flops_profiler_recompute_fwd_factor(self):
      """Return ``flops_profiler_config.recompute_fwd_factor`` from the engine config."""
      profiler_cfg = self._config.flops_profiler_config
      return profiler_cfg.recompute_fwd_factor
  def flops_profiler_profile_step(self):
      """Return the step at which the flops profiler runs.

      When autotuning is enabled, :meth:`autotuning_start_profile_step`
      overrides ``flops_profiler_config.profile_step``.
      """
      if self._config.autotuning_config.enabled:
          return self.autotuning_start_profile_step()
      return self._config.flops_profiler_config.profile_step
  def flops_profiler_module_depth(self):
      """Return ``flops_profiler_config.module_depth`` from the engine config."""
      profiler_cfg = self._config.flops_profiler_config
      return profiler_cfg.module_depth
  def flops_profiler_top_modules(self):
      """Return ``flops_profiler_config.top_modules`` from the engine config."""
      profiler_cfg = self._config.flops_profiler_config
      return profiler_cfg.top_modules
  def flops_profiler_detailed(self):
      """Return ``flops_profiler_config.detailed``; always ``False`` while autotuning is enabled."""
      cfg = self._config
      return False if cfg.autotuning_config.enabled else cfg.flops_profiler_config.detailed
  def flops_profiler_output_file(self):
      """Return ``flops_profiler_config.output_file`` from the engine config."""
      profiler_cfg = self._config.flops_profiler_config
      return profiler_cfg.output_file
  def memory_breakdown(self):
      """Return ``memory_breakdown`` from the engine config."""
      config = self._config
      return config.memory_breakdown
  def autotuning_enabled(self):
      """Return ``autotuning_config.enabled`` from the engine config."""
      tuning_cfg = self._config.autotuning_config
      return tuning_cfg.enabled
  def autotuning_start_profile_step(self):
      """Return ``autotuning_config.start_profile_step`` from the engine config."""
      tuning_cfg = self._config.autotuning_config
      return tuning_cfg.start_profile_step
  def autotuning_end_profile_step(self):
      """Return ``autotuning_config.end_profile_step`` from the engine config."""
      tuning_cfg = self._config.autotuning_config
      return tuning_cfg.end_profile_step
  def autotuning_metric_path(self):
      """Return the autotuning metric path, defaulting to ``<cwd>/autotuning_metric.json``.

      Any falsy configured ``metric_path`` (e.g. ``None`` or ``""``) selects the default.
      """
      configured = self._config.autotuning_config.metric_path
      return configured if configured else os.path.join(os.getcwd(), "autotuning_metric.json")
  def autotuning_model_info_path(self):
      """Return the autotuning model-info path, defaulting to ``<cwd>/autotuning_model_info.json``.

      Any falsy configured ``model_info_path`` (e.g. ``None`` or ``""``) selects the default.
      """
      configured = self._config.autotuning_config.model_info_path
      return configured if configured else os.path.join(os.getcwd(), "autotuning_model_info.json")
  def autotuning_metric(self):
      """Return ``autotuning_config.metric`` from the engine config."""
      tuning_cfg = self._config.autotuning_config
      return tuning_cfg.metric
  def autotuning_profile_model_info(self):
      """Return a truthy value when autotuning is on and its ``model_info`` requests profiling.

      Mirrors ``autotuning_enabled() and model_info and model_info.get("profile", False)``.
      The first falsy operand is returned as-is, otherwise the ``"profile"``
      value, so the result is not necessarily a ``bool``.
      """
      enabled = self.autotuning_enabled()
      if not enabled:
          return enabled
      model_info = self._config.autotuning_config.model_info
      if not model_info:
          return model_info
      return model_info.get("profile", False)
  def sparse_gradients_enabled(self):
      """Return ``sparse_gradients_enabled`` from the engine config."""
      config = self._config
      return config.sparse_gradients_enabled
  def train_batch_size(self):
      """Return the global ``train_batch_size`` from the engine config."""
      config = self._config
      return config.train_batch_size
  def train_micro_batch_size_per_gpu(self):
      """Return ``train_micro_batch_size_per_gpu`` from the engine config."""
      config = self._config
      return config.train_micro_batch_size_per_gpu
  def optimizer_name(self):
      """Return the client optimizer's class name if one is set (truthy), else the configured optimizer name."""
      client = self.client_optimizer
      if client:
          return client.__class__.__name__
      return self._config.optimizer_name
  def optimizer_params(self):
      """Return ``optimizer_params`` from the engine config."""
      config = self._config
      return config.optimizer_params
  def optimizer_legacy_fusion(self):
      """Return ``optimizer_legacy_fusion`` from the engine config."""
      config = self._config
      return config.optimizer_legacy_fusion
  def scheduler_name(self):
      """Return ``scheduler_name`` from the engine config."""
      config = self._config
      return config.scheduler_name
  def scheduler_params(self):
      """Return ``scheduler_params`` from the engine config."""
      config = self._config
      return config.scheduler_params
  def quantize_training(self):
      """Return the shared weight-quantization settings as a 9-tuple.

      Values are read from
      ``compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]`` in this
      key order: WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, WEIGHT_QUANTIZE_ENABLED,
      WEIGHT_QUANTIZE_GROUPS, WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE,
      WEIGHT_QUANTIZE_CHANGE_RATIO, WEIGHT_QUANTIZE_TYPE,
      WEIGHT_QUANTIZE_ROUNDING, WEIGHT_QUANTIZE_VERBOSE, WEIGHT_QUANTIZE_KERNEL.
      """
      shared = self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
      ordered_keys = (
          WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
          WEIGHT_QUANTIZE_ENABLED,
          WEIGHT_QUANTIZE_GROUPS,
          WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE,
          WEIGHT_QUANTIZE_CHANGE_RATIO,
          WEIGHT_QUANTIZE_TYPE,
          WEIGHT_QUANTIZE_ROUNDING,
          WEIGHT_QUANTIZE_VERBOSE,
          WEIGHT_QUANTIZE_KERNEL,
      )
      return tuple(shared[key] for key in ordered_keys)
  def zero_optimization(self):
      """Return ``zero_enabled`` from the engine config."""
      config = self._config
      return config.zero_enabled
  def zero_allow_untested_optimizer(self):
      """Return ``zero_allow_untested_optimizer`` from the engine config."""
      config = self._config
      return config.zero_allow_untested_optimizer
  def zero_force_ds_cpu_optimizer(self):
      """Return ``zero_force_ds_cpu_optimizer`` from the engine config."""
      config = self._config
      return config.zero_force_ds_cpu_optimizer
  def zero_reduce_scatter(self):
      """Return ``zero_config.reduce_scatter`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.reduce_scatter
  def zero_overlap_comm(self):
      """Return ``zero_config.overlap_comm`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.overlap_comm
  def zero_offload_optimizer(self):
      """Return ``zero_config.offload_optimizer`` from the engine config (may be ``None``)."""
      zero_cfg = self._config.zero_config
      return zero_cfg.offload_optimizer
  def zero_offload_param(self):
      """Return ``zero_config.offload_param`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.offload_param
  def zero_use_cpu_optimizer(self):
      """Return True when ZeRO optimizer offload targets the CPU or NVMe device.

      Returns False when no ``offload_optimizer`` config is present.
      """
      offload = self._config.zero_config.offload_optimizer
      if offload is None:
          return False
      return offload.device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]
  def zero_cpu_offload(self):
      """Return True when ZeRO optimizer offload targets the CPU device specifically.

      Returns False when no ``offload_optimizer`` config is present.
      """
      offload = self._config.zero_config.offload_optimizer
      if offload is None:
          return False
      return offload.device == OffloadDeviceEnum.cpu
  def zero_sub_group_size(self):
      """Return ``zero_config.sub_group_size`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.sub_group_size
  def zero_optimization_stage(self):
      """Return ``zero_optimization_stage`` from the engine config."""
      config = self._config
      return config.zero_optimization_stage
  def mics_shard_size(self):
      """Return ``mics_shard_size`` from the engine config."""
      config = self._config
      return config.mics_shard_size
  def zero_reduce_bucket_size(self):
      """Return ``zero_config.reduce_bucket_size`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.reduce_bucket_size
  def zero_allgather_bucket_size(self):
      """Return ``zero_config.allgather_bucket_size`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.allgather_bucket_size
  def zero_optimization_partition_gradients(self):
      """Return True when the ZeRO stage is at least ``ZeroStageEnum.gradients``."""
      stage = self.zero_optimization_stage()
      return stage >= ZeroStageEnum.gradients
  def zero_optimization_partition_weights(self):
      """Return True when the ZeRO stage is at least ``ZeroStageEnum.weights``."""
      stage = self.zero_optimization_stage()
      return stage >= ZeroStageEnum.weights
  def is_first_weights_partition_group(self):
      """Return whether this rank belongs to the first weights-partition group.

      - ``mics_shard_size() < 0``: True iff ZeRO weight partitioning is enabled.
      - ``mics_shard_size() > 0``: True iff ``global_rank < mics_shard_size()``.
      - ``mics_shard_size() == 0``: False.
      """
      shard_size = self.mics_shard_size()
      if shard_size < 0:
          return bool(self.zero_optimization_partition_weights())
      if shard_size > 0:
          return bool(self.global_rank < shard_size)
      return False
  def zero_contiguous_gradients(self):
      """Return ``zero_config.contiguous_gradients`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.contiguous_gradients
  def zero_load_from_fp32_weights(self):
      """Return ``zero_config.load_from_fp32_weights`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.load_from_fp32_weights
  def zero_elastic_checkpoint(self):
      """Return ``zero_config.elastic_checkpoint`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.elastic_checkpoint
  def zero_max_live_parameters(self):
      """Return ``zero_config.max_live_parameters`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.max_live_parameters
  def zero_max_reuse_distance(self):
      """Return ``zero_config.max_reuse_distance`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.max_reuse_distance
  def zero_prefetch_bucket_size(self):
      """Return ``zero_config.prefetch_bucket_size`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.prefetch_bucket_size
  def zero_param_persistence_threshold(self):
      """Return ``zero_config.param_persistence_threshold`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.param_persistence_threshold
  def zero_model_persistence_threshold(self):
      """Return ``zero_config.model_persistence_threshold`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.model_persistence_threshold
  def zero_gather_16bit_weights_on_model_save(self):
      """Return ``zero_config.gather_16bit_weights_on_model_save`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.gather_16bit_weights_on_model_save
  def zero_grad_hooks(self):
      """Return ``zero_config.grad_hooks`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.grad_hooks
  def zero_legacy_stage1(self):
      """Return ``zero_config.legacy_stage1`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.legacy_stage1
  def zero_ignore_unused_parameters(self):
      """Return ``zero_config.ignore_unused_parameters`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.ignore_unused_parameters
  def fp16_enabled(self):
      """Return ``fp16_enabled`` from the engine config."""
      config = self._config
      return config.fp16_enabled
  def bfloat16_enabled(self):
      """Return ``bfloat16_enabled`` from the engine config."""
      config = self._config
      return config.bfloat16_enabled
  def fp16_master_weights_and_gradients(self):
      """Return ``fp16_master_weights_and_gradients`` from the engine config."""
      config = self._config
      return config.fp16_master_weights_and_gradients
  def amp_enabled(self):
      """Return ``amp_enabled`` from the engine config."""
      config = self._config
      return config.amp_enabled
  def amp_params(self):
      """Return ``amp_params`` from the engine config."""
      config = self._config
      return config.amp_params
  def fp16_auto_cast(self):
      """Return ``fp16_auto_cast`` from the engine config."""
      config = self._config
      return config.fp16_auto_cast
  def loss_scale(self):
      """Return ``loss_scale`` from the engine config."""
      config = self._config
      return config.loss_scale
  def gradient_accumulation_steps(self):
      """Return ``gradient_accumulation_steps`` from the engine config."""
      config = self._config
      return config.gradient_accumulation_steps
  def use_node_local_storage(self):
      """Return ``use_node_local_storage`` from the engine config."""
      config = self._config
      return config.use_node_local_storage
  def load_universal_checkpoint(self):
      """Return ``load_universal_checkpoint`` from the engine config."""
      config = self._config
      return config.load_universal_checkpoint
  @property
  def communication_data_type(self):
      """The configured ``communication_data_type``, or a default from the precision flags.

      An explicitly configured (non-``None``) value wins. Otherwise the result
      is ``torch.float16`` if fp16 is enabled, ``torch.bfloat16`` if bf16 is
      enabled, and ``torch.float32`` if neither is.
      """
      configured = self._config.communication_data_type
      if configured is not None:
          return configured
      if self.fp16_enabled():
          return torch.float16
      return torch.bfloat16 if self.bfloat16_enabled() else torch.float32
  def postscale_gradients(self):
      """Return the logical negation of the configured ``prescale_gradients`` flag."""
      prescale = self._config.prescale_gradients
      return not prescale
  def gradient_predivide_factor(self):
      """Return ``gradient_predivide_factor`` from the engine config."""
      config = self._config
      return config.gradient_predivide_factor
  def steps_per_print(self):
      """Return ``steps_per_print`` from the engine config."""
      config = self._config
      return config.steps_per_print
  def zero_allgather_partitions(self):
      """Return ``zero_config.allgather_partitions`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.allgather_partitions
  def zero_round_robin_gradients(self):
      """Return ``zero_config.round_robin_gradients`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.round_robin_gradients
  def zero_hpz_partition_size(self):
      """Return ``zero_config.zero_hpz_partition_size`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.zero_hpz_partition_size
  def zero_quantized_weights(self):
      """Return ``zero_config.zero_quantized_weights`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.zero_quantized_weights
  def zero_quantized_nontrainable_weights(self):
      """Return ``zero_config.zero_quantized_nontrainable_weights`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.zero_quantized_nontrainable_weights
  def zero_quantized_gradients(self):
      """Return ``zero_config.zero_quantized_gradients`` from the engine config."""
      zero_cfg = self._config.zero_config
      return zero_cfg.zero_quantized_gradients
  def dump_state(self):
      """Return ``dump_state`` from the engine config."""
      config = self._config
      return config.dump_state
  def gradient_clipping(self):
      """Return ``gradient_clipping`` from the engine config."""
      config = self._config
      return config.gradient_clipping
  def dynamic_loss_scale(self):
      """Return True when the configured ``loss_scale`` equals 0 (the dynamic-scaling setting)."""
      configured_scale = self._config.loss_scale
      return configured_scale == 0
  def initial_dynamic_scale(self):
      """Return ``initial_dynamic_scale`` from the engine config."""
      config = self._config
      return config.initial_dynamic_scale
  def dynamic_loss_scale_args(self):
      """Return ``dynamic_loss_scale_args`` from the engine config."""
      config = self._config
      return config.dynamic_loss_scale_args
  def swap_tensor_config(self):
      """Return ``swap_tensor_config`` from the engine config."""
      config = self._config
      return config.swap_tensor_config
  def aio_config(self):
      """Return ``aio_config`` from the engine config."""
      config = self._config
      return config.aio_config
  def get_data_types(self):
      """Return ``(model_dtype, grad_accum_dtype)`` for this engine.

      ``model_dtype`` is ``torch.float16`` if fp16 is enabled, else
      ``torch.bfloat16`` if bf16 is enabled, else ``torch.float32``.

      ``grad_accum_dtype`` is taken from the config's ``grad_accum_dtype``
      (converted via ``DtypeEnum(...).value``) when it is set. Otherwise it is
      ``torch.float32`` for a bf16 model without ZeRO, and ``model_dtype`` in
      every other case.
      """
      if self.fp16_enabled():
          model_dtype = torch.float16
      elif self.bfloat16_enabled():
          model_dtype = torch.bfloat16
      else:
          model_dtype = torch.float32

      configured_accum = self._config.grad_accum_dtype
      if configured_accum is not None:
          return (model_dtype, DtypeEnum(configured_accum).value)
      if model_dtype == torch.bfloat16 and not self.zero_optimization():
          return (model_dtype, torch.float32)
      return (model_dtype, model_dtype)
  def _optimizer_has_ckpt_event_prologue(self):
      """Return True if an optimizer is set and it defines ``checkpoint_event_prologue``."""
      optimizer = self.optimizer
      if optimizer is None:
          return False
      return hasattr(optimizer, 'checkpoint_event_prologue')
  def _optimizer_has_ckpt_event_epilogue(self):
      """Return True if an optimizer is set and it defines ``checkpoint_event_epilogue``."""
      optimizer = self.optimizer
      if optimizer is None:
          return False
      return hasattr(optimizer, 'checkpoint_event_epilogue')
  def _configure_lr_scheduler(self, client_lr_scheduler):
      """Set ``self.lr_scheduler``.

      A scheduler built from the JSON config via ``_scheduler_from_config``
      takes priority when it is truthy. Otherwise ``client_lr_scheduler`` is
      used. If it is an instance of ``Callable``, it is invoked with
      ``self.basic_optimizer`` to build the scheduler. If not, it is used
      as-is (which may be ``None``).
      """
      # First check for scheduler in json configuration
      configured = self._scheduler_from_config(self.optimizer)
      if configured:
          log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0])
          self.lr_scheduler = configured
      elif isinstance(client_lr_scheduler, Callable):
          log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0])
          self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
      else:
          log_dist('DeepSpeed using client LR scheduler', ranks=[0])
          self.lr_scheduler = client_lr_scheduler

      log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
  def _configure_checkpointing(self, dist_init_required):
      """Select the checkpoint engine and decide which ranks save which checkpoint parts.

      Sets ``self.checkpoint_engine``: a ``NebulaCheckpointEngine`` when the
      config enables nebula and the import succeeds, otherwise a
      ``TorchCheckpointEngine``. Also sets ``self.save_non_zero_checkpoint``.
      ``self.save_zero_checkpoint`` is set only when ZeRO or bf16 is enabled;
      otherwise it is not assigned here.

      ``dist_init_required`` is not used in this method.
      """
      self.checkpoint_engine = TorchCheckpointEngine()

      if self._config is not None and self._config.nebula_config.enabled:
          try:
              from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \
                  NebulaCheckpointEngine
              self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config)
          except ImportError as err:
              # Deliberate best-effort fallback: log and keep a torch-based engine.
              logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
              self.checkpoint_engine = TorchCheckpointEngine()

      dp_rank = groups._get_sequence_data_parallel_rank()

      rank = self.local_rank if self.use_node_local_storage() else dp_rank

      # only the first data parallel process needs to store the model checkpoint
      # if you want to use node local storage this must be done by rank 0 on each
      # node
      self.save_non_zero_checkpoint = (rank == 0) or (self.zero_optimization_partition_weights()
                                                      and self.is_first_weights_partition_group())

      if self.zero_optimization() or self.bfloat16_enabled():
          # NOTE(review): requires self.optimizer to be configured already and to
          # expose dp_process_group — confirm the call order in the constructor.
          param_rank = dist.get_rank(group=self.optimizer.dp_process_group)

          # Only the first parameter parallel process needs to store the
          # optimizer state checkpoints for zero
          self.save_zero_checkpoint = param_rank == dp_rank
  def _scheduler_from_config(self, optimizer):
      """Instantiate the LR scheduler named in the config, or return ``None`` if none is named.

      The name is looked up first in ``lr_schedules`` and then in
      ``torch.optim.lr_scheduler``. The class is constructed as
      ``scheduler(optimizer, **self.scheduler_params())``.

      Raises:
          AssertionError: if the name is found in neither module.
      """
      scheduler_name = self.scheduler_name()
      if scheduler_name is None:
          return None
      if hasattr(lr_schedules, scheduler_name):
          scheduler_cls = getattr(lr_schedules, scheduler_name)
      else:
          assert hasattr(torch.optim.lr_scheduler,
                         scheduler_name), f"DeepSpeed does not recognize LR scheduler {scheduler_name}"
          scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
      scheduler_params = self.scheduler_params()
      return scheduler_cls(optimizer, **scheduler_params)
  def _set_distributed_vars(self, args):
      """Set ``self.device``, ``self.world_size`` and ``self.global_rank``.

      The device rank is ``args.device_rank`` when ``args`` is not ``None`` and
      has that attribute, else ``self.local_rank``. A non-negative rank selects
      that accelerator device and reads world size / rank from ``dist``. A
      negative rank falls back to world size 1, rank 0 and the accelerator's
      default device.
      """
      if args is not None and hasattr(args, 'device_rank'):
          device_rank = args.device_rank
      else:
          device_rank = self.local_rank

      if device_rank < 0:
          self.world_size = 1
          self.global_rank = 0
          self.device = torch.device(get_accelerator().device_name())
          return

      get_accelerator().set_device(device_rank)
      self.device = torch.device(get_accelerator().device_name(), device_rank)
      self.world_size = dist.get_world_size()
      self.global_rank = dist.get_rank()
  # Configure based on command line arguments
  def _configure_with_arguments(self, args, mpu):
      """Resolve ``self.local_rank`` from the environment and mirror it onto ``args``.

      If ``OMPI_COMM_WORLD_LOCAL_RANK`` is set, ``LOCAL_RANK`` (defaulting to the
      OMPI value) must equal it, and the agreed value is written back to
      ``LOCAL_RANK``. ``self.local_rank`` is then read from ``LOCAL_RANK``, and
      ``args.local_rank`` is overwritten when ``args`` has that attribute.
      ``mpu`` is not used here.
      """
      # After the distributed backend is initialized we are guaranteed the LOCAL_RANK
      # environment variable is set. We must align args.local_rank to this value for
      # backwards compatibility with scripts relying on [args|self].local_rank containing
      # the correct local rank info. _do_args_sanity_check will ensure this is the case.
      env = os.environ
      if "OMPI_COMM_WORLD_LOCAL_RANK" in env:
          ompi_local_rank = env.get("OMPI_COMM_WORLD_LOCAL_RANK")
          local_rank = env.get('LOCAL_RANK', ompi_local_rank)
          assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \
              "not sure how to proceed as we're seeing conflicting local rank info."
          env['LOCAL_RANK'] = local_rank

      self.local_rank = int(env['LOCAL_RANK'])
      if hasattr(args, 'local_rank'):
          args.local_rank = self.local_rank
  # Validate command line arguments
  def _do_args_sanity_check(self, args):
      """Check that a local rank is available and agrees with ``args.local_rank``.

      The environment local rank is ``LOCAL_RANK``, falling back to
      ``OMPI_COMM_WORLD_LOCAL_RANK`` when only the latter is set.

      Raises:
          AssertionError: if neither ``LOCAL_RANK`` nor ``OMPI_COMM_WORLD_LOCAL_RANK``
              is set, if ``args.local_rank`` is not an ``int``, or if a
              non-negative ``args.local_rank`` differs from the environment value.
      """
      assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \
          "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \
          "different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
      if hasattr(args, 'local_rank') and args.local_rank is not None:
          assert isinstance(args.local_rank,
                            int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}"
          if args.local_rank >= 0:
              # The assert above admits an environment where only
              # OMPI_COMM_WORLD_LOCAL_RANK is set; fall back to it (as
              # _configure_with_arguments does) instead of calling int(None).
              env_local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")))
              assert (
                  env_local_rank == args.local_rank
              ), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
  def _is_supported_optimizer(self, optimizer_name):
      """Return True if ``optimizer_name`` is in ``DEEPSPEED_OPTIMIZERS`` or names a non-None ``torch.optim`` attribute."""
      if optimizer_name in DEEPSPEED_OPTIMIZERS:
          return True
      return getattr(torch.optim, optimizer_name, None) is not None
  def _supported_optims(self):
      """Return the list of optimizer base classes accepted from clients.

      Always starts with ``Optimizer``. ``FairseqOptimizer`` is appended when
      ``fairseq`` can be imported, because fairseq optims are not
      ``torch.optim`` objects.
      """
      try:
          from fairseq.optim.fairseq_optimizer import FairseqOptimizer
      except ImportError:
          FairseqOptimizer = None

      optim_types = [Optimizer]
      if FairseqOptimizer:
          optim_types.append(FairseqOptimizer)
      return optim_types
  # Validate configuration based on command line arguments
  def _do_sanity_check(self):
      """Validate the client optimizer / LR scheduler combination and the configured optimizer.

      Raises:
          AssertionError: if the client optimizer is not a supported optimizer
              type, ``None`` or a ``Callable``; if there is no client optimizer
              and the configured optimizer name is not supported; if
              ``optimizer_name()`` is ``LAMB_OPTIMIZER`` or
              ``ONEBIT_LAMB_OPTIMIZER`` without dynamic loss scaling; or if an
              instantiated ``_LRScheduler`` is paired with a client optimizer
              that is not an instantiated ``Optimizer``.
      """
      expected_optim_types = self._supported_optims()
      expected_optim_types += [type(None), Callable]
      assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \
          f'Client Optimizer is of unexpected type {type(self.client_optimizer)}'

      if not self.client_optimizer:
          if self.optimizer_name() is not None:
              assert self._is_supported_optimizer(
                  self.optimizer_name()), "{} is not a supported DeepSpeed Optimizer".format(self.optimizer_name())

      if (self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER):
          assert (self.dynamic_loss_scale()), "DeepSpeed {} optimizer requires dynamic loss scaling".format(
              self.optimizer_name())

      # Detect invalid combinations of client optimizer and client scheduler
      if isinstance(self.client_lr_scheduler, _LRScheduler):
          # Message fix: the "(type = ...)" parenthetical was never closed.
          assert isinstance(self.client_optimizer, Optimizer), \
              f'Client Optimizer (type = {type(self.client_optimizer)}) is not instantiated but Client LR Scheduler is instantiated'
  def _broadcast_model(self):
      """Broadcast the replicated tensor parameters of the client model from their source ranks.

      MoE parameters (per ``is_moe_param``) are broadcast from
      ``groups._get_expert_broadcast_src_rank(p.group_name)`` within
      ``self.expert_data_parallel_group[p.group_name]``. All other parameters
      are broadcast from ``groups._get_broadcast_src_rank()`` within
      ``self.seq_data_parallel_group``. A parameter is skipped if it is not a
      tensor, or if it has a ``ds_status`` that is not
      ``ZeroParamStatus.AVAILABLE``.
      """

      def _is_replicated(param):
          return not (hasattr(param, "ds_status") and param.ds_status is not ZeroParamStatus.AVAILABLE)

      for param in self.module.parameters():
          # Broadcast the model for different parameters
          moe = is_moe_param(param)
          if not (torch.is_tensor(param) and _is_replicated(param)):
              continue
          if moe:
              dist.broadcast(param,
                             groups._get_expert_broadcast_src_rank(param.group_name),
                             group=self.expert_data_parallel_group[param.group_name])
          else:
              dist.broadcast(param, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group)
  @staticmethod
  def __check_params(model: Module, dtype: torch.dtype) -> None:
      """Dtype check for every parameter of ``model`` (currently disabled).

      NOTE: the unconditional ``return`` below makes this a no-op, and the
      code after it is unreachable. If that line were removed, a
      ``ValueError`` listing the mismatching ``(name, dtype)`` pairs would be
      raised, but only on rank 0 (``dist.get_rank() == 0``). Re-enabling it
      would change behaviour for every caller.
      """
      return
      if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0:
          raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is "
                           f"not {dtype}: "
                           f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}")
  def _set_client_model(self, model):
      """Store ``model`` as the engine's wrapped ``module``.

      The model is registered under ``'module'`` in ``self._modules`` so that
      nn.Module methods work correctly. It is also written directly into
      ``self.__dict__`` so attribute access finds it without going through
      ``__getattr__``.
      """
      registered_modules = self.__dict__.get('_modules')
      registered_modules['module'] = model
      self.__dict__['module'] = model
  def _configure_distributed_model(self, model):
      """Adopt ``model`` as the engine's module and set up precision, placement, MoE and groups.

      Steps, in order:
        1. Register ``model`` via ``_set_client_model``.
        2. Cast it with ``.half()`` / ``.bfloat16()`` when fp16 / bf16 is enabled.
        3. Move it to ``self.device``, unless ``dont_change_device`` is set or
           it is a ZeRO-init model (weight partitioning is on and some
           parameter has ``ds_id``).
        4. Record MoE layers, gates and MoE layers, and forward ``self.mpu``
           to ``groups``.
        5. Call ``set_deepspeed_parallelism()`` on submodules that define it.
        6. Cache the parallel groups and world sizes from ``groups``.
        7. Broadcast the model, unless AMP is enabled or it is a ZeRO-init model.
      """
      self._set_client_model(model)
      # Parameters carrying "ds_id" are treated as already set up by ZeRO
      # weight partitioning; such models skip device moves and broadcasting below.
      is_zero_init_model = self.zero_optimization_partition_weights() and any(
          [hasattr(param, "ds_id") for param in self.module.parameters()])

      if self.fp16_enabled():
          if is_zero_init_model:
              self.__check_params(self.module, torch.half)
          self.module.half()
      elif self.bfloat16_enabled():
          if is_zero_init_model:
              self.__check_params(self.module, torch.bfloat16)
          self.module.bfloat16()
      else:
          self.__check_params(self.module, torch.float)

      # zero.Init() handles device placement of model
      if not (self.dont_change_device or is_zero_init_model):
          self.module.to(self.device)

      # MoE related initialization
      # NOTE(review): assumes has_moe_layers / num_experts / gate_modules /
      # moe_layers were initialised (False / empty lists) before this call.
      for _, module in self.module.named_modules():
          if isinstance(module, MoE):
              self.has_moe_layers = True
              self.num_experts.append(module.num_experts)

      if self.has_moe_layers:
          for _, module in self.module.named_modules():
              if isinstance(module, TopKGate):
                  self.gate_modules.append(module)
                  if self.wall_clock_breakdown():
                      module.wall_clock_breakdown = True
              if isinstance(module, MOELayer):
                  self.moe_layers.append(module)
                  if self.wall_clock_breakdown():
                      module.wall_clock_breakdown = True

      # Pass the mpu from here to groups. For subsequent use, just query groups
      if self.mpu is not None:
          groups.mpu = self.mpu

      # Set deepspeed parallelism spec. for the model including expert parallelism
      for _, module in self.module.named_modules():
          if hasattr(module, 'set_deepspeed_parallelism'):
              module.set_deepspeed_parallelism()

      # Query the groups module to get information about various parallel groups
      self.local_all_to_all_group = None
      if self.zero_quantized_gradients():
          log_dist("Using quantized gradients", ranks=[0])
          self.local_all_to_all_group = groups._get_local_all_to_all_group()
      self.data_parallel_group = groups._get_data_parallel_group()
      self.dp_world_size = groups._get_data_parallel_world_size()
      self.seq_data_parallel_group = groups._get_sequence_data_parallel_group()
      self.seq_dp_world_size = groups._get_sequence_data_parallel_world_size()
      self.mp_world_size = groups._get_model_parallel_world_size()
      self.expert_parallel_group = groups._get_expert_parallel_group_dict()
      self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()

      # The AMP path broadcasts later, after amp.initialize, in _configure_optimizer.
      if not (self.amp_enabled() or is_zero_init_model):
          self._broadcast_model()
  # check if parameters are duplicated in optimizer param_groups
  def _check_for_duplicates(self, optimizer):
      """Assert that no parameter of ``self.module`` appears more than once in ``optimizer.param_groups``.

      Occurrences are counted by object identity across all groups in a
      single pass. Each named parameter is then checked in
      ``named_parameters()`` order. The original rebuilt every group's id list
      twice per parameter, which was O(params x total group size).

      Raises:
          AssertionError: naming the first parameter that occurs multiple times.
      """
      from collections import Counter

      occurrences = Counter(id(p) for group in optimizer.param_groups for p in group['params'])
      for name, param in self.module.named_parameters():
          occurrence = occurrences[id(param)]
          assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behavior."
  def _do_optimizer_sanity_check(self, basic_optimizer):
      """Validate the precision / ZeRO / AMP combination and pick the optimizer wrapper.

      Returns:
          ``ZERO_OPTIMIZATION`` or ``BFLOAT16`` when ZeRO is enabled. ``BFLOAT16``
          is used for a bf16 model with fp32 accumulation on stage 1 without
          CPU offload. Otherwise: ``AMP`` when AMP is enabled, ``FP16`` for
          fp16 model and fp16 accumulation, ``BFLOAT16`` for bf16 model with
          fp32 accumulation, and ``None`` for fp32 model with fp32 accumulation.

      Raises:
          AssertionError: if AMP and ZeRO are both enabled, or if ZeRO is used
              with an untested optimizer without ``zero_allow_untested_optimizer``.
          NotImplementedError: for dtype combinations rejected in the branches below.
          RuntimeError: if AMP is enabled but ``amp`` is not importable (NameError).
      """
      model_dtype, grad_accum_dtype = self.get_data_types()
      zero_enabled = self.zero_optimization()
      amp_enabled = self.amp_enabled()
      # config based assertions
      assert (
          not (amp_enabled and zero_enabled)
      ), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
      if zero_enabled:
          if not is_zero_supported_optimizer(basic_optimizer):
              assert (
                  self.zero_allow_untested_optimizer()
              ), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

              if self.global_rank == 0:
                  logger.warning("**** You are using ZeRO with an untested optimizer, proceed with caution *****")
          # bf16 weights with fp32 accumulation on ZeRO stage 1 (no CPU offload)
          # are routed to the BFLOAT16 wrapper instead of the ZeRO optimizer.
          if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage(
          ) == 1 and not self.zero_cpu_offload():
              return BFLOAT16
          return ZERO_OPTIMIZATION
      elif amp_enabled:
          if model_dtype != grad_accum_dtype:
              raise NotImplementedError(
                  "Model data type and gradient accumulation data type must be equal to use Amp")
          if model_dtype == torch.bfloat16 or model_dtype == torch.float16:
              raise NotImplementedError("Cannot enable both amp with (legacy) fp16 or bfloat16 mode")
          try:
              logger.info("Initializing Apex amp from: {}".format(amp.__path__))
          except NameError:
              # If apex/amp is available it will be imported above
              raise RuntimeError("Unable to import apex/amp, please make sure it is installed")
          return AMP
      # data type checks
      elif model_dtype == grad_accum_dtype:
          if model_dtype == torch.bfloat16:
              raise NotImplementedError(
                  "Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation"
              )
          if model_dtype == torch.float16:
              return FP16
          # else optimizer_wrapper = None
      elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
          return BFLOAT16
      else:
          raise NotImplementedError("unsupported mix of model dtype and gradient accumulation type")

      # Reached only for matching non-fp16/bf16 dtypes: no wrapper needed.
      return None
  # Configure optimizer
  def _configure_optimizer(self, client_optimizer, model_parameters):
      """Build the basic optimizer and wrap it according to the precision / ZeRO / AMP setup.

      ``client_optimizer`` may be:
        - ``None``: build one from the DeepSpeed config.
        - An instance of a type from ``_supported_optims()``: used as-is.
        - Any other callable: called with ``model_parameters``.

      Empty param groups are dropped and duplicate parameters rejected. The
      method sets ``self.basic_optimizer``, ``self.optimizer`` (the wrapper
      chosen by ``_do_optimizer_sanity_check``),
      ``self.compression_scheduler`` and ``self.quantizer``.

      Raises:
          ZeRORuntimeException: if CPU/NVMe ZeRO offload is used with a client
              optimizer that is not ``DeepSpeedCPUAdam`` while
              ``zero_force_ds_cpu_optimizer`` is set.
      """
      if client_optimizer is None:
          basic_optimizer = self._configure_basic_optimizer(model_parameters)
          log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0])
      else:
          if isinstance(client_optimizer, tuple(self._supported_optims())):
              basic_optimizer = client_optimizer
              log_dist('Using client Optimizer as basic optimizer', ranks=[0])
          else:
              basic_optimizer = client_optimizer(model_parameters)
              log_dist('Using client callable to create basic optimizer', ranks=[0])

          if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam):
              if self.zero_force_ds_cpu_optimizer():
                  msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.'
                  raise ZeRORuntimeException(msg)

      # Drop param groups without parameters (in place, so the optimizer object is
      # unchanged). The log line is emitted whether or not any group was removed.
      basic_optimizer.param_groups[:] = [pg for pg in basic_optimizer.param_groups if len(pg["params"]) != 0]
      log_dist("Removing param_group that has no 'params' in the basic Optimizer", ranks=[0])

      self._check_for_duplicates(basic_optimizer)

      self.basic_optimizer = basic_optimizer
      log_dist("DeepSpeed Basic Optimizer = {}".format(basic_optimizer.__class__.__name__), ranks=[0])

      optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer)

      if optimizer_wrapper == ZERO_OPTIMIZATION:
          self.optimizer = self._configure_zero_optimizer(basic_optimizer)
      elif optimizer_wrapper == AMP:
          amp_params = self.amp_params()
          log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0])
          # amp.initialize returns a (possibly new) model; re-register it and
          # broadcast, since _configure_distributed_model skipped that when AMP is on.
          model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
          self._set_client_model(model)
          self._broadcast_model()
          # TODO: maybe need to broadcast experts differently?
      elif optimizer_wrapper == FP16:
          self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
      elif optimizer_wrapper == BFLOAT16:
          self.optimizer = self._configure_bf16_optimizer(basic_optimizer)
      else:
          self.optimizer = basic_optimizer

      # NOTE(review): optimizer_name() reports the client/basic optimizer name,
      # not the wrapper class selected above — confirm this log is intended.
      log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0])

      self.compression_scheduler = self._configure_compression_scheduler()
      self.quantizer = self._configure_quantization()
  984. def _configure_basic_optimizer(self, model_parameters):
  985. optimizer_parameters = self.optimizer_params()
  986. if optimizer_parameters is None:
  987. optimizer_parameters = {}
  988. # print(optimizer_parameters.keys())
  989. if "max_grad_norm" in optimizer_parameters.keys():
  990. raise ValueError(
  991. "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
  992. )
  993. if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
  994. torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
  995. adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)
  996. # Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set
  997. effective_adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode
  998. if torch_adam:
  999. if not effective_adam_w_mode:
  1000. optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
  1001. else:
  1002. optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
  1003. else:
  1004. if self.zero_use_cpu_optimizer():
  1005. from deepspeed.ops.adam import DeepSpeedCPUAdam
  1006. optimizer = DeepSpeedCPUAdam(model_parameters,
  1007. **optimizer_parameters,
  1008. adamw_mode=effective_adam_w_mode)
  1009. else:
  1010. from deepspeed.ops.adam import FusedAdam
  1011. optimizer = FusedAdam(
  1012. model_parameters,
  1013. **optimizer_parameters,
  1014. adam_w_mode=effective_adam_w_mode,
  1015. )
  1016. elif self.optimizer_name() == ADAGRAD_OPTIMIZER:
  1017. if self.zero_use_cpu_optimizer():
  1018. from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
  1019. optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters)
  1020. else:
  1021. optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters)
  1022. elif self.optimizer_name() == LAMB_OPTIMIZER:
  1023. from deepspeed.ops.lamb import FusedLamb
  1024. optimizer = FusedLamb(model_parameters, **optimizer_parameters)
  1025. elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
  1026. assert not self.zero_optimization(), "1bit-Adam is not compatible with ZeRO"
  1027. from deepspeed.runtime.fp16.onebit.adam import OnebitAdam
  1028. optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
  1029. if not self.fp16_enabled():
  1030. logger.warning(f"Currently the convergence of 1-bit Adam is only verified under FP16")
  1031. elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER:
  1032. assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO"
  1033. from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
  1034. optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters)
  1035. if not self.fp16_enabled():
  1036. logger.warning(f'Currently the convergence of 0/1 Adam is only verified under FP16')
  1037. elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER:
  1038. assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO"
  1039. from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb
  1040. optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters)
  1041. if not self.fp16_enabled():
  1042. logger.warning(f"Currently the convergence of 1-bit Lamb is only verified under FP16")
  1043. elif self.optimizer_name() == MUADAM_OPTIMIZER:
  1044. try:
  1045. from mup import MuAdam
  1046. except ImportError:
  1047. logger.error(f"Install mup to use MuAdam optimizer")
  1048. optimizer = MuAdam(model_parameters, **optimizer_parameters)
  1049. elif self.optimizer_name() == MUADAMW_OPTIMIZER:
  1050. try:
  1051. from mup import MuAdamW
  1052. except ImportError:
  1053. logger.error(f"Install mup to use MuAdamW optimizer")
  1054. optimizer = MuAdamW(model_parameters, **optimizer_parameters)
  1055. elif self.optimizer_name() == MUSGD_OPTIMIZER:
  1056. try:
  1057. from mup import MuSGD
  1058. except ImportError:
  1059. logger.error(f"Install mup to use MuSGD optimizer")
  1060. optimizer = MuSGD(model_parameters, **optimizer_parameters)
  1061. else:
  1062. torch_optimizer = getattr(torch.optim, self.optimizer_name())
  1063. optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
  1064. return optimizer
  1065. def _configure_compression_scheduler(self):
  1066. return compression_scheduler(self.module, self._config.compression_config)
  1067. def _configure_random_ltd_scheduler(self, configs):
  1068. return RandomLTDScheduler(configs)
  1069. def _configure_quantization(self):
  1070. (
  1071. quantize_weight_in_forward,
  1072. quantize_enabled,
  1073. q_groups,
  1074. q_mixed_fp16,
  1075. q_change_ratio,
  1076. q_type,
  1077. q_rounding,
  1078. q_verbose,
  1079. use_quantizer_kernel,
  1080. ) = self.quantize_training()
  1081. if quantize_enabled and not quantize_weight_in_forward:
  1082. assert self.fp16_enabled(
  1083. ), "MoQ (quantize in optimization step) weight quantization is only supported for FP16"
  1084. quantizer = None
  1085. if quantize_enabled and not quantize_weight_in_forward:
  1086. from deepspeed.runtime.quantize import Quantizer
  1087. quantizer = Quantizer(
  1088. q_groups,
  1089. q_mixed_fp16,
  1090. q_change_ratio,
  1091. q_type,
  1092. q_rounding,
  1093. q_verbose,
  1094. self.eigenvalue_enabled(),
  1095. use_quantizer_kernel,
  1096. self.eigenvalue_layer_num() if self.eigenvalue_enabled() else 0,
  1097. )
  1098. return quantizer
  1099. def _configure_fp16_optimizer(self, optimizer):
  1100. initial_dynamic_scale = self.initial_dynamic_scale()
  1101. dynamic_loss_args = self.dynamic_loss_scale_args()
  1102. clip_grad = self.gradient_clipping()
  1103. if APEX_INSTALLED:
  1104. fused_opts = (apex.optimizers.FusedAdam, FusedAdam)
  1105. else:
  1106. fused_opts = FusedAdam
  1107. if isinstance(optimizer, fused_opts) \
  1108. or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]:
  1109. if self.dynamic_loss_scale():
  1110. log_dist(f'Creating fp16 optimizer with dynamic loss scale', ranks=[0])
  1111. timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
  1112. optimizer = FP16_Optimizer(
  1113. optimizer,
  1114. deepspeed=self,
  1115. dynamic_loss_scale=True,
  1116. initial_dynamic_scale=initial_dynamic_scale,
  1117. dynamic_loss_args=dynamic_loss_args,
  1118. mpu=self.mpu,
  1119. clip_grad=clip_grad,
  1120. fused_adam_legacy=self.optimizer_legacy_fusion(),
  1121. timers=timers,
  1122. has_moe_layers=self.has_moe_layers,
  1123. )
  1124. else:
  1125. log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0])
  1126. optimizer = FP16_Optimizer(
  1127. optimizer,
  1128. deepspeed=self,
  1129. static_loss_scale=self.loss_scale(),
  1130. mpu=self.mpu,
  1131. clip_grad=clip_grad,
  1132. fused_adam_legacy=self.optimizer_legacy_fusion(),
  1133. has_moe_layers=self.has_moe_layers,
  1134. )
  1135. else:
  1136. log_dist(f'Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])
  1137. optimizer = FP16_UnfusedOptimizer(
  1138. optimizer,
  1139. deepspeed=self,
  1140. static_loss_scale=self.loss_scale(),
  1141. dynamic_loss_scale=self.dynamic_loss_scale(),
  1142. dynamic_loss_args=dynamic_loss_args,
  1143. mpu=self.mpu,
  1144. clip_grad=clip_grad,
  1145. fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER,
  1146. )
  1147. return optimizer
  1148. def _configure_bf16_optimizer(self, optimizer):
  1149. clip_grad = self.gradient_clipping()
  1150. if optimizer is None:
  1151. optimizer = DummyOptim(list(self.module.parameters()))
  1152. log_dist('Creating BF16 optimizer', ranks=[0])
  1153. timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
  1154. optimizer = BF16_Optimizer(optimizer,
  1155. self.param_names,
  1156. mpu=self.mpu,
  1157. clip_grad=clip_grad,
  1158. allgather_bucket_size=self.zero_allgather_bucket_size(),
  1159. dp_process_group=self.seq_data_parallel_group,
  1160. timers=timers)
  1161. return optimizer
    def _configure_zero_optimizer(self, optimizer):
        """Wrap ``optimizer`` in the ZeRO optimizer matching the configured stage.

        - Stages up to ``ZeroStageEnum.gradients`` (1 and 2): ``DeepSpeedZeroOptimizer``,
          with gradient partitioning enabled only for the gradients stage.
        - ``ZeroStageEnum.weights`` (stage 3): ``DeepSpeedZeRoOffload`` when no real
          optimizer was supplied (``DummyOptim``); the MiCS optimizer when
          ``mics_shard_size > 0``; otherwise ``DeepSpeedZeroOptimizer_Stage3``.

        Args:
            optimizer: the basic optimizer, or ``None`` (replaced by a ``DummyOptim``
                over all module parameters).

        Returns:
            The wrapped ZeRO optimizer (or the ZeRO-Offload object for stage 3 without
            an optimizer).

        Raises:
            Exception: if the removed legacy ZeRO stage 1 is requested.
            AssertionError: stages 1/2 without a real optimizer, or MoE with stage 3.
            NotImplementedError: for any other stage value.
        """
        zero_stage = self.zero_optimization_stage()
        mics_shard_size = self.mics_shard_size()
        model_dtype, gradient_accumulation_dtype = self.get_data_types()

        # Real timers only when a wall-clock breakdown is requested.
        timers = self.timers if self.wall_clock_breakdown() else NoopTimer()

        if optimizer is None:
            optimizer = DummyOptim(list(self.module.parameters()))

        if self.zero_legacy_stage1():
            raise Exception(
                "The deprecated version of ZeRO Stage 1 is not supported in deepspeed >= 0.5.9. Please downgrade to a version less than 0.5.9 if you need to use this deprecated version of ZeRO."
            )

        if zero_stage <= ZeroStageEnum.gradients:
            overlap_comm = self.zero_overlap_comm()
            contiguous_gradients = self.zero_contiguous_gradients()
            round_robin_gradients = self.zero_round_robin_gradients()
            assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage)

            log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
            # Overlap and contiguous grads are meaningless in stage 1 and are ignored
            if zero_stage == ZeroStageEnum.optimizer_states:
                overlap_comm = False
                round_robin_gradients = False
                # Non-MoE requires contiguous grads to be disabled w. stage 1
                if not self.has_moe_layers:
                    contiguous_gradients = False

            # Overlapped communication is force-disabled for pipeline-parallel modules.
            if isinstance(self.module, PipelineModule):
                if overlap_comm:
                    logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.")
                    overlap_comm = False
            optimizer = DeepSpeedZeroOptimizer(
                optimizer,
                self.param_names,
                timers=timers,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                contiguous_gradients=contiguous_gradients,
                reduce_bucket_size=self.zero_reduce_bucket_size(),
                allgather_bucket_size=self.zero_allgather_bucket_size(),
                dp_process_group=self.seq_data_parallel_group,
                expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None,
                expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=overlap_comm,
                offload_optimizer_config=self.zero_offload_optimizer(),
                mpu=self.mpu,
                postscale_gradients=self.postscale_gradients(),
                gradient_predivide_factor=self.gradient_predivide_factor(),
                gradient_accumulation_steps=self.gradient_accumulation_steps(),
                ignore_unused_parameters=self.zero_ignore_unused_parameters(),
                # Stage 2 partitions gradients; stage 1 only partitions optimizer states.
                partition_grads=zero_stage == ZeroStageEnum.gradients,
                round_robin_gradients=round_robin_gradients,
                has_moe_layers=self.has_moe_layers,
                fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),
                gradient_accumulation_dtype=gradient_accumulation_dtype,
                communication_data_type=self.communication_data_type,
                elastic_checkpoint=self.zero_elastic_checkpoint())

        elif zero_stage == ZeroStageEnum.weights:
            assert not self.has_moe_layers, "MoE not supported with Stage 3"
            if isinstance(optimizer, DummyOptim):
                # No client/config optimizer: only parameter partitioning/offload is set up.
                log_dist("Creating ZeRO Offload", ranks=[0])
                zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
                # Create the hpZ parameter group on demand if it has not been set up yet.
                if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None:
                    self._set_zero_group_parallelism()
                    zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
                optimizer = DeepSpeedZeRoOffload(
                    self.module,
                    timers=timers,
                    ds_config=self.config,
                    overlap_comm=self.zero_overlap_comm(),
                    prefetch_bucket_size=self.zero_prefetch_bucket_size(),
                    max_reuse_distance=self.zero_max_reuse_distance(),
                    max_live_parameters=self.zero_max_live_parameters(),
                    param_persistence_threshold=self.zero_param_persistence_threshold(),
                    model_persistence_threshold=self.zero_model_persistence_threshold(),
                    offload_param_config=self.zero_offload_param(),
                    mpu=self.mpu,
                    zero_param_parallel_group=zero_param_parallel_group,
                    zero_quantized_weights=self.zero_quantized_weights(),
                    zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
                )
            else:
                # NOTE(review): this message says "fp16" regardless of model_dtype; the
                # non-MiCS path below logs again with the actual dtype.
                log_dist(
                    f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
                    f' MiCS is enabled {mics_shard_size>0},'
                    f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}',
                    ranks=[0])
                if mics_shard_size > 0:
                    # MiCS path returns early with its own optimizer wrapper.
                    return self._return_mics_optimizer(optimizer, timers)

                log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
                from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
                optimizer = DeepSpeedZeroOptimizer_Stage3(
                    self.module,
                    optimizer,
                    timers=timers,
                    ds_config=self.config,
                    static_loss_scale=self.loss_scale(),
                    dynamic_loss_scale=self.dynamic_loss_scale(),
                    dynamic_loss_args=self.dynamic_loss_scale_args(),
                    clip_grad=self.gradient_clipping(),
                    contiguous_gradients=self.zero_contiguous_gradients(),
                    reduce_bucket_size=self.zero_reduce_bucket_size(),
                    prefetch_bucket_size=self.zero_prefetch_bucket_size(),
                    max_reuse_distance=self.zero_max_reuse_distance(),
                    max_live_parameters=self.zero_max_live_parameters(),
                    param_persistence_threshold=self.zero_param_persistence_threshold(),
                    model_persistence_threshold=self.zero_model_persistence_threshold(),
                    dp_process_group=self.seq_data_parallel_group,
                    all2all_process_group=self.local_all_to_all_group,
                    reduce_scatter=self.zero_reduce_scatter(),
                    overlap_comm=self.zero_overlap_comm(),
                    offload_optimizer_config=self.zero_offload_optimizer(),
                    offload_param_config=self.zero_offload_param(),
                    sub_group_size=self.zero_sub_group_size(),
                    mpu=self.mpu,
                    postscale_gradients=self.postscale_gradients(),
                    gradient_predivide_factor=self.gradient_predivide_factor(),
                    gradient_accumulation_steps=self.gradient_accumulation_steps(),
                    aio_config=self.aio_config(),
                    gradient_accumulation_dtype=gradient_accumulation_dtype,
                    communication_data_type=self.communication_data_type,
                    zero_hpz_partition_size=self.zero_hpz_partition_size(),
                    zero_quantized_weights=self.zero_quantized_weights(),
                    zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
                )

        else:
            raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))

        return optimizer
  1290. def _return_mics_optimizer(self, basic_optimizer, timers):
  1291. from deepspeed.runtime.zero.mics import MiCS_Optimizer
  1292. model_dtype, gradient_accumulation_dtype = self.get_data_types()
  1293. optimizer = MiCS_Optimizer(self.module,
  1294. basic_optimizer,
  1295. timers=timers,
  1296. ds_config=self.config,
  1297. static_loss_scale=self.loss_scale(),
  1298. dynamic_loss_scale=self.dynamic_loss_scale(),
  1299. dynamic_loss_args=self.dynamic_loss_scale_args(),
  1300. clip_grad=self.gradient_clipping(),
  1301. contiguous_gradients=self.zero_contiguous_gradients(),
  1302. reduce_bucket_size=self.zero_reduce_bucket_size(),
  1303. prefetch_bucket_size=self.zero_prefetch_bucket_size(),
  1304. max_reuse_distance=self.zero_max_reuse_distance(),
  1305. max_live_parameters=self.zero_max_live_parameters(),
  1306. param_persistence_threshold=self.zero_param_persistence_threshold(),
  1307. model_persistence_threshold=self.zero_model_persistence_threshold(),
  1308. dp_process_group=self.seq_data_parallel_group,
  1309. reduce_scatter=self.zero_reduce_scatter(),
  1310. overlap_comm=self.zero_overlap_comm(),
  1311. offload_optimizer_config=self.zero_offload_optimizer(),
  1312. offload_param_config=self.zero_offload_param(),
  1313. sub_group_size=self.zero_sub_group_size(),
  1314. mpu=self.mpu,
  1315. postscale_gradients=self.postscale_gradients(),
  1316. gradient_predivide_factor=self.gradient_predivide_factor(),
  1317. gradient_accumulation_steps=self.gradient_accumulation_steps(),
  1318. aio_config=self.aio_config(),
  1319. gradient_accumulation_dtype=gradient_accumulation_dtype,
  1320. communication_data_type=self.communication_data_type)
  1321. return optimizer
  1322. def _configure_eigenvalue(self):
  1323. eigenvalue = Eigenvalue(
  1324. verbose=self.eigenvalue_verbose(),
  1325. max_iter=self.eigenvalue_max_iter(),
  1326. tol=self.eigenvalue_tol(),
  1327. stability=self.eigenvalue_stability(),
  1328. gas_boundary_resolution=self.eigenvalue_gas_boundary_resolution(),
  1329. layer_name=self.eigenvalue_layer_name(),
  1330. layer_num=self.eigenvalue_layer_num(),
  1331. )
  1332. return eigenvalue
  1333. def _configure_progressive_layer_drop(self):
  1334. pld = ProgressiveLayerDrop(theta=self.pld_theta(), gamma=self.pld_gamma())
  1335. return pld
  1336. def _configure_curriculum_scheduler_legacy(self):
  1337. scheduler = CurriculumScheduler(self.curriculum_params_legacy())
  1338. return scheduler
  1339. @staticmethod
  1340. def is_map_style_dataset(obj):
  1341. return hasattr(obj, "__getitem__") and hasattr(obj, "__len__")
  1342. @staticmethod
  1343. def is_iterable_style_dataset(obj):
  1344. return isinstance(obj, torch.utils.data.IterableDataset) # hasattr(obj, "__iter__") should work as well
  1345. def dataloader_drop_last(self):
  1346. return self._config.dataloader_drop_last
  1347. def was_step_applied(self) -> bool:
  1348. """Returns True if the latest ``step()`` produced in parameter updates.
  1349. Note that a ``False`` return is not an error condition. Steps are frequently
  1350. no-ops, such as between gradient accumulation boundaries or when overflows
  1351. occur.
  1352. Returns:
  1353. bool: Whether the latest ``step()`` modified model parameters.
  1354. """
  1355. return self._step_applied
  1356. def deepspeed_io(self,
  1357. dataset,
  1358. batch_size=None,
  1359. route=ROUTE_TRAIN,
  1360. pin_memory=True,
  1361. data_sampler=None,
  1362. collate_fn=None,
  1363. num_local_io_workers=None):
  1364. if not (self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)):
  1365. raise ValueError("Training data must be a torch Dataset")
  1366. if batch_size is None:
  1367. batch_size = self.train_micro_batch_size_per_gpu()
  1368. if collate_fn is None:
  1369. collate_fn = self.collate_fn
  1370. # Currently we only use timer in train route
  1371. deepspeed_io_timer = None
  1372. if route == ROUTE_TRAIN:
  1373. deepspeed_io_timer = self.tput_timer
  1374. # If mpu is provided, forward world size and parallel rank to sampler.
  1375. data_parallel_world_size = self.dp_world_size
  1376. data_parallel_rank = self.global_rank
  1377. if self.mpu is not None:
  1378. data_parallel_world_size = self.mpu.get_data_parallel_world_size()
  1379. data_parallel_rank = self.mpu.get_data_parallel_rank()
  1380. if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
  1381. data_sampler = torch.utils.data.DistributedSampler(
  1382. dataset,
  1383. num_replicas=data_parallel_world_size,
  1384. rank=data_parallel_rank,
  1385. shuffle=False,
  1386. )
  1387. deepspeed_dataloader_config = {}
  1388. if self.curriculum_learning_enabled():
  1389. deepspeed_dataloader_config = {
  1390. CURRICULUM_LEARNING: self.curriculum_learning_enabled(),
  1391. DATA_EFFICIENCY: self.data_efficiency_config(),
  1392. DATA_PARALLEL_GROUP: self.data_parallel_group,
  1393. GRADIENT_ACCUMULATION_STEPS: self.gradient_accumulation_steps(),
  1394. GLOBAL_RANK: self.global_rank,
  1395. DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
  1396. }
  1397. return DeepSpeedDataLoader(dataset=dataset,
  1398. batch_size=batch_size,
  1399. pin_memory=pin_memory,
  1400. collate_fn=collate_fn,
  1401. local_rank=self.local_rank,
  1402. tput_timer=deepspeed_io_timer,
  1403. num_local_io_workers=num_local_io_workers,
  1404. data_sampler=data_sampler,
  1405. data_parallel_world_size=data_parallel_world_size,
  1406. data_parallel_rank=data_parallel_rank,
  1407. dataloader_drop_last=self.dataloader_drop_last(),
  1408. deepspeed_dataloader_config=deepspeed_dataloader_config)
  1409. def train(self, mode=True):
  1410. r""""""
  1411. self.warn_unscaled_loss = True
  1412. self.module.train(mode)
  1413. def eval(self):
  1414. r""""""
  1415. self.warn_unscaled_loss = True
  1416. self.module.train(False)
  1417. def _scale_loss_by_gas(self, prescaled_loss):
  1418. if isinstance(prescaled_loss, torch.Tensor):
  1419. scaled_loss = prescaled_loss / self.gradient_accumulation_steps()
  1420. elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list):
  1421. scaled_loss = []
  1422. for l in prescaled_loss:
  1423. if isinstance(l, torch.Tensor):
  1424. scaled_loss.append(l / self.gradient_accumulation_steps())
  1425. else:
  1426. scaled_loss.append(l)
  1427. else:
  1428. scaled_loss = prescaled_loss
  1429. if self.warn_unscaled_loss:
  1430. logger.warning(f"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}")
  1431. self.warn_unscaled_loss = False
  1432. return scaled_loss
    @instrument_w_nvtx
    def forward(self, *inputs, **kwargs):
        r"""Execute forward propagation
        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        Returns:
            Whatever ``self.module(*inputs, **kwargs)`` returns.
        Note:
            When autotuning model-info profiling is enabled, this records activation
            memory, writes the model info JSON and then calls ``exit()``.
        """
        # Baseline memory reading for the autotuning activation-memory measurement.
        if self.autotuning_profile_model_info():
            ma = get_ma_status()
        else:
            see_memory_usage("Engine before forward", force=self.memory_breakdown())

        # The flops profiler runs only on rank 0 and only at the configured step.
        flops_profiler_active = (self.flops_profiler_enabled()
                                 and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)

        # used to check quantization happens at step 0!
        if self.global_steps == 0 and hasattr(self, "compression_scheduler"):
            self.compression_scheduler.step(step_zero_check=True)
            if self.quantizer:
                # Stage 2 optimizers expose bit16_groups; otherwise fp16_groups is used.
                tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(
                ) == 2 else self.optimizer.fp16_groups
                if self.compression_scheduler.weight_quantization_enabled:
                    self.quantizer.quantize(
                        tensor_to_quantize,
                        (self.optimizer.overflow if self.fp16_enabled() else False),
                        self.eigenvalue_enabled(),
                        None,
                    )

        if flops_profiler_active:
            self.flops_profiler.start_profile(ignore_list=None)

        # Inject progressive-layer-drop state into the module's kwargs while training.
        if self.module.training:
            if self.progressive_layer_drop:
                kwargs.update(self.progressive_layer_drop.get_state())

        if self.__class__.__name__ != "PipelineEngine":
            # TODO: The above if condition is a HACK since for PipelineEngine
            # it's difficult to inject argument in forward pass.
            if self.module.training and self.curriculum_enabled_legacy():
                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
                if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
                    kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})

        if self.module.training and self.random_ltd_enabled():
            self.random_ltd_scheduler.update_seq(self.global_steps)

        if self.zero_optimization_partition_weights():
            # Enable automated discovery of external parameters by indicating that
            # we are in a forward pass.
            for module in self.module.modules():
                module._parameters._in_forward = True

        self._start_timers(self.engine_timers.forward_timers)

        # Without an engine-owned training dataloader, throughput timing starts here.
        if self.training_dataloader is None:
            self.tput_timer.start()

        if self.fp16_auto_cast():
            inputs = self._cast_inputs_half(inputs)

        # NOTE(review): no try/finally around this call; if the module raises, the
        # _in_forward flags and forward timers set above are left as-is.
        loss = self.module(*inputs, **kwargs)

        if self.zero_optimization_partition_weights():
            # Disable automated discovery of external parameters
            for module in self.module.modules():
                module._parameters._in_forward = False

        self._stop_timers(self.engine_timers.forward_timers)

        if flops_profiler_active:
            self.flops_profiler.stop_profile()

        if self.autotuning_profile_model_info():
            activation_mem = get_ma_status() - ma
            self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
            print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
            # Autotuning profiling run ends the process once model info is written.
            exit()
        else:
            see_memory_usage("Engine after forward", force=self.memory_breakdown())
        return loss
  1499. def _cast_inputs_half(self, inputs):
  1500. if isinstance(inputs, (list, tuple)):
  1501. new_inputs = []
  1502. for v in inputs:
  1503. new_inputs.append(self._cast_inputs_half(v))
  1504. return inputs.__class__(new_inputs)
  1505. elif isinstance(inputs, dict):
  1506. new_inputs = {}
  1507. for k, v in inputs.items():
  1508. new_inputs[k] = self._cast_inputs_half(v)
  1509. return new_inputs
  1510. elif hasattr(inputs, 'half'):
  1511. return inputs.half()
  1512. else:
  1513. return inputs
  1514. def print_forward_breakdown(self, fwd_time):
  1515. gate_time = 0.0
  1516. moe_time = 0.0
  1517. falltoall = 0.0
  1518. salltoall = 0.0
  1519. for gate in self.gate_modules:
  1520. #logger.info(f"Individual TopK gate time: {gate.gate_time:.2f} ms")
  1521. gate_time += gate.gate_time
  1522. for l in self.moe_layers:
  1523. #logger.info(f"MoE layer; total: {l.time_moe:.2f} ms, first alltoall: {l.time_falltoall:.2f}, second alltoall: {l.time_salltoall:.2f}")
  1524. moe_time += l.time_moe
  1525. falltoall += l.time_falltoall
  1526. salltoall += l.time_salltoall
  1527. # TODO: Allreduce/average them across ranks for more accurate timing.
  1528. # if deepspeed.comm.get_rank() == 0:
  1529. log_dist(
  1530. f"time (ms) | fwd: {fwd_time:.2f} (fwd_moe: {moe_time:.2f}, 1st_a2a: {falltoall:.2f}, 2nd_a2a: {salltoall:.2f}, top_k: {gate_time:.2f})",
  1531. ranks=[0])
  1532. @instrument_w_nvtx
  1533. def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
  1534. assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \
  1535. f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled'
  1536. # Pass (PP) gas boundary flag to optimizer (required for zero)
  1537. self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
  1538. # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
  1539. if self.zero_optimization_partition_gradients():
  1540. self.optimizer.overlapping_partition_gradients_reduce_epilogue()
  1541. # Communicate only at gradient accumulation boundaries
  1542. elif self.is_gradient_accumulation_boundary():
  1543. if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr(
  1544. self.optimizer, 'reduce_gradients'):
  1545. self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
  1546. else:
  1547. self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
    @instrument_w_nvtx
    def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):
        r"""Execute backward pass on the loss
        Arguments:
            loss: Torch tensor on which to execute backward propagation
            allreduce_gradients: is deprecated, ignored, and will soon be removed
            release_loss: currently has no effect (the release step is a no-op)
            retain_graph: bool, default: false
                forward on user defined choice of retain_graph
            scale_wrt_gas: if True and gradient_accumulation_steps() > 1, the loss is
                divided by the number of accumulation steps; ``self.scale_wrt_gas``
                overrides this argument when it is not None
        Returns:
            The (possibly GAS-scaled) loss.
        """
        see_memory_usage("Engine before backward", force=self.memory_breakdown())

        # An engine-level setting takes precedence over the argument.
        if self.scale_wrt_gas is not None:
            scale_wrt_gas = self.scale_wrt_gas

        if not allreduce_gradients:
            logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed")

        # scale loss w.r.t. gradient accumulation if needed
        if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
            loss = self._scale_loss_by_gas(loss.float())

        # Log training loss
        # NOTE(review): .item() copies the value to the host on every micro-step.
        self.losses += loss.mean().item()
        if self.monitor.enabled:
            if self.is_gradient_accumulation_boundary():
                if self.global_rank == 0:
                    self.summary_events = [(
                        f"Train/Samples/train_loss",
                        self.losses,
                        self.global_samples,
                    )]
                    self.monitor.write_events(self.summary_events)

        self._start_timers(self.engine_timers.backward_timers)

        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
            "must provide optimizer during init in order to use backward"

        self._start_timers(self.engine_timers.backward_inner_timers)

        # Dispatch backward to the component that owns loss scaling for this mode.
        if self.zero_optimization():
            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
            self.optimizer.backward(loss, retain_graph=retain_graph)
        elif self.amp_enabled():
            # AMP requires delaying unscale when inside gradient accumulation boundaries
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not self.is_gradient_accumulation_boundary()
            with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward(retain_graph=retain_graph)
        elif self.fp16_enabled():
            # Eigenvalue computation needs the graph, so it is created and retained.
            if self.eigenvalue_enabled():
                self.optimizer.backward(loss, create_graph=True, retain_graph=True)
            else:
                self.optimizer.backward(loss, retain_graph=retain_graph)
        elif self.bfloat16_enabled():
            self.optimizer.backward(loss)
        else:
            if self.eigenvalue_enabled():
                loss.backward(create_graph=True, retain_graph=True)
            else:
                loss.backward(retain_graph=retain_graph)

        self._stop_timers(self.engine_timers.backward_inner_timers)

        self._start_timers(self.engine_timers.backward_reduce_timers)

        if allreduce_gradients and self.enable_backward_allreduce:
            # Traditional code path that allreduces the module parameter grads
            self.allreduce_gradients()

        self._stop_timers(self.engine_timers.backward_reduce_timers)

        self._stop_timers(self.engine_timers.backward_timers)

        if release_loss:
            # loss.data = None
            pass

        see_memory_usage("Engine after backward", force=self.memory_breakdown())

        return loss
  1613. def is_gradient_accumulation_boundary(self):
  1614. """
  1615. Query whether the current micro-batch is at the boundary of
  1616. gradient accumulation, and thus will trigger gradient reductions and
  1617. an optimizer step.
  1618. Returns:
  1619. bool: if the current step is a gradient accumulation boundary.
  1620. """
  1621. if self._is_gradient_accumulation_boundary is None:
  1622. return (self.micro_steps + 1) % \
  1623. self.gradient_accumulation_steps() == 0
  1624. else:
  1625. return self._is_gradient_accumulation_boundary
  1626. def set_gradient_accumulation_boundary(self, is_boundary):
  1627. """
  1628. Manually overrides the DeepSpeed engine's gradient accumulation boundary state, this is an optional
  1629. feature and should be used with care. The state should be set before to the intended
  1630. value before each forward/backward. The final forward/backward should have the
  1631. boundary state set to True. This style allows client code to only call engine.step() once after all
  1632. the gradient accumulation passes are complete. See example below:
  1633. .. code-block:: python
  1634. engine.set_gradient_accumulation_boundary(False)
  1635. for _ in range(gradient_accumulation_steps - 1):
  1636. micro_batch = next(data_loader)
  1637. loss = engine(micro_batch)
  1638. engine.backward(loss)
  1639. engine.set_gradient_accumulation_boundary(True)
  1640. micro_batch = next(data_loader)
  1641. loss = engine(micro_batch)
  1642. engine.backward(loss)
  1643. engine.step()
  1644. Arguments:
  1645. is_boundary (bool): are we at a gradient accumulation boundary or not?
  1646. """
  1647. self._is_gradient_accumulation_boundary = is_boundary
  1648. self.optimizer.is_gradient_accumulation_boundary = is_boundary
  1649. def zero_grad(self):
  1650. """
  1651. Zero parameter grads.
  1652. """
  1653. for param_name, param in self.module.named_parameters():
  1654. param.grad = None
  1655. def clip_fp32_gradients(self):
  1656. clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)
  1657. def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
  1658. if self.gradient_clipping() > 0.0:
  1659. if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
  1660. self.clip_fp32_gradients()
  1661. elif self.amp_enabled():
  1662. # AMP's recommended way of doing clipping
  1663. # https://nvidia.github.io/apex/advanced.html#gradient-clipping
  1664. master_params = amp.master_params(self.optimizer)
  1665. clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
  1666. self.optimizer.step()
  1667. if hasattr(self.optimizer, '_global_grad_norm'):
  1668. self._global_grad_norm = self.optimizer._global_grad_norm
  1669. # Quantize the updated parameter if there is no overflow
  1670. if self.quantizer:
  1671. tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(
  1672. ) == 2 else self.optimizer.fp16_groups
  1673. if self.compression_scheduler.weight_quantization_enabled:
  1674. self.quantizer.quantize(
  1675. tensor_to_quantize,
  1676. (self.optimizer.overflow if self.fp16_enabled() else False),
  1677. self.eigenvalue_enabled(),
  1678. block_eigenvalue,
  1679. )
  1680. # zero grad in basic optimizer could be unreliable and may not exhibit
  1681. # the behavior that we want
  1682. if self.bfloat16_enabled():
  1683. # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
  1684. if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
  1685. self.optimizer.zero_grad()
  1686. else:
  1687. pass
  1688. elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled():
  1689. self.optimizer.zero_grad()
  1690. else:
  1691. self.zero_grad()
  1692. report_progress = self.global_rank == 0 if self.global_rank else True
  1693. # Check overflow here since in DS fp16 optimizer, the overflow is updated in above step() function.
  1694. overflow = False
  1695. if hasattr(self.optimizer, "overflow"):
  1696. overflow = self.optimizer.overflow
  1697. self._step_applied = not overflow
  1698. if overflow:
  1699. self.skipped_steps += 1
  1700. else:
  1701. self.compression_scheduler.step()
  1702. if self.lr_scheduler is not None:
  1703. try:
  1704. self.lr_scheduler.step(**(lr_kwargs or {}))
  1705. except TypeError:
  1706. # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.
  1707. # We don't currently have a way to specify lr_kwargs from
  1708. # pipe_engine.train_batch()
  1709. self.lr_scheduler.step(self.train_batch_size())
  1710. if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
  1711. self._report_progress(self.global_steps + 1)
  1712. self.losses = 0.0
  1713. self.global_steps += 1
  1714. self.global_samples += self.train_batch_size()
  1715. def step(self, lr_kwargs=None):
  1716. r"""Execute the weight update step after forward and backward propagation
  1717. on effective_train_batch.
  1718. """
  1719. see_memory_usage("Engine before step", force=self.memory_breakdown())
  1720. # Check early because self.global_steps is incremented at some point here.
  1721. # TODO: Delay self.global_steps increment until very end of this function.
  1722. flops_profiler_active = self.flops_profiler_enabled(
  1723. ) and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0
  1724. self._start_timers(self.engine_timers.step_timers)
  1725. assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
  1726. "must provide optimizer during init in order to use step"
  1727. report_progress = False
  1728. self._step_applied = False # assume False, will flip to True
  1729. # Update the model when we reach gradient accumulation boundaries
  1730. if self.is_gradient_accumulation_boundary():
  1731. self.gas_boundary_ctr += 1
  1732. if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
  1733. and self.quantizer.any_precision_switch()):
  1734. log_dist(f"computing eigenvalue...", ranks=[0])
  1735. self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device,
  1736. self.optimizer.cur_scale)
  1737. if self.progressive_layer_drop:
  1738. self.progressive_layer_drop.update_state(self.global_steps)
  1739. if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()
  1740. and self.quantizer.any_precision_switch()):
  1741. self._take_model_step(lr_kwargs, self.block_eigenvalue)
  1742. else:
  1743. self._take_model_step(lr_kwargs)
  1744. report_progress = self.global_rank == 0 if self.global_rank else True
  1745. self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress)
  1746. self._stop_timers(self.engine_timers.step_timers)
  1747. # Log learning rate
  1748. if self.monitor.enabled:
  1749. if self.is_gradient_accumulation_boundary():
  1750. if self.global_rank == 0:
  1751. self.summary_events = [(f"Train/Samples/lr", self.get_lr()[0], self.global_samples)]
  1752. if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"):
  1753. self.summary_events.append((
  1754. f"Train/Samples/loss_scale",
  1755. self.optimizer.cur_scale,
  1756. self.global_samples,
  1757. ))
  1758. if (self.eigenvalue_enabled()
  1759. and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()):
  1760. ev_values = self.block_eigenvalue.values()
  1761. for i in range(len(ev_values)):
  1762. self.summary_events.append((
  1763. f"Train/Eigenvalues/ModelBlockParam_{i}",
  1764. self.ev_values[i][0],
  1765. self.global_samples,
  1766. ))
  1767. self.monitor.write_events(self.summary_events)
  1768. # Check flops profiling
  1769. if flops_profiler_active:
  1770. if self.autotuning_enabled():
  1771. self.flops = self.flops_profiler.get_total_flops() * 3
  1772. self.fwd_duration = self.flops_profiler.get_total_duration()
  1773. else:
  1774. self.flops_profiler.print_model_profile(
  1775. profile_step=self.global_steps,
  1776. module_depth=self.flops_profiler_module_depth(),
  1777. top_modules=self.flops_profiler_top_modules(),
  1778. detailed=self.flops_profiler_detailed(),
  1779. output_file=self.flops_profiler_output_file(),
  1780. )
  1781. self.flops_profiler.end_profile()
  1782. if self.autotuning_enabled() and self.global_steps == (self.autotuning_end_profile_step() + 1):
  1783. self._autotuning_exit()
  1784. if self.wall_clock_breakdown():
  1785. # Log micro timing and reset
  1786. self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown())
  1787. if self.wall_clock_breakdown() or self.flops_profiler_enabled():
  1788. # Log global timing and reset
  1789. if self.is_gradient_accumulation_boundary():
  1790. if self.monitor.enabled:
  1791. self._write_monitor()
  1792. if self.has_moe_layers:
  1793. fwd_time = self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False)
  1794. self.print_forward_breakdown(fwd_time=fwd_time)
  1795. self.timers.log(self.engine_timers.global_timers)
  1796. self.micro_steps += 1
  1797. see_memory_usage("Engine after step", force=self.memory_breakdown())
  1798. def _start_timers(self, timer_names):
  1799. for name in timer_names:
  1800. self.timers(name).start()
  1801. def _stop_timers(self, timer_names):
  1802. record = self.is_gradient_accumulation_boundary() and \
  1803. self.flops_profiler_enabled() and \
  1804. (self.global_steps >= self.flops_profiler_profile_step())
  1805. for name in timer_names:
  1806. self.timers(name).stop(record=record)
  1807. def _autotuning_exit(self):
  1808. if self.global_rank == 0:
  1809. msg = self.timers.get_mean([
  1810. FORWARD_GLOBAL_TIMER,
  1811. BACKWARD_GLOBAL_TIMER,
  1812. STEP_GLOBAL_TIMER,
  1813. ], reset=False)
  1814. titer = 0.0
  1815. titer += msg[FORWARD_GLOBAL_TIMER] if FORWARD_GLOBAL_TIMER in msg else 0
  1816. titer += msg[BACKWARD_GLOBAL_TIMER] if BACKWARD_GLOBAL_TIMER in msg else 0
  1817. titer += msg[STEP_GLOBAL_TIMER] if STEP_GLOBAL_TIMER in msg else 0
  1818. titer *= self.gradient_accumulation_steps()
  1819. msg["latency"] = titer
  1820. msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer
  1821. msg["throughput"] = self.train_batch_size() * 1_000_000 / \
  1822. msg["latency"]
  1823. print_json_dist(msg, [0], path=self.autotuning_metric_path())
  1824. log_dist(
  1825. f"Wrote metrics to {self.autotuning_metric_path()}, {os.path.abspath(self.autotuning_metric_path())}",
  1826. ranks=[0])
  1827. import atexit
  1828. atexit.register(print, "Autotuning: done with running current ds config.")
  1829. exit()
  1830. def _write_monitor(self):
  1831. if self.global_rank == 0:
  1832. self.summary_events = [
  1833. (
  1834. f"Train/Samples/elapsed_time_ms_forward",
  1835. self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False),
  1836. self.global_samples,
  1837. ),
  1838. (
  1839. f"Train/Samples/elapsed_time_ms_backward",
  1840. self.timers(BACKWARD_GLOBAL_TIMER).elapsed(reset=False),
  1841. self.global_samples,
  1842. ),
  1843. (
  1844. f"Train/Samples/elapsed_time_ms_backward_inner",
  1845. self.timers(BACKWARD_INNER_GLOBAL_TIMER).elapsed(reset=False),
  1846. self.global_samples,
  1847. ),
  1848. (
  1849. f"Train/Samples/elapsed_time_ms_backward_allreduce",
  1850. self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).elapsed(reset=False),
  1851. self.global_samples,
  1852. ),
  1853. (
  1854. f"Train/Samples/elapsed_time_ms_step",
  1855. self.timers(STEP_GLOBAL_TIMER).elapsed(reset=False),
  1856. self.global_samples,
  1857. ),
  1858. ]
  1859. self.monitor.write_events(self.summary_events)
  1860. def _get_optimizer_param(self, param_name):
  1861. result = []
  1862. if not self.optimizer:
  1863. return result
  1864. for group in self.optimizer.param_groups:
  1865. if param_name in group:
  1866. result.append(group[param_name])
  1867. else:
  1868. result.append(0.0)
  1869. return result
  1870. def get_lr(self):
  1871. return self._get_optimizer_param("lr")
  1872. def get_type(self):
  1873. return self._get_optimizer_param("type")
  1874. def get_mom(self):
  1875. if self.optimizer_name() in ["SGD", "RMSprop"]:
  1876. return self._get_optimizer_param("momentum")
  1877. else:
  1878. return self._get_optimizer_param("betas")
  1879. def get_pld_theta(self):
  1880. if self.progressive_layer_drop:
  1881. return self.progressive_layer_drop.get_theta()
  1882. else:
  1883. return None
  1884. def _report_progress(self, step):
  1885. lr = self.get_lr()
  1886. mom = self.get_mom()
  1887. log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0])
  1888. def allreduce_bucket(self, bucket, dp_group):
  1889. tensor = self.flatten(bucket)
  1890. tensor_to_allreduce = tensor
  1891. if self.communication_data_type != tensor.dtype:
  1892. tensor_to_allreduce = tensor.to(self.communication_data_type)
  1893. if self.postscale_gradients():
  1894. if self.gradient_predivide_factor() != 1.0:
  1895. tensor_to_allreduce.mul_(1.0 / self.gradient_predivide_factor())
  1896. dist.all_reduce(tensor_to_allreduce, group=dp_group)
  1897. if self.gradient_average:
  1898. if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group):
  1899. tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
  1900. else:
  1901. tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group))
  1902. dist.all_reduce(tensor_to_allreduce, group=dp_group)
  1903. if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
  1904. tensor.copy_(tensor_to_allreduce)
  1905. return tensor
  1906. def allreduce_and_copy(self, small_bucket, dp_group):
  1907. allreduced = self.allreduce_bucket(small_bucket, dp_group)
  1908. for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
  1909. buf.copy_(synced)
  1910. def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000):
  1911. small_bucket = []
  1912. numel = 0
  1913. for tensor in bucket:
  1914. small_bucket.append(tensor)
  1915. numel = numel + tensor.numel()
  1916. if numel > numel_per_bucket:
  1917. self.allreduce_and_copy(small_bucket, dp_group)
  1918. small_bucket = []
  1919. numel = 0
  1920. if len(small_bucket) > 0:
  1921. self.allreduce_and_copy(small_bucket, dp_group)
  1922. def _get_gradients_for_reduction(self):
  1923. non_expert_grads = []
  1924. expert_grads = {}
  1925. if self.has_moe_layers:
  1926. for key in self.expert_data_parallel_group.keys():
  1927. expert_grads[key] = []
  1928. for param_name, param in self.module.named_parameters():
  1929. if not param.requires_grad:
  1930. continue
  1931. if param.grad is None:
  1932. # In cases where there is an imbalance of empty grads across
  1933. # ranks we must create empty grads, this will ensure that every
  1934. # rank is reducing the same size. In some cases it may make
  1935. # sense in the future to support the ability to average not
  1936. # w.r.t. world size but with a different value.
  1937. param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
  1938. grad_data = param.grad.data
  1939. if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
  1940. # Call param.grad without data to avoid problem with setting of updated grads
  1941. grad_data = SparseTensor(param.grad)
  1942. if is_moe_param(param):
  1943. expert_grads[param.group_name].append(grad_data)
  1944. else:
  1945. non_expert_grads.append(grad_data)
  1946. return non_expert_grads, expert_grads
  1947. def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
  1948. split_buckets = split_half_float_double_sparse(grads)
  1949. for _, bucket_tuple in enumerate(split_buckets):
  1950. bucket_type, bucket = bucket_tuple
  1951. if self.pipeline_parallelism:
  1952. dp_group = self.mpu.get_data_parallel_group()
  1953. else:
  1954. dp_group = groups._get_data_parallel_group()
  1955. if bucket_type == SparseTensor.type():
  1956. self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)
  1957. else:
  1958. self.allreduce_no_retain(bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
  1959. def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
  1960. for ep_name, expert_grads_group in expert_grads.items():
  1961. expert_split_buckets = split_half_float_double_sparse(expert_grads_group)
  1962. for i, bucket_tuple in enumerate(expert_split_buckets):
  1963. bucket_type, bucket = bucket_tuple
  1964. if bucket_type == SparseTensor.type():
  1965. self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name))
  1966. else:
  1967. # Separate between diff groups
  1968. self.allreduce_no_retain(bucket,
  1969. dp_group=groups._get_expert_data_parallel_group(ep_name),
  1970. numel_per_bucket=elements_per_buffer)
  1971. def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
  1972. if grads is None:
  1973. non_expert_grads, expert_grads = self._get_gradients_for_reduction()
  1974. else:
  1975. assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE"
  1976. non_expert_grads = grads
  1977. self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer)
  1978. if self.has_moe_layers:
  1979. self._reduce_expert_gradients(expert_grads, elements_per_buffer)
  1980. def sparse_allreduce_no_retain(self, bucket, dp_group):
  1981. allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group)
  1982. # Densify sparse tensor and copy back to original location
  1983. for tensor in allreduced_sparses:
  1984. if tensor.is_sparse:
  1985. tensor.orig_dense_tensor.data = tensor.to_coo_tensor()
  1986. else:
  1987. tensor.orig_dense_tensor.copy_(tensor.to_dense())
  1988. def sparse_allreduce_bucket(self, bucket, dp_group):
  1989. sparse_list = []
  1990. for sparse in bucket:
  1991. sparse_list.append(self.sparse_allreduce(sparse, dp_group))
  1992. return sparse_list
  1993. def sparse_allreduce(self, sparse, dp_group):
  1994. original_data_type = sparse.values.dtype
  1995. if self.communication_data_type != sparse.values.dtype:
  1996. if self.communication_data_type in (torch.float16, torch.bfloat16):
  1997. indices = sparse.indices.to(torch.int32)
  1998. else:
  1999. indices = sparse.indices
  2000. values = sparse.values.to(self.communication_data_type)
  2001. else:
  2002. indices = sparse.indices
  2003. values = sparse.values
  2004. if self.postscale_gradients():
  2005. if self.gradient_average:
  2006. values.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
  2007. else:
  2008. values.mul_(1. / dist.get_world_size(group=dp_group))
  2009. indices_device_list = self.sparse_all_gather(indices, dp_group)
  2010. values_device_list = self.sparse_all_gather(values, dp_group)
  2011. sparse.indices = torch.cat(indices_device_list).to(torch.long)
  2012. sparse.values = torch.cat(values_device_list).to(original_data_type)
  2013. return sparse
  2014. def sparse_all_gather(self, value, dp_group):
  2015. my_size = torch.LongTensor([value.size()[0]]).to(self.device)
  2016. all_sizes = self.all_gather_scalar(my_size, dp_group)
  2017. max_size = torch.cat(all_sizes).max()
  2018. fill_size = max_size - my_size
  2019. assert value.dim() in [1, 2]
  2020. if value.dim() == 1:
  2021. if fill_size > 0:
  2022. value = torch.cat([value, value.new_empty(fill_size)])
  2023. tensor_list = [value.new_empty(max_size) for _ in range(dist.get_world_size(group=dp_group))]
  2024. else:
  2025. if fill_size > 0:
  2026. value = torch.cat([value, value.new_empty(fill_size, value.size()[1])])
  2027. tensor_list = [
  2028. value.new_empty(max_size,
  2029. value.size()[1]) for _ in range(dist.get_world_size(group=dp_group))
  2030. ]
  2031. dist.all_gather(tensor_list, value, group=dp_group)
  2032. tensors = []
  2033. for dev_idx, t in enumerate(tensor_list):
  2034. size = all_sizes[dev_idx][0]
  2035. tensors.append(t.index_select(0, torch.arange(size, dtype=torch.long, device=self.device)))
  2036. return tensors
  2037. def all_gather_scalar(self, value, dp_group):
  2038. tensor_list = [value.new_zeros(value.size()) for _ in range(dist.get_world_size(group=dp_group))]
  2039. dist.all_gather(tensor_list, value, group=dp_group)
  2040. return tensor_list
  2041. def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
  2042. sd = self.module.state_dict(destination, prefix, keep_vars)
  2043. # Remove frozen parameter weights from state_dict if specified
  2044. if exclude_frozen_parameters:
  2045. for n, p in self.module.named_parameters():
  2046. if not p.requires_grad:
  2047. del sd[n]
  2048. if self.random_ltd_enabled():
  2049. sd = remove_random_ltd_state_dict(sd)
  2050. return sd
  2051. @staticmethod
  2052. def load_moe_state_dict(checkpoint_path,
  2053. tag,
  2054. state_dict,
  2055. old_moe_load,
  2056. model=None,
  2057. mpu=None,
  2058. num_experts=1,
  2059. checkpoint_engine=TorchCheckpointEngine()):
  2060. if old_moe_load:
  2061. expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name())
  2062. num_local_experts = max(num_experts) // groups._get_expert_parallel_world_size(
  2063. groups._get_max_expert_size_name())
  2064. for local_expert_id in range(num_local_experts):
  2065. global_expert_id = expp_rank * num_local_experts + local_expert_id
  2066. expert_state_dict = checkpoint_engine.load(
  2067. DeepSpeedEngine._get_expert_ckpt_name(
  2068. checkpoint_path,
  2069. -1, # -1 means ignore layer_id
  2070. global_expert_id,
  2071. tag,
  2072. mpu),
  2073. map_location=torch.device('cpu'))
  2074. # Updating global -> local expert ids
  2075. moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
  2076. for key in list(expert_state_dict.keys()):
  2077. local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
  2078. f'{moe_str_prefix}{local_expert_id}')
  2079. expert_state_dict[local_key] = expert_state_dict.pop(key)
  2080. state_dict.update(expert_state_dict)
  2081. else:
  2082. moe_layer_id = 0
  2083. for n_module, module in model.named_modules():
  2084. if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
  2085. group_name = module.expert_group_name
  2086. num_local_experts = module.num_local_experts
  2087. expp_rank = groups._get_expert_parallel_rank(group_name)
  2088. # loop all local_experts
  2089. for local_expert_id in range(num_local_experts):
  2090. global_expert_id = expp_rank * num_local_experts + local_expert_id
  2091. expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name(
  2092. checkpoint_path, moe_layer_id, global_expert_id, tag, mpu),
  2093. map_location=torch.device('cpu'))
  2094. # print(expert_state_dict.keys())
  2095. # Updating global -> local expert ids
  2096. moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
  2097. for key in list(expert_state_dict.keys()):
  2098. local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
  2099. f'{moe_str_prefix}{local_expert_id}')
  2100. expert_state_dict[local_key] = expert_state_dict.pop(key)
  2101. state_dict.update(expert_state_dict)
  2102. moe_layer_id += 1
  2103. def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):
  2104. if fetch_z3_params:
  2105. params_to_fetch = [
  2106. p for p in self.module.parameters()
  2107. if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
  2108. ]
  2109. else:
  2110. params_to_fetch = []
  2111. with deepspeed.zero.GatheredParameters(params_to_fetch, modifier_rank=0):
  2112. module_state_dict = checkpoint['module']
  2113. if custom_load_fn:
  2114. custom_load_fn(src=module_state_dict, dst=self.module)
  2115. else:
  2116. self.module.load_state_dict(
  2117. module_state_dict, # TODO
  2118. strict=strict)
  2119. if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None:
  2120. saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS]
  2121. for param in self.module.parameters():
  2122. if param.requires_grad:
  2123. continue
  2124. if param not in self.param_names:
  2125. raise ValueError(f"failed to find frozen {param} in named params")
  2126. name = self.param_names[param]
  2127. if hasattr(param, 'ds_id'):
  2128. param.ds_tensor.data.copy_(saved_frozen_params[name].data)
  2129. else:
  2130. param.data.copy_(saved_frozen_params[name].data)
  2131. def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode):
  2132. return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}'
  2133. def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode):
  2134. file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode)
  2135. zero_ckpt_name = os.path.join(
  2136. checkpoints_path,
  2137. str(tag),
  2138. f"{file_prefix}_mp_rank_{mp_rank:02d}_optim_states.pt",
  2139. )
  2140. return zero_ckpt_name
  2141. def _get_zero_ckpt_name(self, checkpoints_path, tag):
  2142. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  2143. pp_rank = dist.get_rank(group=self.optimizer.dp_process_group)
  2144. bf16_mode = self.bfloat16_enabled()
  2145. return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode)
  2146. def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
  2147. if mp_placeholder is not None:
  2148. mp_rank_str = mp_placeholder
  2149. else:
  2150. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  2151. mp_rank_str = f"{mp_rank:02d}"
  2152. if self.zero_optimization_partition_weights():
  2153. filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
  2154. ckpt_name = os.path.join(
  2155. checkpoints_path,
  2156. str(tag),
  2157. f"{filename}_mp_rank_{mp_rank_str}_model_states.pt",
  2158. )
  2159. else:
  2160. ckpt_name = os.path.join(
  2161. checkpoints_path,
  2162. str(tag),
  2163. "mp_rank_" + mp_rank_str + "_model_states.pt",
  2164. )
  2165. return ckpt_name
  2166. def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank):
  2167. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  2168. ckpt_name = os.path.join(checkpoints_path, str(tag),
  2169. f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt')
  2170. return ckpt_name
  2171. @staticmethod
  2172. def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None):
  2173. mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
  2174. if layer_id <= -1:
  2175. # Used to support old checkpoint loading
  2176. ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),
  2177. f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
  2178. else:
  2179. # Used to support new checkpoint loading
  2180. ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),
  2181. f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
  2182. return ckpt_name
  2183. def _get_all_ckpt_names(self, checkpoints_path, tag):
  2184. # It is required that (checkpoints_path, tag) are consistent among all ranks.
  2185. ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
  2186. import glob
  2187. ckpt_files = glob.glob(ckpt_file_pattern)
  2188. ckpt_files.sort()
  2189. return ckpt_files
  2190. def load_checkpoint(self,
  2191. load_dir,
  2192. tag=None,
  2193. load_module_strict=True,
  2194. load_optimizer_states=True,
  2195. load_lr_scheduler_states=True,
  2196. load_module_only=False,
  2197. custom_load_fn=None):
  2198. """
  2199. Load training checkpoint
  2200. Arguments:
  2201. load_dir: Required. Directory to load the checkpoint from
  2202. tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
  2203. load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
  2204. load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
  2205. load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
  2206. load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting.
  2207. custom_load_fn: Optional. Custom model load function.
  2208. Returns:
  2209. A tuple of ``load_path`` and ``client_state``.
  2210. *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
  2211. *``client_state``: State dictionary used for loading required training states in the client code.
  2212. Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
  2213. after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and
  2214. ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine
  2215. before ``load_checkpoint()``.
  2216. """
  2217. if tag is None:
  2218. latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest"
  2219. latest_path = os.path.join(load_dir, latest_tag)
  2220. if os.path.isfile(latest_path):
  2221. with open(latest_path, "r") as fd:
  2222. tag = fd.read().strip()
  2223. else:
  2224. if self.load_universal_checkpoint():
  2225. raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist')
  2226. else:
  2227. logger.warning(
  2228. f"Unable to find latest file at {latest_path}, if trying to load latest "
  2229. "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
  2230. )
  2231. return None, None
  2232. if self._optimizer_has_ckpt_event_prologue():
  2233. # Prepare for checkpoint load by ensuring all parameters are partitioned
  2234. self.optimizer.checkpoint_event_prologue()
  2235. load_path, client_states = self._load_checkpoint(load_dir,
  2236. tag,
  2237. load_module_strict=load_module_strict,
  2238. load_optimizer_states=load_optimizer_states,
  2239. load_lr_scheduler_states=load_lr_scheduler_states,
  2240. load_module_only=load_module_only,
  2241. custom_load_fn=custom_load_fn)
  2242. load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization()
  2243. or self.bfloat16_enabled())
  2244. if load_zero_checkpoint:
  2245. success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
  2246. if not success:
  2247. self.optimizer._restore_from_bit16_weights()
  2248. if self._optimizer_has_ckpt_event_epilogue():
  2249. self.optimizer.checkpoint_event_epilogue()
  2250. return load_path, client_states
  2251. def _load_checkpoint(self,
  2252. load_dir,
  2253. tag,
  2254. load_module_strict=True,
  2255. load_optimizer_states=True,
  2256. load_lr_scheduler_states=True,
  2257. load_module_only=False,
  2258. custom_load_fn=None):
  2259. from deepspeed.runtime.state_dict_factory import SDLoaderFactory
  2260. ckpt_list = self._get_all_ckpt_names(load_dir, tag)
  2261. sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
  2262. is_pipe_parallel = isinstance(self.module, PipelineModule)
  2263. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  2264. load_path, checkpoint, _ = sd_loader.load(self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel)
  2265. if checkpoint is None:
  2266. return None, None
  2267. fetch_z3_params = False
  2268. if self.zero_optimization_partition_weights() and not load_optimizer_states:
  2269. checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir)
  2270. fetch_z3_params = True
  2271. if is_pipe_parallel:
  2272. # Pipeline parallelism uses this to load its own checkpoint files.
  2273. self._curr_ckpt_path = os.path.join(load_dir, tag)
  2274. if self.has_moe_layers:
  2275. # print(checkpoint.keys())
  2276. old_moe_load = False
  2277. if not isinstance(checkpoint['num_experts'], list):
  2278. old_moe_load = True
  2279. DeepSpeedEngine.load_moe_state_dict(load_dir,
  2280. tag,
  2281. state_dict=checkpoint['module'],
  2282. old_moe_load=old_moe_load,
  2283. model=self.module,
  2284. mpu=self.mpu,
  2285. num_experts=self.num_experts,
  2286. checkpoint_engine=self.checkpoint_engine)
  2287. if not self.load_universal_checkpoint():
  2288. self.load_module_state_dict(checkpoint=checkpoint,
  2289. strict=load_module_strict,
  2290. custom_load_fn=custom_load_fn,
  2291. fetch_z3_params=fetch_z3_params)
  2292. self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
  2293. optim_checkpoint = None
  2294. if load_module_only:
  2295. deepspeed_states = ['module']
  2296. if self.optimizer is not None and self.fp16_enabled():
  2297. self.optimizer.refresh_fp32_params()
  2298. else:
  2299. has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
  2300. if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
  2301. if self.has_moe_layers:
  2302. largest_group_name = groups._get_max_expert_size_name()
  2303. expp_rank = groups._get_expert_parallel_rank(largest_group_name)
  2304. optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
  2305. optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))
  2306. else:
  2307. optim_checkpoint = checkpoint
  2308. if self.fp16_enabled() or self.bfloat16_enabled():
  2309. self.optimizer.load_state_dict(optim_checkpoint['optimizer'],
  2310. load_optimizer_states=load_optimizer_states)
  2311. else:
  2312. optim_checkpoint = checkpoint
  2313. self.optimizer.load_state_dict(optim_checkpoint['optimizer'])
  2314. if load_lr_scheduler_states and self.lr_scheduler is not None:
  2315. self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
  2316. if self.random_ltd_enabled() and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint:
  2317. self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd'])
  2318. if self.training_dataloader is not None and self.curriculum_learning_enabled(
  2319. ) and 'data_sampler' in checkpoint:
  2320. self.training_dataloader.data_sampler.load_state_dict(checkpoint['data_sampler'])
  2321. def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters, loaded_parameters):
  2322. result = set()
  2323. for name in original_set:
  2324. if name in loaded_parameters and name not in loaded_set:
  2325. continue # parameter existed in previous model and was not sparse
  2326. result.add(name)
  2327. for name in loaded_set:
  2328. if name in original_parameters:
  2329. result.add(name) # parameter exists in both configs and it was sparse
  2330. return result
  2331. if 'sparse_tensor_module_names' in checkpoint:
  2332. sparse_tensor_module_names = checkpoint['sparse_tensor_module_names']
  2333. elif 'csr_tensor_module_names' in checkpoint:
  2334. sparse_tensor_module_names = checkpoint['csr_tensor_module_names']
  2335. else:
  2336. sparse_tensor_module_names = None
  2337. if sparse_tensor_module_names is not None:
  2338. if load_module_strict:
  2339. self.sparse_tensor_module_names = sparse_tensor_module_names
  2340. else:
  2341. self.sparse_tensor_module_names = get_sparse_tensor_module_names(
  2342. self.sparse_tensor_module_names, sparse_tensor_module_names,
  2343. dict(self.module.named_parameters()), checkpoint["module"])
  2344. self.global_steps = checkpoint['global_steps']
  2345. self.global_samples = checkpoint.get('global_samples', self.global_steps * self.train_batch_size())
  2346. self.skipped_steps = checkpoint['skipped_steps']
  2347. self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
  2348. deepspeed_states = [
  2349. 'module', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'dp_world_size',
  2350. 'mp_world_size', 'data_sampler', 'random_ltd'
  2351. ]
  2352. client_state = {}
  2353. if load_lr_scheduler_states:
  2354. deepspeed_states.append('lr_scheduler')
  2355. if load_optimizer_states:
  2356. deepspeed_states.append('optimizer')
  2357. client_state = {key: value for key, value in checkpoint.items() if not key in deepspeed_states}
  2358. if optim_checkpoint is not None:
  2359. client_state['optimizer'] = optim_checkpoint['optimizer']
  2360. return load_path, client_state
  def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
      """Restore ZeRO-partitioned optimizer state from ``load_dir/tag``.

      Two sources are supported: a universal checkpoint folder (passed to the optimizer as a
      path) or the per-DP-rank ZeRO files located by ``_get_all_zero_checkpoints``.

      Args:
          load_dir: Root directory holding checkpoint tags.
          tag: Checkpoint tag (sub-directory) to restore.
          load_optimizer_states: Whether optimizer states (not only weights) are restored.

      Returns:
          ``True`` once the optimizer's ``load_state_dict`` has been called, ``False`` when no
          ZeRO checkpoint list is available.

      Raises:
          ZeRORuntimeException: optimizer states are requested for a non-universal checkpoint
              whose DP world size differs from the current ``seq_dp_world_size``.
      """
      load_serial = None
      # Serial ("pipelined") loading: local rank 0 starts at once; every other local rank
      # blocks on a recv from global rank - 1 before it may begin loading.
      if self._config.zero_config.pipeline_loading_checkpoint:
          assert self.zero_optimization_stage(
          ) == ZeroStageEnum.weights, "Only stage3 support for pipeline checkpoint loading"
          load_serial = torch.zeros(1).to(self.device)
          if dist.get_local_rank() != 0:
              dist.recv(tensor=load_serial, src=dist.get_rank() - 1)

      if self.load_universal_checkpoint():
          checkpoint_folder = f'{os.path.join(load_dir, tag)}'
          zero_sd_list = None
      else:
          if load_optimizer_states and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size:
              raise ZeRORuntimeException(
                  "The checkpoint being loaded used a DP "
                  f"world size of {self.loaded_checkpoint_dp_world_size} but the "
                  f"current world size is {self.seq_dp_world_size}. Automatic adjustment "
                  "of ZeRO's optimizer state partitioning with a new world size is not "
                  "currently supported.")
          checkpoint_folder = None
          zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
          if zero_sd_list is None:
              return False

      self.optimizer.load_state_dict(state_dict_list=zero_sd_list,
                                     load_optimizer_states=load_optimizer_states,
                                     load_from_fp32_weights=self.zero_load_from_fp32_weights(),
                                     checkpoint_folder=checkpoint_folder,
                                     load_serial=load_serial)

      if not self.load_universal_checkpoint():
          logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
      else:
          logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
      return True
  2395. def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
  2396. zero_ckpt_names = []
  2397. for dp_rank in range(dp_world_size):
  2398. ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,
  2399. tag=tag,
  2400. mp_rank=mp_rank,
  2401. dp_rank=dp_rank,
  2402. bf16_mode=bf16_mode)
  2403. zero_ckpt_names.append(ckpt_name)
  2404. return zero_ckpt_names
  2405. def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode):
  2406. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  2407. zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(load_dir=load_dir,
  2408. tag=tag,
  2409. mp_rank=mp_rank,
  2410. dp_world_size=self.loaded_checkpoint_dp_world_size,
  2411. bf16_mode=bf16_mode)
  2412. for i, ckpt_name in enumerate(zero_ckpt_names):
  2413. if not os.path.exists(ckpt_name):
  2414. # transparently handle the old file pattern for optim_states
  2415. if "optim_states.pt" in ckpt_name:
  2416. ckpt_name_try = ckpt_name.replace("_optim_states.pt", "optim_states.pt")
  2417. if os.path.exists(ckpt_name_try):
  2418. zero_ckpt_names[i] = ckpt_name_try
  2419. continue
  2420. return zero_ckpt_names
  def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names):
      """Load the optimizer-state payload of the ZeRO checkpoints this rank needs.

      Returns a list aligned with ``zero_ckpt_names``. Entry ``i`` is the file's
      ``OPTIMIZER_STATE_DICT`` value when elastic checkpointing is enabled or when ``i`` equals
      this rank's rank in the optimizer's DP process group. Otherwise, and whenever the name is
      ``None``, the entry is ``None``.
      """
      loaded = []
      for dp_rank, ckpt_path in enumerate(zero_ckpt_names):
          # Fully load state for current rank (or for every rank when elastic).
          wanted = ckpt_path is not None and (self.zero_elastic_checkpoint()
                                              or dist.get_rank(group=self.optimizer.dp_process_group) == dp_rank)
          if wanted:
              loaded.append(self.checkpoint_engine.load(ckpt_path, map_location='cpu'))
          else:
              loaded.append({OPTIMIZER_STATE_DICT: None})

      optimizer_sds = [entry[OPTIMIZER_STATE_DICT] for entry in loaded]
      logger.info(f"successfully read {len(optimizer_sds)} ZeRO state_dicts for rank {self.global_rank}")
      return optimizer_sds
  2439. def _get_all_zero_checkpoints(self, load_dir, tag):
  2440. for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]:
  2441. zero_ckpt_names = self._get_all_zero_checkpoint_names(load_dir, tag, bf16_mode)
  2442. if zero_ckpt_names is not None:
  2443. # Warn if loading checkpoint of different bit16 type
  2444. if bf16_mode is not self.bfloat16_enabled():
  2445. checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16
  2446. engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16
  2447. logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
  2448. return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)
  2449. return None
  def _checkpoint_tag_validation(self, tag):
      """Check that every rank uses the same checkpoint ``tag``.

      The SHA-1 digest of ``tag`` is all-reduced with MAX and MIN. The tag is consistent only if
      both reductions equal the local digest. Nothing happens when tag validation is disabled.

      Raises:
          AssertionError: the tag differs across ranks and validation is configured to fail.
              It is raised explicitly rather than via ``assert``, so the check also runs under
              ``python -O``.
      """
      if not self.checkpoint_tag_validation_enabled():
          return
      s_hash = hashlib.sha1(tag.encode())
      bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device)
      max_bhash = bhash.clone()
      min_bhash = bhash.clone()
      # All ranks agree iff the element-wise MAX and MIN across ranks both equal the local digest.
      dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX)
      dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN)
      valid = all(min_bhash == bhash) and all(max_bhash == bhash)
      msg = (f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across "
             "all ranks. Including rank unique information in checkpoint tag could cause issues when "
             "restoring with different world sizes.")
      if self.checkpoint_tag_validation_fail():
          # `assert valid, msg` was stripped under -O; keep the AssertionError type for callers.
          if not valid:
              raise AssertionError(msg)
      elif not valid:
          logger.warning(msg)
  def save_checkpoint(self, save_dir, tag=None, client_state=None, save_latest=True, exclude_frozen_parameters=False):
      """Save training checkpoint

      Arguments:
          save_dir: Required. Directory for saving the checkpoint
          tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
              used if not provided. Tag name must be the same across all ranks.
          client_state: Optional. State dictionary used for saving required training states in the client code.
              ``None`` (the default) is treated as an empty dict.
          save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
          exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state.

      Returns:
          ``True`` after the final barrier.

      Important: all processes must call this method and not just the process with rank 0. It is
      because each process needs to save its master weights and scheduler+optimizer states. This
      method will hang waiting to synchronize with other processes if it's called just for the
      process with rank 0.
      """
      # Avoid a shared mutable default argument.
      if client_state is None:
          client_state = {}

      if self._optimizer_has_ckpt_event_prologue():
          # Custom preparation for checkpoint save, if applicable
          self.optimizer.checkpoint_event_prologue()

      rank = self.local_rank if self.use_node_local_storage() else self.global_rank

      # This is to make sure the checkpoint names are created without collision
      # There seems to be issue creating them in parallel

      # Ensure save_dir directory exists
      self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
      dist.barrier()

      if tag is None:
          tag = f"global_step{self.global_steps}"

      # Ensure tag is a string
      tag = str(tag)
      self.checkpoint_engine.create(tag)

      # Ensure checkpoint tag is consistent across ranks
      self._checkpoint_tag_validation(tag)

      if self.has_moe_layers:
          self.save_non_zero_checkpoint = False
          self._create_checkpoint_file(save_dir, tag, False)
          self._save_moe_checkpoint(save_dir,
                                    tag,
                                    client_state=client_state,
                                    exclude_frozen_parameters=exclude_frozen_parameters)
      else:
          # We distribute the task of saving layer checkpoint files among
          # data parallel instances, so all procs should call _save_checkpoint.
          # All procs then call module_state_dict(), but only procs of data
          # parallel rank 0 save the general model params.
          self._create_checkpoint_file(save_dir, tag, False)
          self._save_checkpoint(save_dir,
                                tag,
                                client_state=client_state,
                                exclude_frozen_parameters=exclude_frozen_parameters)

      if self.save_zero_checkpoint:
          self._create_zero_checkpoint_files(save_dir, tag)
          self._save_zero_checkpoint(save_dir, tag)

      if self._optimizer_has_ckpt_event_epilogue():
          self.optimizer.checkpoint_event_epilogue()

      # Save latest checkpoint tag
      self.checkpoint_engine.commit(tag)
      if save_latest and rank == 0:
          with open(os.path.join(save_dir, 'latest'), 'w') as fd:
              fd.write(tag)

      dist.barrier()

      return True
  2525. def _get_non_moe_state_dict(self, full_state_dict):
  2526. """
  2527. Get the state dict of the non-moe layers
  2528. """
  2529. for key in list(full_state_dict.keys()):
  2530. if 'expert' in key and 'moe.gate.wg.weight' not in key:
  2531. full_state_dict.pop(key)
  2532. return full_state_dict
  def _save_moe_checkpoint(self, save_dir, tag, client_state=None, exclude_frozen_parameters=False):
      """Save a checkpoint for a model that contains MoE layers.

      Files written:

      * one file per (MoE layer, global expert id). These are written only by ranks whose
        expert-data-parallel rank is 0 for that layer's expert group.
      * one optimizer-state file per expert-parallel rank (``optimizer`` is ``None`` under ZeRO).
        These are written only by expert-data-parallel rank 0 of the largest expert group.
      * the main model checkpoint with non-expert weights and engine bookkeeping. This is
        written only by expert-parallel rank 0 among those ranks.

      Args:
          save_dir: Root checkpoint directory.
          tag: Checkpoint tag (sub-directory).
          client_state: Extra entries merged into the main checkpoint dict; ``None`` means none.
          exclude_frozen_parameters: Forwarded to ``module_state_dict``.

      Raises:
          ValueError: an expert parameter name lacks the
              ``.deepspeed_moe.experts.deepspeed_experts.<id>`` component.
      """
      if client_state is None:
          client_state = {}
      save_path = self._get_ckpt_name(save_dir, tag)

      # Using layer_#_export_# to save the model's expert state_dict
      moe_layer_id = 0
      for n_module, module in self.module.named_modules():
          if not isinstance(module, MoE):
              continue
          group_name = module.expert_group_name
          num_local_experts = module.num_local_experts
          expp_rank = groups._get_expert_parallel_rank(group_name)
          exp_dp_rank = groups._get_expert_data_parallel_rank(group_name)
          if exp_dp_rank != 0:
              # Expert weights are replicated across expert-data-parallel ranks; rank 0 saves them.
              moe_layer_id += 1
              continue

          # get all moe parameters
          moe_state_dict = {}
          for n, p in module.state_dict().items():
              if 'expert' in n and 'moe.gate.wg.weight' not in n:
                  moe_state_dict[n_module + '.' + n] = p
          moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
          # Re-key local expert ids to global ids so that each checkpoint only has one expert
          experts_state_dict = defaultdict(dict)
          for key in list(moe_state_dict.keys()):
              m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
              if not m:
                  # Previously this only warned and then crashed on int(None); fail clearly instead.
                  raise ValueError(f'No expert found in key {key}.')
              local_expert_id = m.group(1)
              global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
              expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}', f'{moe_str_prefix}{global_expert_id}')
              # truncating extra tensor (shared) storage
              truncated = moe_state_dict.pop(key).clone().detach()
              experts_state_dict[str(global_expert_id)][expert_key] = truncated

          # save the moe parameters
          for global_expert_id, expert_state_dict in experts_state_dict.items():
              moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
              if self.random_ltd_enabled():
                  expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
              self.checkpoint_engine.save(expert_state_dict, moe_save_path)
          moe_layer_id += 1

      largest_group_name = groups._get_max_expert_size_name()
      expp_rank = groups._get_expert_parallel_rank(largest_group_name)
      exp_dp_rank = groups._get_expert_data_parallel_rank(largest_group_name)

      # In the case of E + D parallelism, only the
      # first expert parallel group should save the expert weights
      # since each expert parallel group is a copy of the model's experts
      if exp_dp_rank != 0:
          return

      # Save optimizer states. They are different across each exp parallel rank.
      optimizer_state = {
          'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None
      }
      file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
      self.checkpoint_engine.save(optimizer_state, file_path)

      # A hack to save the checkpointing directory. Pipeline parallelism overrides
      # module_state_dict() and uses this path to save the model; module_state_dict()
      # then instead just returns None. The path is always cleared afterwards (the old code
      # reset a misspelled `_curr_save_path`, leaving `_curr_ckpt_path` stale).
      self._curr_ckpt_path = os.path.join(save_dir, tag)
      try:
          # get non-moe parameters
          model_state_dict = self._get_non_moe_state_dict(
              self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))
      finally:
          self._curr_ckpt_path = None

      if expp_rank == 0:
          # TODO: update num experts info,.. in checkpoint
          # NOTE(review): this stores self.dp_world_size while _save_checkpoint stores
          # self.seq_dp_world_size, which is what _load_zero_checkpoint compares against — confirm.
          state = {
              'module': model_state_dict,
              'lr_scheduler': self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
              'data_sampler': self.training_dataloader.data_sampler.state_dict() if
              (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
              'random_ltd': self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
              'sparse_tensor_module_names': self.sparse_tensor_module_names,
              'skipped_steps': self.skipped_steps,
              'global_steps': self.global_steps,
              'global_samples': self.global_samples,
              'dp_world_size': self.dp_world_size,
              'mp_world_size': self.mp_world_size,
              'num_experts': self.num_experts
          }
          state.update(client_state)
          logger.info(f'Saving model checkpoint: {save_path}')
          self.checkpoint_engine.save(state, save_path)
  2631. def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
  2632. name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name)
  2633. try:
  2634. checkpoint_name = name_function(save_dir, tag)
  2635. path = os.path.dirname(checkpoint_name)
  2636. self.checkpoint_engine.makedirs(path, exist_ok=True)
  2637. except:
  2638. logger.error(f"Failed saving model checkpoint to {save_dir} with tag {tag}")
  2639. return False
  2640. return True
  def _create_zero_checkpoint_files(self, save_dir, tag):
      """Create ZeRO checkpoint directories one data-parallel rank at a time.

      Ranks of the optimizer's DP process group take turns, separated by group barriers, so the
      directories are not created concurrently.

      Returns:
          The result of this rank's ``_create_checkpoint_file`` call.
      """
      success = True
      dp_group = self.optimizer.dp_process_group
      # The loop index ranges over the DP group size, so it must be compared with the rank
      # *within* that group. Comparing it with the global rank meant ranks whose global rank is
      # >= the DP group size (e.g. under model parallelism) never took their turn.
      my_dp_rank = dist.get_rank(group=dp_group)
      # zero checkpoint files are created sequentially
      for rank in range(dist.get_world_size(dp_group)):
          if rank == my_dp_rank:
              success = self._create_checkpoint_file(save_dir, tag, True)
          dist.barrier(group=dp_group)
      return success
  def _save_checkpoint(self, save_dir, tag, client_state=None, exclude_frozen_parameters=False):
      """Build the main (non-ZeRO-partition) checkpoint dict and save it if this rank should.

      The dict holds the module state, buffer names, optimizer state (only without ZeRO/bf16
      optimizer state), ZeRO param shapes and shared-param map (only with it), frozen-param
      shapes/fragments (only under gradient partitioning and when frozen params are included),
      scheduler/sampler/random-LTD state, step counters, world sizes, config and version.
      ``client_state`` entries are merged in last. It is written only when
      ``self.save_non_zero_checkpoint`` is true.

      Args:
          save_dir: Root checkpoint directory.
          tag: Checkpoint tag.
          client_state: Extra entries merged into the dict; ``None`` means none.
          exclude_frozen_parameters: Forwarded to ``module_state_dict``.
      """
      if client_state is None:
          client_state = {}
      save_path = self._get_ckpt_name(save_dir, tag)

      zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()

      save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters

      # A hack to save the checkpointing directory. Pipeline parallelism overrides
      # module_state_dict() and uses this path to save the model. module_state_dict()
      # then instead just returns None.  The module_state_dict() implementation in
      # PipelineEngine expects the save path to be set in self._curr_ckpt_path.
      self._curr_ckpt_path = os.path.join(save_dir, tag)
      try:
          module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
      finally:
          # Clear the path even if module_state_dict() raises, so it never goes stale.
          self._curr_ckpt_path = None

      state = dict(module=module,
                   buffer_names=self._get_buffer_names(),
                   optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None,
                   param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,
                   frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)
                   if save_frozen_param else None,
                   shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None,
                   frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
                   if save_frozen_param else None,
                   lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
                   data_sampler=self.training_dataloader.data_sampler.state_dict() if
                   (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
                   random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
                   sparse_tensor_module_names=self.sparse_tensor_module_names,
                   skipped_steps=self.skipped_steps,
                   global_steps=self.global_steps,
                   global_samples=self.global_samples,
                   dp_world_size=self.seq_dp_world_size,
                   mp_world_size=self.mp_world_size,
                   ds_config=self.config,
                   ds_version=version)
      state.update(client_state)

      if self.save_non_zero_checkpoint:
          log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
          self.checkpoint_engine.save(state, save_path)
  2685. def _get_buffer_names(self):
  2686. buffer_names = []
  2687. # we save buffer names so that we could extract later the real buffers from the saved
  2688. # state_dict["module"] in the non-zero checkpoint - the buffers are already there but they
  2689. # are intermixed with param placeholders
  2690. # have to traverse the tree to be able to skip non-persistent buffers
  2691. def get_layer_named_buffers(module, prefix=""):
  2692. for name, buf in module.named_buffers(recurse=False):
  2693. if buf is not None and name not in module._non_persistent_buffers_set:
  2694. buffer_names.append(prefix + name)
  2695. for name, child in module.named_children():
  2696. if child is not None:
  2697. get_layer_named_buffers(child, prefix + name + ".")
  2698. get_layer_named_buffers(self.module, prefix="")
  2699. return buffer_names
  2700. def _get_param_shape_func(self, param):
  2701. return param.ds_shape if hasattr(param, 'ds_id') else param.shape
  2702. def _get_param_fragment_func(self, param):
  2703. return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu()
  2704. def _get_zero_frozen_param_attributes(self, attr_func):
  2705. frozen_param_fragments = OrderedDict()
  2706. for param in self.module.parameters():
  2707. if param.requires_grad:
  2708. continue
  2709. if param not in self.param_names:
  2710. raise ValueError(f"failed to find frozen {param} in named params")
  2711. name = self.param_names[param]
  2712. frozen_param_fragments[name] = attr_func(param)
  2713. return frozen_param_fragments
  2714. def _get_zero_param_shapes(self):
  2715. """Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the
  2716. optimizer. the names are exactly as in state_dict. The order is absolutely important, since
  2717. the saved data is just flattened data with no identifiers and requires reconstruction in the
  2718. same order it was saved.
  2719. We can't rely on self.module.named_parameters() to get the saved tensors, as some params
  2720. will be missing and others unsaved and then it'd be impossible to reconstruct state_dict
  2721. from the flattened weights.
  2722. optimizer.bit16_groups seems to be the easiest to use as it's in all zeroX versions.
  2723. """
  2724. param_group_shapes = []
  2725. cnt = 0
  2726. numel = 0
  2727. # zero2 started using a round_robin_bit16_groups which is a shuffled version of bit16_groups -
  2728. # if we don't use it, we get parameters ordered incorrectly
  2729. if hasattr(self.optimizer, "round_robin_bit16_groups"):
  2730. bit16_groups = self.optimizer.round_robin_bit16_groups
  2731. elif self.bfloat16_enabled() and not self.zero_optimization():
  2732. bit16_groups = self.optimizer.bf16_groups
  2733. else:
  2734. bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage(
  2735. ) == 2 else self.optimizer.fp16_groups
  2736. for bit16_group in bit16_groups:
  2737. param_shapes = OrderedDict()
  2738. for param in bit16_group:
  2739. cnt += 1
  2740. numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel()
  2741. shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape
  2742. if param not in self.param_names:
  2743. raise ValueError(f"failed to find optimizer param in named params")
  2744. name = self.param_names[param]
  2745. param_shapes[name] = shape
  2746. # uncomment to debug zero_to_fp32.py problems
  2747. # if self.global_rank == 0: print(f"saving param {name} {shape} (numel={shape.numel()})")
  2748. param_group_shapes.append(param_shapes)
  2749. # if self.global_rank == 0: print(f"Total saved {numel} numels in {cnt} params")
  2750. return param_group_shapes
  def _get_shared_params(self):
      """Map every aliased parameter name to the name of the first parameter sharing its data.

      The result can later be used to reconstruct the original state dict, e.g. in
      `zero_to_fp32`. Each key is the name of a variable that isn't stored separately; its value
      is the name of the param that actually holds the data. Only rank 0 walks the module tree;
      other ranks return an empty dict.
      """
      shared_params_by_full_name = {}
      is_zero3_model = (self.zero_optimization_partition_weights()
                        and any(hasattr(param, "ds_id") for param in self.module.parameters()))
      if dist.get_rank() != 0:
          return shared_params_by_full_name

      first_owner = {}
      pending = [("", self.module)]
      while pending:
          prefix, mod = pending.pop()
          for name, param in mod.named_parameters(recurse=False):
              if param is None or (is_zero3_model and not hasattr(param, "ds_id")):
                  continue
              full_name = prefix + name
              # Under ZeRO-3, data_ptr() is reused as weights are gathered and released, but
              # ds_id is unique per weight (and identical for shared weights).
              identity = param.ds_id if is_zero3_model else param.data_ptr()
              if identity in first_owner:
                  shared_params_by_full_name[full_name] = first_owner[identity]
              else:
                  first_owner[identity] = full_name
          children = [(prefix + child_name + ".", child) for child_name, child in mod.named_children()
                      if child is not None]
          # Reverse so the first child is processed next, matching recursive pre-order.
          pending.extend(reversed(children))
      return shared_params_by_full_name
  def _copy_recovery_script(self, save_path):
      """Copy ``utils/zero_to_fp32.py`` from the package into ``save_path``.

      The package directory is the parent of this file's directory. After copying, the script
      is made executable on a best-effort basis.
      """
      script_name = "zero_to_fp32.py"
      package_root = os.path.dirname(os.path.dirname(__file__))
      destination = os.path.join(save_path, script_name)
      copyfile(os.path.join(package_root, "utils", script_name), destination)
      self._change_recovery_script_permissions(destination)
  2791. def _change_recovery_script_permissions(self, dst):
  2792. # make executable (safeguard for file shares - Azure as example)
  2793. try:
  2794. os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)
  2795. except (FileNotFoundError, PermissionError) as e:
  2796. #this message is used in unit test TestZeRONonDistributed
  2797. logger.info(
  2798. f'Warning: Could not change permissions for {dst} due to error: {e}. Continuing without changing permissions.'
  2799. )
  def _save_zero_checkpoint(self, save_path, tag):
      """Write this rank's ZeRO optimizer state, together with the DeepSpeed config and version.

      Global rank 0 also copies the ``zero_to_fp32.py`` recovery script into ``save_path``.
      """
      zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
      zero_sd = {
          'optimizer_state_dict': self.optimizer.state_dict(),
          'ds_config': self.config,
          'ds_version': version,
      }
      self.checkpoint_engine.save(zero_sd, zero_checkpoint_name)

      if self.global_rank == 0:
          self._copy_recovery_script(save_path)
      ckpt_type = 'bf16_zero' if not self.zero_optimization() else 'zero'
      logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')
  def _zero3_consolidated_16bit_state_dict(self):
      """
      Get a full, non-partitioned state_dict of a ZeRO-3 model with the weights copied to CPU.

      Important: this function must be called on all ranks and not just rank 0, because every
      rank has to enter each ``deepspeed.zero.GatheredParameters`` context.

      This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:
      1. gathers the partitioned weights of one module at a time (``modifier_rank=0``) and
         copies them to cpu on rank 0 only
      2. keeps GPU memory low by working on one layer at a time: a module's own params are
         gathered, copied to cpu and released before its children are visited
      3. keeps shared (tied) params shared: a param whose ``ds_id`` was already seen is stored
         as the same cpu tensor object under its second name
      Persistent buffers are copied too; non-persistent buffers are skipped.

      The tensors keep the dtype of the gathered params (presumably 16-bit under ZeRO-3 mixed
      precision, as the name suggests — not enforced here).

      Returns:
          a consolidated ``OrderedDict`` state_dict on cpu on rank 0, ``None`` on other ranks

      Raises:
          ValueError: if ZeRO-3 weight partitioning is not enabled.
      """
      if not self.zero_optimization_partition_weights():
          raise ValueError("this function requires ZeRO-3 mode")

      # Only rank 0 accumulates results; other ranks just take part in the gathers.
      state_dict = OrderedDict() if dist.get_rank() == 0 else None
      # ds_id -> first key under which that param was stored (used to keep tied weights tied).
      shared_params = {}

      def get_layer_state_dict(module, prefix=""):
          # gather one layer at a time to be memory-efficient
          # must use modifier_rank=0 to release GPU memory after each layer gathered
          #see_memory_usage("before GatheredParameters", force=True)
          with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
              if dist.get_rank() == 0:
                  # handle params
                  for name, param in module.named_parameters(recurse=False):
                      if param is None:
                          continue
                      key = prefix + name
                      # can't rely on param.data_ptr() as it will be reused as weights gets
                      # gathered and reduced, but param.ds_id is unique across all zero weights
                      # (and shared params will have the same param.ds_id)
                      if param.ds_id in shared_params:
                          # shared weights: reuse the already-copied cpu tensor
                          #print(f"`{key}` is shared with `{shared_params[param.ds_id]}`")
                          state_dict[key] = state_dict[shared_params[param.ds_id]]
                      else:
                          state_dict[key] = param.detach().cpu()
                          shared_params[param.ds_id] = key
                      #print(f"param {param.ds_id} {param.shape} {key} ")

                  # now buffers - not sure if need to take care of potentially shared weights here
                  for name, buf in module.named_buffers(recurse=False):
                      if (buf is not None and name not in module._non_persistent_buffers_set):
                          state_dict[prefix + name] = buf.detach().cpu()
          #see_memory_usage("after GatheredParameters", force=True)

          # Recurse only after leaving the gather context, so at most one module's params are gathered.
          for name, child in module.named_children():
              if child is not None:
                  get_layer_state_dict(child, prefix + name + ".")

      # Prepare for checkpoint save by ensuring all parameters are partitioned
      if self._optimizer_has_ckpt_event_prologue():
          self.optimizer.checkpoint_event_prologue()

      see_memory_usage("before get_layer_state_dict", force=False)
      get_layer_state_dict(self.module, prefix="")
      see_memory_usage("after get_layer_state_dict", force=False)

      if self._optimizer_has_ckpt_event_epilogue():
          self.optimizer.checkpoint_event_epilogue()

      return state_dict
  def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
      """Deprecated alias of ``save_16bit_model``, kept for backwards compatibility.

      Arguments and return value are exactly those of ``save_16bit_model``.
      """
      result = self.save_16bit_model(save_dir, save_filename)
      return result
  def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
      """
      Save 16bit model weights

      This method saves the 16bit model weights at the desired destination.

      Arguments:
          save_dir: Required. Directory for saving the model
          save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``

      Returns:
          ``True`` when a model has been saved, ``False`` otherwise. It will not be saved if
          stage3_gather_16bit_weights_on_model_save is ``False``.

      Important: all processes must call this method and not just the process with rank 0. It is
      because the processes need to work in sync to gather the weights. This method will hang
      waiting to synchronize with other processes if it's called just for the process with rank 0.
      """
      path = os.path.join(save_dir, save_filename)

      if not self.zero_optimization_partition_weights():
          state_dict = self.module.state_dict()
      elif self.zero_gather_16bit_weights_on_model_save():
          # consolidation is expensive in time and memory and therefore isn't a default
          state_dict = self._zero3_consolidated_16bit_state_dict()
      else:
          # the model will be bogus if not consolidated so don't confuse the user by saving it
          logger.info(
              f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False")
          return False

      tag = f"global_step{self.global_steps}"
      self.checkpoint_engine.create(tag)

      # Only rank 0 writes the file; every rank still takes part in create/commit.
      if dist.get_rank() == 0:
          self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
          logger.info(f"Saving model weights to {path}, tag: {tag}")
          self.checkpoint_engine.save(state_dict, path)

      self.checkpoint_engine.commit(tag)
      return True
  def empty_partition_cache(self):
      """
      Release GPU memory consumed by offloaded model parameters.

      A no-op unless the optimizer exposes ``empty_partition_cache``. When it does, that method
      is called, followed by Python garbage collection and the accelerator's ``empty_cache()``.
      """
      if not hasattr(self.optimizer, 'empty_partition_cache'):
          return
      self.optimizer.empty_partition_cache()
      gc.collect()
      get_accelerator().empty_cache()