normalize_kernels.cu
  1. #include "custom_cuda_layers.h"
  2. namespace cg = cooperative_groups;
  3. /*
  4. Fused bias add, residual (elementwise) add, and normalization layer.
  5. For FP16, this kernel does not promote to FP32, in order to utilize the 2x throughput of
  6. __half2 instructions and to avoid the conversion overhead (roughly 1/8 of the __half2 arithmetic cost).
  7. For specific launch constraints, see the launch functions; a minimal sketch of the __half2 strategy follows this comment.
  8. */
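/*
Minimal sketch of the __half2 strategy described above (illustrative only; the
kernel name and parameters below are assumptions, not part of this file's API):
reductions are accumulated in float after unpacking with __half22float2, while
elementwise math stays in packed __half2 to keep the 2x throughput.
*/
__global__ void half2_scale_and_sum_sketch(__half2* data, float* partial_sums, __half2 scale, int n)
{
#ifdef HALF_PRECISION_AVAILABLE
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float2 vf = __half22float2(data[i]);  // promote only for the accumulation
        partial_sums[i] = vf.x + vf.y;        // per-element sum of the packed pair, kept in float
        data[i] = __hmul2(data[i], scale);    // elementwise math stays in packed __half2
    }
#endif
}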
  9. #define NORM_REG (MAX_REGISTERS / 4)
  10. __global__ void fused_bias_residual_layer_norm(float* vals,
  11. const float* residual,
  12. const float* gamma,
  13. const float* beta,
  14. float epsilon,
  15. bool preLayerNorm,
  16. bool training,
  17. float* vars,
  18. float* means,
  19. int row_stride)
  20. {
  21. int iteration_stride = blockDim.x;
  22. int iterations = row_stride / iteration_stride;
  23. cg::thread_block b = cg::this_thread_block();
  24. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  25. int row = blockIdx.x;
  26. int id = threadIdx.x;
  27. int gid = id / WARP_SIZE;
  28. float vals_arr[NORM_REG];
  29. __shared__ float shr[MAX_WARP_NUM];
  30. residual += (row * row_stride);
  31. vals += (row * row_stride);
  32. float sum = 0.f;
  33. int high_index = iterations * iteration_stride + id;
  34. #pragma unroll
  35. for (int i = 0; i < iterations; i++) {
  36. vals_arr[i] = residual[i * iteration_stride + id];
  37. sum += vals_arr[i];
  38. }
  39. if (high_index < row_stride) {
  40. vals_arr[iterations] = residual[high_index];
  41. sum += vals_arr[iterations];
  42. iterations++;
  43. }
  44. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  45. if (g.thread_rank() == 0) shr[gid] = sum;
  46. b.sync();
  47. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  48. #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
  49. b.sync();
  50. #endif
  51. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  52. sum += g.shfl_down(sum, i);
  53. }
  54. sum = g.shfl(sum, 0);
  55. float mean = sum / row_stride;
  56. if (training)
  57. if (threadIdx.x == 0) means[row] = mean;
  58. float variance = 0.f;
  59. for (int i = 0; i < iterations; i++) {
  60. vals_arr[i] -= mean;
  61. variance += vals_arr[i] * vals_arr[i];
  62. }
  63. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  64. if (g.thread_rank() == 0) shr[gid] = variance;
  65. b.sync();
  66. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  67. #ifndef __STOCHASTIC_MODE__
  68. b.sync();
  69. #endif
  70. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  71. variance += g.shfl_down(variance, i);
  72. }
  73. variance = g.shfl(variance, 0);
  74. variance /= row_stride;
  75. variance += epsilon;
  76. if (training)
  77. if (threadIdx.x == 0) vars[row] = variance;
  78. iterations = row_stride / iteration_stride;
  79. for (int i = 0; i < iterations; i++) {
  80. vals_arr[i] = vals_arr[i] * rsqrtf(variance);
  81. vals_arr[i] =
  82. vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
  83. vals[i * iteration_stride + id] = vals_arr[i];
  84. }
  85. if ((high_index) < row_stride) {
  86. vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
  87. vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
  88. vals[high_index] = vals_arr[iterations];
  89. }
  90. }
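/*
Annotated sketch of the two-level block reduction used throughout this file:
a warp-level shfl_down tree, one partial per warp staged through shared memory,
then a second shfl_down tree over the per-warp partials. The helper name is
illustrative (nothing in this file calls it) and it assumes blockDim.x is a
power-of-two multiple of WARP_SIZE.
*/
__device__ float block_reduce_sum_sketch(float val)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
    __shared__ float warp_sums[MAX_WARP_NUM];

    // 1) Reduce within each warp; lane 0 ends up with the warp's sum.
    for (int i = 1; i < WARP_SIZE; i *= 2) val += g.shfl_down(val, i);
    if (g.thread_rank() == 0) warp_sums[threadIdx.x >> WARP_SIZE_BITS] = val;
    b.sync();

    // 2) Every warp loads the per-warp partials and reduces them again,
    //    so lane 0 of each warp holds the full block sum.
    int warp_num = blockDim.x >> WARP_SIZE_BITS;
    val = (g.thread_rank() < warp_num) ? warp_sums[g.thread_rank()] : 0.f;
    for (int i = 1; i < warp_num; i *= 2) val += g.shfl_down(val, i);

    // 3) Broadcast lane 0's result to all lanes of the warp.
    return g.shfl(val, 0);
}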
  91. __global__ void fused_bias_residual_layer_norm(__half* vals,
  92. const __half* residual,
  93. const __half* gamma,
  94. const __half* beta,
  95. float epsilon,
  96. bool preLayerNorm,
  97. bool training,
  98. __half* vars,
  99. __half* means,
  100. int row_stride)
  101. {
  102. #ifdef HALF_PRECISION_AVAILABLE
  103. int iteration_stride = blockDim.x;
  104. int iterations = row_stride / iteration_stride;
  105. cg::thread_block b = cg::this_thread_block();
  106. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  107. int row = blockIdx.x;
  108. int id = threadIdx.x;
  109. int gid = id >> WARP_SIZE_BITS;
  110. float2 vals_f[NORM_REG];
  111. __shared__ float shr[MAX_WARP_NUM];
  112. __half2* vals_cast = reinterpret_cast<__half2*>(vals);
  113. const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
  114. residual_cast += (row * row_stride);
  115. vals_cast += (row * row_stride);
  116. float sum = 0.f;
  117. int high_index = iterations * iteration_stride + id;
  118. #pragma unroll
  119. for (int i = 0; i < iterations; i++) {
  120. vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
  121. sum += vals_f[i].x;
  122. sum += vals_f[i].y;
  123. }
  124. if ((high_index) < row_stride) {
  125. vals_f[iterations] = __half22float2(residual_cast[high_index]);
  126. sum += vals_f[iterations].x;
  127. sum += vals_f[iterations].y;
  128. iterations++;
  129. }
  130. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  131. if (g.thread_rank() == 0) shr[gid] = sum;
  132. b.sync();
  133. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  134. #ifndef __STOCHASTIC_MODE__
  135. b.sync();
  136. #endif
  137. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  138. sum += g.shfl_down(sum, i);
  139. }
  140. sum = g.shfl(sum, 0);
  141. float mean = sum / (row_stride * 2);
  142. float variance = 0.f;
  143. for (int i = 0; i < iterations; i++) {
  144. vals_f[i].x -= mean;
  145. vals_f[i].y -= mean;
  146. variance += vals_f[i].x * vals_f[i].x;
  147. variance += vals_f[i].y * vals_f[i].y;
  148. }
  149. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  150. if (g.thread_rank() == 0) shr[gid] = variance;
  151. b.sync();
  152. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  153. #ifndef __STOCHASTIC_MODE__
  154. b.sync();
  155. #endif
  156. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  157. variance += g.shfl_down(variance, i);
  158. }
  159. variance = g.shfl(variance, 0);
  160. variance /= (row_stride * 2);
  161. variance += epsilon;
  162. __half2 variance_h = __float2half2_rn(variance);
  163. const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
  164. const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
  165. if (training && threadIdx.x == 0) {
  166. vars[row] = __float2half(variance);
  167. means[row] = __float2half(mean);
  168. }
  169. iterations = row_stride / iteration_stride;
  170. for (int i = 0; i < iterations; i++) {
  171. __half2 vals_arr = __float22half2_rn(vals_f[i]);
  172. vals_arr = vals_arr * h2rsqrt(variance_h);
  173. vals_arr =
  174. vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
  175. vals_cast[i * iteration_stride + id] = vals_arr;
  176. }
  177. if ((high_index) < row_stride) {
  178. __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
  179. vals_arr = vals_arr * h2rsqrt(variance_h);
  180. vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
  181. vals_cast[high_index] = vals_arr;
  182. }
  183. #endif
  184. }
  185. template <typename T>
  186. void launch_bias_residual_layer_norm(T* vals,
  187. const T* residual,
  188. const T* gamma,
  189. const T* beta,
  190. float epsilon,
  191. int batch_size,
  192. int hidden_dim,
  193. cudaStream_t stream,
  194. bool preLayerNorm,
  195. bool training,
  196. T* vars,
  197. T* means);
  198. template <>
  199. void launch_bias_residual_layer_norm<float>(float* vals,
  200. const float* residual,
  201. const float* gamma,
  202. const float* beta,
  203. float epsilon,
  204. int batch_size,
  205. int hidden_dim,
  206. cudaStream_t stream,
  207. bool preLayerNorm,
  208. bool training,
  209. float* vars,
  210. float* means)
  211. {
  212. int threads = THREADS;
  213. dim3 grid_dim(batch_size);
  214. if (hidden_dim > 16384 && hidden_dim <= 32768)
  215. threads <<= 1;
  216. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  217. threads <<= 2;
  218. else if (hidden_dim > 65536)
  219. throw std::runtime_error("Unsupported hidden_dim.");
  220. dim3 block_dim(threads);
  221. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  222. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
  223. }
  224. template <>
  225. void launch_bias_residual_layer_norm<__half>(__half* vals,
  226. const __half* residual,
  227. const __half* gamma,
  228. const __half* beta,
  229. float epsilon,
  230. int batch_size,
  231. int hidden_dim,
  232. cudaStream_t stream,
  233. bool preLayerNorm,
  234. bool training,
  235. __half* vars,
  236. __half* means)
  237. {
  238. int threads = 128;
  239. dim3 grid_dim(batch_size);
  240. if (hidden_dim > 8192 && hidden_dim <= 16384)
  241. threads <<= 1;
  242. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  243. threads <<= 2;
  244. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  245. threads <<= 3;
  246. else if (hidden_dim > 65536)
  247. throw std::runtime_error("Unsupported hidden_dim.");
  248. dim3 block_dim(threads);
  249. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  250. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
  251. }
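/*
Hypothetical host-side usage of the float launcher above (buffer names and the
epsilon value are illustrative; device pointers are assumed to be allocated and
populated elsewhere). vars/means are written per row only when training == true.
*/
void example_fused_layer_norm_forward(float* d_vals,           // [batch_size x hidden_dim] output
                                      const float* d_residual, // [batch_size x hidden_dim] input
                                      const float* d_gamma,    // [hidden_dim]
                                      const float* d_beta,     // [hidden_dim]
                                      float* d_vars,           // [batch_size]
                                      float* d_means,          // [batch_size]
                                      int batch_size,
                                      int hidden_dim,
                                      cudaStream_t stream)
{
    const float epsilon = 1e-12f;  // assumed value for illustration
    launch_bias_residual_layer_norm<float>(d_vals,
                                           d_residual,
                                           d_gamma,
                                           d_beta,
                                           epsilon,
                                           batch_size,
                                           hidden_dim,
                                           stream,
                                           /*preLayerNorm=*/false,
                                           /*training=*/true,
                                           d_vars,
                                           d_means);
}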
  252. __global__ void fused_bias_residual_layer_norm(float* vals,
  253. const float* residual,
  254. const float* gamma,
  255. const float* beta,
  256. float epsilon,
  257. bool preLayerNorm,
  258. bool training,
  259. float* vars,
  260. int row_stride)
  261. {
  262. int iteration_stride = blockDim.x;
  263. int iterations = row_stride / iteration_stride;
  264. cg::thread_block b = cg::this_thread_block();
  265. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  266. int row = blockIdx.x;
  267. int id = threadIdx.x;
  268. int gid = id / 32;
  269. float vals_arr[NORM_REG];
  270. __shared__ float shr[MAX_WARP_NUM];
  271. residual += (row * row_stride);
  272. vals += (row * row_stride);
  273. float sum = 0.f;
  274. int high_index = iterations * iteration_stride + id;
  275. #pragma unroll
  276. for (int i = 0; i < iterations; i++) {
  277. vals_arr[i] = residual[i * iteration_stride + id];
  278. sum += vals_arr[i];
  279. }
  280. if ((high_index) < row_stride) {
  281. vals_arr[iterations] = residual[high_index];
  282. sum += vals_arr[iterations];
  283. iterations++;
  284. }
  285. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  286. if (g.thread_rank() == 0) shr[gid] = sum;
  287. b.sync();
  288. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  289. #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
  290. b.sync();
  291. #endif
  292. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  293. sum += g.shfl_down(sum, i);
  294. }
  295. sum = g.shfl(sum, 0);
  296. float mean = sum / row_stride;
  297. float variance = 0.f;
  298. for (int i = 0; i < iterations; i++) {
  299. vals_arr[i] -= mean;
  300. variance += vals_arr[i] * vals_arr[i];
  301. }
  302. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  303. if (g.thread_rank() == 0) shr[gid] = variance;
  304. b.sync();
  305. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  306. #ifndef __STOCHASTIC_MODE__
  307. b.sync();
  308. #endif
  309. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  310. variance += g.shfl_down(variance, i);
  311. }
  312. variance = g.shfl(variance, 0);
  313. variance /= row_stride;
  314. variance += epsilon;
  315. if (training)
  316. if (threadIdx.x == 0) vars[row] = variance;
  317. iterations = row_stride / iteration_stride;
  318. for (int i = 0; i < iterations; i++) {
  319. vals_arr[i] = vals_arr[i] * rsqrtf(variance);
  320. vals_arr[i] =
  321. vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
  322. vals[i * iteration_stride + id] = vals_arr[i];
  323. }
  324. if ((high_index) < row_stride) {
  325. vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
  326. vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
  327. vals[high_index] = vals_arr[iterations];
  328. }
  329. }
  330. __global__ void fused_bias_residual_layer_norm(__half* vals,
  331. const __half* residual,
  332. const __half* gamma,
  333. const __half* beta,
  334. float epsilon,
  335. bool preLayerNorm,
  336. bool training,
  337. __half* vars,
  338. int row_stride)
  339. {
  340. #ifdef HALF_PRECISION_AVAILABLE
  341. int iteration_stride = blockDim.x;
  342. int iterations = row_stride / iteration_stride;
  343. cg::thread_block b = cg::this_thread_block();
  344. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  345. int row = blockIdx.x;
  346. int id = threadIdx.x;
  347. int gid = id >> WARP_SIZE_BITS;
  348. float2 vals_f[NORM_REG];
  349. __shared__ float shr[MAX_WARP_NUM];
  350. __half2* vals_cast = reinterpret_cast<__half2*>(vals);
  351. const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
  352. residual_cast += (row * row_stride);
  353. vals_cast += (row * row_stride);
  354. float sum = 0.f;
  355. int high_index = iterations * iteration_stride + id;
  356. #pragma unroll
  357. for (int i = 0; i < iterations; i++) {
  358. vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
  359. sum += vals_f[i].x;
  360. sum += vals_f[i].y;
  361. }
  362. if ((high_index) < row_stride) {
  363. vals_f[iterations] = __half22float2(residual_cast[high_index]);
  364. sum += vals_f[iterations].x;
  365. sum += vals_f[iterations].y;
  366. iterations++;
  367. }
  368. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  369. if (g.thread_rank() == 0) shr[gid] = sum;
  370. b.sync();
  371. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  372. #ifndef __STOCHASTIC_MODE__
  373. b.sync();
  374. #endif
  375. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  376. sum += g.shfl_down(sum, i);
  377. }
  378. sum = g.shfl(sum, 0);
  379. float mean = sum / (row_stride * 2);
  380. float variance = 0.f;
  381. for (int i = 0; i < iterations; i++) {
  382. vals_f[i].x -= mean;
  383. vals_f[i].y -= mean;
  384. variance += vals_f[i].x * vals_f[i].x;
  385. variance += vals_f[i].y * vals_f[i].y;
  386. }
  387. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  388. if (g.thread_rank() == 0) shr[gid] = variance;
  389. b.sync();
  390. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  391. #ifndef __STOCHASTIC_MODE__
  392. b.sync();
  393. #endif
  394. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  395. variance += g.shfl_down(variance, i);
  396. }
  397. variance = g.shfl(variance, 0);
  398. variance /= (row_stride * 2);
  399. variance += epsilon;
  400. __half2 variance_h = __float2half2_rn(variance);
  401. const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
  402. const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
  403. if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
  404. iterations = row_stride / iteration_stride;
  405. for (int i = 0; i < iterations; i++) {
  406. __half2 vals_arr = __float22half2_rn(vals_f[i]);
  407. vals_arr = vals_arr * h2rsqrt(variance_h);
  408. vals_arr =
  409. vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
  410. vals_cast[i * iteration_stride + id] = vals_arr;
  411. }
  412. if ((high_index) < row_stride) {
  413. __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
  414. vals_arr = vals_arr * h2rsqrt(variance_h);
  415. vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
  416. vals_cast[high_index] = vals_arr;
  417. }
  418. #endif
  419. }
  420. template <typename T>
  421. void launch_bias_residual_layer_norm(T* vals,
  422. const T* residual,
  423. const T* gamma,
  424. const T* beta,
  425. float epsilon,
  426. int batch_size,
  427. int hidden_dim,
  428. cudaStream_t stream,
  429. bool preLayerNorm,
  430. bool training,
  431. T* vars);
  432. /*
  433. To tune these launches, the following restrictions must be met (a worked example follows this comment):
  434. For float:
  435. row_stride == hidden_size
  436. threads * iterations == row_stride
  437. threads is in [32, 64, 128, 256, 512, 1024]
  438. For half:
  439. row_stride == hidden_size / 2
  440. threads * iterations == row_stride
  441. threads is in [32, 64, 128, 256, 512, 1024]
  442. */
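/*
Worked example of the restrictions above (values are illustrative): for float with
hidden_size = 4096, row_stride = 4096, so threads = 1024 gives iterations = 4 and
threads * iterations == row_stride. For half, the same hidden_size is packed as
__half2, so row_stride = 4096 / 2 = 2048 and threads = 1024 gives iterations = 2.
The hypothetical helper below just encodes that check; the launchers do not use it.
*/
inline bool layer_norm_launch_config_ok(int row_stride, int threads)
{
    const int allowed[] = {32, 64, 128, 256, 512, 1024};
    bool threads_ok = false;
    for (int t : allowed) threads_ok |= (threads == t);
    return threads_ok && (row_stride % threads == 0);
}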
  443. template <>
  444. void launch_bias_residual_layer_norm<float>(float* vals,
  445. const float* residual,
  446. const float* gamma,
  447. const float* beta,
  448. float epsilon,
  449. int batch_size,
  450. int hidden_dim,
  451. cudaStream_t stream,
  452. bool preLayerNorm,
  453. bool training,
  454. float* vars)
  455. {
  456. int threads = THREADS;
  457. dim3 grid_dim(batch_size);
  458. // The launch constraints described above must hold; enumerate the supported hidden_dim ranges here.
  459. if (hidden_dim > 16384 && hidden_dim <= 32768)
  460. threads <<= 1;
  461. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  462. threads <<= 2;
  463. else if (hidden_dim > 65536)
  464. throw std::runtime_error("Unsupported hidden_dim.");
  465. dim3 block_dim(threads);
  466. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  467. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
  468. }
  469. template <>
  470. void launch_bias_residual_layer_norm<__half>(__half* vals,
  471. const __half* residual,
  472. const __half* gamma,
  473. const __half* beta,
  474. float epsilon,
  475. int batch_size,
  476. int hidden_dim,
  477. cudaStream_t stream,
  478. bool preLayerNorm,
  479. bool training,
  480. __half* vars)
  481. {
  482. int threads = 128;
  483. dim3 grid_dim(batch_size);
  484. // The launch constraints described above must hold; enumerate the supported hidden_dim ranges here.
  485. if (hidden_dim > 8192 && hidden_dim <= 16384)
  486. threads <<= 1;
  487. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  488. threads <<= 2;
  489. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  490. threads <<= 3;
  491. else if (hidden_dim > 65536)
  492. throw std::runtime_error("Unsupported hidden_dim.");
  493. dim3 block_dim(threads);
  494. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  495. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
  496. }
  497. /* Normalize Gamma & Betta gradients
  498. * Compute gradients using either X_hat or the normalized input (invertible path).
  499. * Combines the transpose with the gradient computation.
  500. * A host-side reference sketch follows the kernel below.
  501. */
  502. template <typename T>
  503. __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
  504. const T* __restrict__ vals_hat,
  505. const T* __restrict__ gamma,
  506. const T* __restrict__ betta,
  507. T* __restrict__ gamma_grad,
  508. T* __restrict__ betta_grad,
  509. int rows,
  510. int width,
  511. bool invertible)
  512. {
  513. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  514. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  515. cg::thread_block b = cg::this_thread_block();
  516. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  517. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  518. int offset = threadIdx.y * width + idx;
  519. int y_stride = width * TILE_DIM;
  520. float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
  521. float gamma_reg = (float)gamma[idx];
  522. // Loop across matrix height
  523. float betta_tmp = 0;
  524. float gamma_tmp = 0;
  525. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  526. float grad = (float)out_grad[offset];
  527. float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
  528. : (float)vals_hat[offset]);
  529. betta_tmp += grad;
  530. gamma_tmp += (val * grad);
  531. offset += y_stride;
  532. }
  533. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  534. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  535. __syncthreads();
  536. // Sum the shared buffer.
  537. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  538. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  539. #ifndef __STOCHASTIC_MODE__
  540. __syncthreads();
  541. #endif
  542. for (int i = 1; i < TILE_DIM; i <<= 1) {
  543. s1 += g.shfl_down(s1, i);
  544. s2 += g.shfl_down(s2, i);
  545. }
  546. if (threadIdx.x == 0) {
  547. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  548. betta_grad[pos] = s1;
  549. gamma_grad[pos] = s2;
  550. }
  551. }
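/*
Minimal host-side reference for the kernel above (float only, not part of this
file's API), useful as a sketch of what it computes. Assumes row-major
[rows x width] tensors; on the invertible path, x_hat is recovered as
(vals_hat - betta) / gamma.
*/
void layer_norm_backward1_reference(const float* out_grad,
                                    const float* vals_hat,
                                    const float* gamma,
                                    const float* betta,
                                    float* gamma_grad,
                                    float* betta_grad,
                                    int rows,
                                    int width,
                                    bool invertible)
{
    for (int c = 0; c < width; c++) {
        float dgamma = 0.f, dbetta = 0.f;
        for (int r = 0; r < rows; r++) {
            float grad = out_grad[r * width + c];
            float x_hat = invertible ? (vals_hat[r * width + c] - betta[c]) / gamma[c]
                                     : vals_hat[r * width + c];
            dbetta += grad;          // d(betta) = column-sum of the output gradient
            dgamma += grad * x_hat;  // d(gamma) = column-sum of grad * x_hat
        }
        gamma_grad[c] = dgamma;
        betta_grad[c] = dbetta;
    }
}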
  552. /* Normalize Gamma & Betta gradients
  553. * Compute gradients using the input to
  554. * the normalization (X, means, vars).
  555. * Combines the transpose with the gradient computation.
  556. */
  557. template <typename T>
  558. __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
  559. const T* __restrict__ X_data,
  560. const T* __restrict__ vars,
  561. const T* __restrict__ means,
  562. T* __restrict__ gamma_grad,
  563. T* __restrict__ betta_grad,
  564. int rows,
  565. int width)
  566. {
  567. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  568. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  569. cg::thread_block b = cg::this_thread_block();
  570. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  571. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  572. int offset = threadIdx.y * width + idx;
  573. int y_stride = width * TILE_DIM;
  574. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  575. // Loop across matrix height
  576. float betta_tmp = 0;
  577. float gamma_tmp = 0;
  578. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  579. float grad = (float)out_grad[offset];
  580. float val = (float)X_data[offset];
  581. val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
  582. betta_tmp += grad;
  583. gamma_tmp += (val * grad);
  584. offset += y_stride;
  585. }
  586. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  587. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  588. __syncthreads();
  589. // Sum the shared buffer.
  590. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  591. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  592. #ifndef __STOCHASTIC_MODE__
  593. __syncthreads();
  594. #endif
  595. for (int i = 1; i < TILE_DIM; i <<= 1) {
  596. s1 += g.shfl_down(s1, i);
  597. s2 += g.shfl_down(s2, i);
  598. }
  599. if (threadIdx.x == 0) {
  600. betta_grad[pos] = s1;
  601. gamma_grad[pos] = s2;
  602. }
  603. }
  604. /* Backward Normalize (Input-Gradient)
  605. * Using the means and variances from the input.
  606. * This type of backward is invertible!
  607. * We do the backward using X_hat = (X - u) / sqrt(variance), i.e. the output of the
  608. * normalization (recovered from it on the invertible path).
  609. */
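/*
For reference, the identity implemented by the kernels below (the standard
LayerNorm input gradient, written in the kernels' notation): with
d = out_grad * gamma, x_hat = (x - mean) * rsqrt(var), and row means taken over
row_stride elements,

    inp_grad = rsqrt(var) * ( d - mean(d) - x_hat * mean(d * x_hat) )

The kernels compute mean(d * x_hat) first, fold it into the per-element values,
and then subtract the remaining row mean in a second reduction (equivalent,
since x_hat has zero row mean).
*/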
  610. __global__ void LayerNormBackward2(const float* out_grad,
  611. const float* vals_hat,
  612. const float* gamma,
  613. const float* betta,
  614. const float* vars,
  615. float* inp_grad,
  616. bool invertible,
  617. int row_stride)
  618. {
  619. int iteration_stride = blockDim.x;
  620. int iterations = row_stride / iteration_stride;
  621. cg::thread_block b = cg::this_thread_block();
  622. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  623. int row = blockIdx.x;
  624. int id = threadIdx.x;
  625. int wid = id / WARP_SIZE;
  626. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  627. __shared__ float partialSum[MAX_WARP_NUM];
  628. out_grad += (row * row_stride);
  629. vals_hat += (row * row_stride);
  630. inp_grad += (row * row_stride);
  631. float vals_arr[NORM_REG];
  632. float vals_hat_arr[NORM_REG];
  633. int high_index = iterations * iteration_stride + id;
  634. #pragma unroll
  635. for (int i = 0; i < iterations; i++) {
  636. float gamma_reg = gamma[i * iteration_stride + id];
  637. vals_arr[i] = out_grad[i * iteration_stride + id];
  638. vals_arr[i] *= gamma_reg;
  639. vals_hat_arr[i] =
  640. (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
  641. gamma_reg
  642. : vals_hat[i * iteration_stride + id]);
  643. }
  644. if ((high_index) < row_stride) {
  645. float gamma_reg = gamma[high_index];
  646. vals_arr[iterations] = out_grad[high_index];
  647. vals_arr[iterations] *= gamma_reg;
  648. vals_hat_arr[iterations] =
  649. (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
  650. : vals_hat[high_index]);
  651. iterations++;
  652. }
  653. float var_reg = vars[row];
  654. float sum = 0;
  655. for (int i = 0; i < iterations; i++) {
  656. sum += vals_hat_arr[i] * vals_arr[i] *
  657. sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
  658. vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
  659. }
  660. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  661. if (g.thread_rank() == 0) partialSum[wid] = sum;
  662. __syncthreads();
  663. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  664. #ifndef __STOCHASTIC_MODE__
  665. __syncthreads();
  666. #endif
  667. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  668. sum = g.shfl(sum, 0);
  669. sum /= row_stride;
  670. for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
  671. sum = 0;
  672. for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
  673. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  674. if (g.thread_rank() == 0) partialSum[wid] = sum;
  675. __syncthreads();
  676. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  677. #ifndef __STOCHASTIC_MODE__
  678. __syncthreads();
  679. #endif
  680. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  681. sum = g.shfl(sum, 0);
  682. sum /= row_stride;
  683. iterations = row_stride / iteration_stride;
  684. for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
  685. if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
  686. }
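/*
Hypothetical host-side reference for the float kernel above (invertible path),
following the identity stated before the kernel; row-major [batch x row_stride].
Not part of this file's API, intended only as a validation sketch.
*/
void layer_norm_backward2_reference(const float* out_grad,
                                    const float* vals_hat,
                                    const float* gamma,
                                    const float* betta,
                                    const float* vars,
                                    float* inp_grad,
                                    int batch,
                                    int row_stride)
{
    for (int r = 0; r < batch; r++) {
        const float* dy = out_grad + r * row_stride;
        const float* yh = vals_hat + r * row_stride;
        float* dx = inp_grad + r * row_stride;
        float inv_std = 1.0f / sqrtf(vars[r]);
        float mean_d = 0.f, mean_dh = 0.f;
        for (int i = 0; i < row_stride; i++) {
            float x_hat = (yh[i] - betta[i]) / gamma[i];  // recover x_hat from the output
            float d = dy[i] * gamma[i];
            mean_d += d;
            mean_dh += d * x_hat;
        }
        mean_d /= row_stride;
        mean_dh /= row_stride;
        for (int i = 0; i < row_stride; i++) {
            float x_hat = (yh[i] - betta[i]) / gamma[i];
            float d = dy[i] * gamma[i];
            dx[i] = inv_std * (d - mean_d - x_hat * mean_dh);
        }
    }
}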
  687. __global__ void LayerNormBackward2(const __half* out_grad,
  688. const __half* vals_hat,
  689. const __half* gamma,
  690. const __half* betta,
  691. const __half* vars,
  692. __half* inp_grad,
  693. bool invertible,
  694. int row_stride)
  695. {
  696. #ifdef HALF_PRECISION_AVAILABLE
  697. int iteration_stride = blockDim.x;
  698. int iterations = row_stride / iteration_stride;
  699. cg::thread_block b = cg::this_thread_block();
  700. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  701. int row = blockIdx.x;
  702. int id = threadIdx.x;
  703. int wid = id / WARP_SIZE;
  704. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  705. __shared__ float partialSum[MAX_WARP_NUM];
  706. __half2 vals_arr[NORM_REG];
  707. float2 vals_arr_f[NORM_REG];
  708. __half2 vals_hat_arr[NORM_REG];
  709. __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
  710. const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
  711. const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
  712. inp_grad_h += (row * row_stride);
  713. out_grad_h += (row * row_stride);
  714. vals_hat_h += (row * row_stride);
  715. const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
  716. const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
  717. int high_index = iterations * iteration_stride + id;
  718. #pragma unroll
  719. for (int i = 0; i < iterations; i++) {
  720. __half2 gamma_reg = gamma_h[i * iteration_stride + id];
  721. vals_arr[i] = out_grad_h[i * iteration_stride + id];
  722. vals_arr[i] *= gamma_reg;
  723. vals_hat_arr[i] =
  724. (invertible
  725. ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
  726. gamma_reg
  727. : vals_hat_h[i * iteration_stride + id]);
  728. }
  729. if ((high_index) < row_stride) {
  730. __half2 gamma_reg = gamma_h[high_index];
  731. vals_arr[iterations] = out_grad_h[high_index];
  732. vals_arr[iterations] *= gamma_reg;
  733. vals_hat_arr[iterations] =
  734. (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
  735. : vals_hat_h[high_index]);
  736. iterations++;
  737. }
  738. __half var_h = vars[row];
  739. __half2 var_reg = __halves2half2(var_h, var_h);
  740. float sum = 0.f;
  741. for (int i = 0; i < iterations; i++) {
  742. __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
  743. float2 result_f = __half22float2(result_h);
  744. sum += result_f.x;
  745. sum += result_f.y;
  746. vals_arr[i] *= h2rsqrt(var_reg);
  747. }
  748. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  749. if (g.thread_rank() == 0) partialSum[wid] = sum;
  750. __syncthreads();
  751. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  752. #ifndef __STOCHASTIC_MODE__
  753. __syncthreads();
  754. #endif
  755. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  756. sum = g.shfl(sum, 0);
  757. sum /= (2 * row_stride);
  758. __half2 sum_h = __float2half2_rn(sum);
  759. for (int i = 0; i < iterations; i++) {
  760. __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
  761. vals_arr_f[i] = __half22float2(vals_arr[i]);
  762. float2 temp_f = __half22float2(temp);
  763. vals_arr_f[i].x += temp_f.x;
  764. vals_arr_f[i].y += temp_f.y;
  765. }
  766. sum = 0.f;
  767. for (int i = 0; i < iterations; i++) {
  768. sum += (vals_arr_f[i].x);
  769. sum += (vals_arr_f[i].y);
  770. }
  771. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  772. if (g.thread_rank() == 0) partialSum[wid] = sum;
  773. __syncthreads();
  774. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  775. #ifndef __STOCHASTIC_MODE__
  776. __syncthreads();
  777. #endif
  778. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  779. sum = g.shfl(sum, 0);
  780. sum /= (2 * row_stride);
  781. iterations = row_stride / iteration_stride;
  782. for (int i = 0; i < iterations; i++) {
  783. vals_arr_f[i].x -= sum;
  784. vals_arr_f[i].y -= sum;
  785. __half2 temp = __float22half2_rn(vals_arr_f[i]);
  786. inp_grad_h[i * iteration_stride + id] = temp;
  787. }
  788. if ((high_index) < row_stride) {
  789. vals_arr_f[iterations].x -= sum;
  790. vals_arr_f[iterations].y -= sum;
  791. __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
  792. inp_grad_h[high_index] = temp;
  793. }
  794. #endif
  795. }
  796. template <>
  797. void launch_layerNorm_backward<float>(const float* out_grad,
  798. const float* vals_hat,
  799. const float* vars,
  800. const float* gamma,
  801. float* gamma_grad,
  802. float* betta_grad,
  803. float* inp_grad,
  804. int batch,
  805. int hidden_dim,
  806. cudaStream_t stream[2],
  807. bool invertible,
  808. const float* betta)
  809. {
  810. int threads = THREADS;
  811. dim3 grid_dim(hidden_dim / TILE_DIM);
  812. dim3 block_dim(TILE_DIM, TILE_DIM);
  813. LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
  814. out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
  815. dim3 grid_dim2(batch);
  816. if (hidden_dim > 16384 && hidden_dim <= 32768)
  817. threads <<= 1;
  818. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  819. threads <<= 2;
  820. else if (hidden_dim > 65536)
  821. throw std::runtime_error("Unsupported hidden_dim.");
  822. dim3 block_dim2(threads);
  823. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  824. out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
  825. }
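/*
Hypothetical host-side usage of the launcher above. The two streams let the
gamma/betta reduction (LayerNormBackward1) and the input-gradient kernel
(LayerNormBackward2) run concurrently; all device pointers are assumed to be
allocated and sized as described by the launcher's parameters.
*/
void example_layer_norm_backward(const float* d_out_grad,
                                 const float* d_vals_hat,
                                 const float* d_vars,
                                 const float* d_gamma,
                                 const float* d_betta,
                                 float* d_gamma_grad,
                                 float* d_betta_grad,
                                 float* d_inp_grad,
                                 int batch,
                                 int hidden_dim)
{
    cudaStream_t streams[2];
    cudaStreamCreate(&streams[0]);
    cudaStreamCreate(&streams[1]);
    launch_layerNorm_backward<float>(d_out_grad, d_vals_hat, d_vars, d_gamma,
                                     d_gamma_grad, d_betta_grad, d_inp_grad,
                                     batch, hidden_dim, streams,
                                     /*invertible=*/true, d_betta);
    cudaStreamSynchronize(streams[0]);
    cudaStreamSynchronize(streams[1]);
    cudaStreamDestroy(streams[0]);
    cudaStreamDestroy(streams[1]);
}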
  826. template <>
  827. void launch_layerNorm_backward<__half>(const __half* out_grad,
  828. const __half* vals_hat,
  829. const __half* vars,
  830. const __half* gamma,
  831. __half* gamma_grad,
  832. __half* betta_grad,
  833. __half* inp_grad,
  834. int batch,
  835. int hidden_dim,
  836. cudaStream_t stream[2],
  837. bool invertible,
  838. const __half* betta)
  839. {
  840. int threads = THREADS;
  841. dim3 grid_dim(hidden_dim / TILE_DIM);
  842. dim3 block_dim(TILE_DIM, TILE_DIM);
  843. // LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
  844. // out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
  845. dim3 grid_dim2(batch);
  846. if (hidden_dim > 8192 && hidden_dim <= 16384)
  847. threads <<= 1;
  848. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  849. threads <<= 2;
  850. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  851. threads <<= 3;
  852. else if (hidden_dim > 65536)
  853. throw std::runtime_error("Unsupported hidden_dim.");
  854. dim3 block_dim2(threads / 2);
  855. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  856. out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
  857. }
  858. /* Backward Normalize (Input-Gradient)
  859. * Using the means and variances from the input
  860. * This type of backward is not invertible!
  861. * We do the backward using the input (X)
  862. */
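/*
Note: in this variant x_hat is recomputed on the fly as
x_hat = (x - means[row]) * rsqrt(vars[row]), so the same input-gradient identity
stated earlier applies; the only difference from the invertible variant is where
x_hat comes from.
*/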
  863. __global__ void LayerNormBackward2(const float* out_grad,
  864. const float* X_vals,
  865. const float* gamma,
  866. const float* vars,
  867. const float* means,
  868. float* inp_grad,
  869. int row_stride)
  870. {
  871. int iteration_stride = blockDim.x;
  872. int iterations = row_stride / iteration_stride;
  873. cg::thread_block b = cg::this_thread_block();
  874. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  875. int row = blockIdx.x;
  876. int id = threadIdx.x;
  877. int wid = id >> WARP_SIZE_BITS;
  878. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  879. __shared__ float partialSum[MAX_WARP_NUM];
  880. out_grad += (row * row_stride);
  881. X_vals += (row * row_stride);
  882. inp_grad += (row * row_stride);
  883. float vals_arr[NORM_REG];
  884. int high_index = iterations * iteration_stride + id;
  885. #pragma unroll
  886. for (int i = 0; i < iterations; i++) {
  887. float gamma_reg = gamma[i * iteration_stride + id];
  888. vals_arr[i] = out_grad[i * iteration_stride + id];
  889. vals_arr[i] *= gamma_reg;
  890. }
  891. if ((high_index) < row_stride) {
  892. float gamma_reg = gamma[high_index];
  893. vals_arr[iterations] = out_grad[high_index];
  894. vals_arr[iterations] *= gamma_reg;
  895. iterations++;
  896. }
  897. float var_reg = vars[row];
  898. float mean_reg = means[row];
  899. float sum = 0;
  900. float xu[NORM_REG];
  901. for (int i = 0; i < iterations; i++) {
  902. xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
  903. sum += vals_arr[i] * xu[i];
  904. vals_arr[i] *= rsqrtf(var_reg);
  905. }
  906. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  907. if (g.thread_rank() == 0) partialSum[wid] = sum;
  908. __syncthreads();
  909. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  910. #ifndef __STOCHASTIC_MODE__
  911. __syncthreads();
  912. #endif
  913. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  914. sum = g.shfl(sum, 0);
  915. sum /= row_stride;
  916. for (int i = 0; i < iterations; i++) {
  917. vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
  918. }
  919. sum = 0;
  920. for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
  921. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  922. if (g.thread_rank() == 0) partialSum[wid] = sum;
  923. __syncthreads();
  924. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  925. #ifndef __STOCHASTIC_MODE__
  926. __syncthreads();
  927. #endif
  928. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  929. sum = g.shfl(sum, 0);
  930. sum /= row_stride;
  931. iterations = row_stride / iteration_stride;
  932. for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
  933. if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
  934. }
  935. __global__ void LayerNormBackward2(const __half* out_grad,
  936. const __half* X_vals,
  937. const __half* gamma,
  938. const __half* vars,
  939. const __half* means,
  940. __half* inp_grad,
  941. int row_stride)
  942. {
  943. #ifdef HALF_PRECISION_AVAILABLE
  944. int iteration_stride = blockDim.x;
  945. int iterations = row_stride / iteration_stride;
  946. cg::thread_block b = cg::this_thread_block();
  947. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  948. int row = blockIdx.x;
  949. int id = threadIdx.x;
  950. int wid = id >> WARP_SIZE_BITS;
  951. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  952. __shared__ float partialSum[MAX_WARP_NUM];
  953. __half2 vals_arr[NORM_REG];
  954. float2 vals_arr_f[NORM_REG];
  955. __half2 xu[NORM_REG];
  956. __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
  957. const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
  958. const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
  959. inp_grad_h += (row * row_stride);
  960. out_grad_h += (row * row_stride);
  961. vals_hat_h += (row * row_stride);
  962. const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
  963. int high_index = iterations * iteration_stride + id;
  964. __half mean_h = means[row];
  965. __half2 mean_reg = __halves2half2(mean_h, mean_h);
  966. #pragma unroll
  967. for (int i = 0; i < iterations; i++) {
  968. __half2 gamma_reg = gamma_h[i * iteration_stride + id];
  969. vals_arr[i] = out_grad_h[i * iteration_stride + id];
  970. vals_arr[i] *= gamma_reg; // out_grad * gamma
  971. xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
  972. }
  973. if ((high_index) < row_stride) {
  974. __half2 gamma_reg = gamma_h[high_index];
  975. vals_arr[iterations] = out_grad_h[high_index];
  976. vals_arr[iterations] *= gamma_reg; // out_grad * gamma
  977. xu[iterations] = (vals_hat_h[high_index] - mean_reg);
  978. iterations++;
  979. }
  980. __half var_h = vars[row];
  981. __half2 var_reg = __halves2half2(var_h, var_h);
  982. float sum = 0.f;
  983. for (int i = 0; i < iterations; i++) {
  984. __half2 result_h = (xu[i] * vals_arr[i]);
  985. float2 result_f = __half22float2(result_h);
  986. sum += result_f.x;
  987. sum += result_f.y;
  988. vals_arr[i] *= h2rsqrt(var_reg);
  989. }
  990. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  991. if (g.thread_rank() == 0) partialSum[wid] = sum;
  992. __syncthreads();
  993. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  994. #ifndef __STOCHASTIC_MODE__
  995. __syncthreads();
  996. #endif
  997. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  998. sum = g.shfl(sum, 0);
  999. sum /= (2 * row_stride);
  1000. __half2 sum_h = __float2half2_rn(sum);
  1001. for (int i = 0; i < iterations; i++) {
  1002. __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
  1003. vals_arr_f[i] = __half22float2(vals_arr[i]);
  1004. float2 xu_grad_f = __half22float2(xu_grad);
  1005. vals_arr_f[i].x += xu_grad_f.x;
  1006. vals_arr_f[i].y += xu_grad_f.y;
  1007. }
  1008. sum = 0.f;
  1009. for (int i = 0; i < iterations; i++) {
  1010. sum += (vals_arr_f[i].x);
  1011. sum += (vals_arr_f[i].y);
  1012. }
  1013. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  1014. if (g.thread_rank() == 0) partialSum[wid] = sum;
  1015. __syncthreads();
  1016. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  1017. #ifndef __STOCHASTIC_MODE__
  1018. __syncthreads();
  1019. #endif
  1020. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  1021. sum = g.shfl(sum, 0);
  1022. sum /= (2 * row_stride);
  1023. iterations = row_stride / iteration_stride;
  1024. for (int i = 0; i < iterations; i++) {
  1025. vals_arr_f[i].x -= sum;
  1026. vals_arr_f[i].y -= sum;
  1027. __half2 temp = __float22half2_rn(vals_arr_f[i]);
  1028. inp_grad_h[i * iteration_stride + id] = temp;
  1029. }
  1030. if ((high_index) < row_stride) {
  1031. vals_arr_f[iterations].x -= sum;
  1032. vals_arr_f[iterations].y -= sum;
  1033. __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
  1034. inp_grad_h[high_index] = temp;
  1035. }
  1036. #endif
  1037. }
  1038. template <>
  1039. void launch_layerNorm_backward<float>(const float* out_grad,
  1040. const float* X_data,
  1041. const float* vars,
  1042. const float* means,
  1043. const float* gamma,
  1044. float* gamma_grad,
  1045. float* betta_grad,
  1046. float* inp_grad,
  1047. int batch,
  1048. int hidden_dim,
  1049. cudaStream_t stream[2])
  1050. {
  1051. int threads = THREADS;
  1052. dim3 grid_dim(hidden_dim / TILE_DIM);
  1053. dim3 block_dim(TILE_DIM, TILE_DIM);
  1054. LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
  1055. out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
  1056. dim3 grid_dim2(batch);
  1057. if (hidden_dim > 16384 && hidden_dim <= 32768)
  1058. threads <<= 1;
  1059. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  1060. threads <<= 2;
  1061. else if (hidden_dim > 65536)
  1062. throw std::runtime_error("Unsupported hidden_dim.");
  1063. dim3 block_dim2(threads);
  1064. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  1065. out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
  1066. }
  1067. template <>
  1068. void launch_layerNorm_backward<__half>(const __half* out_grad,
  1069. const __half* X_data,
  1070. const __half* vars,
  1071. const __half* means,
  1072. const __half* gamma,
  1073. __half* gamma_grad,
  1074. __half* betta_grad,
  1075. __half* inp_grad,
  1076. int batch,
  1077. int hidden_dim,
  1078. cudaStream_t stream[2])
  1079. {
  1080. int threads = THREADS;
  1081. dim3 grid_dim(hidden_dim / TILE_DIM);
  1082. dim3 block_dim(TILE_DIM, TILE_DIM);
  1083. LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
  1084. out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
  1085. dim3 grid_dim2(batch);
  1086. if (hidden_dim > 8192 && hidden_dim <= 16384)
  1087. threads <<= 1;
  1088. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  1089. threads <<= 2;
  1090. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  1091. threads <<= 3;
  1092. else if (hidden_dim > 65536)
  1093. throw std::runtime_error("Unsupported hidden_dim.");
  1094. dim3 block_dim2(threads / 2);
  1095. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  1096. out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
  1097. }
  1098. template <typename T>
  1099. __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
  1100. const T* __restrict__ out_grad2,
  1101. const T* __restrict__ vals_hat,
  1102. const T* __restrict__ gamma,
  1103. const T* __restrict__ betta,
  1104. T* __restrict__ gamma_grad,
  1105. T* __restrict__ betta_grad,
  1106. int rows,
  1107. int width,
  1108. bool invertible)
  1109. {
  1110. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  1111. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  1112. cg::thread_block b = cg::this_thread_block();
  1113. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  1114. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  1115. int offset = threadIdx.y * width + idx;
  1116. int y_stride = width * TILE_DIM;
  1117. float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
  1118. float gamma_reg = (float)gamma[idx];
  1119. // Loop across matrix height
  1120. float betta_tmp = 0;
  1121. float gamma_tmp = 0;
  1122. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  1123. float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
  1124. float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
  1125. : (float)vals_hat[offset]);
  1126. betta_tmp += grad;
  1127. gamma_tmp += (val * grad);
  1128. offset += y_stride;
  1129. }
  1130. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  1131. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  1132. __syncthreads();
  1133. // Sum the shared buffer.
  1134. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  1135. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  1136. #ifndef __STOCHASTIC_MODE__
  1137. __syncthreads();
  1138. #endif
  1139. for (int i = 1; i < TILE_DIM; i <<= 1) {
  1140. s1 += g.shfl_down(s1, i);
  1141. s2 += g.shfl_down(s2, i);
  1142. }
  1143. if (threadIdx.x == 0) {
  1144. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  1145. betta_grad[pos] = s1;
  1146. gamma_grad[pos] = s2;
  1147. }
  1148. }
  1149. template <typename T>
  1150. __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
  1151. const T* __restrict__ out_grad2,
  1152. const T* __restrict__ X_data,
  1153. const T* __restrict__ vars,
  1154. const T* __restrict__ means,
  1155. T* __restrict__ gamma_grad,
  1156. T* __restrict__ betta_grad,
  1157. int rows,
  1158. int width)
  1159. {
  1160. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  1161. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  1162. cg::thread_block b = cg::this_thread_block();
  1163. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  1164. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  1165. int offset = threadIdx.y * width + idx;
  1166. int y_stride = width * TILE_DIM;
  1167. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  1168. // Loop across matrix height
  1169. float betta_tmp = 0;
  1170. float gamma_tmp = 0;
  1171. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  1172. float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
  1173. float val = (float)X_data[offset];
  1174. val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
  1175. betta_tmp += grad;
  1176. gamma_tmp += (val * grad);
  1177. offset += y_stride;
  1178. }
  1179. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  1180. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  1181. __syncthreads();
  1182. // Sum the shared buffer.
  1183. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  1184. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  1185. #ifndef __STOCHASTIC_MODE__
  1186. __syncthreads();
  1187. #endif
  1188. for (int i = 1; i < TILE_DIM; i <<= 1) {
  1189. s1 += g.shfl_down(s1, i);
  1190. s2 += g.shfl_down(s2, i);
  1191. }
  1192. if (threadIdx.x == 0) {
  1193. betta_grad[pos] = s1;
  1194. gamma_grad[pos] = s2;
  1195. }
  1196. }
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* vals_hat,
                                             const float* gamma,
                                             const float* betta,
                                             const float* vars,
                                             float* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    vals_hat += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
                              gamma_reg
                        : vals_hat[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
                        : vals_hat[high_index]);
        iterations++;
    }

    float var_reg = vars[row];

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
        vals_arr[i] *= rsqrtf(var_reg);
    }

    // Reduce sum within each warp, then across warps through shared memory.
    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    // Second reduction: mean of the partially corrected gradients.
    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] =
            (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
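
/* Per-row math of the kernel above (H = row_stride, s = sqrt(var)):
 *   dyg_i  = out_grad1_i * gamma_i
 *   xhat_i = normalized activation (recovered from vals_hat and betta when invertible)
 *   t_i    = (dyg_i - xhat_i * mean_j(xhat_j * dyg_j)) / s
 *   dx_i   = t_i - mean_j(t_j) + out_grad2_i
 * A minimal single-row host reference of that formula follows; it is a sketch for
 * checking the math only (layer_norm_backward2_ref is a hypothetical name, not part
 * of the original file).
 */
inline void layer_norm_backward2_ref(const float* dyg,   // out_grad1 * gamma, [H]
                                     const float* xhat,  // normalized activations, [H]
                                     const float* dz,    // out_grad2 (residual branch), [H]
                                     float var,          // per-row variance
                                     float* dx,          // input gradient, [H]
                                     int H)
{
    float s = sqrtf(var);

    float m1 = 0.f;
    for (int i = 0; i < H; i++) m1 += xhat[i] * dyg[i];
    m1 /= H;

    float m2 = 0.f;
    for (int i = 0; i < H; i++) {
        dx[i] = (dyg[i] - xhat[i] * m1) / s;
        m2 += dx[i];
    }
    m2 /= H;

    for (int i = 0; i < H; i++) dx[i] = (dx[i] - m2) + dz[i];
}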
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* vals_hat,
                                             const __half* gamma,
                                             const __half* betta,
                                             const __half* vars,
                                             __half* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    // float2 result[iterations];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);

    inp_grad_h += (row * row_stride);
    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] =
            (invertible
                 ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
                       gamma_reg
                 : vals_hat_h[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
                        : vals_hat_h[high_index]);
        iterations++;
    }

    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 temp_f = __half22float2(temp);
        vals_arr_f[i].x += temp_f.x;
        vals_arr_f[i].y += temp_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
#endif
}
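
/* The half-precision kernel above reads the row as __half2 pairs, so each thread
 * column covers two hidden units; that is why the launchers below pass
 * hidden_dim / 2 as row_stride and the mean reductions divide by 2 * row_stride.
 * A small hedged illustration of that packed arithmetic follows (half2_dot is a
 * hypothetical helper, not used by the original kernels); it mirrors how the
 * kernels multiply packed pairs and accumulate both lanes in float.
 */
#ifdef HALF_PRECISION_AVAILABLE
__device__ __forceinline__ float half2_dot(const __half2 a, const __half2 b)
{
    // Packed multiply, then unpack and accumulate both lanes in float precision.
    float2 p = __half22float2(a * b);
    return p.x + p.y;
}
#endif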
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* vals_hat,
                                                const float* vars,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2],
                                                bool invertible,
                                                const float* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
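
/* Hedged usage sketch, not part of the original file: example_fused_add_backward_fp32
 * and all d_* names are hypothetical. It assumes the caller owns device buffers of
 * shape [batch x hidden_dim] (gradients/activations), [hidden_dim] (gamma, betta and
 * their gradients), [batch] (vars), and two already-created CUDA streams.
 */
inline void example_fused_add_backward_fp32(const float* d_out_grad1,
                                            const float* d_out_grad2,
                                            const float* d_vals_hat,
                                            const float* d_vars,
                                            const float* d_gamma,
                                            const float* d_betta,
                                            float* d_gamma_grad,
                                            float* d_betta_grad,
                                            float* d_inp_grad,
                                            int batch,
                                            int hidden_dim,
                                            cudaStream_t streams[2])
{
    // Invertible path: vals_hat holds the post-affine output, so betta is needed
    // inside the kernels to recover the normalized activations.
    launch_layerNorm_backward_fused_add<float>(d_out_grad1,
                                               d_out_grad2,
                                               d_vals_hat,
                                               d_vars,
                                               d_gamma,
                                               d_gamma_grad,
                                               d_betta_grad,
                                               d_inp_grad,
                                               batch,
                                               hidden_dim,
                                               streams,
                                               true,
                                               d_betta);
}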
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* vals_hat,
                                                 const __half* vars,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2],
                                                 bool invertible,
                                                 const __half* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads / 2);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
 * Uses the means and variances saved from the forward pass.
 * This type of backward is not invertible: the normalized values cannot be
 * recovered from the output, so the backward is computed from the input (X).
 */
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* X_vals,
                                             const float* gamma,
                                             const float* vars,
                                             const float* means,
                                             float* inp_grad,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    X_vals += (row * row_stride);
    inp_grad += (row * row_stride);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] = X_vals[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] = X_vals[high_index];
        iterations++;
    }

    float var_reg = vars[row];
    float mean_reg = means[row];

    float sum = 0;
    float xu[NORM_REG];
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);
        sum += vals_arr[i] * xu[i];
        vals_arr[i] *= rsqrtf(var_reg);
    }

    // Reduce sum within each warp, then across warps through shared memory.
    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) {
        vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
    }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    // Second reduction: mean of the partially corrected gradients.
    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] =
            (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
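
/* Per-row math of the kernel above (H = row_stride, mu = means[row], s = sqrt(var)):
 *   dyg_i = out_grad1_i * gamma_i
 *   xu_i  = X_i - mu
 *   t_i   = dyg_i / s - mean_j(dyg_j * xu_j) * xu_i / (s * var)
 *   dx_i  = t_i - mean_j(t_j) + out_grad2_i
 * A minimal single-row host reference follows; it is a sketch for checking the math
 * only (layer_norm_backward2_x_ref is a hypothetical name, not part of the original
 * file).
 */
inline void layer_norm_backward2_x_ref(const float* dyg,  // out_grad1 * gamma, [H]
                                       const float* X,    // layer input, [H]
                                       const float* dz,   // out_grad2 (residual branch), [H]
                                       float mean,        // per-row mean
                                       float var,         // per-row variance
                                       float* dx,         // input gradient, [H]
                                       int H)
{
    float s = sqrtf(var);

    float m1 = 0.f;
    for (int i = 0; i < H; i++) m1 += dyg[i] * (X[i] - mean);
    m1 /= H;

    float m2 = 0.f;
    for (int i = 0; i < H; i++) {
        dx[i] = dyg[i] / s - m1 * (X[i] - mean) / (s * var);
        m2 += dx[i];
    }
    m2 /= H;

    for (int i = 0; i < H; i++) dx[i] = (dx[i] - m2) + dz[i];
}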
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* X_vals,
                                             const __half* gamma,
                                             const __half* vars,
                                             const __half* means,
                                             __half* inp_grad,
                                             int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);

    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    inp_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] = vals_hat_h[high_index];
        iterations++;
    }

    __half mean_h = means[row];
    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);
    __half2 mean_reg = __halves2half2(mean_h, mean_h);
    __half2 xu[NORM_REG];

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);
        __half2 result_h = (xu[i] * vals_arr[i]);
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 xu_grad_f = __half22float2(xu_grad);
        vals_arr_f[i].x += xu_grad_f.x;
        vals_arr_f[i].y += xu_grad_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
#endif
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* X_data,
                                                const float* vars,
                                                const float* means,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* X_data,
                                                 const __half* vars,
                                                 const __half* means,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads / 2);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
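
/* The half-precision launchers above scale the block size with hidden_dim and then
 * launch with threads / 2 because each thread handles a __half2 pair. The sketch
 * below restates that selection logic as a helper; it is hypothetical (not part of
 * the original file) and assumes the same THREADS constant and hidden_dim limits
 * used by the launchers in this file.
 */
inline int fused_add_backward_fp16_block_dim(int hidden_dim)
{
    int threads = THREADS;
    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
    // Each thread processes a __half2 pair, so half the threads cover the row.
    return threads / 2;
}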