normalize_kernels.cu 73 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542
0552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #include "custom_cuda_layers.h"
  5. namespace cg = cooperative_groups;
  6. /*
  7. Fused bias add, residual (elementwise) add, and normalization layer.
  8. For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
  9. __half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic).
  10. For specific launch constraints, see the launch functions.
  11. */
  12. #define NORM_REG (MAX_REGISTERS / 4)
  13. __global__ void fused_bias_residual_layer_norm(float* vals,
  14. const float* residual,
  15. const float* gamma,
  16. const float* beta,
  17. float epsilon,
  18. bool preLayerNorm,
  19. bool training,
  20. float* vars,
  21. float* means,
  22. int row_stride)
  23. {
  24. int iteration_stride = blockDim.x;
  25. int iterations = row_stride / iteration_stride;
  26. cg::thread_block b = cg::this_thread_block();
  27. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  28. int row = blockIdx.x;
  29. int id = threadIdx.x;
  30. int gid = id / WARP_SIZE;
  31. float vals_arr[NORM_REG];
  32. __shared__ float shr[MAX_WARP_NUM];
  33. residual += (row * row_stride);
  34. vals += (row * row_stride);
  35. float sum = 0.f;
  36. int high_index = iterations * iteration_stride + id;
  37. #pragma unroll
  38. for (int i = 0; i < iterations; i++) {
  39. vals_arr[i] = residual[i * iteration_stride + id];
  40. sum += vals_arr[i];
  41. }
  42. if (high_index < row_stride) {
  43. vals_arr[iterations] = residual[high_index];
  44. sum += vals_arr[iterations];
  45. iterations++;
  46. }
  47. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  48. if (g.thread_rank() == 0) shr[gid] = sum;
  49. b.sync();
  50. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  51. #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
  52. b.sync();
  53. #endif
  54. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  55. sum += g.shfl_down(sum, i);
  56. }
  57. sum = g.shfl(sum, 0);
  58. float mean = sum / row_stride;
  59. if (training)
  60. if (threadIdx.x == 0) means[row] = mean;
  61. float variance = 0.f;
  62. for (int i = 0; i < iterations; i++) {
  63. vals_arr[i] -= mean;
  64. variance += vals_arr[i] * vals_arr[i];
  65. }
  66. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  67. if (g.thread_rank() == 0) shr[gid] = variance;
  68. b.sync();
  69. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  70. #ifndef __STOCHASTIC_MODE__
  71. b.sync();
  72. #endif
  73. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  74. variance += g.shfl_down(variance, i);
  75. }
  76. variance = g.shfl(variance, 0);
  77. variance /= row_stride;
  78. variance += epsilon;
  79. if (training)
  80. if (threadIdx.x == 0) vars[row] = variance;
  81. iterations = row_stride / iteration_stride;
  82. for (int i = 0; i < iterations; i++) {
  83. vals_arr[i] = vals_arr[i] * rsqrtf(variance);
  84. vals_arr[i] =
  85. vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
  86. vals[i * iteration_stride + id] = vals_arr[i];
  87. }
  88. if ((high_index) < row_stride) {
  89. vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
  90. vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
  91. vals[high_index] = vals_arr[iterations];
  92. }
  93. }
  94. __global__ void fused_bias_residual_layer_norm(__half* vals,
  95. const __half* residual,
  96. const __half* gamma,
  97. const __half* beta,
  98. float epsilon,
  99. bool preLayerNorm,
  100. bool training,
  101. __half* vars,
  102. __half* means,
  103. int row_stride)
  104. {
  105. #ifdef HALF_PRECISION_AVAILABLE
  106. int iteration_stride = blockDim.x;
  107. int iterations = row_stride / iteration_stride;
  108. cg::thread_block b = cg::this_thread_block();
  109. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  110. int row = blockIdx.x;
  111. int id = threadIdx.x;
  112. int gid = id >> WARP_SIZE_BITS;
  113. float2 vals_f[NORM_REG];
  114. __shared__ float shr[MAX_WARP_NUM];
  115. __half2* vals_cast = reinterpret_cast<__half2*>(vals);
  116. const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
  117. residual_cast += (row * row_stride);
  118. vals_cast += (row * row_stride);
  119. float sum = 0.f;
  120. int high_index = iterations * iteration_stride + id;
  121. #pragma unroll
  122. for (int i = 0; i < iterations; i++) {
  123. vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
  124. sum += vals_f[i].x;
  125. sum += vals_f[i].y;
  126. }
  127. if ((high_index) < row_stride) {
  128. vals_f[iterations] = __half22float2(residual_cast[high_index]);
  129. sum += vals_f[iterations].x;
  130. sum += vals_f[iterations].y;
  131. iterations++;
  132. }
  133. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  134. if (g.thread_rank() == 0) shr[gid] = sum;
  135. b.sync();
  136. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  137. #ifndef __STOCHASTIC_MODE__
  138. b.sync();
  139. #endif
  140. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  141. sum += g.shfl_down(sum, i);
  142. }
  143. sum = g.shfl(sum, 0);
  144. float mean = sum / (row_stride * 2);
  145. float variance = 0.f;
  146. for (int i = 0; i < iterations; i++) {
  147. vals_f[i].x -= mean;
  148. vals_f[i].y -= mean;
  149. variance += vals_f[i].x * vals_f[i].x;
  150. variance += vals_f[i].y * vals_f[i].y;
  151. }
  152. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  153. if (g.thread_rank() == 0) shr[gid] = variance;
  154. b.sync();
  155. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  156. #ifndef __STOCHASTIC_MODE__
  157. b.sync();
  158. #endif
  159. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  160. variance += g.shfl_down(variance, i);
  161. }
  162. variance = g.shfl(variance, 0);
  163. variance /= (row_stride * 2);
  164. variance += epsilon;
  165. __half2 variance_h = __float2half2_rn(variance);
  166. const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
  167. const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
  168. if (training && threadIdx.x == 0) {
  169. vars[row] = __float2half(variance);
  170. means[row] = __float2half(mean);
  171. }
  172. iterations = row_stride / iteration_stride;
  173. for (int i = 0; i < iterations; i++) {
  174. __half2 vals_arr = __float22half2_rn(vals_f[i]);
  175. vals_arr = vals_arr * h2rsqrt(variance_h);
  176. vals_arr =
  177. vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
  178. vals_cast[i * iteration_stride + id] = vals_arr;
  179. }
  180. if ((high_index) < row_stride) {
  181. __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
  182. vals_arr = vals_arr * h2rsqrt(variance_h);
  183. vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
  184. vals_cast[high_index] = vals_arr;
  185. }
  186. #endif
  187. }
  188. template <typename T>
  189. void launch_bias_residual_layer_norm(T* vals,
  190. const T* residual,
  191. const T* gamma,
  192. const T* beta,
  193. float epsilon,
  194. int batch_size,
  195. int hidden_dim,
  196. cudaStream_t stream,
  197. bool preLayerNorm,
  198. bool training,
  199. T* vars,
  200. T* means);
  201. template <>
  202. void launch_bias_residual_layer_norm<float>(float* vals,
  203. const float* residual,
  204. const float* gamma,
  205. const float* beta,
  206. float epsilon,
  207. int batch_size,
  208. int hidden_dim,
  209. cudaStream_t stream,
  210. bool preLayerNorm,
  211. bool training,
  212. float* vars,
  213. float* means)
  214. {
  215. int threads = THREADS;
  216. dim3 grid_dim(batch_size);
  217. if (hidden_dim > 16384 && hidden_dim <= 32768)
  218. threads <<= 1;
  219. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  220. threads <<= 2;
  221. else if (hidden_dim > 65536)
  222. throw std::runtime_error("Unsupport hidden_dim.");
  223. dim3 block_dim(threads);
  224. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  225. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
  226. }
  227. template <>
  228. void launch_bias_residual_layer_norm<__half>(__half* vals,
  229. const __half* residual,
  230. const __half* gamma,
  231. const __half* beta,
  232. float epsilon,
  233. int batch_size,
  234. int hidden_dim,
  235. cudaStream_t stream,
  236. bool preLayerNorm,
  237. bool training,
  238. __half* vars,
  239. __half* means)
  240. {
  241. int threads = 128;
  242. dim3 grid_dim(batch_size);
  243. if (hidden_dim > 8192 && hidden_dim <= 16384)
  244. threads <<= 1;
  245. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  246. threads <<= 2;
  247. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  248. threads <<= 3;
  249. else if (hidden_dim > 65536)
  250. throw std::runtime_error("Unsupport hidden_dim.");
  251. dim3 block_dim(threads);
  252. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  253. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
  254. }
  255. __global__ void fused_bias_residual_layer_norm(float* vals,
  256. const float* residual,
  257. const float* gamma,
  258. const float* beta,
  259. float epsilon,
  260. bool preLayerNorm,
  261. bool training,
  262. float* vars,
  263. int row_stride)
  264. {
  265. int iteration_stride = blockDim.x;
  266. int iterations = row_stride / iteration_stride;
  267. cg::thread_block b = cg::this_thread_block();
  268. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  269. int row = blockIdx.x;
  270. int id = threadIdx.x;
  271. int gid = id / 32;
  272. float vals_arr[NORM_REG];
  273. __shared__ float shr[MAX_WARP_NUM];
  274. residual += (row * row_stride);
  275. vals += (row * row_stride);
  276. float sum = 0.f;
  277. int high_index = iterations * iteration_stride + id;
  278. #pragma unroll
  279. for (int i = 0; i < iterations; i++) {
  280. vals_arr[i] = residual[i * iteration_stride + id];
  281. sum += vals_arr[i];
  282. }
  283. if ((high_index) < row_stride) {
  284. vals_arr[iterations] = residual[high_index];
  285. sum += vals_arr[iterations];
  286. iterations++;
  287. }
  288. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  289. if (g.thread_rank() == 0) shr[gid] = sum;
  290. b.sync();
  291. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  292. #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
  293. b.sync();
  294. #endif
  295. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  296. sum += g.shfl_down(sum, i);
  297. }
  298. sum = g.shfl(sum, 0);
  299. float mean = sum / row_stride;
  300. float variance = 0.f;
  301. for (int i = 0; i < iterations; i++) {
  302. vals_arr[i] -= mean;
  303. variance += vals_arr[i] * vals_arr[i];
  304. }
  305. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  306. if (g.thread_rank() == 0) shr[gid] = variance;
  307. b.sync();
  308. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  309. #ifndef __STOCHASTIC_MODE__
  310. b.sync();
  311. #endif
  312. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  313. variance += g.shfl_down(variance, i);
  314. }
  315. variance = g.shfl(variance, 0);
  316. variance /= row_stride;
  317. variance += epsilon;
  318. if (training)
  319. if (threadIdx.x == 0) vars[row] = variance;
  320. iterations = row_stride / iteration_stride;
  321. for (int i = 0; i < iterations; i++) {
  322. vals_arr[i] = vals_arr[i] * rsqrtf(variance);
  323. vals_arr[i] =
  324. vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
  325. vals[i * iteration_stride + id] = vals_arr[i];
  326. }
  327. if ((high_index) < row_stride) {
  328. vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
  329. vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
  330. vals[high_index] = vals_arr[iterations];
  331. }
  332. }
  333. __global__ void fused_bias_residual_layer_norm(__half* vals,
  334. const __half* residual,
  335. const __half* gamma,
  336. const __half* beta,
  337. float epsilon,
  338. bool preLayerNorm,
  339. bool training,
  340. __half* vars,
  341. int row_stride)
  342. {
  343. #ifdef HALF_PRECISION_AVAILABLE
  344. int iteration_stride = blockDim.x;
  345. int iterations = row_stride / iteration_stride;
  346. cg::thread_block b = cg::this_thread_block();
  347. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  348. int row = blockIdx.x;
  349. int id = threadIdx.x;
  350. int gid = id >> WARP_SIZE_BITS;
  351. float2 vals_f[NORM_REG];
  352. __shared__ float shr[MAX_WARP_NUM];
  353. __half2* vals_cast = reinterpret_cast<__half2*>(vals);
  354. const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
  355. residual_cast += (row * row_stride);
  356. vals_cast += (row * row_stride);
  357. float sum = 0.f;
  358. int high_index = iterations * iteration_stride + id;
  359. #pragma unroll
  360. for (int i = 0; i < iterations; i++) {
  361. vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
  362. sum += vals_f[i].x;
  363. sum += vals_f[i].y;
  364. }
  365. if ((high_index) < row_stride) {
  366. vals_f[iterations] = __half22float2(residual_cast[high_index]);
  367. sum += vals_f[iterations].x;
  368. sum += vals_f[iterations].y;
  369. iterations++;
  370. }
  371. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  372. if (g.thread_rank() == 0) shr[gid] = sum;
  373. b.sync();
  374. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  375. #ifndef __STOCHASTIC_MODE__
  376. b.sync();
  377. #endif
  378. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  379. sum += g.shfl_down(sum, i);
  380. }
  381. sum = g.shfl(sum, 0);
  382. float mean = sum / (row_stride * 2);
  383. float variance = 0.f;
  384. for (int i = 0; i < iterations; i++) {
  385. vals_f[i].x -= mean;
  386. vals_f[i].y -= mean;
  387. variance += vals_f[i].x * vals_f[i].x;
  388. variance += vals_f[i].y * vals_f[i].y;
  389. }
  390. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  391. if (g.thread_rank() == 0) shr[gid] = variance;
  392. b.sync();
  393. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  394. #ifndef __STOCHASTIC_MODE__
  395. b.sync();
  396. #endif
  397. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  398. variance += g.shfl_down(variance, i);
  399. }
  400. variance = g.shfl(variance, 0);
  401. variance /= (row_stride * 2);
  402. variance += epsilon;
  403. __half2 variance_h = __float2half2_rn(variance);
  404. const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
  405. const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
  406. if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
  407. iterations = row_stride / iteration_stride;
  408. for (int i = 0; i < iterations; i++) {
  409. __half2 vals_arr = __float22half2_rn(vals_f[i]);
  410. vals_arr = vals_arr * h2rsqrt(variance_h);
  411. vals_arr =
  412. vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
  413. vals_cast[i * iteration_stride + id] = vals_arr;
  414. }
  415. if ((high_index) < row_stride) {
  416. __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
  417. vals_arr = vals_arr * h2rsqrt(variance_h);
  418. vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
  419. vals_cast[high_index] = vals_arr;
  420. }
  421. #endif
  422. }
  423. template <typename T>
  424. void launch_bias_residual_layer_norm(T* vals,
  425. const T* residual,
  426. const T* gamma,
  427. const T* beta,
  428. float epsilon,
  429. int batch_size,
  430. int hidden_dim,
  431. cudaStream_t stream,
  432. bool preLayerNorm,
  433. bool training,
  434. T* vars);
  435. /*
  436. To tune this launch the following restrictions must be met:
  437. For float:
  438. row_stride == hidden_size
  439. threads * iterations == row_stride
  440. threads is in [32, 64, 128, 256, 512, 1024]
  441. For half:
  442. row_stride == hidden_size / 2
  443. threads * iterations == row_stride
  444. threads is in [32, 64, 128, 256, 512, 1024]
  445. */
  446. template <>
  447. void launch_bias_residual_layer_norm<float>(float* vals,
  448. const float* residual,
  449. const float* gamma,
  450. const float* beta,
  451. float epsilon,
  452. int batch_size,
  453. int hidden_dim,
  454. cudaStream_t stream,
  455. bool preLayerNorm,
  456. bool training,
  457. float* vars)
  458. {
  459. int threads = THREADS;
  460. dim3 grid_dim(batch_size);
  461. // There are some limitations to call below functions, now just enumerate the situations.
  462. if (hidden_dim > 16384 && hidden_dim <= 32768)
  463. threads <<= 1;
  464. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  465. threads <<= 2;
  466. else if (hidden_dim > 65536)
  467. throw std::runtime_error("Unsupport hidden_dim.");
  468. dim3 block_dim(threads);
  469. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  470. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
  471. }
  472. template <>
  473. void launch_bias_residual_layer_norm<__half>(__half* vals,
  474. const __half* residual,
  475. const __half* gamma,
  476. const __half* beta,
  477. float epsilon,
  478. int batch_size,
  479. int hidden_dim,
  480. cudaStream_t stream,
  481. bool preLayerNorm,
  482. bool training,
  483. __half* vars)
  484. {
  485. int threads = 128;
  486. dim3 grid_dim(batch_size);
  487. // There are some limitations to call below functions, now just enumerate the situations.
  488. if (hidden_dim > 8192 && hidden_dim <= 16384)
  489. threads <<= 1;
  490. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  491. threads <<= 2;
  492. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  493. threads <<= 3;
  494. else if (hidden_dim > 65536)
  495. throw std::runtime_error("Unsupport hidden_dim.");
  496. dim3 block_dim(threads);
  497. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  498. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
  499. }
  500. /* Normalize Gamma & Betta gradients
  501. * Compute gradients using either X_hat or
  502. * normalize input (invertible).
  503. * Combine transpose with gradients computation.
  504. */
  505. template <typename T>
  506. __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
  507. const T* __restrict__ vals_hat,
  508. const T* __restrict__ gamma,
  509. const T* __restrict__ betta,
  510. T* __restrict__ gamma_grad,
  511. T* __restrict__ betta_grad,
  512. int rows,
  513. int width,
  514. bool invertible)
  515. {
  516. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  517. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  518. cg::thread_block b = cg::this_thread_block();
  519. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  520. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  521. int offset = threadIdx.y * width + idx;
  522. int y_stride = width * TILE_DIM;
  523. float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
  524. float gamma_reg = (float)gamma[idx];
  525. // Loop across matrix height
  526. float betta_tmp = 0;
  527. float gamma_tmp = 0;
  528. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  529. float grad = (float)out_grad[offset];
  530. float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
  531. : (float)vals_hat[offset]);
  532. betta_tmp += grad;
  533. gamma_tmp += (val * grad);
  534. offset += y_stride;
  535. }
  536. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  537. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  538. __syncthreads();
  539. // Sum the shared buffer.
  540. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  541. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  542. #ifndef __STOCHASTIC_MODE__
  543. __syncthreads();
  544. #endif
  545. for (int i = 1; i < TILE_DIM; i <<= 1) {
  546. s1 += g.shfl_down(s1, i);
  547. s2 += g.shfl_down(s2, i);
  548. }
  549. if (threadIdx.x == 0) {
  550. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  551. betta_grad[pos] = s1;
  552. gamma_grad[pos] = s2;
  553. }
  554. }
  555. /* Normalize Gamma & Betta gradients
  556. * Compute gradients using the input to
  557. * the normalize.
  558. * Combine transpose with gradients computation.
  559. */
  560. template <typename T>
  561. __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
  562. const T* __restrict__ X_data,
  563. const T* __restrict__ vars,
  564. const T* __restrict__ means,
  565. T* __restrict__ gamma_grad,
  566. T* __restrict__ betta_grad,
  567. int rows,
  568. int width)
  569. {
  570. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  571. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  572. cg::thread_block b = cg::this_thread_block();
  573. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  574. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  575. int offset = threadIdx.y * width + idx;
  576. int y_stride = width * TILE_DIM;
  577. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  578. // Loop across matrix height
  579. float betta_tmp = 0;
  580. float gamma_tmp = 0;
  581. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  582. float grad = (float)out_grad[offset];
  583. float val = (float)X_data[offset];
  584. val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
  585. betta_tmp += grad;
  586. gamma_tmp += (val * grad);
  587. offset += y_stride;
  588. }
  589. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  590. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  591. __syncthreads();
  592. // Sum the shared buffer.
  593. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  594. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  595. #ifndef __STOCHASTIC_MODE__
  596. __syncthreads();
  597. #endif
  598. for (int i = 1; i < TILE_DIM; i <<= 1) {
  599. s1 += g.shfl_down(s1, i);
  600. s2 += g.shfl_down(s2, i);
  601. }
  602. if (threadIdx.x == 0) {
  603. betta_grad[pos] = s1;
  604. gamma_grad[pos] = s2;
  605. }
  606. }
  607. /*
  608. /* Backward Normalize (Input-Gradient)
  609. * Using the means and variances from the input
  610. * This type of backward is invertible!
  611. * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization.
  612. */
  613. __global__ void LayerNormBackward2(const float* out_grad,
  614. const float* vals_hat,
  615. const float* gamma,
  616. const float* betta,
  617. const float* vars,
  618. float* inp_grad,
  619. bool invertible,
  620. int row_stride)
  621. {
  622. int iteration_stride = blockDim.x;
  623. int iterations = row_stride / iteration_stride;
  624. cg::thread_block b = cg::this_thread_block();
  625. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  626. int row = blockIdx.x;
  627. int id = threadIdx.x;
  628. int wid = id / WARP_SIZE;
  629. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  630. __shared__ float partialSum[MAX_WARP_NUM];
  631. out_grad += (row * row_stride);
  632. vals_hat += (row * row_stride);
  633. inp_grad += (row * row_stride);
  634. float vals_arr[NORM_REG];
  635. float vals_hat_arr[NORM_REG];
  636. int high_index = iterations * iteration_stride + id;
  637. #pragma unroll
  638. for (int i = 0; i < iterations; i++) {
  639. float gamma_reg = gamma[i * iteration_stride + id];
  640. vals_arr[i] = out_grad[i * iteration_stride + id];
  641. vals_arr[i] *= gamma_reg;
  642. vals_hat_arr[i] =
  643. (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
  644. gamma_reg
  645. : vals_hat[i * iteration_stride + id]);
  646. }
  647. if ((high_index) < row_stride) {
  648. float gamma_reg = gamma[high_index];
  649. vals_arr[iterations] = out_grad[high_index];
  650. vals_arr[iterations] *= gamma_reg;
  651. vals_hat_arr[iterations] =
  652. (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
  653. : vals_hat[high_index]);
  654. iterations++;
  655. }
  656. float var_reg = vars[row];
  657. float sum = 0;
  658. for (int i = 0; i < iterations; i++) {
  659. sum += vals_hat_arr[i] * vals_arr[i] *
  660. sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
  661. vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
  662. }
  663. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  664. if (g.thread_rank() == 0) partialSum[wid] = sum;
  665. __syncthreads();
  666. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  667. #ifndef __STOCHASTIC_MODE__
  668. __syncthreads();
  669. #endif
  670. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  671. sum = g.shfl(sum, 0);
  672. sum /= row_stride;
  673. for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
  674. sum = 0;
  675. for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
  676. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  677. if (g.thread_rank() == 0) partialSum[wid] = sum;
  678. __syncthreads();
  679. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  680. #ifndef __STOCHASTIC_MODE__
  681. __syncthreads();
  682. #endif
  683. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  684. sum = g.shfl(sum, 0);
  685. sum /= row_stride;
  686. iterations = row_stride / iteration_stride;
  687. for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
  688. if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
  689. }
  690. __global__ void LayerNormBackward2(const __half* out_grad,
  691. const __half* vals_hat,
  692. const __half* gamma,
  693. const __half* betta,
  694. const __half* vars,
  695. __half* inp_grad,
  696. bool invertible,
  697. int row_stride)
  698. {
  699. #ifdef HALF_PRECISION_AVAILABLE
  700. int iteration_stride = blockDim.x;
  701. int iterations = row_stride / iteration_stride;
  702. cg::thread_block b = cg::this_thread_block();
  703. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  704. int row = blockIdx.x;
  705. int id = threadIdx.x;
  706. int wid = id / WARP_SIZE;
  707. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  708. __shared__ float partialSum[MAX_WARP_NUM];
  709. __half2 vals_arr[NORM_REG];
  710. float2 vals_arr_f[NORM_REG];
  711. __half2 vals_hat_arr[NORM_REG];
  712. __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
  713. const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
  714. const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
  715. inp_grad_h += (row * row_stride);
  716. out_grad_h += (row * row_stride);
  717. vals_hat_h += (row * row_stride);
  718. const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
  719. const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
  720. int high_index = iterations * iteration_stride + id;
  721. #pragma unroll
  722. for (int i = 0; i < iterations; i++) {
  723. __half2 gamma_reg = gamma_h[i * iteration_stride + id];
  724. vals_arr[i] = out_grad_h[i * iteration_stride + id];
  725. vals_arr[i] *= gamma_reg;
  726. vals_hat_arr[i] =
  727. (invertible
  728. ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
  729. gamma_reg
  730. : vals_hat_h[i * iteration_stride + id]);
  731. }
  732. if ((high_index) < row_stride) {
  733. __half2 gamma_reg = gamma_h[high_index];
  734. vals_arr[iterations] = out_grad_h[high_index];
  735. vals_arr[iterations] *= gamma_reg;
  736. vals_hat_arr[iterations] =
  737. (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
  738. : vals_hat_h[high_index]);
  739. iterations++;
  740. }
  741. __half var_h = vars[row];
  742. __half2 var_reg = __halves2half2(var_h, var_h);
  743. float sum = 0.f;
  744. for (int i = 0; i < iterations; i++) {
  745. __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
  746. float2 result_f = __half22float2(result_h);
  747. sum += result_f.x;
  748. sum += result_f.y;
  749. vals_arr[i] *= h2rsqrt(var_reg);
  750. }
  751. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  752. if (g.thread_rank() == 0) partialSum[wid] = sum;
  753. __syncthreads();
  754. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  755. #ifndef __STOCHASTIC_MODE__
  756. __syncthreads();
  757. #endif
  758. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  759. sum = g.shfl(sum, 0);
  760. sum /= (2 * row_stride);
  761. __half2 sum_h = __float2half2_rn(sum);
  762. for (int i = 0; i < iterations; i++) {
  763. __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
  764. vals_arr_f[i] = __half22float2(vals_arr[i]);
  765. float2 temp_f = __half22float2(temp);
  766. vals_arr_f[i].x += temp_f.x;
  767. vals_arr_f[i].y += temp_f.y;
  768. }
  769. sum = 0.f;
  770. for (int i = 0; i < iterations; i++) {
  771. sum += (vals_arr_f[i].x);
  772. sum += (vals_arr_f[i].y);
  773. }
  774. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  775. if (g.thread_rank() == 0) partialSum[wid] = sum;
  776. __syncthreads();
  777. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  778. #ifndef __STOCHASTIC_MODE__
  779. __syncthreads();
  780. #endif
  781. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  782. sum = g.shfl(sum, 0);
  783. sum /= (2 * row_stride);
  784. iterations = row_stride / iteration_stride;
  785. for (int i = 0; i < iterations; i++) {
  786. vals_arr_f[i].x -= sum;
  787. vals_arr_f[i].y -= sum;
  788. __half2 temp = __float22half2_rn(vals_arr_f[i]);
  789. inp_grad_h[i * iteration_stride + id] = temp;
  790. }
  791. if ((high_index) < row_stride) {
  792. vals_arr_f[iterations].x -= sum;
  793. vals_arr_f[iterations].y -= sum;
  794. __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
  795. inp_grad_h[high_index] = temp;
  796. }
  797. #endif
  798. }
  799. template <>
  800. void launch_layerNorm_backward<float>(const float* out_grad,
  801. const float* vals_hat,
  802. const float* vars,
  803. const float* gamma,
  804. float* gamma_grad,
  805. float* betta_grad,
  806. float* inp_grad,
  807. int batch,
  808. int hidden_dim,
  809. cudaStream_t stream[2],
  810. bool invertible,
  811. const float* betta)
  812. {
  813. int threads = THREADS;
  814. dim3 grid_dim(hidden_dim / TILE_DIM);
  815. dim3 block_dim(TILE_DIM, TILE_DIM);
  816. LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
  817. out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
  818. dim3 grid_dim2(batch);
  819. if (hidden_dim > 16384 && hidden_dim <= 32768)
  820. threads <<= 1;
  821. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  822. threads <<= 2;
  823. else if (hidden_dim > 65536)
  824. throw std::runtime_error("Unsupport hidden_dim.");
  825. dim3 block_dim2(threads);
  826. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  827. out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
  828. }
  829. template <>
  830. void launch_layerNorm_backward<__half>(const __half* out_grad,
  831. const __half* vals_hat,
  832. const __half* vars,
  833. const __half* gamma,
  834. __half* gamma_grad,
  835. __half* betta_grad,
  836. __half* inp_grad,
  837. int batch,
  838. int hidden_dim,
  839. cudaStream_t stream[2],
  840. bool invertible,
  841. const __half* betta)
  842. {
  843. int threads = THREADS;
  844. dim3 grid_dim(hidden_dim / TILE_DIM);
  845. dim3 block_dim(TILE_DIM, TILE_DIM);
  846. // LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
  847. // out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
  848. dim3 grid_dim2(batch);
  849. if (hidden_dim > 8192 && hidden_dim <= 16384)
  850. threads <<= 1;
  851. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  852. threads <<= 2;
  853. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  854. threads <<= 3;
  855. else if (hidden_dim > 65536)
  856. throw std::runtime_error("Unsupport hidden_dim.");
  857. dim3 block_dim2(threads / 2);
  858. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  859. out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
  860. }
  861. /* Backward Normalize (Input-Gradient)
  862. * Using the means and variances from the input
  863. * This type of backward is not invertible!
  864. * We do the backward using the input (X)
  865. */
  866. __global__ void LayerNormBackward2(const float* out_grad,
  867. const float* X_vals,
  868. const float* gamma,
  869. const float* vars,
  870. const float* means,
  871. float* inp_grad,
  872. int row_stride)
  873. {
  874. int iteration_stride = blockDim.x;
  875. int iterations = row_stride / iteration_stride;
  876. cg::thread_block b = cg::this_thread_block();
  877. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  878. int row = blockIdx.x;
  879. int id = threadIdx.x;
  880. int wid = id >> WARP_SIZE_BITS;
  881. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  882. __shared__ float partialSum[MAX_WARP_NUM];
  883. out_grad += (row * row_stride);
  884. X_vals += (row * row_stride);
  885. inp_grad += (row * row_stride);
  886. float vals_arr[NORM_REG];
  887. int high_index = iterations * iteration_stride + id;
  888. #pragma unroll
  889. for (int i = 0; i < iterations; i++) {
  890. float gamma_reg = gamma[i * iteration_stride + id];
  891. vals_arr[i] = out_grad[i * iteration_stride + id];
  892. vals_arr[i] *= gamma_reg;
  893. }
  894. if ((high_index) < row_stride) {
  895. float gamma_reg = gamma[high_index];
  896. vals_arr[iterations] = out_grad[high_index];
  897. vals_arr[iterations] *= gamma_reg;
  898. iterations++;
  899. }
  900. float var_reg = vars[row];
  901. float mean_reg = means[row];
  902. float sum = 0;
  903. float xu[NORM_REG];
  904. for (int i = 0; i < iterations; i++) {
  905. xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
  906. sum += vals_arr[i] * xu[i];
  907. vals_arr[i] *= rsqrtf(var_reg);
  908. }
  909. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  910. if (g.thread_rank() == 0) partialSum[wid] = sum;
  911. __syncthreads();
  912. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  913. #ifndef __STOCHASTIC_MODE__
  914. __syncthreads();
  915. #endif
  916. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  917. sum = g.shfl(sum, 0);
  918. sum /= row_stride;
  919. for (int i = 0; i < iterations; i++) {
  920. vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
  921. }
  922. sum = 0;
  923. for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
  924. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  925. if (g.thread_rank() == 0) partialSum[wid] = sum;
  926. __syncthreads();
  927. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  928. #ifndef __STOCHASTIC_MODE__
  929. __syncthreads();
  930. #endif
  931. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  932. sum = g.shfl(sum, 0);
  933. sum /= row_stride;
  934. iterations = row_stride / iteration_stride;
  935. for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
  936. if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
  937. }
  938. __global__ void LayerNormBackward2(const __half* out_grad,
  939. const __half* X_vals,
  940. const __half* gamma,
  941. const __half* vars,
  942. const __half* means,
  943. __half* inp_grad,
  944. int row_stride)
  945. {
  946. #ifdef HALF_PRECISION_AVAILABLE
  947. int iteration_stride = blockDim.x;
  948. int iterations = row_stride / iteration_stride;
  949. cg::thread_block b = cg::this_thread_block();
  950. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  951. int row = blockIdx.x;
  952. int id = threadIdx.x;
  953. int wid = id >> WARP_SIZE_BITS;
  954. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  955. __shared__ float partialSum[MAX_WARP_NUM];
  956. __half2 vals_arr[NORM_REG];
  957. float2 vals_arr_f[NORM_REG];
  958. __half2 xu[NORM_REG];
  959. __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
  960. const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
  961. const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
  962. inp_grad_h += (row * row_stride);
  963. out_grad_h += (row * row_stride);
  964. vals_hat_h += (row * row_stride);
  965. const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
  966. int high_index = iterations * iteration_stride + id;
  967. __half mean_h = means[row];
  968. __half2 mean_reg = __halves2half2(mean_h, mean_h);
  969. #pragma unroll
  970. for (int i = 0; i < iterations; i++) {
  971. __half2 gamma_reg = gamma_h[i * iteration_stride + id];
  972. vals_arr[i] = out_grad_h[i * iteration_stride + id];
  973. vals_arr[i] *= gamma_reg; // out_grad * gamma
  974. xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
  975. }
  976. if ((high_index) < row_stride) {
  977. __half2 gamma_reg = gamma_h[high_index];
  978. vals_arr[iterations] = out_grad_h[high_index];
  979. vals_arr[iterations] *= gamma_reg; // out_grad * gamma
  980. xu[iterations] = (vals_hat_h[high_index] - mean_reg);
  981. iterations++;
  982. }
  983. __half var_h = vars[row];
  984. __half2 var_reg = __halves2half2(var_h, var_h);
  985. float sum = 0.f;
  986. for (int i = 0; i < iterations; i++) {
  987. __half2 result_h = (xu[i] * vals_arr[i]);
  988. float2 result_f = __half22float2(result_h);
  989. sum += result_f.x;
  990. sum += result_f.y;
  991. vals_arr[i] *= h2rsqrt(var_reg);
  992. }
  993. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  994. if (g.thread_rank() == 0) partialSum[wid] = sum;
  995. __syncthreads();
  996. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  997. #ifndef __STOCHASTIC_MODE__
  998. __syncthreads();
  999. #endif
  1000. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  1001. sum = g.shfl(sum, 0);
  1002. sum /= (2 * row_stride);
  1003. __half2 sum_h = __float2half2_rn(sum);
  1004. for (int i = 0; i < iterations; i++) {
  1005. __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
  1006. vals_arr_f[i] = __half22float2(vals_arr[i]);
  1007. float2 xu_grad_f = __half22float2(xu_grad);
  1008. vals_arr_f[i].x += xu_grad_f.x;
  1009. vals_arr_f[i].y += xu_grad_f.y;
  1010. }
  1011. sum = 0.f;
  1012. for (int i = 0; i < iterations; i++) {
  1013. sum += (vals_arr_f[i].x);
  1014. sum += (vals_arr_f[i].y);
  1015. }
  1016. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  1017. if (g.thread_rank() == 0) partialSum[wid] = sum;
  1018. __syncthreads();
  1019. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  1020. #ifndef __STOCHASTIC_MODE__
  1021. __syncthreads();
  1022. #endif
  1023. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  1024. sum = g.shfl(sum, 0);
  1025. sum /= (2 * row_stride);
  1026. iterations = row_stride / iteration_stride;
  1027. for (int i = 0; i < iterations; i++) {
  1028. vals_arr_f[i].x -= sum;
  1029. vals_arr_f[i].y -= sum;
  1030. __half2 temp = __float22half2_rn(vals_arr_f[i]);
  1031. inp_grad_h[i * iteration_stride + id] = temp;
  1032. }
  1033. if ((high_index) < row_stride) {
  1034. vals_arr_f[iterations].x -= sum;
  1035. vals_arr_f[iterations].y -= sum;
  1036. __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
  1037. inp_grad_h[high_index] = temp;
  1038. }
  1039. #endif
  1040. }
  1041. template <>
  1042. void launch_layerNorm_backward<float>(const float* out_grad,
  1043. const float* X_data,
  1044. const float* vars,
  1045. const float* means,
  1046. const float* gamma,
  1047. float* gamma_grad,
  1048. float* betta_grad,
  1049. float* inp_grad,
  1050. int batch,
  1051. int hidden_dim,
  1052. cudaStream_t stream[2])
  1053. {
  1054. int threads = THREADS;
  1055. dim3 grid_dim(hidden_dim / TILE_DIM);
  1056. dim3 block_dim(TILE_DIM, TILE_DIM);
  1057. LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
  1058. out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
  1059. dim3 grid_dim2(batch);
  1060. if (hidden_dim > 16384 && hidden_dim <= 32768)
  1061. threads <<= 1;
  1062. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  1063. threads <<= 2;
  1064. else if (hidden_dim > 65536)
  1065. throw std::runtime_error("Unsupport hidden_dim.");
  1066. dim3 block_dim2(threads);
  1067. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  1068. out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
  1069. }
  1070. template <>
  1071. void launch_layerNorm_backward<__half>(const __half* out_grad,
  1072. const __half* X_data,
  1073. const __half* vars,
  1074. const __half* means,
  1075. const __half* gamma,
  1076. __half* gamma_grad,
  1077. __half* betta_grad,
  1078. __half* inp_grad,
  1079. int batch,
  1080. int hidden_dim,
  1081. cudaStream_t stream[2])
  1082. {
  1083. int threads = THREADS;
  1084. dim3 grid_dim(hidden_dim / TILE_DIM);
  1085. dim3 block_dim(TILE_DIM, TILE_DIM);
  1086. LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
  1087. out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
  1088. dim3 grid_dim2(batch);
  1089. if (hidden_dim > 8192 && hidden_dim <= 16384)
  1090. threads <<= 1;
  1091. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  1092. threads <<= 2;
  1093. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  1094. threads <<= 3;
  1095. else if (hidden_dim > 65536)
  1096. throw std::runtime_error("Unsupport hidden_dim.");
  1097. dim3 block_dim2(threads / 2);
  1098. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  1099. out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
  1100. }
  1101. template <typename T>
  1102. __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
  1103. const T* __restrict__ out_grad2,
  1104. const T* __restrict__ vals_hat,
  1105. const T* __restrict__ gamma,
  1106. const T* __restrict__ betta,
  1107. T* __restrict__ gamma_grad,
  1108. T* __restrict__ betta_grad,
  1109. int rows,
  1110. int width,
  1111. bool invertible)
  1112. {
  1113. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  1114. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  1115. cg::thread_block b = cg::this_thread_block();
  1116. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  1117. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  1118. int offset = threadIdx.y * width + idx;
  1119. int y_stride = width * TILE_DIM;
  1120. float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
  1121. float gamma_reg = (float)gamma[idx];
  1122. // Loop across matrix height
  1123. float betta_tmp = 0;
  1124. float gamma_tmp = 0;
  1125. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  1126. float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
  1127. float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
  1128. : (float)vals_hat[offset]);
  1129. betta_tmp += grad;
  1130. gamma_tmp += (val * grad);
  1131. offset += y_stride;
  1132. }
  1133. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  1134. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  1135. __syncthreads();
  1136. // Sum the shared buffer.
  1137. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  1138. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  1139. #ifndef __STOCHASTIC_MODE__
  1140. __syncthreads();
  1141. #endif
  1142. for (int i = 1; i < TILE_DIM; i <<= 1) {
  1143. s1 += g.shfl_down(s1, i);
  1144. s2 += g.shfl_down(s2, i);
  1145. }
  1146. if (threadIdx.x == 0) {
  1147. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  1148. betta_grad[pos] = s1;
  1149. gamma_grad[pos] = s2;
  1150. }
  1151. }
  1152. template <typename T>
  1153. __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
  1154. const T* __restrict__ out_grad2,
  1155. const T* __restrict__ X_data,
  1156. const T* __restrict__ vars,
  1157. const T* __restrict__ means,
  1158. T* __restrict__ gamma_grad,
  1159. T* __restrict__ betta_grad,
  1160. int rows,
  1161. int width)
  1162. {
  1163. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  1164. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  1165. cg::thread_block b = cg::this_thread_block();
  1166. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  1167. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  1168. int offset = threadIdx.y * width + idx;
  1169. int y_stride = width * TILE_DIM;
  1170. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  1171. // Loop across matrix height
  1172. float betta_tmp = 0;
  1173. float gamma_tmp = 0;
  1174. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  1175. float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
  1176. float val = (float)X_data[offset];
  1177. val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
  1178. betta_tmp += grad;
  1179. gamma_tmp += (val * grad);
  1180. offset += y_stride;
  1181. }
  1182. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  1183. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  1184. __syncthreads();
  1185. // Sum the shared buffer.
  1186. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  1187. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  1188. #ifndef __STOCHASTIC_MODE__
  1189. __syncthreads();
  1190. #endif
  1191. for (int i = 1; i < TILE_DIM; i <<= 1) {
  1192. s1 += g.shfl_down(s1, i);
  1193. s2 += g.shfl_down(s2, i);
  1194. }
  1195. if (threadIdx.x == 0) {
  1196. betta_grad[pos] = s1;
  1197. gamma_grad[pos] = s2;
  1198. }
  1199. }
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* vals_hat,
                                             const float* gamma,
                                             const float* betta,
                                             const float* vars,
                                             float* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    vals_hat += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
                              gamma_reg
                        : vals_hat[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
                        : vals_hat[high_index]);
        iterations++;
    }

    float var_reg = vars[row];

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] =
            (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
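/* Editor's note: a hedged, single-row host sketch (not in the original file) of
 * the per-row arithmetic the float kernel above performs, with the warp/block
 * reductions replaced by plain loops. The name and flat per-row signature are
 * illustrative; `xhat` stands for the normalized activations, which the
 * invertible path recovers as (vals_hat - betta) / gamma.
 */
static void fused_add_backward_row_reference(const float* out_grad1,  // dy for this row
                                             const float* out_grad2,  // fused residual gradient
                                             const float* xhat,       // normalized activations
                                             const float* gamma,
                                             float var,
                                             float* inp_grad,
                                             int width)
{
    float inv_std = 1.0f / sqrtf(var);

    // Row mean of (dy * gamma) * xhat.
    float mean_gx = 0.f;
    for (int j = 0; j < width; j++) mean_gx += out_grad1[j] * gamma[j] * xhat[j];
    mean_gx /= width;

    // First stage of the kernel: scale by 1/sqrt(var) and remove the xhat-correlated part.
    float mean_dx = 0.f;
    for (int j = 0; j < width; j++) {
        float dx = (out_grad1[j] * gamma[j] - xhat[j] * mean_gx) * inv_std;
        inp_grad[j] = dx;
        mean_dx += dx;
    }
    mean_dx /= width;

    // Second stage: subtract the row mean and add the residual branch's gradient.
    for (int j = 0; j < width; j++) inp_grad[j] = (inp_grad[j] - mean_dx) + out_grad2[j];
}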
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* vals_hat,
                                             const __half* gamma,
                                             const __half* betta,
                                             const __half* vars,
                                             __half* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    // float2 result[iterations];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);

    inp_grad_h += (row * row_stride);
    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] =
            (invertible
                 ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
                       gamma_reg
                 : vals_hat_h[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
                        : vals_hat_h[high_index]);
        iterations++;
    }

    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 temp_f = __half22float2(temp);
        vals_arr_f[i].x += temp_f.x;
        vals_arr_f[i].y += temp_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
#endif
}
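/* Editor's note: illustrative only. The __half kernels reinterpret each row as
 * __half2 so every thread loads and stores two elements at a time; this is why
 * the half launchers below pass hidden_dim / 2 as row_stride and use half as
 * many threads. A minimal sketch of that packing assumption (hypothetical
 * helper, not used elsewhere in this file):
 */
static __device__ __half2 load_half_pair(const __half* row, int pair_index)
{
    // Elements 2*pair_index and 2*pair_index + 1 of `row` become the two lanes.
    return reinterpret_cast<const __half2*>(row)[pair_index];
}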
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* vals_hat,
                                                const float* vars,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2],
                                                bool invertible,
                                                const float* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
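/* Editor's note: a hedged usage sketch, not part of the original file. Buffer
 * names are hypothetical; all pointers are assumed to be device allocations of
 * shape [batch, hidden_dim] (parameter gradients are [hidden_dim]). The two
 * streams let the parameter-gradient and input-gradient kernels overlap.
 */
static void example_layerNorm_backward_fused_add_fp32(const float* d_out_grad,
                                                      const float* d_residual_grad,
                                                      const float* d_vals_hat,
                                                      const float* d_vars,
                                                      const float* d_gamma,
                                                      const float* d_betta,
                                                      float* d_gamma_grad,
                                                      float* d_betta_grad,
                                                      float* d_inp_grad,
                                                      int batch,
                                                      int hidden_dim)
{
    cudaStream_t streams[2];
    cudaStreamCreate(&streams[0]);
    cudaStreamCreate(&streams[1]);

    launch_layerNorm_backward_fused_add<float>(d_out_grad,
                                               d_residual_grad,
                                               d_vals_hat,
                                               d_vars,
                                               d_gamma,
                                               d_gamma_grad,
                                               d_betta_grad,
                                               d_inp_grad,
                                               batch,
                                               hidden_dim,
                                               streams,
                                               true,  // invertible: recover xhat from (vals_hat - betta) / gamma
                                               d_betta);

    cudaStreamSynchronize(streams[0]);
    cudaStreamSynchronize(streams[1]);
    cudaStreamDestroy(streams[0]);
    cudaStreamDestroy(streams[1]);
}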
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* vals_hat,
                                                 const __half* vars,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2],
                                                 bool invertible,
                                                 const __half* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads / 2);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
 * Uses the means and variances computed from the input.
 * This type of backward is not invertible!
 * We do the backward using the input (X).
 */
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* X_vals,
                                             const float* gamma,
                                             const float* vars,
                                             const float* means,
                                             float* inp_grad,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    X_vals += (row * row_stride);
    inp_grad += (row * row_stride);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] = X_vals[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] = X_vals[high_index];
        iterations++;
    }

    float var_reg = vars[row];
    float mean_reg = means[row];

    float sum = 0;
    float xu[NORM_REG];
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);
        sum += vals_arr[i] * xu[i];
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) {
        vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
    }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] =
            (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
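/* Editor's note: a hedged, single-row host sketch (not in the original file) of
 * the X-based kernel above, with the warp/block reductions replaced by plain
 * loops. The name and flat per-row signature are illustrative. Unlike the
 * invertible path, xhat is rebuilt here from the raw input as
 * (x - mean) / sqrt(var), which is why this variant needs means[].
 */
static void x_based_backward_row_reference(const float* out_grad1,  // dy for this row
                                           const float* out_grad2,  // fused residual gradient
                                           const float* x,          // original (pre-norm) input row
                                           const float* gamma,
                                           float mean,
                                           float var,
                                           float* inp_grad,
                                           int width)
{
    float inv_std = 1.0f / sqrtf(var);

    // Row mean of (dy * gamma) * xhat, with xhat recomputed from x.
    float mean_gx = 0.f;
    for (int j = 0; j < width; j++) mean_gx += out_grad1[j] * gamma[j] * (x[j] - mean) * inv_std;
    mean_gx /= width;

    float mean_dx = 0.f;
    for (int j = 0; j < width; j++) {
        float xhat = (x[j] - mean) * inv_std;
        float dx = (out_grad1[j] * gamma[j] - xhat * mean_gx) * inv_std;
        inp_grad[j] = dx;
        mean_dx += dx;
    }
    mean_dx /= width;

    // Subtract the row mean and add the residual branch's gradient.
    for (int j = 0; j < width; j++) inp_grad[j] = (inp_grad[j] - mean_dx) + out_grad2[j];
}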
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* X_vals,
                                             const __half* gamma,
                                             const __half* vars,
                                             const __half* means,
                                             __half* inp_grad,
                                             int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);

    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    inp_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] = vals_hat_h[high_index];
        iterations++;
    }

    __half mean_h = means[row];
    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);
    __half2 mean_reg = __halves2half2(mean_h, mean_h);
    __half2 xu[NORM_REG];

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);
        __half2 result_h = (xu[i] * vals_arr[i]);
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 xu_grad_f = __half22float2(xu_grad);
        vals_arr_f[i].x += xu_grad_f.x;
        vals_arr_f[i].y += xu_grad_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
#endif
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* X_data,
                                                const float* vars,
                                                const float* means,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* X_data,
                                                 const __half* vars,
                                                 const __half* means,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads / 2);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
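/* Editor's note: a hedged usage sketch for the __half overload above; not part
 * of the original file. Names are hypothetical, all pointers are assumed to be
 * device buffers, and hidden_dim is assumed to be even because the half kernels
 * process elements as __half2 pairs (the launcher passes hidden_dim / 2).
 */
static void example_layerNorm_backward_fused_add_fp16(const __half* d_out_grad,
                                                      const __half* d_residual_grad,
                                                      const __half* d_input,
                                                      const __half* d_vars,
                                                      const __half* d_means,
                                                      const __half* d_gamma,
                                                      __half* d_gamma_grad,
                                                      __half* d_betta_grad,
                                                      __half* d_inp_grad,
                                                      int batch,
                                                      int hidden_dim,
                                                      cudaStream_t streams[2])
{
    // Dispatches to the non-invertible (X-based) overload specialized above.
    launch_layerNorm_backward_fused_add<__half>(d_out_grad,
                                                d_residual_grad,
                                                d_input,
                                                d_vars,
                                                d_means,
                                                d_gamma,
                                                d_gamma_grad,
                                                d_betta_grad,
                                                d_inp_grad,
                                                batch,
                                                hidden_dim,
                                                streams);
}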