// normalize_kernels.cu

  1. #include "custom_cuda_layers.h"
  2. namespace cg = cooperative_groups;
3. /*
4. Fused bias add, residual (element-wise) add, and normalization layer.
5. For FP16, this kernel does not promote to FP32, in order to use the 2x throughput of
6. __half2 instructions and to avoid the conversion overhead (roughly 1/8 of the __half2 arithmetic cost).
7. For the specific launch constraints, see the launch functions.
8. */
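/*
Layout used by the kernels below: one CUDA block processes one row (one
token's hidden vector). Each thread loads a strided slice of the row into a
small register buffer (at most NORM_REG elements per thread), and the
row-wide mean and variance are obtained with a two-stage reduction:
intra-warp shuffles followed by a shared-memory pass across the warps.
*/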
  9. #define NORM_REG (MAX_REGISTERS / 4)
  10. __global__ void fused_bias_residual_layer_norm(float* vals,
  11. const float* residual,
  12. const float* gamma,
  13. const float* beta,
  14. float epsilon,
  15. bool preLayerNorm,
  16. bool training,
  17. float* vars,
  18. float* means,
  19. int row_stride)
  20. {
  21. int iteration_stride = blockDim.x;
  22. int iterations = row_stride / iteration_stride;
  23. cg::thread_block b = cg::this_thread_block();
  24. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  25. int row = blockIdx.x;
  26. int id = threadIdx.x;
  27. int gid = id / WARP_SIZE;
  28. float vals_arr[NORM_REG];
  29. __shared__ float shr[MAX_WARP_NUM];
  30. residual += (row * row_stride);
  31. vals += (row * row_stride);
  32. float sum = 0.f;
  33. int high_index = iterations * iteration_stride + id;
  34. #pragma unroll
  35. for (int i = 0; i < iterations; i++) {
  36. vals_arr[i] = residual[i * iteration_stride + id];
  37. sum += vals_arr[i];
  38. }
  39. if (high_index < row_stride) {
  40. vals_arr[iterations] = residual[high_index];
  41. sum += vals_arr[iterations];
  42. iterations++;
  43. }
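// Two-stage block reduction: reduce within each warp via shfl_down, park the
// per-warp partials in shared memory, then have the first warp combine those
// partials and broadcast the result back with shfl.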
  44. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  45. if (g.thread_rank() == 0) shr[gid] = sum;
  46. b.sync();
  47. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  48. #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
  49. b.sync();
  50. #endif
  51. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  52. sum += g.shfl_down(sum, i);
  53. }
  54. sum = g.shfl(sum, 0);
  55. float mean = sum / row_stride;
  56. if (training)
  57. if (threadIdx.x == 0) means[row] = mean;
  58. float variance = 0.f;
  59. for (int i = 0; i < iterations; i++) {
  60. vals_arr[i] -= mean;
  61. variance += vals_arr[i] * vals_arr[i];
  62. }
  63. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  64. if (g.thread_rank() == 0) shr[gid] = variance;
  65. b.sync();
  66. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  67. #ifndef __STOCHASTIC_MODE__
  68. b.sync();
  69. #endif
  70. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  71. variance += g.shfl_down(variance, i);
  72. }
  73. variance = g.shfl(variance, 0);
  74. variance /= row_stride;
  75. variance += epsilon;
  76. if (training)
  77. if (threadIdx.x == 0) vars[row] = variance;
  78. iterations = row_stride / iteration_stride;
  79. for (int i = 0; i < iterations; i++) {
  80. vals_arr[i] = vals_arr[i] * rsqrtf(variance);
  81. vals_arr[i] =
  82. vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
  83. vals[i * iteration_stride + id] = vals_arr[i];
  84. }
  85. if ((high_index) < row_stride) {
  86. vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
  87. vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
  88. vals[high_index] = vals_arr[iterations];
  89. }
  90. }
  91. __global__ void fused_bias_residual_layer_norm(__half* vals,
  92. const __half* residual,
  93. const __half* gamma,
  94. const __half* beta,
  95. float epsilon,
  96. bool preLayerNorm,
  97. bool training,
  98. __half* vars,
  99. __half* means,
  100. int row_stride)
  101. {
  102. #ifdef HALF_PRECISION_AVAILABLE
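// FP16 path: the row is handled as __half2 pairs (row_stride is in units of
// __half2, i.e. hidden_dim / 2), each pair is widened to float2 once on load,
// the reductions accumulate in FP32 (hence the divisions by row_stride * 2),
// and only the final normalize/scale step runs in __half2 arithmetic.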
  103. int iteration_stride = blockDim.x;
  104. int iterations = row_stride / iteration_stride;
  105. cg::thread_block b = cg::this_thread_block();
  106. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  107. int row = blockIdx.x;
  108. int id = threadIdx.x;
  109. int gid = id >> WARP_SIZE_BITS;
  110. float2 vals_f[NORM_REG];
  111. __shared__ float shr[MAX_WARP_NUM];
  112. __half2* vals_cast = reinterpret_cast<__half2*>(vals);
  113. const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
  114. residual_cast += (row * row_stride);
  115. vals_cast += (row * row_stride);
  116. float sum = 0.f;
  117. int high_index = iterations * iteration_stride + id;
  118. #pragma unroll
  119. for (int i = 0; i < iterations; i++) {
  120. vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
  121. sum += vals_f[i].x;
  122. sum += vals_f[i].y;
  123. }
  124. if ((high_index) < row_stride) {
  125. vals_f[iterations] = __half22float2(residual_cast[high_index]);
  126. sum += vals_f[iterations].x;
  127. sum += vals_f[iterations].y;
  128. iterations++;
  129. }
  130. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  131. if (g.thread_rank() == 0) shr[gid] = sum;
  132. b.sync();
  133. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  134. #ifndef __STOCHASTIC_MODE__
  135. b.sync();
  136. #endif
  137. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  138. sum += g.shfl_down(sum, i);
  139. }
  140. sum = g.shfl(sum, 0);
  141. float mean = sum / (row_stride * 2);
  142. float variance = 0.f;
  143. for (int i = 0; i < iterations; i++) {
  144. vals_f[i].x -= mean;
  145. vals_f[i].y -= mean;
  146. variance += vals_f[i].x * vals_f[i].x;
  147. variance += vals_f[i].y * vals_f[i].y;
  148. }
  149. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  150. if (g.thread_rank() == 0) shr[gid] = variance;
  151. b.sync();
  152. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  153. #ifndef __STOCHASTIC_MODE__
  154. b.sync();
  155. #endif
  156. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  157. variance += g.shfl_down(variance, i);
  158. }
  159. variance = g.shfl(variance, 0);
  160. variance /= (row_stride * 2);
  161. variance += epsilon;
  162. __half2 variance_h = __float2half2_rn(variance);
  163. const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
  164. const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
  165. if (training && threadIdx.x == 0) {
  166. vars[row] = __float2half(variance);
  167. means[row] = __float2half(mean);
  168. }
  169. iterations = row_stride / iteration_stride;
  170. for (int i = 0; i < iterations; i++) {
  171. __half2 vals_arr = __float22half2_rn(vals_f[i]);
  172. vals_arr = vals_arr * h2rsqrt(variance_h);
  173. vals_arr =
  174. vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
  175. vals_cast[i * iteration_stride + id] = vals_arr;
  176. }
  177. if ((high_index) < row_stride) {
  178. __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
  179. vals_arr = vals_arr * h2rsqrt(variance_h);
  180. vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
  181. vals_cast[high_index] = vals_arr;
  182. }
  183. #endif
  184. }
  185. template <typename T>
  186. void launch_bias_residual_layer_norm(T* vals,
  187. const T* residual,
  188. const T* gamma,
  189. const T* beta,
  190. float epsilon,
  191. int batch_size,
  192. int hidden_dim,
  193. cudaStream_t stream,
  194. bool preLayerNorm,
  195. bool training,
  196. T* vars,
  197. T* means);
  198. template <>
  199. void launch_bias_residual_layer_norm<float>(float* vals,
  200. const float* residual,
  201. const float* gamma,
  202. const float* beta,
  203. float epsilon,
  204. int batch_size,
  205. int hidden_dim,
  206. cudaStream_t stream,
  207. bool preLayerNorm,
  208. bool training,
  209. float* vars,
  210. float* means)
  211. {
  212. int threads = THREADS;
  213. dim3 grid_dim(batch_size);
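// Block size scales with hidden_dim so that each thread's per-row iteration
// count stays within the NORM_REG register buffer (presumably the reason for
// the fixed set of supported ranges below).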
  214. if (hidden_dim > 16384 && hidden_dim <= 32768)
  215. threads <<= 1;
  216. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  217. threads <<= 2;
  218. else if (hidden_dim > 65536)
219. throw std::runtime_error("Unsupported hidden_dim.");
  220. dim3 block_dim(threads);
  221. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  222. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
  223. }
  224. template <>
  225. void launch_bias_residual_layer_norm<__half>(__half* vals,
  226. const __half* residual,
  227. const __half* gamma,
  228. const __half* beta,
  229. float epsilon,
  230. int batch_size,
  231. int hidden_dim,
  232. cudaStream_t stream,
  233. bool preLayerNorm,
  234. bool training,
  235. __half* vars,
  236. __half* means)
  237. {
  238. int threads = 128;
  239. dim3 grid_dim(batch_size);
  240. if (hidden_dim > 8192 && hidden_dim <= 16384)
  241. threads <<= 1;
  242. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  243. threads <<= 2;
  244. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  245. threads <<= 3;
  246. else if (hidden_dim > 65536)
247. throw std::runtime_error("Unsupported hidden_dim.");
  248. dim3 block_dim(threads);
  249. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  250. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
  251. }
  252. __global__ void fused_bias_residual_layer_norm(float* vals,
  253. const float* residual,
  254. const float* gamma,
  255. const float* beta,
  256. float epsilon,
  257. bool preLayerNorm,
  258. bool training,
  259. float* vars,
  260. int row_stride)
  261. {
  262. int iteration_stride = blockDim.x;
  263. int iterations = row_stride / iteration_stride;
  264. cg::thread_block b = cg::this_thread_block();
  265. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  266. int row = blockIdx.x;
  267. int id = threadIdx.x;
  268. int gid = id / 32;
  269. float vals_arr[NORM_REG];
  270. __shared__ float shr[MAX_WARP_NUM];
  271. residual += (row * row_stride);
  272. vals += (row * row_stride);
  273. float sum = 0.f;
  274. int high_index = iterations * iteration_stride + id;
  275. #pragma unroll
  276. for (int i = 0; i < iterations; i++) {
  277. vals_arr[i] = residual[i * iteration_stride + id];
  278. sum += vals_arr[i];
  279. }
  280. if ((high_index) < row_stride) {
  281. vals_arr[iterations] = residual[high_index];
  282. sum += vals_arr[iterations];
  283. iterations++;
  284. }
  285. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  286. if (g.thread_rank() == 0) shr[gid] = sum;
  287. b.sync();
  288. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  289. #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
  290. b.sync();
  291. #endif
  292. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  293. sum += g.shfl_down(sum, i);
  294. }
  295. sum = g.shfl(sum, 0);
  296. float mean = sum / row_stride;
  297. float variance = 0.f;
  298. for (int i = 0; i < iterations; i++) {
  299. vals_arr[i] -= mean;
  300. variance += vals_arr[i] * vals_arr[i];
  301. }
  302. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  303. if (g.thread_rank() == 0) shr[gid] = variance;
  304. b.sync();
  305. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  306. #ifndef __STOCHASTIC_MODE__
  307. b.sync();
  308. #endif
  309. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  310. variance += g.shfl_down(variance, i);
  311. }
  312. variance = g.shfl(variance, 0);
  313. variance /= row_stride;
  314. variance += epsilon;
  315. if (training)
  316. if (threadIdx.x == 0) vars[row] = variance;
  317. iterations = row_stride / iteration_stride;
  318. for (int i = 0; i < iterations; i++) {
  319. vals_arr[i] = vals_arr[i] * rsqrtf(variance);
  320. vals_arr[i] =
  321. vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
  322. vals[i * iteration_stride + id] = vals_arr[i];
  323. }
  324. if ((high_index) < row_stride) {
  325. vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
  326. vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
  327. vals[high_index] = vals_arr[iterations];
  328. }
  329. }
  330. __global__ void fused_bias_residual_layer_norm(__half* vals,
  331. const __half* residual,
  332. const __half* gamma,
  333. const __half* beta,
  334. float epsilon,
  335. bool preLayerNorm,
  336. bool training,
  337. __half* vars,
  338. int row_stride)
  339. {
  340. #ifdef HALF_PRECISION_AVAILABLE
  341. int iteration_stride = blockDim.x;
  342. int iterations = row_stride / iteration_stride;
  343. cg::thread_block b = cg::this_thread_block();
  344. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  345. int row = blockIdx.x;
  346. int id = threadIdx.x;
  347. int gid = id >> WARP_SIZE_BITS;
  348. float2 vals_f[NORM_REG];
  349. __shared__ float shr[MAX_WARP_NUM];
  350. __half2* vals_cast = reinterpret_cast<__half2*>(vals);
  351. const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
  352. residual_cast += (row * row_stride);
  353. vals_cast += (row * row_stride);
  354. float sum = 0.f;
  355. int high_index = iterations * iteration_stride + id;
  356. #pragma unroll
  357. for (int i = 0; i < iterations; i++) {
  358. vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
  359. sum += vals_f[i].x;
  360. sum += vals_f[i].y;
  361. }
  362. if ((high_index) < row_stride) {
  363. vals_f[iterations] = __half22float2(residual_cast[high_index]);
  364. sum += vals_f[iterations].x;
  365. sum += vals_f[iterations].y;
  366. iterations++;
  367. }
  368. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  369. if (g.thread_rank() == 0) shr[gid] = sum;
  370. b.sync();
  371. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
  372. #ifndef __STOCHASTIC_MODE__
  373. b.sync();
  374. #endif
  375. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  376. sum += g.shfl_down(sum, i);
  377. }
  378. sum = g.shfl(sum, 0);
  379. float mean = sum / (row_stride * 2);
  380. float variance = 0.f;
  381. for (int i = 0; i < iterations; i++) {
  382. vals_f[i].x -= mean;
  383. vals_f[i].y -= mean;
  384. variance += vals_f[i].x * vals_f[i].x;
  385. variance += vals_f[i].y * vals_f[i].y;
  386. }
  387. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  388. if (g.thread_rank() == 0) shr[gid] = variance;
  389. b.sync();
  390. if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
  391. #ifndef __STOCHASTIC_MODE__
  392. b.sync();
  393. #endif
  394. for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
  395. variance += g.shfl_down(variance, i);
  396. }
  397. variance = g.shfl(variance, 0);
  398. variance /= (row_stride * 2);
  399. variance += epsilon;
  400. __half2 variance_h = __float2half2_rn(variance);
  401. const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
  402. const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
  403. if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
  404. iterations = row_stride / iteration_stride;
  405. for (int i = 0; i < iterations; i++) {
  406. __half2 vals_arr = __float22half2_rn(vals_f[i]);
  407. vals_arr = vals_arr * h2rsqrt(variance_h);
  408. vals_arr =
  409. vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
  410. vals_cast[i * iteration_stride + id] = vals_arr;
  411. }
  412. if ((high_index) < row_stride) {
  413. __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
  414. vals_arr = vals_arr * h2rsqrt(variance_h);
  415. vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
  416. vals_cast[high_index] = vals_arr;
  417. }
  418. #endif
  419. }
  420. template <typename T>
  421. void launch_bias_residual_layer_norm(T* vals,
  422. const T* residual,
  423. const T* gamma,
  424. const T* beta,
  425. float epsilon,
  426. int batch_size,
  427. int hidden_dim,
  428. cudaStream_t stream,
  429. bool preLayerNorm,
  430. bool training,
  431. T* vars);
432. /*
433. To tune this launch, the following restrictions must be met:
434. For float:
435. row_stride == hidden_size
436. threads * iterations == row_stride
437. threads is one of [32, 64, 128, 256, 512, 1024]
438. For half:
439. row_stride == hidden_size / 2
440. threads * iterations == row_stride
441. threads is one of [32, 64, 128, 256, 512, 1024]
442. */
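/*
Illustrative call of the float specialization below (a sketch only: the tensor
names, epsilon value, and sizes here are assumptions, not taken from this file):

    // activations laid out as [batch * seq_len, hidden_size], hidden_size = 4096
    launch_bias_residual_layer_norm<float>(vals, residual, gamma, beta,
                                           1e-12f,           // epsilon (assumed)
                                           batch * seq_len,  // one block per row
                                           4096,             // hidden_dim (== row_stride for float)
                                           stream,
                                           false,            // preLayerNorm
                                           true,             // training
                                           vars);

With THREADS threads per block, each thread then covers hidden_dim / THREADS
elements of its row from registers.
*/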
  443. template <>
  444. void launch_bias_residual_layer_norm<float>(float* vals,
  445. const float* residual,
  446. const float* gamma,
  447. const float* beta,
  448. float epsilon,
  449. int batch_size,
  450. int hidden_dim,
  451. cudaStream_t stream,
  452. bool preLayerNorm,
  453. bool training,
  454. float* vars)
  455. {
  456. int threads = THREADS;
  457. dim3 grid_dim(batch_size);
458. // The kernels below only support certain block sizes, so enumerate the supported hidden_dim ranges here.
  459. if (hidden_dim > 16384 && hidden_dim <= 32768)
  460. threads <<= 1;
  461. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  462. threads <<= 2;
  463. else if (hidden_dim > 65536)
464. throw std::runtime_error("Unsupported hidden_dim.");
  465. dim3 block_dim(threads);
  466. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  467. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
  468. }
  469. template <>
  470. void launch_bias_residual_layer_norm<__half>(__half* vals,
  471. const __half* residual,
  472. const __half* gamma,
  473. const __half* beta,
  474. float epsilon,
  475. int batch_size,
  476. int hidden_dim,
  477. cudaStream_t stream,
  478. bool preLayerNorm,
  479. bool training,
  480. __half* vars)
  481. {
  482. int threads = 128;
  483. dim3 grid_dim(batch_size);
484. // The kernels below only support certain block sizes, so enumerate the supported hidden_dim ranges here.
  485. if (hidden_dim > 8192 && hidden_dim <= 16384)
  486. threads <<= 1;
  487. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  488. threads <<= 2;
  489. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  490. threads <<= 3;
  491. else if (hidden_dim > 65536)
492. throw std::runtime_error("Unsupported hidden_dim.");
  493. dim3 block_dim(threads);
  494. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  495. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
  496. }
497. /* Normalize Gamma & Betta gradients.
498. * Computes the gradients using either X_hat or
499. * the normalized input (invertible path).
500. * Combines the transpose with the gradient computation.
501. */
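// Launch shape: grid(width / TILE_DIM), block(TILE_DIM, TILE_DIM). Each block
// owns a TILE_DIM-wide strip of columns; threads stride down the rows
// accumulating per-column partials, and the shared-memory tiles transpose those
// partials so each column can be finished with warp shuffles.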
  502. template <typename T>
  503. __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
  504. const T* __restrict__ vals_hat,
  505. const T* __restrict__ gamma,
  506. const T* __restrict__ betta,
  507. T* __restrict__ gamma_grad,
  508. T* __restrict__ betta_grad,
  509. int rows,
  510. int width,
  511. bool invertible)
  512. {
  513. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  514. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  515. cg::thread_block b = cg::this_thread_block();
  516. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  517. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  518. int offset = threadIdx.y * width + idx;
  519. int y_stride = width * TILE_DIM;
  520. float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
  521. float gamma_reg = (float)gamma[idx];
  522. // Loop across matrix height
  523. float betta_tmp = 0;
  524. float gamma_tmp = 0;
  525. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  526. float grad = (float)out_grad[offset];
  527. float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
  528. : (float)vals_hat[offset]);
  529. betta_tmp += grad;
  530. gamma_tmp += (val * grad);
  531. offset += y_stride;
  532. }
  533. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  534. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  535. __syncthreads();
  536. // Sum the shared buffer.
  537. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  538. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  539. #ifndef __STOCHASTIC_MODE__
  540. __syncthreads();
  541. #endif
  542. for (int i = 1; i < TILE_DIM; i <<= 1) {
  543. s1 += g.shfl_down(s1, i);
  544. s2 += g.shfl_down(s2, i);
  545. }
  546. if (threadIdx.x == 0) {
  547. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  548. betta_grad[pos] = s1;
  549. gamma_grad[pos] = s2;
  550. }
  551. }
552. /* Normalize Gamma & Betta gradients.
553. * Computes the gradients using the input to
554. * the normalization (X) together with the saved means and variances.
555. * Combines the transpose with the gradient computation.
556. */
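// Unlike the invertible variant above, this overload recomputes
// x_hat = (x - mean[r]) * rsqrt(var[r]) from the saved statistics instead of
// undoing the gamma/betta affine transform on the output.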
  557. template <typename T>
  558. __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
  559. const T* __restrict__ X_data,
  560. const T* __restrict__ vars,
  561. const T* __restrict__ means,
  562. T* __restrict__ gamma_grad,
  563. T* __restrict__ betta_grad,
  564. int rows,
  565. int width)
  566. {
  567. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  568. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  569. cg::thread_block b = cg::this_thread_block();
  570. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  571. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  572. int offset = threadIdx.y * width + idx;
  573. int y_stride = width * TILE_DIM;
  574. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  575. // Loop across matrix height
  576. float betta_tmp = 0;
  577. float gamma_tmp = 0;
  578. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  579. float grad = (float)out_grad[offset];
  580. float val = (float)X_data[offset];
  581. val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
  582. betta_tmp += grad;
  583. gamma_tmp += (val * grad);
  584. offset += y_stride;
  585. }
  586. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  587. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  588. __syncthreads();
  589. // Sum the shared buffer.
  590. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  591. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  592. #ifndef __STOCHASTIC_MODE__
  593. __syncthreads();
  594. #endif
  595. for (int i = 1; i < TILE_DIM; i <<= 1) {
  596. s1 += g.shfl_down(s1, i);
  597. s2 += g.shfl_down(s2, i);
  598. }
  599. if (threadIdx.x == 0) {
  600. betta_grad[pos] = s1;
  601. gamma_grad[pos] = s2;
  602. }
  603. }
605. /* Backward Normalize (Input-Gradient)
606. * Uses the variance saved from the forward pass.
607. * This type of backward is invertible:
608. * it works on X_hat = (X - u) / sqrt(variance), taken either as stored or recovered from the output of the normalization.
609. */
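/*
With g = out_grad * gamma and x_hat the normalized activations, the kernels
below compute the standard LayerNorm input gradient
    dx = rsqrt(var) * (g - mean(g) - x_hat * mean(g * x_hat)),
evaluated as c_i = rsqrt(var) * (g_i - x_hat_i * mean(g * x_hat)) followed by
subtracting mean(c); the two forms agree because mean(x_hat) == 0. Note that
the stored var already includes epsilon from the forward pass.
*/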
  610. __global__ void LayerNormBackward2(const float* out_grad,
  611. const float* vals_hat,
  612. const float* gamma,
  613. const float* betta,
  614. const float* vars,
  615. float* inp_grad,
  616. bool invertible,
  617. int row_stride)
  618. {
  619. int iteration_stride = blockDim.x;
  620. int iterations = row_stride / iteration_stride;
  621. cg::thread_block b = cg::this_thread_block();
  622. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  623. int row = blockIdx.x;
  624. int id = threadIdx.x;
  625. int wid = id / WARP_SIZE;
  626. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  627. __shared__ float partialSum[MAX_WARP_NUM];
  628. out_grad += (row * row_stride);
  629. vals_hat += (row * row_stride);
  630. inp_grad += (row * row_stride);
  631. float vals_arr[NORM_REG];
  632. float vals_hat_arr[NORM_REG];
  633. int high_index = iterations * iteration_stride + id;
  634. #pragma unroll
  635. for (int i = 0; i < iterations; i++) {
  636. float gamma_reg = gamma[i * iteration_stride + id];
  637. vals_arr[i] = out_grad[i * iteration_stride + id];
  638. vals_arr[i] *= gamma_reg;
  639. vals_hat_arr[i] =
  640. (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
  641. gamma_reg
  642. : vals_hat[i * iteration_stride + id]);
  643. }
  644. if ((high_index) < row_stride) {
  645. float gamma_reg = gamma[high_index];
  646. vals_arr[iterations] = out_grad[high_index];
  647. vals_arr[iterations] *= gamma_reg;
  648. vals_hat_arr[iterations] =
  649. (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
  650. : vals_hat[high_index]);
  651. iterations++;
  652. }
  653. float var_reg = vars[row];
  654. float sum = 0;
  655. for (int i = 0; i < iterations; i++) {
  656. sum += vals_hat_arr[i] * vals_arr[i] *
  657. sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
  658. vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
  659. }
  660. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  661. if (g.thread_rank() == 0) partialSum[wid] = sum;
  662. __syncthreads();
  663. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  664. #ifndef __STOCHASTIC_MODE__
  665. __syncthreads();
  666. #endif
  667. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  668. sum = g.shfl(sum, 0);
  669. sum /= row_stride;
  670. for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
  671. sum = 0;
  672. for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
  673. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  674. if (g.thread_rank() == 0) partialSum[wid] = sum;
  675. __syncthreads();
  676. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  677. #ifndef __STOCHASTIC_MODE__
  678. __syncthreads();
  679. #endif
  680. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  681. sum = g.shfl(sum, 0);
  682. sum /= row_stride;
  683. iterations = row_stride / iteration_stride;
  684. for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
  685. if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
  686. }
  687. __global__ void LayerNormBackward2(const __half* out_grad,
  688. const __half* vals_hat,
  689. const __half* gamma,
  690. const __half* betta,
  691. const __half* vars,
  692. __half* inp_grad,
  693. bool invertible,
  694. int row_stride)
  695. {
  696. int iteration_stride = blockDim.x;
  697. int iterations = row_stride / iteration_stride;
  698. cg::thread_block b = cg::this_thread_block();
  699. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  700. int row = blockIdx.x;
  701. int id = threadIdx.x;
  702. int wid = id / WARP_SIZE;
  703. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  704. __shared__ float partialSum[MAX_WARP_NUM];
  705. __half2 vals_arr[NORM_REG];
  706. float2 vals_arr_f[NORM_REG];
  707. __half2 vals_hat_arr[NORM_REG];
  708. __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
  709. const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
  710. const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
  711. inp_grad_h += (row * row_stride);
  712. out_grad_h += (row * row_stride);
  713. vals_hat_h += (row * row_stride);
  714. const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
  715. const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
  716. int high_index = iterations * iteration_stride + id;
  717. #pragma unroll
  718. for (int i = 0; i < iterations; i++) {
  719. __half2 gamma_reg = gamma_h[i * iteration_stride + id];
  720. vals_arr[i] = out_grad_h[i * iteration_stride + id];
  721. vals_arr[i] *= gamma_reg;
  722. vals_hat_arr[i] =
  723. (invertible
  724. ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
  725. gamma_reg
  726. : vals_hat_h[i * iteration_stride + id]);
  727. }
  728. if ((high_index) < row_stride) {
  729. __half2 gamma_reg = gamma_h[high_index];
  730. vals_arr[iterations] = out_grad_h[high_index];
  731. vals_arr[iterations] *= gamma_reg;
  732. vals_hat_arr[iterations] =
  733. (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
  734. : vals_hat_h[high_index]);
  735. iterations++;
  736. }
  737. __half var_h = vars[row];
  738. __half2 var_reg = __halves2half2(var_h, var_h);
  739. float sum = 0.f;
  740. for (int i = 0; i < iterations; i++) {
  741. __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
  742. float2 result_f = __half22float2(result_h);
  743. sum += result_f.x;
  744. sum += result_f.y;
  745. vals_arr[i] *= h2rsqrt(var_reg);
  746. }
  747. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  748. if (g.thread_rank() == 0) partialSum[wid] = sum;
  749. __syncthreads();
  750. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  751. #ifndef __STOCHASTIC_MODE__
  752. __syncthreads();
  753. #endif
  754. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  755. sum = g.shfl(sum, 0);
  756. sum /= (2 * row_stride);
  757. __half2 sum_h = __float2half2_rn(sum);
  758. for (int i = 0; i < iterations; i++) {
  759. __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
  760. vals_arr_f[i] = __half22float2(vals_arr[i]);
  761. float2 temp_f = __half22float2(temp);
  762. vals_arr_f[i].x += temp_f.x;
  763. vals_arr_f[i].y += temp_f.y;
  764. }
  765. sum = 0.f;
  766. for (int i = 0; i < iterations; i++) {
  767. sum += (vals_arr_f[i].x);
  768. sum += (vals_arr_f[i].y);
  769. }
  770. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  771. if (g.thread_rank() == 0) partialSum[wid] = sum;
  772. __syncthreads();
  773. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  774. #ifndef __STOCHASTIC_MODE__
  775. __syncthreads();
  776. #endif
  777. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  778. sum = g.shfl(sum, 0);
  779. sum /= (2 * row_stride);
  780. iterations = row_stride / iteration_stride;
  781. for (int i = 0; i < iterations; i++) {
  782. vals_arr_f[i].x -= sum;
  783. vals_arr_f[i].y -= sum;
  784. __half2 temp = __float22half2_rn(vals_arr_f[i]);
  785. inp_grad_h[i * iteration_stride + id] = temp;
  786. }
  787. if ((high_index) < row_stride) {
  788. vals_arr_f[iterations].x -= sum;
  789. vals_arr_f[iterations].y -= sum;
  790. __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
  791. inp_grad_h[high_index] = temp;
  792. }
  793. }
  794. template <>
  795. void launch_layerNorm_backward<float>(const float* out_grad,
  796. const float* vals_hat,
  797. const float* vars,
  798. const float* gamma,
  799. float* gamma_grad,
  800. float* betta_grad,
  801. float* inp_grad,
  802. int batch,
  803. int hidden_dim,
  804. cudaStream_t stream[2],
  805. bool invertible,
  806. const float* betta)
  807. {
  808. int threads = THREADS;
  809. dim3 grid_dim(hidden_dim / TILE_DIM);
  810. dim3 block_dim(TILE_DIM, TILE_DIM);
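// Backward1 (gamma/betta gradients) and Backward2 (input gradients) are
// independent of each other, so they are issued on the two provided streams.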
  811. LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
  812. out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
  813. dim3 grid_dim2(batch);
  814. if (hidden_dim > 16384 && hidden_dim <= 32768)
  815. threads <<= 1;
  816. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  817. threads <<= 2;
  818. else if (hidden_dim > 65536)
819. throw std::runtime_error("Unsupported hidden_dim.");
  820. dim3 block_dim2(threads);
  821. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  822. out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
  823. }
  824. template <>
  825. void launch_layerNorm_backward<__half>(const __half* out_grad,
  826. const __half* vals_hat,
  827. const __half* vars,
  828. const __half* gamma,
  829. __half* gamma_grad,
  830. __half* betta_grad,
  831. __half* inp_grad,
  832. int batch,
  833. int hidden_dim,
  834. cudaStream_t stream[2],
  835. bool invertible,
  836. const __half* betta)
  837. {
  838. int threads = THREADS;
  839. dim3 grid_dim(hidden_dim / TILE_DIM);
  840. dim3 block_dim(TILE_DIM, TILE_DIM);
  841. // LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
  842. // out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
  843. dim3 grid_dim2(batch);
  844. if (hidden_dim > 8192 && hidden_dim <= 16384)
  845. threads <<= 1;
  846. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  847. threads <<= 2;
  848. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  849. threads <<= 3;
  850. else if (hidden_dim > 65536)
851. throw std::runtime_error("Unsupported hidden_dim.");
  852. dim3 block_dim2(threads / 2);
  853. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  854. out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
  855. }
856. /* Backward Normalize (Input-Gradient)
857. * Uses the means and variances saved from the forward pass.
858. * This type of backward is not invertible:
859. * it needs the original input (X) to recompute X_hat.
860. */
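// Same input-gradient formula as the invertible variant, but x_hat is
// reconstructed on the fly: the (x - mean) terms are kept in the xu[]
// registers and scaled by rsqrt(var) only where needed.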
  861. __global__ void LayerNormBackward2(const float* out_grad,
  862. const float* X_vals,
  863. const float* gamma,
  864. const float* vars,
  865. const float* means,
  866. float* inp_grad,
  867. int row_stride)
  868. {
  869. int iteration_stride = blockDim.x;
  870. int iterations = row_stride / iteration_stride;
  871. cg::thread_block b = cg::this_thread_block();
  872. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  873. int row = blockIdx.x;
  874. int id = threadIdx.x;
  875. int wid = id >> WARP_SIZE_BITS;
  876. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  877. __shared__ float partialSum[MAX_WARP_NUM];
  878. out_grad += (row * row_stride);
  879. X_vals += (row * row_stride);
  880. inp_grad += (row * row_stride);
  881. float vals_arr[NORM_REG];
  882. int high_index = iterations * iteration_stride + id;
  883. #pragma unroll
  884. for (int i = 0; i < iterations; i++) {
  885. float gamma_reg = gamma[i * iteration_stride + id];
  886. vals_arr[i] = out_grad[i * iteration_stride + id];
  887. vals_arr[i] *= gamma_reg;
  888. }
  889. if ((high_index) < row_stride) {
  890. float gamma_reg = gamma[high_index];
  891. vals_arr[iterations] = out_grad[high_index];
  892. vals_arr[iterations] *= gamma_reg;
  893. iterations++;
  894. }
  895. float var_reg = vars[row];
  896. float mean_reg = means[row];
  897. float sum = 0;
  898. float xu[NORM_REG];
  899. for (int i = 0; i < iterations; i++) {
  900. xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
  901. sum += vals_arr[i] * xu[i];
  902. vals_arr[i] *= rsqrtf(var_reg);
  903. }
  904. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  905. if (g.thread_rank() == 0) partialSum[wid] = sum;
  906. __syncthreads();
  907. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  908. #ifndef __STOCHASTIC_MODE__
  909. __syncthreads();
  910. #endif
  911. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  912. sum = g.shfl(sum, 0);
  913. sum /= row_stride;
  914. for (int i = 0; i < iterations; i++) {
  915. vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
  916. }
  917. sum = 0;
  918. for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
  919. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  920. if (g.thread_rank() == 0) partialSum[wid] = sum;
  921. __syncthreads();
  922. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  923. #ifndef __STOCHASTIC_MODE__
  924. __syncthreads();
  925. #endif
  926. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  927. sum = g.shfl(sum, 0);
  928. sum /= row_stride;
  929. iterations = row_stride / iteration_stride;
  930. for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
  931. if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
  932. }
  933. __global__ void LayerNormBackward2(const __half* out_grad,
  934. const __half* X_vals,
  935. const __half* gamma,
  936. const __half* vars,
  937. const __half* means,
  938. __half* inp_grad,
  939. int row_stride)
  940. {
  941. int iteration_stride = blockDim.x;
  942. int iterations = row_stride / iteration_stride;
  943. cg::thread_block b = cg::this_thread_block();
  944. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  945. int row = blockIdx.x;
  946. int id = threadIdx.x;
  947. int wid = id >> WARP_SIZE_BITS;
  948. int warp_num = iteration_stride >> WARP_SIZE_BITS;
  949. __shared__ float partialSum[MAX_WARP_NUM];
  950. __half2 vals_arr[NORM_REG];
  951. float2 vals_arr_f[NORM_REG];
  952. __half2 xu[NORM_REG];
  953. __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
  954. const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
  955. const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
  956. inp_grad_h += (row * row_stride);
  957. out_grad_h += (row * row_stride);
  958. vals_hat_h += (row * row_stride);
  959. const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
  960. int high_index = iterations * iteration_stride + id;
  961. __half mean_h = means[row];
  962. __half2 mean_reg = __halves2half2(mean_h, mean_h);
  963. #pragma unroll
  964. for (int i = 0; i < iterations; i++) {
  965. __half2 gamma_reg = gamma_h[i * iteration_stride + id];
  966. vals_arr[i] = out_grad_h[i * iteration_stride + id];
  967. vals_arr[i] *= gamma_reg; // out_grad * gamma
  968. xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
  969. }
  970. if ((high_index) < row_stride) {
  971. __half2 gamma_reg = gamma_h[high_index];
  972. vals_arr[iterations] = out_grad_h[high_index];
  973. vals_arr[iterations] *= gamma_reg; // out_grad * gamma
  974. xu[iterations] = (vals_hat_h[high_index] - mean_reg);
  975. iterations++;
  976. }
  977. __half var_h = vars[row];
  978. __half2 var_reg = __halves2half2(var_h, var_h);
  979. float sum = 0.f;
  980. for (int i = 0; i < iterations; i++) {
  981. __half2 result_h = (xu[i] * vals_arr[i]);
  982. float2 result_f = __half22float2(result_h);
  983. sum += result_f.x;
  984. sum += result_f.y;
  985. vals_arr[i] *= h2rsqrt(var_reg);
  986. }
  987. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  988. if (g.thread_rank() == 0) partialSum[wid] = sum;
  989. __syncthreads();
  990. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  991. #ifndef __STOCHASTIC_MODE__
  992. __syncthreads();
  993. #endif
  994. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  995. sum = g.shfl(sum, 0);
  996. sum /= (2 * row_stride);
  997. __half2 sum_h = __float2half2_rn(sum);
  998. for (int i = 0; i < iterations; i++) {
  999. __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
  1000. vals_arr_f[i] = __half22float2(vals_arr[i]);
  1001. float2 xu_grad_f = __half22float2(xu_grad);
  1002. vals_arr_f[i].x += xu_grad_f.x;
  1003. vals_arr_f[i].y += xu_grad_f.y;
  1004. }
  1005. sum = 0.f;
  1006. for (int i = 0; i < iterations; i++) {
  1007. sum += (vals_arr_f[i].x);
  1008. sum += (vals_arr_f[i].y);
  1009. }
  1010. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  1011. if (g.thread_rank() == 0) partialSum[wid] = sum;
  1012. __syncthreads();
  1013. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  1014. #ifndef __STOCHASTIC_MODE__
  1015. __syncthreads();
  1016. #endif
  1017. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  1018. sum = g.shfl(sum, 0);
  1019. sum /= (2 * row_stride);
  1020. iterations = row_stride / iteration_stride;
  1021. for (int i = 0; i < iterations; i++) {
  1022. vals_arr_f[i].x -= sum;
  1023. vals_arr_f[i].y -= sum;
  1024. __half2 temp = __float22half2_rn(vals_arr_f[i]);
  1025. inp_grad_h[i * iteration_stride + id] = temp;
  1026. }
  1027. if ((high_index) < row_stride) {
  1028. vals_arr_f[iterations].x -= sum;
  1029. vals_arr_f[iterations].y -= sum;
  1030. __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
  1031. inp_grad_h[high_index] = temp;
  1032. }
  1033. }
  1034. template <>
  1035. void launch_layerNorm_backward<float>(const float* out_grad,
  1036. const float* X_data,
  1037. const float* vars,
  1038. const float* means,
  1039. const float* gamma,
  1040. float* gamma_grad,
  1041. float* betta_grad,
  1042. float* inp_grad,
  1043. int batch,
  1044. int hidden_dim,
  1045. cudaStream_t stream[2])
  1046. {
  1047. int threads = THREADS;
  1048. dim3 grid_dim(hidden_dim / TILE_DIM);
  1049. dim3 block_dim(TILE_DIM, TILE_DIM);
  1050. LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
  1051. out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
  1052. dim3 grid_dim2(batch);
  1053. if (hidden_dim > 16384 && hidden_dim <= 32768)
  1054. threads <<= 1;
  1055. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  1056. threads <<= 2;
  1057. else if (hidden_dim > 65536)
1058. throw std::runtime_error("Unsupported hidden_dim.");
  1059. dim3 block_dim2(threads);
  1060. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  1061. out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
  1062. }
  1063. template <>
  1064. void launch_layerNorm_backward<__half>(const __half* out_grad,
  1065. const __half* X_data,
  1066. const __half* vars,
  1067. const __half* means,
  1068. const __half* gamma,
  1069. __half* gamma_grad,
  1070. __half* betta_grad,
  1071. __half* inp_grad,
  1072. int batch,
  1073. int hidden_dim,
  1074. cudaStream_t stream[2])
  1075. {
  1076. int threads = THREADS;
  1077. dim3 grid_dim(hidden_dim / TILE_DIM);
  1078. dim3 block_dim(TILE_DIM, TILE_DIM);
  1079. LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
  1080. out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
  1081. dim3 grid_dim2(batch);
  1082. if (hidden_dim > 8192 && hidden_dim <= 16384)
  1083. threads <<= 1;
  1084. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  1085. threads <<= 2;
  1086. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  1087. threads <<= 3;
  1088. else if (hidden_dim > 65536)
1089. throw std::runtime_error("Unsupported hidden_dim.");
  1090. dim3 block_dim2(threads / 2);
  1091. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  1092. out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
  1093. }
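/*
Fused-add backward variants: identical to the kernels above except that the
gradient coming from the residual branch (out_grad2) is folded in, either by
summing it with out_grad1 for the gamma/betta reduction or by adding it to the
final input gradient.
*/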
  1094. template <typename T>
  1095. __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
  1096. const T* __restrict__ out_grad2,
  1097. const T* __restrict__ vals_hat,
  1098. const T* __restrict__ gamma,
  1099. const T* __restrict__ betta,
  1100. T* __restrict__ gamma_grad,
  1101. T* __restrict__ betta_grad,
  1102. int rows,
  1103. int width,
  1104. bool invertible)
  1105. {
  1106. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  1107. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  1108. cg::thread_block b = cg::this_thread_block();
  1109. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  1110. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  1111. int offset = threadIdx.y * width + idx;
  1112. int y_stride = width * TILE_DIM;
  1113. float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
  1114. float gamma_reg = (float)gamma[idx];
  1115. // Loop across matrix height
  1116. float betta_tmp = 0;
  1117. float gamma_tmp = 0;
  1118. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  1119. float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
  1120. float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
  1121. : (float)vals_hat[offset]);
  1122. betta_tmp += grad;
  1123. gamma_tmp += (val * grad);
  1124. offset += y_stride;
  1125. }
  1126. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  1127. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  1128. __syncthreads();
  1129. // Sum the shared buffer.
  1130. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  1131. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  1132. #ifndef __STOCHASTIC_MODE__
  1133. __syncthreads();
  1134. #endif
  1135. for (int i = 1; i < TILE_DIM; i <<= 1) {
  1136. s1 += g.shfl_down(s1, i);
  1137. s2 += g.shfl_down(s2, i);
  1138. }
  1139. if (threadIdx.x == 0) {
  1140. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  1141. betta_grad[pos] = s1;
  1142. gamma_grad[pos] = s2;
  1143. }
  1144. }
  1145. template <typename T>
  1146. __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
  1147. const T* __restrict__ out_grad2,
  1148. const T* __restrict__ X_data,
  1149. const T* __restrict__ vars,
  1150. const T* __restrict__ means,
  1151. T* __restrict__ gamma_grad,
  1152. T* __restrict__ betta_grad,
  1153. int rows,
  1154. int width)
  1155. {
  1156. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  1157. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  1158. cg::thread_block b = cg::this_thread_block();
  1159. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  1160. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  1161. int offset = threadIdx.y * width + idx;
  1162. int y_stride = width * TILE_DIM;
  1163. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  1164. // Loop across matrix height
  1165. float betta_tmp = 0;
  1166. float gamma_tmp = 0;
  1167. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  1168. float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
  1169. float val = (float)X_data[offset];
  1170. val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
  1171. betta_tmp += grad;
  1172. gamma_tmp += (val * grad);
  1173. offset += y_stride;
  1174. }
  1175. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  1176. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  1177. __syncthreads();
  1178. // Sum the shared buffer.
  1179. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  1180. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  1181. #ifndef __STOCHASTIC_MODE__
  1182. __syncthreads();
  1183. #endif
  1184. for (int i = 1; i < TILE_DIM; i <<= 1) {
  1185. s1 += g.shfl_down(s1, i);
  1186. s2 += g.shfl_down(s2, i);
  1187. }
  1188. if (threadIdx.x == 0) {
  1189. betta_grad[pos] = s1;
  1190. gamma_grad[pos] = s2;
  1191. }
  1192. }
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* vals_hat,
                                             const float* gamma,
                                             const float* betta,
                                             const float* vars,
                                             float* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    vals_hat += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma = d(x_hat)
        // When invertible, recover x_hat from the saved output: (y - betta) / gamma.
        vals_hat_arr[i] =
            (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
                              gamma_reg
                        : vals_hat[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
                        : vals_hat[high_index]);
        iterations++;
    }

    float var_reg = vars[row];

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
        vals_arr[i] *= rsqrtf(var_reg);
    }

    // Block-wide reduction: warp shuffles first, then combine per-warp partials via shared memory.
    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;  // mean over the row

    for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] =
            (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
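
/*
 * Note on the two fused-add kernels above and below (informal): when `invertible` is set,
 * the saved layer-norm input is never read; x_hat is reconstructed from the stored output
 * as (vals_hat - betta) / gamma. The two rounds of shfl_down plus the partialSum staging
 * buffer are block-wide reductions that produce the row means required by the layer-norm
 * backward, and out_grad2 (the residual-branch gradient) is added when writing inp_grad,
 * which is what makes these the "fused_add" variants.
 */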
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* vals_hat,
                                             const __half* gamma,
                                             const __half* betta,
                                             const __half* vars,
                                             __half* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];
    // float2 result[iterations];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);

    inp_grad_h += (row * row_stride);
    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] =
            (invertible
                 ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
                       gamma_reg
                 : vals_hat_h[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
                        : vals_hat_h[high_index]);
        iterations++;
    }

    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 temp_f = __half22float2(temp);
        vals_arr_f[i].x += temp_f.x;
        vals_arr_f[i].y += temp_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* vals_hat,
                                                const float* vars,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2],
                                                bool invertible,
                                                const float* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    // Use more threads for larger hidden sizes so each thread's per-row iteration
    // count stays within the NORM_REG register arrays.
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* vals_hat,
                                                 const __half* vars,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2],
                                                 bool invertible,
                                                 const __half* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    // The __half2 kernel processes two elements per thread, so halve both the
    // block size and the row stride it sees.
    dim3 block_dim2(threads / 2);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input
 * This type of backward is not invertible!
 * We do the backward using the input (X)
 */
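
/*
 * Gradient formula (informal sketch): writing x_hat = (x - mu) * rsqrt(var) and
 * dx_hat = out_grad1 * gamma, the kernels below compute the usual layer-norm
 * input gradient
 *
 *     inp_grad = (dx_hat - mean(dx_hat) - x_hat * mean(dx_hat * x_hat)) * rsqrt(var)
 *                + out_grad2,
 *
 * where the means run over the row (row_stride elements) and out_grad2 is the fused
 * residual-branch gradient. The code keeps (x - mu) rather than x_hat and folds the
 * rsqrt(var) factors into the individual terms, which is algebraically equivalent.
 */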
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* X_vals,
                                             const float* gamma,
                                             const float* vars,
                                             const float* means,
                                             float* inp_grad,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    X_vals += (row * row_stride);
    inp_grad += (row * row_stride);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] = X_vals[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] = X_vals[high_index];
        iterations++;
    }

    float var_reg = vars[row];
    float mean_reg = means[row];

    float sum = 0;
    float xu[NORM_REG];
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);  // x - mu
        sum += vals_arr[i] * xu[i];            // accumulate (out_grad * gamma) * (x - mu)
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) {
        vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
    }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] =
            (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* X_vals,
                                             const __half* gamma,
                                             const __half* vars,
                                             const __half* means,
                                             __half* inp_grad,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);

    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    inp_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] = vals_hat_h[high_index];
        iterations++;
    }

    __half mean_h = means[row];
    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);
    __half2 mean_reg = __halves2half2(mean_h, mean_h);
    __half2 xu[NORM_REG];

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);
        __half2 result_h = (xu[i] * vals_arr[i]);
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 xu_grad_f = __half22float2(xu_grad);
        vals_arr_f[i].x += xu_grad_f.x;
        vals_arr_f[i].y += xu_grad_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* X_data,
                                                const float* vars,
                                                const float* means,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* X_data,
                                                 const __half* vars,
                                                 const __half* means,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    // __half2 path: two elements per thread, so halve both the block size and the row stride.
    dim3 block_dim2(threads / 2);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
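
/*
 * Usage sketch (illustrative only; the device buffers and their names are placeholders,
 * not part of this file): a host-side caller that already owns [rows, hidden_dim]
 * gradients plus the saved per-row means and variances could drive the fp32 fused-add
 * backward like this:
 *
 *   cudaStream_t streams[2];
 *   cudaStreamCreate(&streams[0]);
 *   cudaStreamCreate(&streams[1]);
 *   launch_layerNorm_backward_fused_add<float>(d_out_grad,       // gradient w.r.t. layer-norm output
 *                                              d_residual_grad,  // fused residual-branch gradient
 *                                              d_input,          // saved input X
 *                                              d_vars,
 *                                              d_means,
 *                                              d_gamma,
 *                                              d_gamma_grad,
 *                                              d_betta_grad,
 *                                              d_inp_grad,
 *                                              rows,             // number of normalized rows (the `batch` argument)
 *                                              hidden_dim,
 *                                              streams);
 *
 * LayerNormBackward1 is launched on stream[0] and LayerNormBackward2_fused_add on
 * stream[1], so the parameter-gradient and input-gradient passes can overlap.
 */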