- // Copyright (c) Microsoft Corporation.
- // SPDX-License-Identifier: Apache-2.0
- // DeepSpeed Team
- #include "custom_cuda_layers.h"
- namespace cg = cooperative_groups;
- /*
- Fused bias add, residual (elementwise) add, and normalization layer.
- For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
- __half2 instructions, and avoid the conversion overhead (1/8 of __half2 arithmetic).
- For specific launch constraints, see the launch functions.
- */
- #define NORM_REG (MAX_REGISTERS / 4)
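- /*
- Reference math for the forward kernels below (per row, H = hidden_size):
-     mean = (1/H) * sum_j x[j]
-     var  = (1/H) * sum_j (x[j] - mean)^2
-     y[j] = gamma[j] * (x[j] - mean) * rsqrt(var + epsilon) + beta[j]
- The host-side helper below is an illustrative sketch added for documentation only
- (it is not part of the original kernels); it can be used to sanity-check one row of
- the fused kernels' output.
- */
- #include <cmath>
- static inline void layer_norm_row_reference(const float* x,
- const float* gamma,
- const float* beta,
- float epsilon,
- int hidden_size,
- float* y)
- {
- float mean = 0.f;
- for (int j = 0; j < hidden_size; j++) mean += x[j];
- mean /= hidden_size;
- float variance = 0.f;
- for (int j = 0; j < hidden_size; j++) variance += (x[j] - mean) * (x[j] - mean);
- variance = variance / hidden_size + epsilon;
- const float inv_std = 1.f / std::sqrt(variance);
- for (int j = 0; j < hidden_size; j++)
- y[j] = (x[j] - mean) * inv_std * gamma[j] + beta[j];
- }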
- __global__ void fused_bias_residual_layer_norm(float* vals,
- const float* residual,
- const float* gamma,
- const float* beta,
- float epsilon,
- bool preLayerNorm,
- bool training,
- float* vars,
- float* means,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int gid = id / WARP_SIZE;
- float vals_arr[NORM_REG];
- __shared__ float shr[MAX_WARP_NUM];
- residual += (row * row_stride);
- vals += (row * row_stride);
- float sum = 0.f;
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] = residual[i * iteration_stride + id];
- sum += vals_arr[i];
- }
- if (high_index < row_stride) {
- vals_arr[iterations] = residual[high_index];
- sum += vals_arr[iterations];
- iterations++;
- }
- for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) shr[gid] = sum;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
- #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- sum += g.shfl_down(sum, i);
- }
- sum = g.shfl(sum, 0);
- float mean = sum / row_stride;
- if (training)
- if (threadIdx.x == 0) means[row] = mean;
- float variance = 0.f;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] -= mean;
- variance += vals_arr[i] * vals_arr[i];
- }
- for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
- if (g.thread_rank() == 0) shr[gid] = variance;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- variance += g.shfl_down(variance, i);
- }
- variance = g.shfl(variance, 0);
- variance /= row_stride;
- variance += epsilon;
- if (training)
- if (threadIdx.x == 0) vars[row] = variance;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] = vals_arr[i] * rsqrtf(variance);
- vals_arr[i] =
- vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
- vals[i * iteration_stride + id] = vals_arr[i];
- }
- if ((high_index) < row_stride) {
- vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
- vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
- vals[high_index] = vals_arr[iterations];
- }
- }
- __global__ void fused_bias_residual_layer_norm(__half* vals,
- const __half* residual,
- const __half* gamma,
- const __half* beta,
- float epsilon,
- bool preLayerNorm,
- bool training,
- __half* vars,
- __half* means,
- int row_stride)
- {
- #ifdef HALF_PRECISION_AVAILABLE
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int gid = id >> WARP_SIZE_BITS;
- float2 vals_f[NORM_REG];
- __shared__ float shr[MAX_WARP_NUM];
- __half2* vals_cast = reinterpret_cast<__half2*>(vals);
- const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
- residual_cast += (row * row_stride);
- vals_cast += (row * row_stride);
- float sum = 0.f;
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
- sum += vals_f[i].x;
- sum += vals_f[i].y;
- }
- if ((high_index) < row_stride) {
- vals_f[iterations] = __half22float2(residual_cast[high_index]);
- sum += vals_f[iterations].x;
- sum += vals_f[iterations].y;
- iterations++;
- }
- for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) shr[gid] = sum;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- sum += g.shfl_down(sum, i);
- }
- sum = g.shfl(sum, 0);
- float mean = sum / (row_stride * 2);
- float variance = 0.f;
- for (int i = 0; i < iterations; i++) {
- vals_f[i].x -= mean;
- vals_f[i].y -= mean;
- variance += vals_f[i].x * vals_f[i].x;
- variance += vals_f[i].y * vals_f[i].y;
- }
- for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
- if (g.thread_rank() == 0) shr[gid] = variance;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- variance += g.shfl_down(variance, i);
- }
- variance = g.shfl(variance, 0);
- variance /= (row_stride * 2);
- variance += epsilon;
- __half2 variance_h = __float2half2_rn(variance);
- const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
- const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
- if (training && threadIdx.x == 0) {
- vars[row] = __float2half(variance);
- means[row] = __float2half(mean);
- }
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- __half2 vals_arr = __float22half2_rn(vals_f[i]);
- vals_arr = vals_arr * h2rsqrt(variance_h);
- vals_arr =
- vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
- vals_cast[i * iteration_stride + id] = vals_arr;
- }
- if ((high_index) < row_stride) {
- __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
- vals_arr = vals_arr * h2rsqrt(variance_h);
- vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
- vals_cast[high_index] = vals_arr;
- }
- #endif
- }
- template <typename T>
- void launch_bias_residual_layer_norm(T* vals,
- const T* residual,
- const T* gamma,
- const T* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- T* vars,
- T* means);
- template <>
- void launch_bias_residual_layer_norm<float>(float* vals,
- const float* residual,
- const float* gamma,
- const float* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- float* vars,
- float* means)
- {
- int threads = THREADS;
- dim3 grid_dim(batch_size);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim(threads);
- fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
- }
- template <>
- void launch_bias_residual_layer_norm<__half>(__half* vals,
- const __half* residual,
- const __half* gamma,
- const __half* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- __half* vars,
- __half* means)
- {
- int threads = 128;
- dim3 grid_dim(batch_size);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim(threads);
- fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
- }
- __global__ void fused_bias_residual_layer_norm(float* vals,
- const float* residual,
- const float* gamma,
- const float* beta,
- float epsilon,
- bool preLayerNorm,
- bool training,
- float* vars,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int gid = id / 32;
- float vals_arr[NORM_REG];
- __shared__ float shr[MAX_WARP_NUM];
- residual += (row * row_stride);
- vals += (row * row_stride);
- float sum = 0.f;
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] = residual[i * iteration_stride + id];
- sum += vals_arr[i];
- }
- if ((high_index) < row_stride) {
- vals_arr[iterations] = residual[high_index];
- sum += vals_arr[iterations];
- iterations++;
- }
- for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) shr[gid] = sum;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
- #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- sum += g.shfl_down(sum, i);
- }
- sum = g.shfl(sum, 0);
- float mean = sum / row_stride;
- float variance = 0.f;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] -= mean;
- variance += vals_arr[i] * vals_arr[i];
- }
- for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
- if (g.thread_rank() == 0) shr[gid] = variance;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- variance += g.shfl_down(variance, i);
- }
- variance = g.shfl(variance, 0);
- variance /= row_stride;
- variance += epsilon;
- if (training)
- if (threadIdx.x == 0) vars[row] = variance;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] = vals_arr[i] * rsqrtf(variance);
- vals_arr[i] =
- vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
- vals[i * iteration_stride + id] = vals_arr[i];
- }
- if ((high_index) < row_stride) {
- vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
- vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
- vals[high_index] = vals_arr[iterations];
- }
- }
- __global__ void fused_bias_residual_layer_norm(__half* vals,
- const __half* residual,
- const __half* gamma,
- const __half* beta,
- float epsilon,
- bool preLayerNorm,
- bool training,
- __half* vars,
- int row_stride)
- {
- #ifdef HALF_PRECISION_AVAILABLE
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int gid = id >> WARP_SIZE_BITS;
- float2 vals_f[NORM_REG];
- __shared__ float shr[MAX_WARP_NUM];
- __half2* vals_cast = reinterpret_cast<__half2*>(vals);
- const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
- residual_cast += (row * row_stride);
- vals_cast += (row * row_stride);
- float sum = 0.f;
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
- sum += vals_f[i].x;
- sum += vals_f[i].y;
- }
- if ((high_index) < row_stride) {
- vals_f[iterations] = __half22float2(residual_cast[high_index]);
- sum += vals_f[iterations].x;
- sum += vals_f[iterations].y;
- iterations++;
- }
- for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) shr[gid] = sum;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- sum += g.shfl_down(sum, i);
- }
- sum = g.shfl(sum, 0);
- float mean = sum / (row_stride * 2);
- float variance = 0.f;
- for (int i = 0; i < iterations; i++) {
- vals_f[i].x -= mean;
- vals_f[i].y -= mean;
- variance += vals_f[i].x * vals_f[i].x;
- variance += vals_f[i].y * vals_f[i].y;
- }
- for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
- if (g.thread_rank() == 0) shr[gid] = variance;
- b.sync();
- if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- b.sync();
- #endif
- for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
- variance += g.shfl_down(variance, i);
- }
- variance = g.shfl(variance, 0);
- variance /= (row_stride * 2);
- variance += epsilon;
- __half2 variance_h = __float2half2_rn(variance);
- const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
- const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
- if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- __half2 vals_arr = __float22half2_rn(vals_f[i]);
- vals_arr = vals_arr * h2rsqrt(variance_h);
- vals_arr =
- vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
- vals_cast[i * iteration_stride + id] = vals_arr;
- }
- if ((high_index) < row_stride) {
- __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
- vals_arr = vals_arr * h2rsqrt(variance_h);
- vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
- vals_cast[high_index] = vals_arr;
- }
- #endif
- }
- template <typename T>
- void launch_bias_residual_layer_norm(T* vals,
- const T* residual,
- const T* gamma,
- const T* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- T* vars);
- /*
- To tune this launch, the following restrictions must be met:
- For float:
- row_stride == hidden_size
- threads * iterations == row_stride
- threads is in [32, 64, 128, 256, 512, 1024]
- For half:
- row_stride == hidden_size / 2
- threads * iterations == row_stride
- threads is in [32, 64, 128, 256, 512, 1024]
- */
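- /*
- Worked example (derived from the launchers below): for __half with hidden_size = 16384,
- the kernel is launched with row_stride = hidden_size / 2 = 8192 __half2 elements and
- threads = 128 << 1 = 256, so each thread handles 8192 / 256 = 32 __half2 values per row.
- threads * iterations must also keep iterations within the NORM_REG register-array bound.
- */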
- template <>
- void launch_bias_residual_layer_norm<float>(float* vals,
- const float* residual,
- const float* gamma,
- const float* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- float* vars)
- {
- int threads = THREADS;
- dim3 grid_dim(batch_size);
- // There are limits on how the kernels below can be launched; for now we just enumerate the supported cases.
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim(threads);
- fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
- }
- template <>
- void launch_bias_residual_layer_norm<__half>(__half* vals,
- const __half* residual,
- const __half* gamma,
- const __half* beta,
- float epsilon,
- int batch_size,
- int hidden_dim,
- cudaStream_t stream,
- bool preLayerNorm,
- bool training,
- __half* vars)
- {
- int threads = 128;
- dim3 grid_dim(batch_size);
- // There are limits on how the kernels below can be launched; for now we just enumerate the supported cases.
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim(threads);
- fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
- vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
- }
- /* Normalize Gamma & Betta gradients
- * Compute gradients using either X_hat or
- * the normalized input (invertible).
- * Combine the transpose with the gradient computation.
- */
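- /*
- Sketch of the reduction computed below (per column j, summing over the batch rows r):
-     betta_grad[j] = sum_r out_grad[r][j]
-     gamma_grad[j] = sum_r out_grad[r][j] * x_hat[r][j]
- where x_hat is either recovered from the stored output as (vals_hat - betta) / gamma
- (invertible path) or taken directly from vals_hat. Each TILE_DIM x TILE_DIM block
- reduces a tile of columns, using the transposed shared-memory buffers to finish the
- sums with warp shuffles.
- */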
- template <typename T>
- __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
- const T* __restrict__ vals_hat,
- const T* __restrict__ gamma,
- const T* __restrict__ betta,
- T* __restrict__ gamma_grad,
- T* __restrict__ betta_grad,
- int rows,
- int width,
- bool invertible)
- {
- __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
- __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
- int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
- int y_stride = width * TILE_DIM;
- float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
- float gamma_reg = (float)gamma[idx];
- // Loop across matrix height
- float betta_tmp = 0;
- float gamma_tmp = 0;
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- float grad = (float)out_grad[offset];
- float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
- : (float)vals_hat[offset]);
- betta_tmp += grad;
- gamma_tmp += (val * grad);
- offset += y_stride;
- }
- betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
- gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
- __syncthreads();
- // Sum the shared buffer.
- float s1 = betta_buffer[threadIdx.y][threadIdx.x];
- float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < TILE_DIM; i <<= 1) {
- s1 += g.shfl_down(s1, i);
- s2 += g.shfl_down(s2, i);
- }
- if (threadIdx.x == 0) {
- int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- betta_grad[pos] = s1;
- gamma_grad[pos] = s2;
- }
- }
- /* Normalize Gamma & Betta gradients
- * Compute gradients using the input to
- * the normalization.
- * Combine the transpose with the gradient computation.
- */
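- /*
- Same gamma/beta reduction as above, but x_hat is recomputed from the saved statistics:
-     x_hat[r][j] = (X_data[r][j] - means[r]) * rsqrt(vars[r])
- Note that the forward kernels store vars[r] with epsilon already added, so no extra
- epsilon is applied here.
- */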
- template <typename T>
- __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
- const T* __restrict__ X_data,
- const T* __restrict__ vars,
- const T* __restrict__ means,
- T* __restrict__ gamma_grad,
- T* __restrict__ betta_grad,
- int rows,
- int width)
- {
- __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
- __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
- int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
- int y_stride = width * TILE_DIM;
- int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- // Loop across matrix height
- float betta_tmp = 0;
- float gamma_tmp = 0;
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- float grad = (float)out_grad[offset];
- float val = (float)X_data[offset];
- val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
- betta_tmp += grad;
- gamma_tmp += (val * grad);
- offset += y_stride;
- }
- betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
- gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
- __syncthreads();
- // Sum the shared buffer.
- float s1 = betta_buffer[threadIdx.y][threadIdx.x];
- float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < TILE_DIM; i <<= 1) {
- s1 += g.shfl_down(s1, i);
- s2 += g.shfl_down(s2, i);
- }
- if (threadIdx.x == 0) {
- betta_grad[pos] = s1;
- gamma_grad[pos] = s2;
- }
- }
- /* Backward Normalize (Input-Gradient)
- * Using the means and variances from the input
- * This type of backward is invertible!
- * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization.
- */
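- /*
- Sketch of the input gradient computed below (per row of H = row_stride elements,
- sigma = sqrt(vars[row]), where vars[row] already includes epsilon from the forward pass):
-     dxhat[j]    = out_grad[j] * gamma[j]
-     inp_grad[j] = (1/sigma) * ( dxhat[j] - mean_j(dxhat) - x_hat[j] * mean_j(dxhat * x_hat) )
- with x_hat either recovered from the output as (vals_hat - betta) / gamma, or read
- directly when the stored activation is already normalized.
- */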
- __global__ void LayerNormBackward2(const float* out_grad,
- const float* vals_hat,
- const float* gamma,
- const float* betta,
- const float* vars,
- float* inp_grad,
- bool invertible,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- out_grad += (row * row_stride);
- vals_hat += (row * row_stride);
- inp_grad += (row * row_stride);
- float vals_arr[NORM_REG];
- float vals_hat_arr[NORM_REG];
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- float gamma_reg = gamma[i * iteration_stride + id];
- vals_arr[i] = out_grad[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- vals_hat_arr[i] =
- (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
- gamma_reg
- : vals_hat[i * iteration_stride + id]);
- }
- if ((high_index) < row_stride) {
- float gamma_reg = gamma[high_index];
- vals_arr[iterations] = out_grad[high_index];
- vals_arr[iterations] *= gamma_reg;
- vals_hat_arr[iterations] =
- (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
- : vals_hat[high_index]);
- iterations++;
- }
- float var_reg = vars[row];
- float sum = 0;
- for (int i = 0; i < iterations; i++) {
- sum += vals_hat_arr[i] * vals_arr[i] *
- sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
- vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
- sum = 0;
- for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
- if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
- }
- __global__ void LayerNormBackward2(const __half* out_grad,
- const __half* vals_hat,
- const __half* gamma,
- const __half* betta,
- const __half* vars,
- __half* inp_grad,
- bool invertible,
- int row_stride)
- {
- #ifdef HALF_PRECISION_AVAILABLE
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- __half2 vals_arr[NORM_REG];
- float2 vals_arr_f[NORM_REG];
- __half2 vals_hat_arr[NORM_REG];
- __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
- const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
- const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
- inp_grad_h += (row * row_stride);
- out_grad_h += (row * row_stride);
- vals_hat_h += (row * row_stride);
- const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
- const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- __half2 gamma_reg = gamma_h[i * iteration_stride + id];
- vals_arr[i] = out_grad_h[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- vals_hat_arr[i] =
- (invertible
- ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
- gamma_reg
- : vals_hat_h[i * iteration_stride + id]);
- }
- if ((high_index) < row_stride) {
- __half2 gamma_reg = gamma_h[high_index];
- vals_arr[iterations] = out_grad_h[high_index];
- vals_arr[iterations] *= gamma_reg;
- vals_hat_arr[iterations] =
- (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
- : vals_hat_h[high_index]);
- iterations++;
- }
- __half var_h = vars[row];
- __half2 var_reg = __halves2half2(var_h, var_h);
- float sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
- float2 result_f = __half22float2(result_h);
- sum += result_f.x;
- sum += result_f.y;
- vals_arr[i] *= h2rsqrt(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- __half2 sum_h = __float2half2_rn(sum);
- for (int i = 0; i < iterations; i++) {
- __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
- vals_arr_f[i] = __half22float2(vals_arr[i]);
- float2 temp_f = __half22float2(temp);
- vals_arr_f[i].x += temp_f.x;
- vals_arr_f[i].y += temp_f.y;
- }
- sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- sum += (vals_arr_f[i].x);
- sum += (vals_arr_f[i].y);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr_f[i].x -= sum;
- vals_arr_f[i].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[i]);
- inp_grad_h[i * iteration_stride + id] = temp;
- }
- if ((high_index) < row_stride) {
- vals_arr_f[iterations].x -= sum;
- vals_arr_f[iterations].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
- inp_grad_h[high_index] = temp;
- }
- #endif
- }
- template <>
- void launch_layerNorm_backward<float>(const float* out_grad,
- const float* vals_hat,
- const float* vars,
- const float* gamma,
- float* gamma_grad,
- float* betta_grad,
- float* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2],
- bool invertible,
- const float* betta)
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
- dim3 grid_dim2(batch);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim2(threads);
- LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
- }
- template <>
- void launch_layerNorm_backward<__half>(const __half* out_grad,
- const __half* vals_hat,
- const __half* vars,
- const __half* gamma,
- __half* gamma_grad,
- __half* betta_grad,
- __half* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2],
- bool invertible,
- const __half* betta)
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
- dim3 grid_dim2(batch);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim2(threads / 2);
- LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
- }
- /* Backward Normalize (Input-Gradient)
- * Using the means and variances from the input
- * This type of backward is not invertible!
- * We do the backward using the input (X)
- */
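- /*
- Same input-gradient formula as the invertible variant above, except that x_hat is
- rebuilt on the fly from the raw input: xu[j] = X_vals[j] - means[row] and
- x_hat[j] = xu[j] * rsqrt(vars[row]); vars[row] already contains epsilon.
- */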
- __global__ void LayerNormBackward2(const float* out_grad,
- const float* X_vals,
- const float* gamma,
- const float* vars,
- const float* means,
- float* inp_grad,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id >> WARP_SIZE_BITS;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- out_grad += (row * row_stride);
- X_vals += (row * row_stride);
- inp_grad += (row * row_stride);
- float vals_arr[NORM_REG];
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- float gamma_reg = gamma[i * iteration_stride + id];
- vals_arr[i] = out_grad[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- }
- if ((high_index) < row_stride) {
- float gamma_reg = gamma[high_index];
- vals_arr[iterations] = out_grad[high_index];
- vals_arr[iterations] *= gamma_reg;
- iterations++;
- }
- float var_reg = vars[row];
- float mean_reg = means[row];
- float sum = 0;
- float xu[NORM_REG];
- for (int i = 0; i < iterations; i++) {
- xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
- sum += vals_arr[i] * xu[i];
- vals_arr[i] *= rsqrtf(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
- }
- sum = 0;
- for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
- if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
- }
- __global__ void LayerNormBackward2(const __half* out_grad,
- const __half* X_vals,
- const __half* gamma,
- const __half* vars,
- const __half* means,
- __half* inp_grad,
- int row_stride)
- {
- #ifdef HALF_PRECISION_AVAILABLE
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id >> WARP_SIZE_BITS;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- __half2 vals_arr[NORM_REG];
- float2 vals_arr_f[NORM_REG];
- __half2 xu[NORM_REG];
- __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
- const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
- const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
- inp_grad_h += (row * row_stride);
- out_grad_h += (row * row_stride);
- vals_hat_h += (row * row_stride);
- const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
- int high_index = iterations * iteration_stride + id;
- __half mean_h = means[row];
- __half2 mean_reg = __halves2half2(mean_h, mean_h);
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- __half2 gamma_reg = gamma_h[i * iteration_stride + id];
- vals_arr[i] = out_grad_h[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg; // out_grad * gamma
- xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
- }
- if ((high_index) < row_stride) {
- __half2 gamma_reg = gamma_h[high_index];
- vals_arr[iterations] = out_grad_h[high_index];
- vals_arr[iterations] *= gamma_reg; // out_grad * gamma
- xu[iterations] = (vals_hat_h[high_index] - mean_reg);
- iterations++;
- }
- __half var_h = vars[row];
- __half2 var_reg = __halves2half2(var_h, var_h);
- float sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- __half2 result_h = (xu[i] * vals_arr[i]);
- float2 result_f = __half22float2(result_h);
- sum += result_f.x;
- sum += result_f.y;
- vals_arr[i] *= h2rsqrt(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- __half2 sum_h = __float2half2_rn(sum);
- for (int i = 0; i < iterations; i++) {
- __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
- vals_arr_f[i] = __half22float2(vals_arr[i]);
- float2 xu_grad_f = __half22float2(xu_grad);
- vals_arr_f[i].x += xu_grad_f.x;
- vals_arr_f[i].y += xu_grad_f.y;
- }
- sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- sum += (vals_arr_f[i].x);
- sum += (vals_arr_f[i].y);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr_f[i].x -= sum;
- vals_arr_f[i].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[i]);
- inp_grad_h[i * iteration_stride + id] = temp;
- }
- if ((high_index) < row_stride) {
- vals_arr_f[iterations].x -= sum;
- vals_arr_f[iterations].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
- inp_grad_h[high_index] = temp;
- }
- #endif
- }
- template <>
- void launch_layerNorm_backward<float>(const float* out_grad,
- const float* X_data,
- const float* vars,
- const float* means,
- const float* gamma,
- float* gamma_grad,
- float* betta_grad,
- float* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2])
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
- dim3 grid_dim2(batch);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim2(threads);
- LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
- }
- template <>
- void launch_layerNorm_backward<__half>(const __half* out_grad,
- const __half* X_data,
- const __half* vars,
- const __half* means,
- const __half* gamma,
- __half* gamma_grad,
- __half* betta_grad,
- __half* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2])
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
- dim3 grid_dim2(batch);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim2(threads / 2);
- LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
- }
- template <typename T>
- __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
- const T* __restrict__ out_grad2,
- const T* __restrict__ vals_hat,
- const T* __restrict__ gamma,
- const T* __restrict__ betta,
- T* __restrict__ gamma_grad,
- T* __restrict__ betta_grad,
- int rows,
- int width,
- bool invertible)
- {
- __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
- __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
- int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
- int y_stride = width * TILE_DIM;
- float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
- float gamma_reg = (float)gamma[idx];
- // Loop across matrix height
- float betta_tmp = 0;
- float gamma_tmp = 0;
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
- float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
- : (float)vals_hat[offset]);
- betta_tmp += grad;
- gamma_tmp += (val * grad);
- offset += y_stride;
- }
- betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
- gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
- __syncthreads();
- // Sum the shared buffer.
- float s1 = betta_buffer[threadIdx.y][threadIdx.x];
- float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < TILE_DIM; i <<= 1) {
- s1 += g.shfl_down(s1, i);
- s2 += g.shfl_down(s2, i);
- }
- if (threadIdx.x == 0) {
- int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- betta_grad[pos] = s1;
- gamma_grad[pos] = s2;
- }
- }
- template <typename T>
- __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
- const T* __restrict__ out_grad2,
- const T* __restrict__ X_data,
- const T* __restrict__ vars,
- const T* __restrict__ means,
- T* __restrict__ gamma_grad,
- T* __restrict__ betta_grad,
- int rows,
- int width)
- {
- __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
- __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
- int idx = blockDim.x * blockIdx.x + threadIdx.x;
- int offset = threadIdx.y * width + idx;
- int y_stride = width * TILE_DIM;
- int pos = blockIdx.x * TILE_DIM + threadIdx.y;
- // Loop across matrix height
- float betta_tmp = 0;
- float gamma_tmp = 0;
- for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
- float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
- float val = (float)X_data[offset];
- val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
- betta_tmp += grad;
- gamma_tmp += (val * grad);
- offset += y_stride;
- }
- betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
- gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
- __syncthreads();
- // Sum the shared buffer.
- float s1 = betta_buffer[threadIdx.y][threadIdx.x];
- float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < TILE_DIM; i <<= 1) {
- s1 += g.shfl_down(s1, i);
- s2 += g.shfl_down(s2, i);
- }
- if (threadIdx.x == 0) {
- betta_grad[pos] = s1;
- gamma_grad[pos] = s2;
- }
- }
- __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
- const float* out_grad2,
- const float* vals_hat,
- const float* gamma,
- const float* betta,
- const float* vars,
- float* inp_grad,
- bool invertible,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- out_grad1 += (row * row_stride);
- out_grad2 += (row * row_stride);
- vals_hat += (row * row_stride);
- inp_grad += (row * row_stride);
- float vals_arr[NORM_REG];
- float vals_hat_arr[NORM_REG];
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- float gamma_reg = gamma[i * iteration_stride + id];
- vals_arr[i] = out_grad1[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- vals_hat_arr[i] =
- (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
- gamma_reg
- : vals_hat[i * iteration_stride + id]);
- }
- if ((high_index) < row_stride) {
- float gamma_reg = gamma[high_index];
- vals_arr[iterations] = out_grad1[high_index];
- vals_arr[iterations] *= gamma_reg;
- vals_hat_arr[iterations] =
- (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
- : vals_hat[high_index]);
- iterations++;
- }
- float var_reg = vars[row];
- float sum = 0;
- for (int i = 0; i < iterations; i++) {
- sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
- vals_arr[i] *= rsqrtf(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
- sum = 0;
- for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++)
- inp_grad[i * iteration_stride + id] =
- (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
- if ((high_index) < row_stride)
- inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
- }
- __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
- const __half* out_grad2,
- const __half* vals_hat,
- const __half* gamma,
- const __half* betta,
- const __half* vars,
- __half* inp_grad,
- bool invertible,
- int row_stride)
- {
- #ifdef HALF_PRECISION_AVAILABLE
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- __half2 vals_arr[NORM_REG];
- float2 vals_arr_f[NORM_REG];
- __half2 vals_hat_arr[NORM_REG];
- // float2 result[iterations];
- __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
- const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
- const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
- const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
- inp_grad_h += (row * row_stride);
- out_grad_h1 += (row * row_stride);
- out_grad_h2 += (row * row_stride);
- vals_hat_h += (row * row_stride);
- const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
- const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- __half2 gamma_reg = gamma_h[i * iteration_stride + id];
- vals_arr[i] = out_grad_h1[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg; // out_grad * gamma
- vals_hat_arr[i] =
- (invertible
- ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
- gamma_reg
- : vals_hat_h[i * iteration_stride + id]);
- }
- if ((high_index) < row_stride) {
- __half2 gamma_reg = gamma_h[high_index];
- vals_arr[iterations] = out_grad_h1[high_index];
- vals_arr[iterations] *= gamma_reg; // out_grad * gamma
- vals_hat_arr[iterations] =
- (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
- : vals_hat_h[high_index]);
- iterations++;
- }
- __half var_h = vars[row];
- __half2 var_reg = __halves2half2(var_h, var_h);
- float sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
- float2 result_f = __half22float2(result_h);
- sum += result_f.x;
- sum += result_f.y;
- vals_arr[i] *= h2rsqrt(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- __half2 sum_h = __float2half2_rn(sum);
- for (int i = 0; i < iterations; i++) {
- __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
- vals_arr_f[i] = __half22float2(vals_arr[i]);
- float2 temp_f = __half22float2(temp);
- vals_arr_f[i].x += temp_f.x;
- vals_arr_f[i].y += temp_f.y;
- }
- sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- sum += (vals_arr_f[i].x);
- sum += (vals_arr_f[i].y);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr_f[i].x -= sum;
- vals_arr_f[i].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[i]);
- inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
- }
- if ((high_index) < row_stride) {
- vals_arr_f[iterations].x -= sum;
- vals_arr_f[iterations].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
- inp_grad_h[high_index] = temp + out_grad_h2[high_index];
- }
- #endif
- }
- template <>
- void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
- const float* out_grad2,
- const float* vals_hat,
- const float* vars,
- const float* gamma,
- float* gamma_grad,
- float* betta_grad,
- float* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2],
- bool invertible,
- const float* betta)
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
- dim3 grid_dim2(batch);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupported hidden_dim.");
- dim3 block_dim2(threads);
- LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
- }
- template <>
- void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
- const __half* out_grad2,
- const __half* vals_hat,
- const __half* vars,
- const __half* gamma,
- __half* gamma_grad,
- __half* betta_grad,
- __half* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2],
- bool invertible,
- const __half* betta)
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
- dim3 grid_dim2(batch);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads / 2);
- LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
- }
- /* Backward Normalize (Input-Gradient)
- * Using the means and variances computed from the input during the forward pass.
- * This type of backward is not invertible!
- * We do the backward using the input (X) rather than the normalized values.
- */
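- /* In outline (up to the order in which the means are taken), with
- * xhat = (x - mu) * rsqrt(var) and dxhat = out_grad1 * gamma, the kernel below
- * computes per row
- *   inp_grad = (dxhat - mean(dxhat) - xhat * mean(dxhat * xhat)) * rsqrt(var) + out_grad2
- * i.e. the standard layer-norm input gradient plus the fused residual gradient.
- */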
- __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
- const float* out_grad2,
- const float* X_vals,
- const float* gamma,
- const float* vars,
- const float* means,
- float* inp_grad,
- int row_stride)
- {
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- float vals_arr[NORM_REG];
- float vals_hat_arr[NORM_REG];
- out_grad1 += (row * row_stride);
- out_grad2 += (row * row_stride);
- X_vals += (row * row_stride);
- inp_grad += (row * row_stride);
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- float gamma_reg = gamma[i * iteration_stride + id];
- vals_arr[i] = out_grad1[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg;
- vals_hat_arr[i] = X_vals[i * iteration_stride + id];
- }
- if ((high_index) < row_stride) {
- float gamma_reg = gamma[high_index];
- vals_arr[iterations] = out_grad1[high_index];
- vals_arr[iterations] *= gamma_reg;
- vals_hat_arr[iterations] = X_vals[high_index];
- iterations++;
- }
- float var_reg = vars[row];
- float mean_reg = means[row];
- float sum = 0;
- float xu[NORM_REG];
- for (int i = 0; i < iterations; i++) {
- xu[i] = (vals_hat_arr[i] - mean_reg);
- sum += vals_arr[i] * xu[i];
- vals_arr[i] *= rsqrtf(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
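- // Variance-gradient term: -(x - mu) * mean(dxhat * (x - mu)) / var^(3/2).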
- for (int i = 0; i < iterations; i++) {
- vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
- }
- sum = 0;
- for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= row_stride;
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++)
- inp_grad[i * iteration_stride + id] =
- (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
- if ((high_index) < row_stride)
- inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
- }
- __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
- const __half* out_grad2,
- const __half* X_vals,
- const __half* gamma,
- const __half* vars,
- const __half* means,
- __half* inp_grad,
- int row_stride)
- {
- #ifdef HALF_PRECISION_AVAILABLE
- int iteration_stride = blockDim.x;
- int iterations = row_stride / iteration_stride;
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
- int row = blockIdx.x;
- int id = threadIdx.x;
- int wid = id / WARP_SIZE;
- int warp_num = iteration_stride >> WARP_SIZE_BITS;
- __shared__ float partialSum[MAX_WARP_NUM];
- __half2 vals_arr[NORM_REG];
- float2 vals_arr_f[NORM_REG];
- __half2 vals_hat_arr[NORM_REG];
- __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
- const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
- const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
- const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
- out_grad_h1 += (row * row_stride);
- out_grad_h2 += (row * row_stride);
- inp_grad_h += (row * row_stride);
- vals_hat_h += (row * row_stride);
- const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
- int high_index = iterations * iteration_stride + id;
- #pragma unroll
- for (int i = 0; i < iterations; i++) {
- __half2 gamma_reg = gamma_h[i * iteration_stride + id];
- vals_arr[i] = out_grad_h1[i * iteration_stride + id];
- vals_arr[i] *= gamma_reg; // out_grad * gamma
- vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
- }
- if ((high_index) < row_stride) {
- __half2 gamma_reg = gamma_h[high_index];
- vals_arr[iterations] = out_grad_h1[high_index];
- vals_arr[iterations] *= gamma_reg; // out_grad * gamma
- vals_hat_arr[iterations] = vals_hat_h[high_index];
- iterations++;
- }
- __half mean_h = means[row];
- __half var_h = vars[row];
- __half2 var_reg = __halves2half2(var_h, var_h);
- __half2 mean_reg = __halves2half2(mean_h, mean_h);
- __half2 xu[NORM_REG];
- float sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- xu[i] = (vals_hat_arr[i] - mean_reg);
- __half2 result_h = (xu[i] * vals_arr[i]);
- float2 result_f = __half22float2(result_h);
- sum += result_f.x;
- sum += result_f.y;
- vals_arr[i] *= h2rsqrt(var_reg);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- __half2 sum_h = __float2half2_rn(sum);
- for (int i = 0; i < iterations; i++) {
- __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
- vals_arr_f[i] = __half22float2(vals_arr[i]);
- float2 xu_grad_f = __half22float2(xu_grad);
- vals_arr_f[i].x += xu_grad_f.x;
- vals_arr_f[i].y += xu_grad_f.y;
- }
- sum = 0.f;
- for (int i = 0; i < iterations; i++) {
- sum += (vals_arr_f[i].x);
- sum += (vals_arr_f[i].y);
- }
- for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
- if (g.thread_rank() == 0) partialSum[wid] = sum;
- __syncthreads();
- if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
- #ifndef __STOCHASTIC_MODE__
- __syncthreads();
- #endif
- for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
- sum = g.shfl(sum, 0);
- sum /= (2 * row_stride);
- iterations = row_stride / iteration_stride;
- for (int i = 0; i < iterations; i++) {
- vals_arr_f[i].x -= sum;
- vals_arr_f[i].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[i]);
- inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
- }
- if ((high_index) < row_stride) {
- vals_arr_f[iterations].x -= sum;
- vals_arr_f[iterations].y -= sum;
- __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
- inp_grad_h[high_index] = temp + out_grad_h2[high_index];
- }
- #endif
- }
- template <>
- void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
- const float* out_grad2,
- const float* X_data,
- const float* vars,
- const float* means,
- const float* gamma,
- float* gamma_grad,
- float* betta_grad,
- float* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2])
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
- dim3 grid_dim2(batch);
- if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 1;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 2;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads);
- LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
- }
- template <>
- void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
- const __half* out_grad2,
- const __half* X_data,
- const __half* vars,
- const __half* means,
- const __half* gamma,
- __half* gamma_grad,
- __half* betta_grad,
- __half* inp_grad,
- int batch,
- int hidden_dim,
- cudaStream_t stream[2])
- {
- int threads = THREADS;
- dim3 grid_dim(hidden_dim / TILE_DIM);
- dim3 block_dim(TILE_DIM, TILE_DIM);
- LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
- out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
- dim3 grid_dim2(batch);
- if (hidden_dim > 8192 && hidden_dim <= 16384)
- threads <<= 1;
- else if (hidden_dim > 16384 && hidden_dim <= 32768)
- threads <<= 2;
- else if (hidden_dim > 32768 && hidden_dim <= 65536)
- threads <<= 3;
- else if (hidden_dim > 65536)
- throw std::runtime_error("Unsupport hidden_dim.");
- dim3 block_dim2(threads / 2);
- LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
- out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
- }
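- /* Hypothetical host-side usage sketch (not part of the original kernels): the buffer
- * names below are illustrative assumptions. It shows how the non-invertible fused-add
- * backward launcher above might be driven, using two CUDA streams so the gamma/betta
- * reduction and the input-gradient kernel can overlap.
- *
- * cudaStream_t streams[2];
- * cudaStreamCreate(&streams[0]);
- * cudaStreamCreate(&streams[1]);
- * // d_out_grad1 / d_out_grad2 / d_X / d_inp_grad: [batch, hidden_dim] device buffers;
- * // d_vars / d_means: per-row statistics saved by the forward pass;
- * // d_gamma / d_gamma_grad / d_betta_grad: [hidden_dim] device buffers.
- * launch_layerNorm_backward_fused_add<float>(d_out_grad1, d_out_grad2, d_X,
- * d_vars, d_means, d_gamma,
- * d_gamma_grad, d_betta_grad, d_inp_grad,
- * batch, hidden_dim, streams);
- * cudaStreamSynchronize(streams[0]);
- * cudaStreamSynchronize(streams[1]);
- */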