123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121 |
- #include "custom_cuda_layers.h"
- namespace cg = cooperative_groups;
- /*
- Fused bias add, residual (elementwise) add, and normalization layer.
- For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
- __half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic).
- For specific launch constraints, see the launch functions.
- */
- #define NORM_REG (MAX_REGISTERS / 4)
- __global__ void fused_bias_residual_layer_norm(float* vals,
- const float* residual,
- const float* gamma,
- const float* beta,
- float epsilon,
- bool preLayerNorm,
- bool training,
- float* vars,
- float* means,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int gid = id / WARP_SIZE;
- float vals_arr[NORM_REG];
- __shared__ float shr[MAX_WARP_NUM];
- residual += (row * row_stride);
- vals += (row * row_stride);
- float sum = 0.f;
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] = residual[i * iteration_stride + id];
- sum += vals_arr[i];
- }
- if (high_index < row_stride) {
- vals_arr[iterations] = residual[high_index];
- sum += vals_arr[iterations];
- iterations++;
- }
- for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) shr[gid] = sum;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
- #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- sum += g.shfl_down(sum, i);
- }
- sum = g.shfl(sum, 0);
- float mean = sum / row_stride;
- if (training)
- if (threadIdx.x == 0) means[row] = mean;
- float variance = 0.f;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] -= mean;
- variance += vals_arr[i] * vals_arr[i];
- }
- for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
- if (g.thread_rank() == 0) shr[gid] = variance;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- variance += g.shfl_down(variance, i);
- }
- variance = g.shfl(variance, 0);
- variance /= row_stride;
- variance += epsilon;
- if (training)
- if (threadIdx.x == 0) vars[row] = variance;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] = vals_arr[i] * rsqrtf(variance);
- vals_arr[i] =
- vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
- vals[i * iteration_stride + id] = vals_arr[i];
- }
- if ((high_index) < row_stride) {
- vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
- vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
- vals[high_index] = vals_arr[iterations];
- }
- }
- __global__ void fused_bias_residual_layer_norm(__half* vals,
- const __half* residual,
- const __half* gamma,
- const __half* beta,
- float epsilon,
- bool preLayerNorm,
- bool training,
- __half* vars,
- __half* means,
- int row_stride)
- {
- #ifdef HALF_PRECISION_AVAILABLE
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int gid = id >> WARP_SIZE_BITS;
- float2 vals_f[NORM_REG];
- __shared__ float shr[MAX_WARP_NUM];
- __half2* vals_cast = reinterpret_cast<__half2*>(vals);
- const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
- residual_cast += (row * row_stride);
- vals_cast += (row * row_stride);
- float sum = 0.f;
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
- sum += vals_f[i].x;
- sum += vals_f[i].y;
- }
- if ((high_index) < row_stride) {
- vals_f[iterations] = __half22float2(residual_cast[high_index]);
- sum += vals_f[iterations].x;
- sum += vals_f[iterations].y;
- iterations++;
- }
- for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) shr[gid] = sum;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- sum += g.shfl_down(sum, i);
- }
- sum = g.shfl(sum, 0);
- float mean = sum / (row_stride * 2);
- float variance = 0.f;
- for (int i = 0; i < iterations; i++) {
- vals_f[i].x -= mean;
- vals_f[i].y -= mean;
- variance += vals_f[i].x * vals_f[i].x;
- variance += vals_f[i].y * vals_f[i].y;
- }
- for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
- if (g.thread_rank() == 0) shr[gid] = variance;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- variance += g.shfl_down(variance, i);
- }
- variance = g.shfl(variance, 0);
- variance /= (row_stride * 2);
- variance += epsilon;
- __half2 variance_h = __float2half2_rn(variance);
- const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
- const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
- if (training && threadIdx.x == 0) {
- vars[row] = __float2half(variance);
- means[row] = __float2half(mean);
- }
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- __half2 vals_arr = __float22half2_rn(vals_f[i]);
- vals_arr = vals_arr * h2rsqrt(variance_h);
- vals_arr =
- vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
- vals_cast[i * iteration_stride + id] = vals_arr;
- }
- if ((high_index) < row_stride) {
- __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
- vals_arr = vals_arr * h2rsqrt(variance_h);
- vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
- vals_cast[high_index] = vals_arr;
- }
- #endif
- }
- template <typename T>
- void launch_bias_residual_layer_norm(T* vals,
- const T* residual,
- const T* gamma,
- const T* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- T* vars,
- T* means);
- template <>
- void launch_bias_residual_layer_norm<float>(float* vals,
- const float* residual,
- const float* gamma,
- const float* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- float* vars,
- float* means)
- {
- int threads = THREADS;
- dim3 grid_dim(batch_size);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim(threads);
- fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
- }
- template <>
- void launch_bias_residual_layer_norm<__half>(__half* vals,
- const __half* residual,
- const __half* gamma,
- const __half* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- __half* vars,
- __half* means)
- {
- int threads = 128;
- dim3 grid_dim(batch_size);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim(threads);
- fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
- }
- __global__ void fused_bias_residual_layer_norm(float* vals,
- const float* residual,
- const float* gamma,
- const float* beta,
- float epsilon,
- bool preLayerNorm,
- bool training,
- float* vars,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int gid = id / 32;
- float vals_arr[NORM_REG];
- __shared__ float shr[MAX_WARP_NUM];
- residual += (row * row_stride);
- vals += (row * row_stride);
- float sum = 0.f;
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] = residual[i * iteration_stride + id];
- sum += vals_arr[i];
- }
- if ((high_index) < row_stride) {
- vals_arr[iterations] = residual[high_index];
- sum += vals_arr[iterations];
- iterations++;
- }
- for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) shr[gid] = sum;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
- #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- sum += g.shfl_down(sum, i);
- }
- sum = g.shfl(sum, 0);
- float mean = sum / row_stride;
- float variance = 0.f;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] -= mean;
- variance += vals_arr[i] * vals_arr[i];
- }
- for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
- if (g.thread_rank() == 0) shr[gid] = variance;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- variance += g.shfl_down(variance, i);
- }
- variance = g.shfl(variance, 0);
- variance /= row_stride;
- variance += epsilon;
- if (training)
- if (threadIdx.x == 0) vars[row] = variance;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] = vals_arr[i] * rsqrtf(variance);
- vals_arr[i] =
- vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
- vals[i * iteration_stride + id] = vals_arr[i];
- }
- if ((high_index) < row_stride) {
- vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
- vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
- vals[high_index] = vals_arr[iterations];
- }
- }
- __global__ void fused_bias_residual_layer_norm(__half* vals,
- const __half* residual,
- const __half* gamma,
- const __half* beta,
- float epsilon,
- bool preLayerNorm,
- bool training,
- __half* vars,
- int row_stride)
- {
- #ifdef HALF_PRECISION_AVAILABLE
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int gid = id >> WARP_SIZE_BITS;
- float2 vals_f[NORM_REG];
- __shared__ float shr[MAX_WARP_NUM];
- __half2* vals_cast = reinterpret_cast<__half2*>(vals);
- const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
- residual_cast += (row * row_stride);
- vals_cast += (row * row_stride);
- float sum = 0.f;
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
- sum += vals_f[i].x;
- sum += vals_f[i].y;
- }
- if ((high_index) < row_stride) {
- vals_f[iterations] = __half22float2(residual_cast[high_index]);
- sum += vals_f[iterations].x;
- sum += vals_f[iterations].y;
- iterations++;
- }
- for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) shr[gid] = sum;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- sum += g.shfl_down(sum, i);
- }
- sum = g.shfl(sum, 0);
- float mean = sum / (row_stride * 2);
- float variance = 0.f;
- for (int i = 0; i < iterations; i++) {
- vals_f[i].x -= mean;
- vals_f[i].y -= mean;
- variance += vals_f[i].x * vals_f[i].x;
- variance += vals_f[i].y * vals_f[i].y;
- }
- for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
- if (g.thread_rank() == 0) shr[gid] = variance;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- variance += g.shfl_down(variance, i);
- }
- variance = g.shfl(variance, 0);
- variance /= (row_stride * 2);
- variance += epsilon;
- __half2 variance_h = __float2half2_rn(variance);
- const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
- const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
- if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- __half2 vals_arr = __float22half2_rn(vals_f[i]);
- vals_arr = vals_arr * h2rsqrt(variance_h);
- vals_arr =
- vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
- vals_cast[i * iteration_stride + id] = vals_arr;
- }
- if ((high_index) < row_stride) {
- __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
- vals_arr = vals_arr * h2rsqrt(variance_h);
- vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
- vals_cast[high_index] = vals_arr;
- }
- #endif
- }
- template <typename T>
- void launch_bias_residual_layer_norm(T* vals,
- const T* residual,
- const T* gamma,
- const T* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- T* vars);
- /*
- To tune this launch the following restrictions must be met:
- For float:
- row_stride == hidden_size
- threads * iterations == row_stride
- threads is in [32, 64, 128, 256, 512, 1024]
- For half:
- row_stride == hidden_size / 2
- threads * iterations == row_stride
- threads is in [32, 64, 128, 256, 512, 1024]
- */
- template <>
- void launch_bias_residual_layer_norm<float>(float* vals,
- const float* residual,
- const float* gamma,
- const float* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- float* vars)
- {
- int threads = THREADS;
- dim3 grid_dim(batch_size);
- // There are some limitations to call below functions, now just enumerate the situations.
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim(threads);
- fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
- }
- template <>
- void launch_bias_residual_layer_norm<__half>(__half* vals,
- const __half* residual,
- const __half* gamma,
- const __half* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- __half* vars)
- {
- int threads = 128;
- dim3 grid_dim(batch_size);
- // There are some limitations to call below functions, now just enumerate the situations.
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim(threads);
- fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
- }
- /* Normalize Gamma & Betta gradients
- * Compute gradients using either X_hat or
- * normalize input (invertible).
- * Combine transpose with gradients computation.
- */
- template <typename T>
- __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
- const T* __restrict__ vals_hat,
- const T* __restrict__ gamma,
- const T* __restrict__ betta,
- T* __restrict__ gamma_grad,
- T* __restrict__ betta_grad,
- int rows,
- int width,
- bool invertible)
- {
- __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
- __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
- int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
- int y_stride = width * TILE_DIM;
- float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
- float gamma_reg = (float)gamma[idx];
- // Loop across matrix height
- float betta_tmp = 0;
- float gamma_tmp = 0;
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- float grad = (float)out_grad[offset];
- float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
- : (float)vals_hat[offset]);
- betta_tmp += grad;
- gamma_tmp += (val * grad);
- offset += y_stride;
- }
- betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
- gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
- __syncthreads();
- // Sum the shared buffer.
- float s1 = betta_buffer[threadIdx.y][threadIdx.x];
- float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < TILE_DIM; i <<= 1) {
- s1 += g.shfl_down(s1, i);
- s2 += g.shfl_down(s2, i);
- }
- if (threadIdx.x == 0) {
- int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- betta_grad[pos] = s1;
- gamma_grad[pos] = s2;
- }
- }
- /* Normalize Gamma & Betta gradients
- * Compute gradients using the input to
- * the normalize.
- * Combine transpose with gradients computation.
- */
- template <typename T>
- __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
- const T* __restrict__ X_data,
- const T* __restrict__ vars,
- const T* __restrict__ means,
- T* __restrict__ gamma_grad,
- T* __restrict__ betta_grad,
- int rows,
- int width)
- {
- __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
- __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
- int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
- int y_stride = width * TILE_DIM;
- int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- // Loop across matrix height
- float betta_tmp = 0;
- float gamma_tmp = 0;
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- float grad = (float)out_grad[offset];
- float val = (float)X_data[offset];
- val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
- betta_tmp += grad;
- gamma_tmp += (val * grad);
- offset += y_stride;
- }
- betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
- gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
- __syncthreads();
- // Sum the shared buffer.
- float s1 = betta_buffer[threadIdx.y][threadIdx.x];
- float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < TILE_DIM; i <<= 1) {
- s1 += g.shfl_down(s1, i);
- s2 += g.shfl_down(s2, i);
- }
- if (threadIdx.x == 0) {
- betta_grad[pos] = s1;
- gamma_grad[pos] = s2;
- }
- }
- /*
- /* Backward Normalize (Input-Gradient)
- * Using the means and variances from the input
- * This type of backward is invertible!
- * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization.
- */
- __global__ void LayerNormBackward2(const float* out_grad,
- const float* vals_hat,
- const float* gamma,
- const float* betta,
- const float* vars,
- float* inp_grad,
- bool invertible,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- out_grad += (row * row_stride);
- vals_hat += (row * row_stride);
- inp_grad += (row * row_stride);
- float vals_arr[NORM_REG];
- float vals_hat_arr[NORM_REG];
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- float gamma_reg = gamma[i * iteration_stride + id];
- vals_arr[i] = out_grad[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- vals_hat_arr[i] =
- (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
- gamma_reg
- : vals_hat[i * iteration_stride + id]);
- }
- if ((high_index) < row_stride) {
- float gamma_reg = gamma[high_index];
- vals_arr[iterations] = out_grad[high_index];
- vals_arr[iterations] *= gamma_reg;
- vals_hat_arr[iterations] =
- (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
- : vals_hat[high_index]);
- iterations++;
- }
- float var_reg = vars[row];
- float sum = 0;
- for (int i = 0; i < iterations; i++) {
- sum += vals_hat_arr[i] * vals_arr[i] *
- sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
- vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
- sum = 0;
- for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
- if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
- }
- __global__ void LayerNormBackward2(const __half* out_grad,
- const __half* vals_hat,
- const __half* gamma,
- const __half* betta,
- const __half* vars,
- __half* inp_grad,
- bool invertible,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- __half2 vals_arr[NORM_REG];
- float2 vals_arr_f[NORM_REG];
- __half2 vals_hat_arr[NORM_REG];
- __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
- const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
- const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
- inp_grad_h += (row * row_stride);
- out_grad_h += (row * row_stride);
- vals_hat_h += (row * row_stride);
- const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
- const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- __half2 gamma_reg = gamma_h[i * iteration_stride + id];
- vals_arr[i] = out_grad_h[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- vals_hat_arr[i] =
- (invertible
- ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
- gamma_reg
- : vals_hat_h[i * iteration_stride + id]);
- }
- if ((high_index) < row_stride) {
- __half2 gamma_reg = gamma_h[high_index];
- vals_arr[iterations] = out_grad_h[high_index];
- vals_arr[iterations] *= gamma_reg;
- vals_hat_arr[iterations] =
- (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
- : vals_hat_h[high_index]);
- iterations++;
- }
- __half var_h = vars[row];
- __half2 var_reg = __halves2half2(var_h, var_h);
- float sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
- float2 result_f = __half22float2(result_h);
- sum += result_f.x;
- sum += result_f.y;
- vals_arr[i] *= h2rsqrt(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- __half2 sum_h = __float2half2_rn(sum);
- for (int i = 0; i < iterations; i++) {
- __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
- vals_arr_f[i] = __half22float2(vals_arr[i]);
- float2 temp_f = __half22float2(temp);
- vals_arr_f[i].x += temp_f.x;
- vals_arr_f[i].y += temp_f.y;
- }
- sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- sum += (vals_arr_f[i].x);
- sum += (vals_arr_f[i].y);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr_f[i].x -= sum;
- vals_arr_f[i].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[i]);
- inp_grad_h[i * iteration_stride + id] = temp;
- }
- if ((high_index) < row_stride) {
- vals_arr_f[iterations].x -= sum;
- vals_arr_f[iterations].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
- inp_grad_h[high_index] = temp;
- }
- }
- template <>
- void launch_layerNorm_backward<float>(const float* out_grad,
- const float* vals_hat,
- const float* vars,
- const float* gamma,
- float* gamma_grad,
- float* betta_grad,
- float* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2],
- bool invertible,
- const float* betta)
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
- dim3 grid_dim2(batch);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads);
- LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
- }
- template <>
- void launch_layerNorm_backward<__half>(const __half* out_grad,
- const __half* vals_hat,
- const __half* vars,
- const __half* gamma,
- __half* gamma_grad,
- __half* betta_grad,
- __half* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2],
- bool invertible,
- const __half* betta)
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- // LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
- // out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
- dim3 grid_dim2(batch);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads / 2);
- LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
- }
- /* Backward Normalize (Input-Gradient)
- * Using the means and variances from the input
- * This type of backward is not invertible!
- * We do the backward using the input (X)
- */
- __global__ void LayerNormBackward2(const float* out_grad,
- const float* X_vals,
- const float* gamma,
- const float* vars,
- const float* means,
- float* inp_grad,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id >> WARP_SIZE_BITS;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- out_grad += (row * row_stride);
- X_vals += (row * row_stride);
- inp_grad += (row * row_stride);
- float vals_arr[NORM_REG];
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- float gamma_reg = gamma[i * iteration_stride + id];
- vals_arr[i] = out_grad[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- }
- if ((high_index) < row_stride) {
- float gamma_reg = gamma[high_index];
- vals_arr[iterations] = out_grad[high_index];
- vals_arr[iterations] *= gamma_reg;
- iterations++;
- }
- float var_reg = vars[row];
- float mean_reg = means[row];
- float sum = 0;
- float xu[NORM_REG];
- for (int i = 0; i < iterations; i++) {
- xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
- sum += vals_arr[i] * xu[i];
- vals_arr[i] *= rsqrtf(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
- }
- sum = 0;
- for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
- if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
- }
- __global__ void LayerNormBackward2(const __half* out_grad,
- const __half* X_vals,
- const __half* gamma,
- const __half* vars,
- const __half* means,
- __half* inp_grad,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id >> WARP_SIZE_BITS;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- __half2 vals_arr[NORM_REG];
- float2 vals_arr_f[NORM_REG];
- __half2 xu[NORM_REG];
- __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
- const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
- const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
- inp_grad_h += (row * row_stride);
- out_grad_h += (row * row_stride);
- vals_hat_h += (row * row_stride);
- const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
- int high_index = iterations * iteration_stride + id;
- __half mean_h = means[row];
- __half2 mean_reg = __halves2half2(mean_h, mean_h);
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- __half2 gamma_reg = gamma_h[i * iteration_stride + id];
- vals_arr[i] = out_grad_h[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg; // out_grad * gamma
- xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
- }
- if ((high_index) < row_stride) {
- __half2 gamma_reg = gamma_h[high_index];
- vals_arr[iterations] = out_grad_h[high_index];
- vals_arr[iterations] *= gamma_reg; // out_grad * gamma
- xu[iterations] = (vals_hat_h[high_index] - mean_reg);
- iterations++;
- }
- __half var_h = vars[row];
- __half2 var_reg = __halves2half2(var_h, var_h);
- float sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- __half2 result_h = (xu[i] * vals_arr[i]);
- float2 result_f = __half22float2(result_h);
- sum += result_f.x;
- sum += result_f.y;
- vals_arr[i] *= h2rsqrt(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- __half2 sum_h = __float2half2_rn(sum);
- for (int i = 0; i < iterations; i++) {
- __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
- vals_arr_f[i] = __half22float2(vals_arr[i]);
- float2 xu_grad_f = __half22float2(xu_grad);
- vals_arr_f[i].x += xu_grad_f.x;
- vals_arr_f[i].y += xu_grad_f.y;
- }
- sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- sum += (vals_arr_f[i].x);
- sum += (vals_arr_f[i].y);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr_f[i].x -= sum;
- vals_arr_f[i].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[i]);
- inp_grad_h[i * iteration_stride + id] = temp;
- }
- if ((high_index) < row_stride) {
- vals_arr_f[iterations].x -= sum;
- vals_arr_f[iterations].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
- inp_grad_h[high_index] = temp;
- }
- }
- template <>
- void launch_layerNorm_backward<float>(const float* out_grad,
- const float* X_data,
- const float* vars,
- const float* means,
- const float* gamma,
- float* gamma_grad,
- float* betta_grad,
- float* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2])
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
- dim3 grid_dim2(batch);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads);
- LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
- }
- template <>
- void launch_layerNorm_backward<__half>(const __half* out_grad,
- const __half* X_data,
- const __half* vars,
- const __half* means,
- const __half* gamma,
- __half* gamma_grad,
- __half* betta_grad,
- __half* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2])
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
- dim3 grid_dim2(batch);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads / 2);
- LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
- }
- template <typename T>
- __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
- const T* __restrict__ out_grad2,
- const T* __restrict__ vals_hat,
- const T* __restrict__ gamma,
- const T* __restrict__ betta,
- T* __restrict__ gamma_grad,
- T* __restrict__ betta_grad,
- int rows,
- int width,
- bool invertible)
- {
- __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
- __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
- int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
- int y_stride = width * TILE_DIM;
- float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
- float gamma_reg = (float)gamma[idx];
- // Loop across matrix height
- float betta_tmp = 0;
- float gamma_tmp = 0;
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
- float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
- : (float)vals_hat[offset]);
- betta_tmp += grad;
- gamma_tmp += (val * grad);
- offset += y_stride;
- }
- betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
- gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
- __syncthreads();
- // Sum the shared buffer.
- float s1 = betta_buffer[threadIdx.y][threadIdx.x];
- float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < TILE_DIM; i <<= 1) {
- s1 += g.shfl_down(s1, i);
- s2 += g.shfl_down(s2, i);
- }
- if (threadIdx.x == 0) {
- int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- betta_grad[pos] = s1;
- gamma_grad[pos] = s2;
- }
- }
- template <typename T>
- __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
- const T* __restrict__ out_grad2,
- const T* __restrict__ X_data,
- const T* __restrict__ vars,
- const T* __restrict__ means,
- T* __restrict__ gamma_grad,
- T* __restrict__ betta_grad,
- int rows,
- int width)
- {
- __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
- __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
- int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
- int y_stride = width * TILE_DIM;
- int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- // Loop across matrix height
- float betta_tmp = 0;
- float gamma_tmp = 0;
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
- float val = (float)X_data[offset];
- val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
- betta_tmp += grad;
- gamma_tmp += (val * grad);
- offset += y_stride;
- }
- betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
- gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
- __syncthreads();
- // Sum the shared buffer.
- float s1 = betta_buffer[threadIdx.y][threadIdx.x];
- float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < TILE_DIM; i <<= 1) {
- s1 += g.shfl_down(s1, i);
- s2 += g.shfl_down(s2, i);
- }
- if (threadIdx.x == 0) {
- betta_grad[pos] = s1;
- gamma_grad[pos] = s2;
- }
- }
- __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
- const float* out_grad2,
- const float* vals_hat,
- const float* gamma,
- const float* betta,
- const float* vars,
- float* inp_grad,
- bool invertible,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- out_grad1 += (row * row_stride);
- out_grad2 += (row * row_stride);
- vals_hat += (row * row_stride);
- inp_grad += (row * row_stride);
- float vals_arr[NORM_REG];
- float vals_hat_arr[NORM_REG];
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- float gamma_reg = gamma[i * iteration_stride + id];
- vals_arr[i] = out_grad1[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- vals_hat_arr[i] =
- (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
- gamma_reg
- : vals_hat[i * iteration_stride + id]);
- }
- if ((high_index) < row_stride) {
- float gamma_reg = gamma[high_index];
- vals_arr[iterations] = out_grad1[high_index];
- vals_arr[iterations] *= gamma_reg;
- vals_hat_arr[iterations] =
- (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
- : vals_hat[high_index]);
- iterations++;
- }
- float var_reg = vars[row];
- float sum = 0;
- for (int i = 0; i < iterations; i++) {
- sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
- vals_arr[i] *= rsqrtf(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
- sum = 0;
- for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++)
- inp_grad[i * iteration_stride + id] =
- (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
- if ((high_index) < row_stride)
- inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
- }
- __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
- const __half* out_grad2,
- const __half* vals_hat,
- const __half* gamma,
- const __half* betta,
- const __half* vars,
- __half* inp_grad,
- bool invertible,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- __half2 vals_arr[NORM_REG];
- float2 vals_arr_f[NORM_REG];
- __half2 vals_hat_arr[NORM_REG];
- // float2 result[iterations];
- __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
- const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
- const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
- const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
- inp_grad_h += (row * row_stride);
- out_grad_h1 += (row * row_stride);
- out_grad_h2 += (row * row_stride);
- vals_hat_h += (row * row_stride);
- const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
- const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- __half2 gamma_reg = gamma_h[i * iteration_stride + id];
- vals_arr[i] = out_grad_h1[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg; // out_grad * gamma
- vals_hat_arr[i] =
- (invertible
- ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
- gamma_reg
- : vals_hat_h[i * iteration_stride + id]);
- }
- if ((high_index) < row_stride) {
- __half2 gamma_reg = gamma_h[high_index];
- vals_arr[iterations] = out_grad_h1[high_index];
- vals_arr[iterations] *= gamma_reg; // out_grad * gamma
- vals_hat_arr[iterations] =
- (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
- : vals_hat_h[high_index]);
- iterations++;
- }
- __half var_h = vars[row];
- __half2 var_reg = __halves2half2(var_h, var_h);
- float sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
- float2 result_f = __half22float2(result_h);
- sum += result_f.x;
- sum += result_f.y;
- vals_arr[i] *= h2rsqrt(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- __half2 sum_h = __float2half2_rn(sum);
- for (int i = 0; i < iterations; i++) {
- __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
- vals_arr_f[i] = __half22float2(vals_arr[i]);
- float2 temp_f = __half22float2(temp);
- vals_arr_f[i].x += temp_f.x;
- vals_arr_f[i].y += temp_f.y;
- }
- sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- sum += (vals_arr_f[i].x);
- sum += (vals_arr_f[i].y);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr_f[i].x -= sum;
- vals_arr_f[i].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[i]);
- inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
- }
- if ((high_index) < row_stride) {
- vals_arr_f[iterations].x -= sum;
- vals_arr_f[iterations].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
- inp_grad_h[high_index] = temp + out_grad_h2[high_index];
- }
- }
- template <>
- void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
- const float* out_grad2,
- const float* vals_hat,
- const float* vars,
- const float* gamma,
- float* gamma_grad,
- float* betta_grad,
- float* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2],
- bool invertible,
- const float* betta)
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
- dim3 grid_dim2(batch);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads);
- LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
- }
- template <>
- void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
- const __half* out_grad2,
- const __half* vals_hat,
- const __half* vars,
- const __half* gamma,
- __half* gamma_grad,
- __half* betta_grad,
- __half* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2],
- bool invertible,
- const __half* betta)
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
- dim3 grid_dim2(batch);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads / 2);
- LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
- }
- /* Backward Normalize (Input-Gradient)
- * Using the means and variances from the input
- * This type of backward is not invertible!
- * We do the backward using the input (X)
- */
- __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
- const float* out_grad2,
- const float* X_vals,
- const float* gamma,
- const float* vars,
- const float* means,
- float* inp_grad,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- float vals_arr[NORM_REG];
- float vals_hat_arr[NORM_REG];
- out_grad1 += (row * row_stride);
- out_grad2 += (row * row_stride);
- X_vals += (row * row_stride);
- inp_grad += (row * row_stride);
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- float gamma_reg = gamma[i * iteration_stride + id];
- vals_arr[i] = out_grad1[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- vals_hat_arr[i] = X_vals[i * iteration_stride + id];
- }
- if ((high_index) < row_stride) {
- float gamma_reg = gamma[high_index];
- vals_arr[iterations] = out_grad1[high_index];
- vals_arr[iterations] *= gamma_reg;
- vals_hat_arr[iterations] = X_vals[high_index];
- iterations++;
- }
- float var_reg = vars[row];
- float mean_reg = means[row];
- float sum = 0;
- float xu[NORM_REG];
- for (int i = 0; i < iterations; i++) {
- xu[i] = (vals_hat_arr[i] - mean_reg);
- sum += vals_arr[i] * xu[i];
- vals_arr[i] *= rsqrtf(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
- }
- sum = 0;
- for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++)
- inp_grad[i * iteration_stride + id] =
- (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
- if ((high_index) < row_stride)
- inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
- }
- __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
- const __half* out_grad2,
- const __half* X_vals,
- const __half* gamma,
- const __half* vars,
- const __half* means,
- __half* inp_grad,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- __half2 vals_arr[NORM_REG];
- float2 vals_arr_f[NORM_REG];
- __half2 vals_hat_arr[NORM_REG];
- __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
- const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
- const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
- const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
- out_grad_h1 += (row * row_stride);
- out_grad_h2 += (row * row_stride);
- inp_grad_h += (row * row_stride);
- vals_hat_h += (row * row_stride);
- const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- __half2 gamma_reg = gamma_h[i * iteration_stride + id];
- vals_arr[i] = out_grad_h1[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg; // out_grad * gamma
- vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
- }
- if ((high_index) < row_stride) {
- __half2 gamma_reg = gamma_h[high_index];
- vals_arr[iterations] = out_grad_h1[high_index];
- vals_arr[iterations] *= gamma_reg; // out_grad * gamma
- vals_hat_arr[iterations] = vals_hat_h[high_index];
- iterations++;
- }
- __half mean_h = means[row];
- __half var_h = vars[row];
- __half2 var_reg = __halves2half2(var_h, var_h);
- __half2 mean_reg = __halves2half2(mean_h, mean_h);
- __half2 xu[NORM_REG];
- float sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- xu[i] = (vals_hat_arr[i] - mean_reg);
- __half2 result_h = (xu[i] * vals_arr[i]);
- float2 result_f = __half22float2(result_h);
- sum += result_f.x;
- sum += result_f.y;
- vals_arr[i] *= h2rsqrt(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- __half2 sum_h = __float2half2_rn(sum);
- for (int i = 0; i < iterations; i++) {
- __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
- vals_arr_f[i] = __half22float2(vals_arr[i]);
- float2 xu_grad_f = __half22float2(xu_grad);
- vals_arr_f[i].x += xu_grad_f.x;
- vals_arr_f[i].y += xu_grad_f.y;
- }
- sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- sum += (vals_arr_f[i].x);
- sum += (vals_arr_f[i].y);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr_f[i].x -= sum;
- vals_arr_f[i].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[i]);
- inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
- }
- if ((high_index) < row_stride) {
- vals_arr_f[iterations].x -= sum;
- vals_arr_f[iterations].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
- inp_grad_h[high_index] = temp + out_grad_h2[high_index];
- }
- }
- template <>
- void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
- const float* out_grad2,
- const float* X_data,
- const float* vars,
- const float* means,
- const float* gamma,
- float* gamma_grad,
- float* betta_grad,
- float* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2])
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
- dim3 grid_dim2(batch);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads);
- LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
- }
- template <>
- void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
- const __half* out_grad2,
- const __half* X_data,
- const __half* vars,
- const __half* means,
- const __half* gamma,
- __half* gamma_grad,
- __half* betta_grad,
- __half* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2])
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
- dim3 grid_dim2(batch);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads / 2);
- LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
- }
|