normalize_kernels.cu

  1. #include "custom_cuda_layers.h"
  2. namespace cg = cooperative_groups;
  3. /*
  4. Fused bias add, residual (elementwise) add, and normalization layer.
  5. For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
  6. __half2 instructions, and avoid the conversion overhead (1/8 of __half2 arithmetic).
  7. For specific launch constraints, see the launch functions.
  8. */
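/*
Illustrative host-side sketch (hypothetical name, not part of the original
kernels): the per-row math that the forward kernels below apply to their
`residual` input, written naively for one row. Any bias/residual accumulation
feeding that buffer is assumed to have happened upstream.
*/
#include <cmath>

inline void reference_layer_norm_row(float* out,
                                     const float* residual,
                                     const float* gamma,
                                     const float* beta,
                                     float epsilon,
                                     int row_stride)
{
    // Row mean.
    float mean = 0.f;
    for (int i = 0; i < row_stride; i++) mean += residual[i];
    mean /= row_stride;
    // Row variance, with epsilon folded in exactly as the kernels store it.
    float variance = 0.f;
    for (int i = 0; i < row_stride; i++) {
        float d = residual[i] - mean;
        variance += d * d;
    }
    variance = variance / row_stride + epsilon;
    // Normalize, then scale and shift.
    float inv_std = 1.f / sqrtf(variance);
    for (int i = 0; i < row_stride; i++)
        out[i] = (residual[i] - mean) * inv_std * gamma[i] + beta[i];
}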
  9. #define NORM_REG (MAX_REGISTERS / 4)
  10. __global__ void fused_bias_residual_layer_norm(float* vals,
  11. const float* residual,
  12. const float* gamma,
  13. const float* beta,
  14. float epsilon,
  15. bool preLayerNorm,
  16. bool training,
  17. float* vars,
  18. float* means,
  19. int row_stride)
  20. {
  21. int iteration_stride = blockDim.x;
  22. int iterations = row_stride / iteration_stride;
  23. cg::thread_block b = cg::this_thread_block();
  24. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  25. int row = blockIdx.x;
  26. int id = threadIdx.x;
  27. int gid = id / WARP_SIZE;
  28. float vals_arr[NORM_REG];
  29. __shared__ float shr[MAX_WARP_NUM];
  30. residual += (row * row_stride);
  31. vals += (row * row_stride);
  32. float sum = 0.f;
  33. int high_index = iterations * iteration_stride + id;
  34. #pragma unroll
  35. for (int i = 0; i < iterations; i++) {
  36. vals_arr[i] = residual[i * iteration_stride + id];
  37. sum += vals_arr[i];
  38. }
  39. if (high_index < row_stride) {
  40. vals_arr[iterations] = residual[high_index];
  41. sum += vals_arr[iterations];
  42. iterations++;
  43. }
  44. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  45. if (g.thread_rank() == 0) shr[gid] = sum;
  46. b.sync();
  47. if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
  48. #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
  49. b.sync();
  50. #endif
  51. for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
  52. sum = g.shfl(sum, 0);
  53. float mean = sum / row_stride;
  54. if (training)
  55. if (threadIdx.x == 0) means[row] = mean;
  56. float variance = 0.f;
  57. for (int i = 0; i < iterations; i++) {
  58. vals_arr[i] -= mean;
  59. variance += vals_arr[i] * vals_arr[i];
  60. }
  61. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  62. if (g.thread_rank() == 0) shr[gid] = variance;
  63. b.sync();
  64. if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
  65. #ifndef __STOCHASTIC_MODE__
  66. b.sync();
  67. #endif
  68. for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
  69. variance = g.shfl(variance, 0);
  70. variance /= row_stride;
  71. variance += epsilon;
  72. if (training)
  73. if (threadIdx.x == 0) vars[row] = variance;
  74. iterations = row_stride / iteration_stride;
  75. for (int i = 0; i < iterations; i++) {
  76. vals_arr[i] = vals_arr[i] * rsqrtf(variance);
  77. vals_arr[i] =
  78. vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
  79. vals[i * iteration_stride + id] = vals_arr[i];
  80. }
  81. if ((high_index) < row_stride) {
  82. vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
  83. vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
  84. vals[high_index] = vals_arr[iterations];
  85. }
  86. }
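/*
Note: the reduction pattern above recurs throughout this file. Each warp first
reduces its partial sum with shfl_down, lane 0 of every warp writes its value
into the shared array (shr / partialSum), the first warp then reduces those
per-warp partials, and the result is broadcast to all lanes with shfl(sum, 0).
*/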
  87. __global__ void fused_bias_residual_layer_norm(__half* vals,
  88. const __half* residual,
  89. const __half* gamma,
  90. const __half* beta,
  91. float epsilon,
  92. bool preLayerNorm,
  93. bool training,
  94. __half* vars,
  95. __half* means,
  96. int row_stride)
  97. {
  98. #if __CUDA_ARCH__ >= 700
  99. int iteration_stride = blockDim.x;
  100. int iterations = row_stride / iteration_stride;
  101. cg::thread_block b = cg::this_thread_block();
  102. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  103. int row = blockIdx.x;
  104. int id = threadIdx.x;
  105. int gid = id >> 5;
  106. float2 vals_f[NORM_REG];
  107. __shared__ float shr[MAX_WARP_NUM];
  108. __half2* vals_cast = reinterpret_cast<__half2*>(vals);
  109. const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
  110. residual_cast += (row * row_stride);
  111. vals_cast += (row * row_stride);
  112. float sum = 0.f;
  113. int high_index = iterations * iteration_stride + id;
  114. #pragma unroll
  115. for (int i = 0; i < iterations; i++) {
  116. vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
  117. sum += vals_f[i].x;
  118. sum += vals_f[i].y;
  119. }
  120. if ((high_index) < row_stride) {
  121. vals_f[iterations] = __half22float2(residual_cast[high_index]);
  122. sum += vals_f[iterations].x;
  123. sum += vals_f[iterations].y;
  124. iterations++;
  125. }
  126. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  127. if (g.thread_rank() == 0) shr[gid] = sum;
  128. b.sync();
  129. if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
  130. #ifndef __STOCHASTIC_MODE__
  131. b.sync();
  132. #endif
  133. for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
  134. sum = g.shfl(sum, 0);
  135. float mean = sum / (row_stride * 2);
  136. float variance = 0.f;
  137. for (int i = 0; i < iterations; i++) {
  138. vals_f[i].x -= mean;
  139. vals_f[i].y -= mean;
  140. variance += vals_f[i].x * vals_f[i].x;
  141. variance += vals_f[i].y * vals_f[i].y;
  142. }
  143. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  144. if (g.thread_rank() == 0) shr[gid] = variance;
  145. b.sync();
  146. if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
  147. #ifndef __STOCHASTIC_MODE__
  148. b.sync();
  149. #endif
  150. for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
  151. variance = g.shfl(variance, 0);
  152. variance /= (row_stride * 2);
  153. variance += epsilon;
  154. __half2 variance_h = __float2half2_rn(variance);
  155. const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
  156. const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
  157. if (training && threadIdx.x == 0) {
  158. vars[row] = __float2half(variance);
  159. means[row] = __float2half(mean);
  160. }
  161. iterations = row_stride / iteration_stride;
  162. for (int i = 0; i < iterations; i++) {
  163. __half2 vals_arr = __float22half2_rn(vals_f[i]);
  164. vals_arr = vals_arr * h2rsqrt(variance_h);
  165. vals_arr =
  166. vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
  167. vals_cast[i * iteration_stride + id] = vals_arr;
  168. }
  169. if ((high_index) < row_stride) {
  170. __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
  171. vals_arr = vals_arr * h2rsqrt(variance_h);
  172. vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
  173. vals_cast[high_index] = vals_arr;
  174. }
  175. #endif
  176. }
  177. template <typename T>
  178. void launch_bias_residual_layer_norm(T* vals,
  179. const T* residual,
  180. const T* gamma,
  181. const T* beta,
  182. float epsilon,
  183. int batch_size,
  184. int hidden_dim,
  185. cudaStream_t stream,
  186. bool preLayerNorm,
  187. bool training,
  188. T* vars,
  189. T* means);
  190. template <>
  191. void launch_bias_residual_layer_norm<float>(float* vals,
  192. const float* residual,
  193. const float* gamma,
  194. const float* beta,
  195. float epsilon,
  196. int batch_size,
  197. int hidden_dim,
  198. cudaStream_t stream,
  199. bool preLayerNorm,
  200. bool training,
  201. float* vars,
  202. float* means)
  203. {
  204. int threads = THREADS;
  205. dim3 grid_dim(batch_size);
  206. if (hidden_dim > 16384 && hidden_dim <= 32768)
  207. threads <<= 1;
  208. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  209. threads <<= 2;
  210. else if (hidden_dim > 65536)
  211. throw std::runtime_error("Unsupported hidden_dim.");
  212. dim3 block_dim(threads);
  213. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  214. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
  215. }
  216. template <>
  217. void launch_bias_residual_layer_norm<__half>(__half* vals,
  218. const __half* residual,
  219. const __half* gamma,
  220. const __half* beta,
  221. float epsilon,
  222. int batch_size,
  223. int hidden_dim,
  224. cudaStream_t stream,
  225. bool preLayerNorm,
  226. bool training,
  227. __half* vars,
  228. __half* means)
  229. {
  230. int threads = 128;
  231. dim3 grid_dim(batch_size);
  232. if (hidden_dim > 8192 && hidden_dim <= 16384)
  233. threads <<= 1;
  234. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  235. threads <<= 2;
  236. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  237. threads <<= 3;
  238. else if (hidden_dim > 65536)
  239. throw std::runtime_error("Unsupported hidden_dim.");
  240. dim3 block_dim(threads);
  241. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  242. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
  243. }
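/*
Illustrative usage sketch (hypothetical function and buffer names): how a host
caller might drive the float launcher above. batch_size counts normalized rows
(one block per row), all pointers are device buffers, the epsilon value is an
arbitrary example, and error checking is omitted.
*/
#include <cuda_runtime.h>

void example_bias_residual_layer_norm_fp32(int batch_size, int hidden_dim)
{
    size_t act_bytes = (size_t)batch_size * hidden_dim * sizeof(float);
    float *vals, *residual, *gamma, *beta, *vars, *means;
    cudaMalloc((void**)&vals, act_bytes);
    cudaMalloc((void**)&residual, act_bytes);
    cudaMalloc((void**)&gamma, hidden_dim * sizeof(float));
    cudaMalloc((void**)&beta, hidden_dim * sizeof(float));
    cudaMalloc((void**)&vars, batch_size * sizeof(float));   // per-row variance (training)
    cudaMalloc((void**)&means, batch_size * sizeof(float));  // per-row mean (training)

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    launch_bias_residual_layer_norm<float>(vals,
                                           residual,
                                           gamma,
                                           beta,
                                           1e-12f,
                                           batch_size,
                                           hidden_dim,
                                           stream,
                                           /*preLayerNorm=*/false,
                                           /*training=*/true,
                                           vars,
                                           means);

    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    cudaFree(vals);
    cudaFree(residual);
    cudaFree(gamma);
    cudaFree(beta);
    cudaFree(vars);
    cudaFree(means);
}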
  244. __global__ void fused_bias_residual_layer_norm(float* vals,
  245. const float* residual,
  246. const float* gamma,
  247. const float* beta,
  248. float epsilon,
  249. bool preLayerNorm,
  250. bool training,
  251. float* vars,
  252. int row_stride)
  253. {
  254. int iteration_stride = blockDim.x;
  255. int iterations = row_stride / iteration_stride;
  256. cg::thread_block b = cg::this_thread_block();
  257. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  258. int row = blockIdx.x;
  259. int id = threadIdx.x;
  260. int gid = id / 32;
  261. float vals_arr[NORM_REG];
  262. __shared__ float shr[MAX_WARP_NUM];
  263. residual += (row * row_stride);
  264. vals += (row * row_stride);
  265. float sum = 0.f;
  266. int high_index = iterations * iteration_stride + id;
  267. #pragma unroll
  268. for (int i = 0; i < iterations; i++) {
  269. vals_arr[i] = residual[i * iteration_stride + id];
  270. sum += vals_arr[i];
  271. }
  272. if ((high_index) < row_stride) {
  273. vals_arr[iterations] = residual[high_index];
  274. sum += vals_arr[iterations];
  275. iterations++;
  276. }
  277. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  278. if (g.thread_rank() == 0) shr[gid] = sum;
  279. b.sync();
  280. if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
  281. #if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
  282. b.sync();
  283. #endif
  284. for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
  285. sum = g.shfl(sum, 0);
  286. float mean = sum / row_stride;
  287. float variance = 0.f;
  288. for (int i = 0; i < iterations; i++) {
  289. vals_arr[i] -= mean;
  290. variance += vals_arr[i] * vals_arr[i];
  291. }
  292. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  293. if (g.thread_rank() == 0) shr[gid] = variance;
  294. b.sync();
  295. if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
  296. #ifndef __STOCHASTIC_MODE__
  297. b.sync();
  298. #endif
  299. for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
  300. variance = g.shfl(variance, 0);
  301. variance /= row_stride;
  302. variance += epsilon;
  303. if (training)
  304. if (threadIdx.x == 0) vars[row] = variance;
  305. iterations = row_stride / iteration_stride;
  306. for (int i = 0; i < iterations; i++) {
  307. vals_arr[i] = vals_arr[i] * rsqrtf(variance);
  308. vals_arr[i] =
  309. vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
  310. vals[i * iteration_stride + id] = vals_arr[i];
  311. }
  312. if ((high_index) < row_stride) {
  313. vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
  314. vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
  315. vals[high_index] = vals_arr[iterations];
  316. }
  317. }
  318. __global__ void fused_bias_residual_layer_norm(__half* vals,
  319. const __half* residual,
  320. const __half* gamma,
  321. const __half* beta,
  322. float epsilon,
  323. bool preLayerNorm,
  324. bool training,
  325. __half* vars,
  326. int row_stride)
  327. {
  328. #if __CUDA_ARCH__ >= 700
  329. int iteration_stride = blockDim.x;
  330. int iterations = row_stride / iteration_stride;
  331. cg::thread_block b = cg::this_thread_block();
  332. cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
  333. int row = blockIdx.x;
  334. int id = threadIdx.x;
  335. int gid = id >> 5;
  336. float2 vals_f[NORM_REG];
  337. __shared__ float shr[MAX_WARP_NUM];
  338. __half2* vals_cast = reinterpret_cast<__half2*>(vals);
  339. const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
  340. residual_cast += (row * row_stride);
  341. vals_cast += (row * row_stride);
  342. float sum = 0.f;
  343. int high_index = iterations * iteration_stride + id;
  344. #pragma unroll
  345. for (int i = 0; i < iterations; i++) {
  346. vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
  347. sum += vals_f[i].x;
  348. sum += vals_f[i].y;
  349. }
  350. if ((high_index) < row_stride) {
  351. vals_f[iterations] = __half22float2(residual_cast[high_index]);
  352. sum += vals_f[iterations].x;
  353. sum += vals_f[iterations].y;
  354. iterations++;
  355. }
  356. for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
  357. if (g.thread_rank() == 0) shr[gid] = sum;
  358. b.sync();
  359. if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
  360. #ifndef __STOCHASTIC_MODE__
  361. b.sync();
  362. #endif
  363. for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
  364. sum = g.shfl(sum, 0);
  365. float mean = sum / (row_stride * 2);
  366. float variance = 0.f;
  367. for (int i = 0; i < iterations; i++) {
  368. vals_f[i].x -= mean;
  369. vals_f[i].y -= mean;
  370. variance += vals_f[i].x * vals_f[i].x;
  371. variance += vals_f[i].y * vals_f[i].y;
  372. }
  373. for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
  374. if (g.thread_rank() == 0) shr[gid] = variance;
  375. b.sync();
  376. if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
  377. #ifndef __STOCHASTIC_MODE__
  378. b.sync();
  379. #endif
  380. for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
  381. variance = g.shfl(variance, 0);
  382. variance /= (row_stride * 2);
  383. variance += epsilon;
  384. __half2 variance_h = __float2half2_rn(variance);
  385. const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
  386. const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
  387. if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
  388. iterations = row_stride / iteration_stride;
  389. for (int i = 0; i < iterations; i++) {
  390. __half2 vals_arr = __float22half2_rn(vals_f[i]);
  391. vals_arr = vals_arr * h2rsqrt(variance_h);
  392. vals_arr =
  393. vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
  394. vals_cast[i * iteration_stride + id] = vals_arr;
  395. }
  396. if ((high_index) < row_stride) {
  397. __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
  398. vals_arr = vals_arr * h2rsqrt(variance_h);
  399. vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
  400. vals_cast[high_index] = vals_arr;
  401. }
  402. #endif
  403. }
  404. template <typename T>
  405. void launch_bias_residual_layer_norm(T* vals,
  406. const T* residual,
  407. const T* gamma,
  408. const T* beta,
  409. float epsilon,
  410. int batch_size,
  411. int hidden_dim,
  412. cudaStream_t stream,
  413. bool preLayerNorm,
  414. bool training,
  415. T* vars);
  416. /*
  417. To tune this launch, the following restrictions must be met:
  418. For float:
  419. row_stride == hidden_size
  420. threads * iterations == row_stride
  421. threads is in [32, 64, 128, 256, 512, 1024]
  422. For half:
  423. row_stride == hidden_size / 2
  424. threads * iterations == row_stride
  425. threads is in [32, 64, 128, 256, 512, 1024]
  426. */
  427. template <>
  428. void launch_bias_residual_layer_norm<float>(float* vals,
  429. const float* residual,
  430. const float* gamma,
  431. const float* beta,
  432. float epsilon,
  433. int batch_size,
  434. int hidden_dim,
  435. cudaStream_t stream,
  436. bool preLayerNorm,
  437. bool training,
  438. float* vars)
  439. {
  440. int threads = THREADS;
  441. dim3 grid_dim(batch_size);
  442. // There are limitations on calling the kernels below; for now we just enumerate the supported cases.
  443. if (hidden_dim > 16384 && hidden_dim <= 32768)
  444. threads <<= 1;
  445. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  446. threads <<= 2;
  447. else if (hidden_dim > 65536)
  448. throw std::runtime_error("Unsupported hidden_dim.");
  449. dim3 block_dim(threads);
  450. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  451. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
  452. }
  453. template <>
  454. void launch_bias_residual_layer_norm<__half>(__half* vals,
  455. const __half* residual,
  456. const __half* gamma,
  457. const __half* beta,
  458. float epsilon,
  459. int batch_size,
  460. int hidden_dim,
  461. cudaStream_t stream,
  462. bool preLayerNorm,
  463. bool training,
  464. __half* vars)
  465. {
  466. int threads = 128;
  467. dim3 grid_dim(batch_size);
  468. // There are limitations on calling the kernels below; for now we just enumerate the supported cases.
  469. if (hidden_dim > 8192 && hidden_dim <= 16384)
  470. threads <<= 1;
  471. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  472. threads <<= 2;
  473. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  474. threads <<= 3;
  475. else if (hidden_dim > 65536)
  476. throw std::runtime_error("Unsupported hidden_dim.");
  477. dim3 block_dim(threads);
  478. fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
  479. vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
  480. }
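/*
Hypothetical helper (not from the original file): makes the float-path
restrictions above concrete. It follows the same thread-count enumeration as
the launcher and additionally checks that a single thread's register tile
(NORM_REG entries, see the #define near the top) can hold its share of the
row, including the possible remainder element handled via high_index.
*/
#include <stdexcept>

inline int example_pick_threads_fp32(int hidden_dim)
{
    int threads = THREADS;
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
    // For float, row_stride == hidden_dim; each thread owns
    // ceil(hidden_dim / threads) elements of the per-thread register array.
    int per_thread = hidden_dim / threads + (hidden_dim % threads ? 1 : 0);
    if (per_thread > NORM_REG)
        throw std::runtime_error("hidden_dim exceeds the NORM_REG register tile.");
    return threads;
}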
  481. /* Normalize Gamma & Betta gradients
  482. * Compute gradients using either X_hat or
  483. * the normalized input (invertible).
  484. * Combine transpose with gradients computation.
  485. */
  486. template <typename T>
  487. __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
  488. const T* __restrict__ vals_hat,
  489. const T* __restrict__ gamma,
  490. const T* __restrict__ betta,
  491. T* __restrict__ gamma_grad,
  492. T* __restrict__ betta_grad,
  493. int rows,
  494. int width,
  495. bool invertible)
  496. {
  497. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  498. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  499. cg::thread_block b = cg::this_thread_block();
  500. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  501. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  502. int offset = threadIdx.y * width + idx;
  503. int y_stride = width * TILE_DIM;
  504. float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
  505. float gamma_reg = (float)gamma[idx];
  506. // Loop across matrix height
  507. float betta_tmp = 0;
  508. float gamma_tmp = 0;
  509. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  510. float grad = (float)out_grad[offset];
  511. float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
  512. : (float)vals_hat[offset]);
  513. betta_tmp += grad;
  514. gamma_tmp += (val * grad);
  515. offset += y_stride;
  516. }
  517. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  518. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  519. __syncthreads();
  520. // Sum the shared buffer.
  521. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  522. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  523. #ifndef __STOCHASTIC_MODE__
  524. __syncthreads();
  525. #endif
  526. for (int i = 1; i < TILE_DIM; i <<= 1) {
  527. s1 += g.shfl_down(s1, i);
  528. s2 += g.shfl_down(s2, i);
  529. }
  530. if (threadIdx.x == 0) {
  531. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  532. betta_grad[pos] = s1;
  533. gamma_grad[pos] = s2;
  534. }
  535. }
  536. /* Normalize Gamma & Betta gradients
  537. * Compute gradients using the input to
  538. * the normalization.
  539. * Combine transpose with gradients computation.
  540. */
  541. template <typename T>
  542. __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
  543. const T* __restrict__ X_data,
  544. const T* __restrict__ vars,
  545. const T* __restrict__ means,
  546. T* __restrict__ gamma_grad,
  547. T* __restrict__ betta_grad,
  548. int rows,
  549. int width)
  550. {
  551. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  552. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  553. cg::thread_block b = cg::this_thread_block();
  554. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  555. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  556. int offset = threadIdx.y * width + idx;
  557. int y_stride = width * TILE_DIM;
  558. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  559. // Loop across matrix height
  560. float betta_tmp = 0;
  561. float gamma_tmp = 0;
  562. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  563. float grad = (float)out_grad[offset];
  564. float val = (float)X_data[offset];
  565. val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
  566. betta_tmp += grad;
  567. gamma_tmp += (val * grad);
  568. offset += y_stride;
  569. }
  570. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  571. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  572. __syncthreads();
  573. // Sum the shared buffer.
  574. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  575. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  576. #ifndef __STOCHASTIC_MODE__
  577. __syncthreads();
  578. #endif
  579. for (int i = 1; i < TILE_DIM; i <<= 1) {
  580. s1 += g.shfl_down(s1, i);
  581. s2 += g.shfl_down(s2, i);
  582. }
  583. if (threadIdx.x == 0) {
  584. betta_grad[pos] = s1;
  585. gamma_grad[pos] = s2;
  586. }
  587. }
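/*
Illustrative host-side reference (hypothetical name): the per-column
reductions that both LayerNormBackward1 variants above compute. x_hat is the
normalized activation; the invertible variant recovers it as
(vals_hat - betta) / gamma, the other recomputes it from X, means and vars.
*/
inline void reference_gamma_betta_grad(const float* out_grad, // rows x width
                                       const float* x_hat,    // rows x width
                                       float* gamma_grad,     // width
                                       float* betta_grad,     // width
                                       int rows,
                                       int width)
{
    for (int c = 0; c < width; c++) {
        float dgamma = 0.f;
        float dbetta = 0.f;
        for (int r = 0; r < rows; r++) {
            float grad = out_grad[r * width + c];
            dbetta += grad;                         // d(betta_c) = sum_r dy[r][c]
            dgamma += grad * x_hat[r * width + c];  // d(gamma_c) = sum_r dy[r][c] * x_hat[r][c]
        }
        gamma_grad[c] = dgamma;
        betta_grad[c] = dbetta;
    }
}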
  588. /*
  589. * Backward Normalize (Input-Gradient)
  590. * Using the means and variances from the input
  591. * This type of backward is invertible!
  592. * We do the backward using X_hat = (X - u) / sqrt(variance), i.e. the output of the normalization.
  593. */
  594. __global__ void LayerNormBackward2(const float* out_grad,
  595. const float* vals_hat,
  596. const float* gamma,
  597. const float* betta,
  598. const float* vars,
  599. float* inp_grad,
  600. bool invertible,
  601. int row_stride)
  602. {
  603. int iteration_stride = blockDim.x;
  604. int iterations = row_stride / iteration_stride;
  605. cg::thread_block b = cg::this_thread_block();
  606. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  607. int row = blockIdx.x;
  608. int id = threadIdx.x;
  609. int wid = id / WARP_SIZE;
  610. int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
  611. __shared__ float partialSum[MAX_WARP_NUM];
  612. out_grad += (row * row_stride);
  613. vals_hat += (row * row_stride);
  614. inp_grad += (row * row_stride);
  615. float vals_arr[NORM_REG];
  616. float vals_hat_arr[NORM_REG];
  617. int high_index = iterations * iteration_stride + id;
  618. #pragma unroll
  619. for (int i = 0; i < iterations; i++) {
  620. float gamma_reg = gamma[i * iteration_stride + id];
  621. vals_arr[i] = out_grad[i * iteration_stride + id];
  622. vals_arr[i] *= gamma_reg;
  623. vals_hat_arr[i] =
  624. (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
  625. gamma_reg
  626. : vals_hat[i * iteration_stride + id]);
  627. }
  628. if ((high_index) < row_stride) {
  629. float gamma_reg = gamma[high_index];
  630. vals_arr[iterations] = out_grad[high_index];
  631. vals_arr[iterations] *= gamma_reg;
  632. vals_hat_arr[iterations] =
  633. (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
  634. : vals_hat[high_index]);
  635. iterations++;
  636. }
  637. float var_reg = vars[row];
  638. float sum = 0;
  639. for (int i = 0; i < iterations; i++) {
  640. sum += vals_hat_arr[i] * vals_arr[i] *
  641. sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
  642. vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
  643. }
  644. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  645. if (g.thread_rank() == 0) partialSum[wid] = sum;
  646. __syncthreads();
  647. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  648. #ifndef __STOCHASTIC_MODE__
  649. __syncthreads();
  650. #endif
  651. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  652. sum = g.shfl(sum, 0);
  653. sum /= row_stride;
  654. for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
  655. sum = 0;
  656. for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
  657. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  658. if (g.thread_rank() == 0) partialSum[wid] = sum;
  659. __syncthreads();
  660. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  661. #ifndef __STOCHASTIC_MODE__
  662. __syncthreads();
  663. #endif
  664. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  665. sum = g.shfl(sum, 0);
  666. sum /= row_stride;
  667. iterations = row_stride / iteration_stride;
  668. for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
  669. if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
  670. }
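/*
Illustrative host-side reference for one row of the input-gradient kernel
above (hypothetical name). It assumes x_hat is already available (the
invertible path recovers it as (vals_hat - betta) / gamma) and that var is the
stored per-row variance with epsilon already added. The kernel computes the
same quantity with warp/shared-memory reductions; only the association of the
sqrtf(var) / var factors differs.
*/
#include <cmath>

inline void reference_layer_norm_backward_row(const float* out_grad,
                                              const float* x_hat,
                                              const float* gamma,
                                              float var,
                                              float* inp_grad,
                                              int row_stride)
{
    // mean_i( x_hat_i * dy_i * gamma_i )
    float dot = 0.f;
    for (int i = 0; i < row_stride; i++) dot += x_hat[i] * out_grad[i] * gamma[i];
    dot /= row_stride;

    float inv_std = 1.f / sqrtf(var);
    float mean_dx = 0.f;
    for (int i = 0; i < row_stride; i++) {
        inp_grad[i] = (out_grad[i] * gamma[i] - dot * x_hat[i]) * inv_std;
        mean_dx += inp_grad[i];
    }
    mean_dx /= row_stride;

    // Subtract the row mean of the partial gradient (the second reduction above).
    for (int i = 0; i < row_stride; i++) inp_grad[i] -= mean_dx;
}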
  671. __global__ void LayerNormBackward2(const __half* out_grad,
  672. const __half* vals_hat,
  673. const __half* gamma,
  674. const __half* betta,
  675. const __half* vars,
  676. __half* inp_grad,
  677. bool invertible,
  678. int row_stride)
  679. {
  680. int iteration_stride = blockDim.x;
  681. int iterations = row_stride / iteration_stride;
  682. cg::thread_block b = cg::this_thread_block();
  683. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  684. int row = blockIdx.x;
  685. int id = threadIdx.x;
  686. int wid = id / WARP_SIZE;
  687. int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
  688. __shared__ float partialSum[MAX_WARP_NUM];
  689. __half2 vals_arr[NORM_REG];
  690. float2 vals_arr_f[NORM_REG];
  691. __half2 vals_hat_arr[NORM_REG];
  692. __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
  693. const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
  694. const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
  695. inp_grad_h += (row * row_stride);
  696. out_grad_h += (row * row_stride);
  697. vals_hat_h += (row * row_stride);
  698. const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
  699. const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
  700. int high_index = iterations * iteration_stride + id;
  701. #pragma unroll
  702. for (int i = 0; i < iterations; i++) {
  703. __half2 gamma_reg = gamma_h[i * iteration_stride + id];
  704. vals_arr[i] = out_grad_h[i * iteration_stride + id];
  705. vals_arr[i] *= gamma_reg;
  706. vals_hat_arr[i] =
  707. (invertible
  708. ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
  709. gamma_reg
  710. : vals_hat_h[i * iteration_stride + id]);
  711. }
  712. if ((high_index) < row_stride) {
  713. __half2 gamma_reg = gamma_h[high_index];
  714. vals_arr[iterations] = out_grad_h[high_index];
  715. vals_arr[iterations] *= gamma_reg;
  716. vals_hat_arr[iterations] =
  717. (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
  718. : vals_hat_h[high_index]);
  719. iterations++;
  720. }
  721. __half var_h = vars[row];
  722. __half2 var_reg = __halves2half2(var_h, var_h);
  723. float sum = 0.f;
  724. for (int i = 0; i < iterations; i++) {
  725. __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
  726. float2 result_f = __half22float2(result_h);
  727. sum += result_f.x;
  728. sum += result_f.y;
  729. vals_arr[i] *= h2rsqrt(var_reg);
  730. }
  731. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  732. if (g.thread_rank() == 0) partialSum[wid] = sum;
  733. __syncthreads();
  734. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  735. #ifndef __STOCHASTIC_MODE__
  736. __syncthreads();
  737. #endif
  738. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  739. sum = g.shfl(sum, 0);
  740. sum /= (2 * row_stride);
  741. __half2 sum_h = __float2half2_rn(sum);
  742. for (int i = 0; i < iterations; i++) {
  743. __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
  744. vals_arr_f[i] = __half22float2(vals_arr[i]);
  745. float2 temp_f = __half22float2(temp);
  746. vals_arr_f[i].x += temp_f.x;
  747. vals_arr_f[i].y += temp_f.y;
  748. }
  749. sum = 0.f;
  750. for (int i = 0; i < iterations; i++) {
  751. sum += (vals_arr_f[i].x);
  752. sum += (vals_arr_f[i].y);
  753. }
  754. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  755. if (g.thread_rank() == 0) partialSum[wid] = sum;
  756. __syncthreads();
  757. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  758. #ifndef __STOCHASTIC_MODE__
  759. __syncthreads();
  760. #endif
  761. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  762. sum = g.shfl(sum, 0);
  763. sum /= (2 * row_stride);
  764. iterations = row_stride / iteration_stride;
  765. for (int i = 0; i < iterations; i++) {
  766. vals_arr_f[i].x -= sum;
  767. vals_arr_f[i].y -= sum;
  768. __half2 temp = __float22half2_rn(vals_arr_f[i]);
  769. inp_grad_h[i * iteration_stride + id] = temp;
  770. }
  771. if ((high_index) < row_stride) {
  772. vals_arr_f[iterations].x -= sum;
  773. vals_arr_f[iterations].y -= sum;
  774. __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
  775. inp_grad_h[high_index] = temp;
  776. }
  777. }
  778. template <>
  779. void launch_layerNorm_backward<float>(const float* out_grad,
  780. const float* vals_hat,
  781. const float* vars,
  782. const float* gamma,
  783. float* gamma_grad,
  784. float* betta_grad,
  785. float* inp_grad,
  786. int batch,
  787. int hidden_dim,
  788. cudaStream_t stream[2],
  789. bool invertible,
  790. const float* betta)
  791. {
  792. int threads = THREADS;
  793. dim3 grid_dim(hidden_dim / TILE_DIM);
  794. dim3 block_dim(TILE_DIM, TILE_DIM);
  795. LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
  796. out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
  797. dim3 grid_dim2(batch);
  798. if (hidden_dim > 16384 && hidden_dim <= 32768)
  799. threads <<= 1;
  800. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  801. threads <<= 2;
  802. else if (hidden_dim > 65536)
  803. throw std::runtime_error("Unsupported hidden_dim.");
  804. dim3 block_dim2(threads);
  805. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  806. out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
  807. }
  808. template <>
  809. void launch_layerNorm_backward<__half>(const __half* out_grad,
  810. const __half* vals_hat,
  811. const __half* vars,
  812. const __half* gamma,
  813. __half* gamma_grad,
  814. __half* betta_grad,
  815. __half* inp_grad,
  816. int batch,
  817. int hidden_dim,
  818. cudaStream_t stream[2],
  819. bool invertible,
  820. const __half* betta)
  821. {
  822. int threads = THREADS;
  823. dim3 grid_dim(hidden_dim / TILE_DIM);
  824. dim3 block_dim(TILE_DIM, TILE_DIM);
  825. LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
  826. out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
  827. dim3 grid_dim2(batch);
  828. if (hidden_dim > 8192 && hidden_dim <= 16384)
  829. threads <<= 1;
  830. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  831. threads <<= 2;
  832. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  833. threads <<= 3;
  834. else if (hidden_dim > 65536)
  835. throw std::runtime_error("Unsupported hidden_dim.");
  836. dim3 block_dim2(threads / 2);
  837. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  838. out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
  839. }
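/*
Illustrative usage sketch (hypothetical function name): the backward launcher
above issues LayerNormBackward1 on stream[0] and LayerNormBackward2 on
stream[1], so the caller supplies a two-element stream array. All pointers are
device buffers; error checking is omitted.
*/
#include <cuda_runtime.h>

void example_layer_norm_backward_fp32(const float* out_grad,
                                      const float* vals_hat,
                                      const float* vars,
                                      const float* gamma,
                                      const float* betta,
                                      float* gamma_grad,
                                      float* betta_grad,
                                      float* inp_grad,
                                      int batch,
                                      int hidden_dim)
{
    cudaStream_t streams[2];
    cudaStreamCreate(&streams[0]);
    cudaStreamCreate(&streams[1]);

    launch_layerNorm_backward<float>(out_grad,
                                     vals_hat,
                                     vars,
                                     gamma,
                                     gamma_grad,
                                     betta_grad,
                                     inp_grad,
                                     batch,
                                     hidden_dim,
                                     streams,
                                     /*invertible=*/true,
                                     betta);

    cudaStreamSynchronize(streams[0]);
    cudaStreamSynchronize(streams[1]);
    cudaStreamDestroy(streams[0]);
    cudaStreamDestroy(streams[1]);
}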
  840. /* Backward Normalize (Input-Gradient)
  841. * Using the means and variances from the input
  842. * This type of backward is not invertible!
  843. * We do the backward using the input (X)
  844. */
  845. __global__ void LayerNormBackward2(const float* out_grad,
  846. const float* X_vals,
  847. const float* gamma,
  848. const float* vars,
  849. const float* means,
  850. float* inp_grad,
  851. int row_stride)
  852. {
  853. int iteration_stride = blockDim.x;
  854. int iterations = row_stride / iteration_stride;
  855. cg::thread_block b = cg::this_thread_block();
  856. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  857. int row = blockIdx.x;
  858. int id = threadIdx.x;
  859. int wid = id / WARP_SIZE;
  860. int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
  861. __shared__ float partialSum[MAX_WARP_NUM];
  862. out_grad += (row * row_stride);
  863. X_vals += (row * row_stride);
  864. inp_grad += (row * row_stride);
  865. float vals_arr[NORM_REG];
  866. int high_index = iterations * iteration_stride + id;
  867. #pragma unroll
  868. for (int i = 0; i < iterations; i++) {
  869. float gamma_reg = gamma[i * iteration_stride + id];
  870. vals_arr[i] = out_grad[i * iteration_stride + id];
  871. vals_arr[i] *= gamma_reg;
  872. }
  873. if ((high_index) < row_stride) {
  874. float gamma_reg = gamma[high_index];
  875. vals_arr[iterations] = out_grad[high_index];
  876. vals_arr[iterations] *= gamma_reg;
  877. iterations++;
  878. }
  879. float var_reg = vars[row];
  880. float mean_reg = means[row];
  881. float sum = 0;
  882. float xu[NORM_REG];
  883. for (int i = 0; i < iterations; i++) {
  884. xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
  885. sum += vals_arr[i] * xu[i];
  886. vals_arr[i] *= rsqrtf(var_reg);
  887. }
  888. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  889. if (g.thread_rank() == 0) partialSum[wid] = sum;
  890. __syncthreads();
  891. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  892. #ifndef __STOCHASTIC_MODE__
  893. __syncthreads();
  894. #endif
  895. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  896. sum = g.shfl(sum, 0);
  897. sum /= row_stride;
  898. for (int i = 0; i < iterations; i++) {
  899. vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
  900. }
  901. sum = 0;
  902. for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
  903. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  904. if (g.thread_rank() == 0) partialSum[wid] = sum;
  905. __syncthreads();
  906. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  907. #ifndef __STOCHASTIC_MODE__
  908. __syncthreads();
  909. #endif
  910. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  911. sum = g.shfl(sum, 0);
  912. sum /= row_stride;
  913. iterations = row_stride / iteration_stride;
  914. for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
  915. if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
  916. }
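/*
Note: unlike the invertible variant earlier in this file, this path does not
reconstruct x_hat from the stored output. It recomputes (X - mean) from the
saved means and rescales with rsqrtf(vars[row]), which is why it needs both
means and vars but no betta.
*/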
  917. __global__ void LayerNormBackward2(const __half* out_grad,
  918. const __half* X_vals,
  919. const __half* gamma,
  920. const __half* vars,
  921. const __half* means,
  922. __half* inp_grad,
  923. int row_stride)
  924. {
  925. int iteration_stride = blockDim.x;
  926. int iterations = row_stride / iteration_stride;
  927. cg::thread_block b = cg::this_thread_block();
  928. cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  929. int row = blockIdx.x;
  930. int id = threadIdx.x;
  931. int wid = id / WARP_SIZE;
  932. int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
  933. __shared__ float partialSum[MAX_WARP_NUM];
  934. __half2 vals_arr[NORM_REG];
  935. float2 vals_arr_f[NORM_REG];
  936. __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
  937. const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
  938. const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
  939. inp_grad_h += (row * row_stride);
  940. out_grad_h += (row * row_stride);
  941. vals_hat_h += (row * row_stride);
  942. const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
  943. int high_index = iterations * iteration_stride + id;
  944. #pragma unroll
  945. for (int i = 0; i < iterations; i++) {
  946. __half2 gamma_reg = gamma_h[i * iteration_stride + id];
  947. vals_arr[i] = out_grad_h[i * iteration_stride + id];
  948. vals_arr[i] *= gamma_reg; // out_grad * gamma
  949. }
  950. if ((high_index) < row_stride) {
  951. __half2 gamma_reg = gamma_h[high_index];
  952. vals_arr[iterations] = out_grad_h[high_index];
  953. vals_arr[iterations] *= gamma_reg; // out_grad * gamma
  954. iterations++;
  955. }
  956. __half mean_h = means[row];
  957. __half var_h = vars[row];
  958. __half2 var_reg = __halves2half2(var_h, var_h);
  959. __half2 mean_reg = __halves2half2(mean_h, mean_h);
  960. __half2 xu[NORM_REG];
  961. float sum = 0.f;
  962. for (int i = 0; i < iterations; i++) {
  963. xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
  964. __half2 result_h = (xu[i] * vals_arr[i]);
  965. float2 result_f = __half22float2(result_h);
  966. sum += result_f.x;
  967. sum += result_f.y;
  968. vals_arr[i] *= h2rsqrt(var_reg);
  969. }
  970. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  971. if (g.thread_rank() == 0) partialSum[wid] = sum;
  972. __syncthreads();
  973. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  974. #ifndef __STOCHASTIC_MODE__
  975. __syncthreads();
  976. #endif
  977. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  978. sum = g.shfl(sum, 0);
  979. sum /= (2 * row_stride);
  980. __half2 sum_h = __float2half2_rn(sum);
  981. for (int i = 0; i < iterations; i++) {
  982. __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
  983. vals_arr_f[i] = __half22float2(vals_arr[i]);
  984. float2 xu_grad_f = __half22float2(xu_grad);
  985. vals_arr_f[i].x += xu_grad_f.x;
  986. vals_arr_f[i].y += xu_grad_f.y;
  987. }
  988. sum = 0.f;
  989. for (int i = 0; i < iterations; i++) {
  990. sum += (vals_arr_f[i].x);
  991. sum += (vals_arr_f[i].y);
  992. }
  993. for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
  994. if (g.thread_rank() == 0) partialSum[wid] = sum;
  995. __syncthreads();
  996. if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
  997. #ifndef __STOCHASTIC_MODE__
  998. __syncthreads();
  999. #endif
  1000. for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
  1001. sum = g.shfl(sum, 0);
  1002. sum /= (2 * row_stride);
  1003. iterations = row_stride / iteration_stride;
  1004. for (int i = 0; i < iterations; i++) {
  1005. vals_arr_f[i].x -= sum;
  1006. vals_arr_f[i].y -= sum;
  1007. __half2 temp = __float22half2_rn(vals_arr_f[i]);
  1008. inp_grad_h[i * iteration_stride + id] = temp;
  1009. }
  1010. if ((high_index) < row_stride) {
  1011. vals_arr_f[iterations].x -= sum;
  1012. vals_arr_f[iterations].y -= sum;
  1013. __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
  1014. inp_grad_h[high_index] = temp;
  1015. }
  1016. }
  1017. template <>
  1018. void launch_layerNorm_backward<float>(const float* out_grad,
  1019. const float* X_data,
  1020. const float* vars,
  1021. const float* means,
  1022. const float* gamma,
  1023. float* gamma_grad,
  1024. float* betta_grad,
  1025. float* inp_grad,
  1026. int batch,
  1027. int hidden_dim,
  1028. cudaStream_t stream[2])
  1029. {
  1030. int threads = THREADS;
  1031. dim3 grid_dim(hidden_dim / TILE_DIM);
  1032. dim3 block_dim(TILE_DIM, TILE_DIM);
  1033. LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
  1034. out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
  1035. dim3 grid_dim2(batch);
  1036. if (hidden_dim > 16384 && hidden_dim <= 32768)
  1037. threads <<= 1;
  1038. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  1039. threads <<= 2;
  1040. else if (hidden_dim > 65536)
  1041. throw std::runtime_error("Unsupported hidden_dim.");
  1042. dim3 block_dim2(threads);
  1043. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  1044. out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
  1045. }
  1046. template <>
  1047. void launch_layerNorm_backward<__half>(const __half* out_grad,
  1048. const __half* X_data,
  1049. const __half* vars,
  1050. const __half* means,
  1051. const __half* gamma,
  1052. __half* gamma_grad,
  1053. __half* betta_grad,
  1054. __half* inp_grad,
  1055. int batch,
  1056. int hidden_dim,
  1057. cudaStream_t stream[2])
  1058. {
  1059. int threads = THREADS;
  1060. dim3 grid_dim(hidden_dim / TILE_DIM);
  1061. dim3 block_dim(TILE_DIM, TILE_DIM);
  1062. LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
  1063. out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
  1064. dim3 grid_dim2(batch);
  1065. if (hidden_dim > 8192 && hidden_dim <= 16384)
  1066. threads <<= 1;
  1067. else if (hidden_dim > 16384 && hidden_dim <= 32768)
  1068. threads <<= 2;
  1069. else if (hidden_dim > 32768 && hidden_dim <= 65536)
  1070. threads <<= 3;
  1071. else if (hidden_dim > 65536)
  1072. throw std::runtime_error("Unsupported hidden_dim.");
  1073. dim3 block_dim2(threads / 2);
  1074. LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
  1075. out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
  1076. }
  1077. template <typename T>
  1078. __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
  1079. const T* __restrict__ out_grad2,
  1080. const T* __restrict__ vals_hat,
  1081. const T* __restrict__ gamma,
  1082. const T* __restrict__ betta,
  1083. T* __restrict__ gamma_grad,
  1084. T* __restrict__ betta_grad,
  1085. int rows,
  1086. int width,
  1087. bool invertible)
  1088. {
  1089. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  1090. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  1091. cg::thread_block b = cg::this_thread_block();
  1092. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  1093. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  1094. int offset = threadIdx.y * width + idx;
  1095. int y_stride = width * TILE_DIM;
  1096. float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
  1097. float gamma_reg = (float)gamma[idx];
  1098. // Loop across matrix height
  1099. float betta_tmp = 0;
  1100. float gamma_tmp = 0;
  1101. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  1102. float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
  1103. float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
  1104. : (float)vals_hat[offset]);
  1105. betta_tmp += grad;
  1106. gamma_tmp += (val * grad);
  1107. offset += y_stride;
  1108. }
  1109. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  1110. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  1111. __syncthreads();
  1112. // Sum the shared buffer.
  1113. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  1114. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  1115. #ifndef __STOCHASTIC_MODE__
  1116. __syncthreads();
  1117. #endif
  1118. for (int i = 1; i < TILE_DIM; i <<= 1) {
  1119. s1 += g.shfl_down(s1, i);
  1120. s2 += g.shfl_down(s2, i);
  1121. }
  1122. if (threadIdx.x == 0) {
  1123. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  1124. betta_grad[pos] = s1;
  1125. gamma_grad[pos] = s2;
  1126. }
  1127. }
  1128. template <typename T>
  1129. __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
  1130. const T* __restrict__ out_grad2,
  1131. const T* __restrict__ X_data,
  1132. const T* __restrict__ vars,
  1133. const T* __restrict__ means,
  1134. T* __restrict__ gamma_grad,
  1135. T* __restrict__ betta_grad,
  1136. int rows,
  1137. int width)
  1138. {
  1139. __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  1140. __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
  1141. cg::thread_block b = cg::this_thread_block();
  1142. cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
  1143. int idx = blockDim.x * blockIdx.x + threadIdx.x;
  1144. int offset = threadIdx.y * width + idx;
  1145. int y_stride = width * TILE_DIM;
  1146. int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  1147. // Loop across matrix height
  1148. float betta_tmp = 0;
  1149. float gamma_tmp = 0;
  1150. for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
  1151. float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
  1152. float val = (float)X_data[offset];
  1153. val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
  1154. betta_tmp += grad;
  1155. gamma_tmp += (val * grad);
  1156. offset += y_stride;
  1157. }
  1158. betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
  1159. gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
  1160. __syncthreads();
  1161. // Sum the shared buffer.
  1162. float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  1163. float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  1164. #ifndef __STOCHASTIC_MODE__
  1165. __syncthreads();
  1166. #endif
  1167. for (int i = 1; i < TILE_DIM; i <<= 1) {
  1168. s1 += g.shfl_down(s1, i);
  1169. s2 += g.shfl_down(s2, i);
  1170. }
  1171. if (threadIdx.x == 0) {
  1172. betta_grad[pos] = s1;
  1173. gamma_grad[pos] = s2;
  1174. }
  1175. }
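/*
Note: the _fused_add variants fold the residual branch into the backward pass.
In LayerNormBackward1_fused_add the upstream gradient is the elementwise sum
out_grad1 + out_grad2 before the same per-column reductions; in
LayerNormBackward2_fused_add below, the normalization backward is applied to
out_grad1 and out_grad2 is added to the resulting input gradient.
*/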

__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* vals_hat,
                                             const float* gamma,
                                             const float* betta,
                                             const float* vars,
                                             float* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;

    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    vals_hat += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
                              gamma_reg
                        : vals_hat[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
                        : vals_hat[high_index]);
        iterations++;
    }

    float var_reg = vars[row];

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] =
            (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
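
/* Half-precision variant of the kernel above. Elements are processed as
 * __half2 pairs (so row_stride here counts __half2 elements, i.e. hidden_dim / 2),
 * the per-row reductions are accumulated in float for accuracy, and the reduced
 * sums are divided by 2 * row_stride because every __half2 register holds two
 * scalar values.
 */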

__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* vals_hat,
                                             const __half* gamma,
                                             const __half* betta,
                                             const __half* vars,
                                             __half* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    // float2 result[iterations];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);

    inp_grad_h += (row * row_stride);
    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] =
            (invertible
                 ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
                       gamma_reg
                 : vals_hat_h[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
                        : vals_hat_h[high_index]);
        iterations++;
    }

    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 temp_f = __half22float2(temp);
        vals_arr_f[i].x += temp_f.x;
        vals_arr_f[i].y += temp_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
}

template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* vals_hat,
                                                const float* vars,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2],
                                                bool invertible,
                                                const float* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
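
/* Illustrative usage sketch (not part of the original kernels): a minimal
 * host-side driver for the float launcher above. The function name, the
 * batch/hidden sizes and the two-stream setup below are hypothetical; only
 * the launcher signature comes from this file. Error checking is omitted.
 */
static void example_invoke_fused_add_backward_fp32(const float* out_grad1,
                                                   const float* out_grad2,
                                                   const float* vals_hat,
                                                   const float* vars,
                                                   const float* gamma,
                                                   const float* betta,
                                                   float* gamma_grad,
                                                   float* betta_grad,
                                                   float* inp_grad)
{
    const int batch = 8 * 512;    // rows = batch size * sequence length (example values)
    const int hidden_dim = 1024;  // elements per row; assumed a multiple of TILE_DIM

    // The launcher expects two streams: one for the gamma/betta-gradient
    // kernel and one for the input-gradient kernel.
    cudaStream_t streams[2];
    cudaStreamCreate(&streams[0]);
    cudaStreamCreate(&streams[1]);

    launch_layerNorm_backward_fused_add<float>(out_grad1,
                                               out_grad2,
                                               vals_hat,
                                               vars,
                                               gamma,
                                               gamma_grad,
                                               betta_grad,
                                               inp_grad,
                                               batch,
                                               hidden_dim,
                                               streams,
                                               /*invertible=*/true,
                                               betta);

    cudaStreamSynchronize(streams[0]);
    cudaStreamSynchronize(streams[1]);
    cudaStreamDestroy(streams[0]);
    cudaStreamDestroy(streams[1]);
}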

template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* vals_hat,
                                                 const __half* vars,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2],
                                                 bool invertible,
                                                 const __half* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads / 2);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
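
/* Note on the __half launch configuration above: the half-precision kernel
 * indexes its buffers as __half2, so the launcher passes hidden_dim / 2 as the
 * row stride and launches threads / 2 threads per block; each thread covers
 * two scalar elements per iteration, so a block still covers `threads` scalars
 * per step, matching the float configuration.
 */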

/* Backward Normalize (Input-Gradient)
 * Uses the means and variances saved from the forward pass over the input.
 * This variant of the backward pass is not invertible: it reads the original
 * input (X) directly instead of reconstructing it from the output.
 */
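
/* Per row, with g_i = out_grad1_i * gamma_i and x_hat_i = (X_i - mean) * rsqrtf(var),
 * the kernels below produce (up to floating-point rounding and the order of
 * accumulation)
 *
 *     inp_grad_i = rsqrtf(var) * (g_i - mean(g) - x_hat_i * mean(g * x_hat)) + out_grad2_i,
 *
 * i.e. the standard layer-norm input gradient plus the fused residual gradient,
 * computed with the same two-level warp/block reductions used above.
 */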

__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* X_vals,
                                             const float* gamma,
                                             const float* vars,
                                             const float* means,
                                             float* inp_grad,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;

    __shared__ float partialSum[MAX_WARP_NUM];

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    X_vals += (row * row_stride);
    inp_grad += (row * row_stride);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] = X_vals[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] = X_vals[high_index];
        iterations++;
    }

    float var_reg = vars[row];
    float mean_reg = means[row];

    float sum = 0;
    float xu[NORM_REG];
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);
        sum += vals_arr[i] * xu[i];
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) {
        vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
    }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] =
            (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}

__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* X_vals,
                                             const __half* gamma,
                                             const __half* vars,
                                             const __half* means,
                                             __half* inp_grad,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);

    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    inp_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] = vals_hat_h[high_index];
        iterations++;
    }

    __half mean_h = means[row];
    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);
    __half2 mean_reg = __halves2half2(mean_h, mean_h);
    __half2 xu[NORM_REG];

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);
        __half2 result_h = (xu[i] * vals_arr[i]);
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 xu_grad_f = __half22float2(xu_grad);
        vals_arr_f[i].x += xu_grad_f.x;
        vals_arr_f[i].y += xu_grad_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
}

template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* X_data,
                                                const float* vars,
                                                const float* means,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}

template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* X_data,
                                                 const __half* vars,
                                                 const __half* means,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");

    dim3 block_dim2(threads / 2);

    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}