kernel_backward.h

  1. /***************************************************************************************************
  2. * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. * SPDX-License-Identifier: BSD-3-Clause
  4. *
  5. * Redistribution and use in source and binary forms, with or without
  6. * modification, are permitted provided that the following conditions are met:
  7. *
  8. * 1. Redistributions of source code must retain the above copyright notice, this
  9. * list of conditions and the following disclaimer.
  10. *
  11. * 2. Redistributions in binary form must reproduce the above copyright notice,
  12. * this list of conditions and the following disclaimer in the documentation
  13. * and/or other materials provided with the distribution.
  14. *
  15. * 3. Neither the name of the copyright holder nor the names of its
  16. * contributors may be used to endorse or promote products derived from
  17. * this software without specific prior written permission.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  20. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  21. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  22. * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  23. * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  24. * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  25. * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  26. * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  27. * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  28. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *
  30. **************************************************************************************************/
  31. // Copyright (c) Microsoft Corporation.
  32. // SPDX-License-Identifier: Apache-2.0
  33. // DeepSpeed Team
  34. #pragma once
  35. #include <cmath>
  36. #include <type_traits>
  37. #include <vector>
  38. #include <cuda_fp16.h>
  39. #include "cutlass/cutlass.h"
  40. #include "cutlass/epilogue/thread/linear_combination.h"
  41. #include "cutlass/epilogue/thread/scale_type.h"
  42. #include "cutlass/fast_math.h"
  43. #include "cutlass/functional.h"
  44. #include "cutlass/gemm/gemm.h"
  45. #include "cutlass/layout/matrix.h"
  46. #include "cutlass/layout/vector.h"
  47. #include "cutlass/numeric_conversion.h"
  48. #include "cutlass/numeric_types.h"
  49. #include "cutlass/tensor_ref.h"
  50. #include "gemm_kernel_utils.h"
  51. #include "cutlass/epilogue/thread/linear_combination_relu.h"
  52. #include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
  53. #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
  54. #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
  55. #include "cutlass/gemm/device/default_gemm_configuration.h"
  56. #include "cutlass/gemm/kernel/default_gemm.h"
  57. #include "cutlass/gemm/threadblock/default_mma.h"
  58. #include "cutlass/gemm/threadblock/default_mma_core_simt.h"
  59. #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
  60. #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
  61. #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
  62. #include "cutlass/matrix_shape.h"
  63. #include "cutlass/platform/platform.h"
  64. #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
  65. #include "cutlass/transform/threadblock/vector_iterator.h"
  66. #include "epilogue/epilogue_pipelined.h"
  67. #include "iterators/epilogue_predicated_tile_iterator.h"
  68. #include "epilogue/epilogue_grad_bias.h"
  69. #include "gemm/custom_mma.h"
  70. #include "gemm/find_default_mma.h"
  71. #include "gemm/mma_accum_lambda_iterator.h"
  72. #include "gemm/mma_from_smem.h"
  73. #include "transform/bias_broadcast.h"
  74. #include "transform/tile_smem_loader.h"
  75. #include <inttypes.h>
  76. using namespace gemm_kernel_utils;
  77. namespace {
  78. template <typename FragmentType, int32_t kNumThreads>
  79. struct GmemTile {
  80. /*
  81. Helper functions to efficiently store/load RF to/from gmem
  82. GEMM accumulators have a particular format on A100, and
  83. it takes some compute/shared-memory to rearrange them to
  84. a RowMajor or ColumnMajor format in global memory through
  85. an Epilogue. The same complexity goes for loading into RF.
  86. This class loads/stores RF as they are, and can be used for
  87. efficient accumulation across gemms for instance:
  88. ```
  89. GmemTile tile;
  90. for (int i = 0; i < N; ++i) {
  91. // ...
  92. Fragment accum;
  93. if (i == 0) {
  94. accum.clear();
  95. } else {
  96. tile.load(accum);
  97. }
  98. mma(accum, ...);
  99. if (i < N-1) {
  100. // Store for next GEMM
  101. tile.store(accum);
  102. } else {
  103. // Store in tensor (eg RowMajor)
  104. epilogue(accum);
  105. }
  106. // ...
  107. }
  108. ```
  109. */
  110. // 128bits per thread
  111. using AccessType = cutlass::Array<float, 4>;
  112. static constexpr int32_t kBytes = sizeof(AccessType);
  113. static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
  114. static constexpr int32_t kNumIters = FragmentType::kElements / AccessType::kElements;
  115. static constexpr int32_t kElementsStored = kNumThreads * FragmentType::kElements;
  116. static_assert(FragmentType::kElements % AccessType::kElements == 0,
  117. "fragment not aligned on 128 bits");
  118. float* ptr;
  119. CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id)
  120. {
  121. CUTLASS_PRAGMA_UNROLL
  122. for (int i = 0; i < kNumIters; ++i) {
  123. AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
  124. ptr + thread_id * AccessType::kElements + i * kStride);
  125. AccessType sub_fragment;
  126. cutlass::arch::global_load<AccessType, kBytes>(sub_fragment, gmem_ptr, true);
  127. CUTLASS_PRAGMA_UNROLL
  128. for (int j = 0; j < AccessType::kElements; ++j) {
  129. fragment[i * AccessType::kElements + j] = sub_fragment[j];
  130. }
  131. }
  132. }
  133. CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id)
  134. {
  135. CUTLASS_PRAGMA_UNROLL
  136. for (int i = 0; i < kNumIters; ++i) {
  137. AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
  138. ptr + thread_id * AccessType::kElements + i * kStride);
  139. AccessType sub_fragment;
  140. CUTLASS_PRAGMA_UNROLL
  141. for (int j = 0; j < AccessType::kElements; ++j) {
  142. sub_fragment[j] = fragment[i * AccessType::kElements + j];
  143. }
  144. cutlass::arch::global_store<AccessType, kBytes>(sub_fragment, gmem_ptr, true);
  145. }
  146. }
  147. };
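// Illustrative usage sketch (not part of the original header): the accumulation
// pattern described in the GmemTile docstring above, written against the actual
// load/store signatures (both take a `thread_id`, and `ptr` must be set by the
// caller). `MmaFn`/`EpilogueFn` are hypothetical callables standing in for the
// caller's MMA and epilogue.
template <typename FragmentType, int32_t kNumThreads, typename MmaFn, typename EpilogueFn>
CUTLASS_DEVICE void accumulate_with_gmem_tile(GmemTile<FragmentType, kNumThreads> tile,
                                              int num_gemms,
                                              int thread_id,
                                              MmaFn run_mma,
                                              EpilogueFn run_epilogue)
{
    FragmentType accum;
    for (int i = 0; i < num_gemms; ++i) {
        if (i == 0) {
            accum.clear(); // first GEMM starts from a zeroed accumulator
        } else {
            tile.load(accum, thread_id); // reload raw RF-layout partial sums from gmem
        }
        run_mma(accum); // accumulate this GEMM into `accum`
        if (i < num_gemms - 1) {
            tile.store(accum, thread_id); // keep the accumulator around for the next GEMM
        } else {
            run_epilogue(accum); // last GEMM: write out through a real epilogue (e.g. RowMajor)
        }
    }
}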
  148. template <typename scalar_t, typename Arch>
  149. constexpr int getWarpsPerSm()
  150. {
  151. constexpr bool is_half = !cutlass::platform::is_same<scalar_t, float>::value;
  152. if (Arch::kMinComputeCapability >= 80) { return is_half ? 12 : 8; }
  153. return 8;
  154. }
  155. } // namespace
  156. template <
  157. // which arch we target (eg `cutlass::arch::Sm80`)
  158. typename ArchTag_,
  159. // input/output type
  160. typename scalar_t_,
  161. // run optimized kernel because memory accesses will be aligned
  162. bool kIsAligned_,
  163. // use dropout if enabled
  164. bool kApplyDropout_,
  165. // when doing a GEMM, preload the next one (uses more shmem)
  166. bool kPreload_,
  167. // block dimensions
  168. int kBlockSizeI_,
  169. int kBlockSizeJ_,
  170. // upperbound on `max(value.shape[-1], query.shape[-1])`
  171. int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
  172. template <typename, typename, typename> class Broadcast1_ = BroadcastNoLoad,
  173. template <typename, typename, typename> class Broadcast2_ = BroadcastNoLoad>
  174. struct AttentionBackwardKernel {
  175. using scalar_t = scalar_t_;
  176. using output_t = scalar_t;
  177. using output_accum_t = float;
  178. using lse_scalar_t = float;
  179. using accum_t = float;
  180. using ArchTag = ArchTag_;
  181. static constexpr bool kIsAligned = kIsAligned_;
  182. static constexpr bool kApplyDropout = kApplyDropout_;
  183. static constexpr bool kPreload = kPreload_;
  184. static constexpr int kBlockSizeI = kBlockSizeI_;
  185. static constexpr int kBlockSizeJ = kBlockSizeJ_;
  186. static constexpr int kMaxK = kMaxK_;
  187. struct Params {
  188. // Input tensors
  189. scalar_t* query_ptr; // [Mq, nH, K]
  190. scalar_t* key_ptr; // [Mk, nH, K]
  191. scalar_t* value_ptr; // [Mk, nH, Kv]
  192. lse_scalar_t* logsumexp_ptr; // [nH, Mq]
  193. scalar_t* output_ptr; // [Mq, nH, Kv]
  194. scalar_t* grad_output_ptr; // [Mq, nH, Kv]
  195. accum_t* delta_ptr; // [nH, Mq]
  196. int32_t* cu_seqlens_q_ptr = nullptr;
  197. int32_t* cu_seqlens_k_ptr = nullptr;
  198. // Output tensors
  199. output_t* grad_query_ptr; // [Mq, nH, K]
  200. output_t* grad_key_ptr; // [Mk, nH, K]
  201. output_t* grad_value_ptr; // [Mk, nH, Kv]
  202. accum_t* grad_bias1_ptr = nullptr;
  203. accum_t* grad_bias2_ptr = nullptr;
  204. int32_t B = 0;
  205. int32_t N = 0;
  206. scalar_t* bias1_ptr = nullptr;
  207. scalar_t* bias2_ptr = nullptr;
  208. // Accumulators
  209. union {
  210. output_accum_t* workspace = nullptr; // [Mkv, Kq] + [Mkv, Kv] + [Mq, Kq] (dK, dV, dQ accumulators)
  211. output_accum_t* workspace_gk;
  212. };
  213. output_accum_t* workspace_gv; // (will be calculated by the kernel)
  214. output_accum_t* workspace_gq; // (will be calculated by the kernel)
  215. // Scale
  216. accum_t scale;
  217. // Dimensions/strides
  218. int32_t head_dim = -1;
  219. int32_t head_dim_value = -1;
  220. int32_t num_queries = -1;
  221. int32_t num_keys = -1;
  222. int32_t num_heads = -1;
  223. int32_t q_strideM;
  224. int32_t k_strideM;
  225. int32_t v_strideM;
  226. int32_t gO_strideM;
  227. int32_t gB_strideM;
  228. int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise
  229. // RNG sequence offset based on batch_id and head_id
  230. unsigned long long dropout_batch_head_rng_offset;
  231. float dropout_prob = 0.0f;
  232. CUTLASS_HOST_DEVICE int32_t o_strideM() const { return head_dim_value * num_heads; }
  233. CUTLASS_HOST_DEVICE int32_t gQ_strideM() const
  234. {
  235. return gQKV_strideM_multiplier * num_heads * head_dim;
  236. }
  237. CUTLASS_HOST_DEVICE int32_t gK_strideM() const
  238. {
  239. return gQKV_strideM_multiplier * num_heads * head_dim;
  240. }
  241. CUTLASS_HOST_DEVICE int32_t gV_strideM() const
  242. {
  243. return gQKV_strideM_multiplier * num_heads * head_dim_value;
  244. }
  245. // Everything below is only used in `advance_to_block`
  246. // and shouldn't use registers
  247. int64_t o_strideH;
  248. int32_t q_strideH;
  249. int32_t k_strideH;
  250. int32_t v_strideH;
  251. int64_t o_strideB;
  252. int64_t q_strideB;
  253. int64_t k_strideB;
  254. int64_t v_strideB;
  255. int64_t lse_strideB;
  256. int64_t lse_strideH;
  257. int64_t delta_strideB;
  258. int64_t delta_strideH;
  259. int32_t num_batches;
  260. int64_t gO_strideB = 0;
  261. int64_t gQ_strideB = 0;
  262. int64_t gK_strideB = 0;
  263. int64_t gV_strideB = 0;
  264. int64_t gB_strideB = 0;
  265. int64_t gO_strideH = 0;
  266. int64_t gQ_strideH = 0;
  267. int64_t gK_strideH = 0;
  268. int64_t gV_strideH = 0;
  269. int64_t gB_strideH = 0;
  270. CUTLASS_DEVICE bool advance_to_block()
  271. {
  272. int64_t batch_id = blockIdx.z;
  273. int32_t head_id = blockIdx.y;
  274. if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) {
  275. assert(workspace_size() == 0 || workspace != nullptr);
  276. workspace += (batch_id * num_heads + head_id) * workspace_strideBH();
  277. workspace = warp_uniform(workspace);
  278. workspace_gv = workspace + workspace_elements_gk();
  279. workspace_gq = workspace_gv + workspace_elements_gv();
  280. } else {
  281. workspace = nullptr;
  282. }
  283. // Advance pointers that depend on the total concatenated
  284. // number of queries, as `num_queries` is modified in the block
  285. // below
  286. dropout_batch_head_rng_offset = batch_id * (num_heads * num_queries * num_keys) +
  287. head_id * (num_queries * num_keys);
  288. logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH;
  289. query_ptr += batch_id * q_strideB + head_id * q_strideH;
  290. key_ptr += batch_id * k_strideB + head_id * k_strideH;
  291. value_ptr += batch_id * v_strideB + head_id * v_strideH;
  292. output_ptr += batch_id * o_strideB + head_id * o_strideH;
  293. grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH;
  294. delta_ptr += batch_id * delta_strideB + head_id * delta_strideH;
  295. grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH;
  296. grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH;
  297. grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH;
  298. using broadcast_1 = Broadcast1_<typename MatmulQK::BiasLoader::ThreadMap,
  299. typename MatmulQK::BiasLoader::Shape,
  300. scalar_t>;
  301. using broadcast_2 = Broadcast2_<typename MatmulQK::BiasLoader::ThreadMap,
  302. typename MatmulQK::BiasLoader::Shape,
  303. scalar_t>;
  304. if (broadcast_1::kEnable && grad_bias1_ptr) {
  305. grad_bias1_ptr += batch_id * num_queries;
  306. }
  307. if (broadcast_2::kEnable && grad_bias2_ptr) {
  308. auto strideB = num_heads * num_queries * num_keys;
  309. auto strideH = num_queries * num_keys;
  310. grad_bias2_ptr += (batch_id / N) * strideB + head_id * strideH;
  311. }
  312. if (broadcast_1::kEnable && bias1_ptr) {
  313. bias1_ptr = broadcast_1::advance(bias1_ptr,
  314. batch_id / N,
  315. batch_id % N,
  316. head_id,
  317. num_queries * N,
  318. num_queries,
  319. 0);
  320. }
  321. if (broadcast_2::kEnable && bias2_ptr) {
  322. auto strideB = num_heads * num_queries * num_keys;
  323. auto strideH = num_queries * num_keys;
  324. bias2_ptr = broadcast_2::advance(
  325. bias2_ptr, batch_id / N, batch_id % N, head_id, strideB, 0, strideH);
  326. }
  327. num_queries = warp_uniform(num_queries);
  328. num_keys = warp_uniform(num_keys);
  329. query_ptr = warp_uniform(query_ptr);
  330. key_ptr = warp_uniform(key_ptr);
  331. value_ptr = warp_uniform(value_ptr);
  332. logsumexp_ptr = warp_uniform(logsumexp_ptr);
  333. output_ptr = warp_uniform(output_ptr);
  334. grad_output_ptr = warp_uniform(grad_output_ptr);
  335. delta_ptr = warp_uniform(delta_ptr);
  336. grad_query_ptr = warp_uniform(grad_query_ptr);
  337. grad_key_ptr = warp_uniform(grad_key_ptr);
  338. grad_value_ptr = warp_uniform(grad_value_ptr);
  339. if (broadcast_1::kEnable) {
  340. grad_bias1_ptr = warp_uniform(grad_bias1_ptr);
  341. bias1_ptr = warp_uniform(bias1_ptr);
  342. }
  343. if (broadcast_2::kEnable) {
  344. grad_bias2_ptr = warp_uniform(grad_bias2_ptr);
  345. bias2_ptr = warp_uniform(bias2_ptr);
  346. }
  347. return true;
  348. }
  349. __host__ dim3 getBlocksGrid() const { return dim3(1, num_heads, num_batches); }
  350. __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); }
  351. CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const
  352. {
  353. if (!kNeedsAccumGradK) { return 0; }
  354. return align_up(num_keys, (int32_t)kBlockSizeJ) *
  355. align_up(head_dim, (int32_t)kBlockSizeI);
  356. }
  357. CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const
  358. {
  359. if (!kNeedsAccumGradV) { return 0; }
  360. return align_up(num_keys, (int32_t)kBlockSizeJ) *
  361. align_up(head_dim_value, (int32_t)kBlockSizeI);
  362. }
  363. CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const
  364. {
  365. if (!kNeedsAccumGradQ) { return 0; }
  366. if (num_keys <= kBlockSizeJ) { return 0; }
  367. return align_up(num_queries, (int32_t)kBlockSizeI) *
  368. align_up(head_dim, (int32_t)kBlockSizeJ);
  369. }
  370. CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const
  371. {
  372. // Aligned on 128bits
  373. return align_up(
  374. workspace_elements_gk() + workspace_elements_gv() + workspace_elements_gq(),
  375. int64_t(4));
  376. }
  377. CUTLASS_HOST_DEVICE int64_t workspace_size() const
  378. {
  379. // Returns size of buffer we need to run this kernel
  380. return num_batches * num_heads * workspace_strideBH() * sizeof(float);
  381. }
  382. };
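// Host-side workspace sketch (illustrative, not part of the original header).
// When gradients must be accumulated in float (half-precision outputs that do
// not stay in registers), the caller allocates a single buffer and the kernel
// slices it per (batch, head) inside `advance_to_block()`:
//
//   using Kernel = AttentionBackwardKernel</* concrete template arguments */>;
//   typename Kernel::Params p;
//   /* ... fill dimensions, strides and tensor pointers ... */
//   float* ws = nullptr;
//   size_t ws_bytes = p.workspace_size();      // already expressed in bytes
//   if (ws_bytes > 0) { cudaMalloc(&ws, ws_bytes); }
//   p.workspace = ws;                          // sliced per (batch, head) in advance_to_block()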
  383. static constexpr int64_t kWarpSize = 32;
  384. // If this is true, we store and accumulate dK/dV in RF
  385. // rather than going back to gmem every time
  386. static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value <= 16;
  387. static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
  388. static_assert(!kPreload || (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF),
  389. "preload MMA not supported");
  390. static constexpr bool kPrologueQK = kPreload;
  391. static constexpr bool kPrologueGV = kPreload;
  392. static constexpr bool kPrologueDOV = kPreload;
  393. static constexpr bool kPrologueGQ = kPreload;
  394. static constexpr bool kPrologueGK = kPreload;
  395. static constexpr int64_t kNumWarpsPerBlock = (kBlockSizeI * kBlockSizeJ) / (32 * 32);
  396. // Compute delta for the f16 kernels
  397. // TODO: Figure out why it's slower on the f32 kernels
  398. // (something due to RF pressure?)
  399. // TODO: Remove condition on `kOutputInRF` - this is needed to work
  400. // around a compiler bug on V100, not exactly sure why but I spent
  401. // too much time on this already. Reproducible with
  402. // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance
  403. static constexpr bool kKernelComputesDelta =
  404. kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70);
  405. static constexpr bool kNeedsAccumGradQ =
  406. !cutlass::platform::is_same<output_accum_t, output_t>::value;
  407. static constexpr bool kNeedsAccumGradK =
  408. !kOutputInRF && !cutlass::platform::is_same<output_accum_t, output_t>::value;
  409. static constexpr bool kNeedsAccumGradV =
  410. !kOutputInRF && !cutlass::platform::is_same<output_accum_t, output_t>::value;
  411. // Launch bounds
  412. static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock;
  413. static constexpr int64_t kMinBlocksPerSm =
  414. getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
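// Illustrative note (assumption, not from this header): kNumThreads and
// kMinBlocksPerSm are intended to feed a __launch_bounds__ annotation on the
// __global__ trampoline that wraps this kernel, e.g.:
//
//   template <typename Kernel>
//   __global__ void __launch_bounds__(Kernel::kNumThreads, Kernel::kMinBlocksPerSm)
//   backward_trampoline(typename Kernel::Params p)
//   {
//       if (p.advance_to_block()) { Kernel::attention_kernel(p); }
//   }
//
// `backward_trampoline` is a hypothetical name; the real launcher lives outside this file.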
  415. using GemmType = DefaultGemmType<ArchTag, scalar_t>;
  416. using DefaultConfig =
  417. typename cutlass::gemm::device::DefaultGemmConfiguration<typename GemmType::OpClass,
  418. ArchTag,
  419. scalar_t,
  420. scalar_t,
  421. scalar_t, // ElementC
  422. accum_t // ElementAccumulator
  423. >;
  424. static constexpr auto kOptimalAlignement =
  425. cutlass::platform::max(DefaultConfig::kAlignmentA, DefaultConfig::kAlignmentB);
  426. static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment;
  427. struct MatmulQK {
  428. /*
  429. attn_T = k_j @ q_i.transpose(-2, -1) # matmul
  430. attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2,
  431. -1)).exp() # epilogue
  432. with attn_T.shape = (kBlockSizeJ, kBlockSizeI)
  433. */
  434. using ThreadblockShape =
  435. cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
  436. using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
  437. using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma<
  438. scalar_t, // ElementA
  439. cutlass::layout::RowMajor, // LayoutA
  440. kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
  441. scalar_t, // ElementB
  442. cutlass::layout::ColumnMajor, // LayoutB
  443. kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
  444. accum_t, // ElementC
  445. cutlass::layout::RowMajor, // LayoutC
  446. typename GemmType::OpClass,
  447. ArchTag,
  448. ThreadblockShape,
  449. WarpShape,
  450. typename GemmType::InstructionShape,
  451. DefaultConfig::kStages,
  452. typename GemmType::Operator,
  453. false, // AccumulatorsInRowMajor = false,
  454. cutlass::gemm::SharedMemoryClearOption::kNone>;
  455. using MmaCore = typename DefaultMma::MmaCore;
  456. using Mma = typename MakeCustomMma<typename DefaultMma::ThreadblockMma, kMaxK>::Mma;
  457. // used for efficient load of bias tile (Bij) from global memory to shared
  458. // memory
  459. using BiasLoader =
  460. TileSmemLoader<scalar_t,
  461. // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded
  462. // row-major but needs to have transposed shape so we get the same
  463. // elements.
  464. cutlass::MatrixShape<ThreadblockShape::kN, ThreadblockShape::kM>,
  465. MmaCore::kThreads,
  466. // input restriction: kv_len has to be a multiple of this value
  467. 128 / cutlass::sizeof_bits<scalar_t>::value>;
  468. // Epilogue to store to shared-memory in a format that we can use later for
  469. // the second matmul
  470. using B2bGemm =
  471. typename cutlass::gemm::threadblock::B2bGemm<typename Mma::Operator::IteratorC,
  472. typename Mma::Operator,
  473. scalar_t,
  474. WarpShape,
  475. ThreadblockShape>;
  476. using AccumLambdaIterator =
  477. typename DefaultMmaAccumLambdaIterator<typename Mma::Operator::IteratorC,
  478. accum_t,
  479. kWarpSize>::Iterator;
  480. using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
  481. };
  482. struct MatmulGradV {
  483. /*
  484. grad_v[j_start:j_end] += attn_T @ do_i # matmul
  485. Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K)
  486. (we might need to iterate multiple times on K)
  487. */
  488. using ThreadblockShape =
  489. cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
  490. using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
  491. using InstructionShape = typename GemmType::InstructionShape;
  492. using DefaultGemm =
  493. cutlass::gemm::kernel::DefaultGemm<scalar_t, // ElementA,
  494. cutlass::layout::RowMajor, // LayoutA,
  495. DefaultConfig::kAlignmentA,
  496. scalar_t, // ElementB,
  497. cutlass::layout::RowMajor, // LayoutB,
  498. kIsAligned ? DefaultConfig::kAlignmentB
  499. : GemmType::kMinimumAlignment,
  500. output_t,
  501. cutlass::layout::RowMajor, // LayoutC,
  502. accum_t,
  503. typename GemmType::OpClass,
  504. ArchTag,
  505. ThreadblockShape,
  506. WarpShape,
  507. typename GemmType::InstructionShape,
  508. typename DefaultConfig::EpilogueOutputOp,
  509. void, // ThreadblockSwizzle - not used
  510. DefaultConfig::kStages,
  511. false, // SplitKSerial
  512. typename GemmType::Operator>;
  513. // if dropout:
  514. // for computing dVj += (Pij.T * Zij) @ dOi
  515. // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of
  516. // Pij.T are loaded in. The reason we do it this way is because Pij.T and
  517. // Zij are reused in later steps, while Pij_dropped.T is only needed in
  518. // this step. computing Pij_dropped.T on the fly allows us to avoid
  519. // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the
  520. // same time.
  521. // if no dropout:
  522. // for computing dVj += Pij.T @ dOi
  523. using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
  524. typename DefaultGemm::Mma,
  525. typename MatmulQK::AccumulatorSharedStorage,
  526. kApplyDropout>; // kScaleOperandA
  527. using Mma = typename DefaultMmaFromSmem::Mma;
  528. using WarpIteratorA = typename DefaultMmaFromSmem::WarpIteratorA;
  529. using IteratorB = typename Mma::IteratorB;
  530. using WarpCount = typename Mma::WarpCount;
  531. // Epilogue
  532. using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
  533. using DefaultEpilogue = typename DefaultGemm::Epilogue;
  534. using OutputTileIterator =
  535. typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
  536. typename DefaultEpilogue::OutputTileIterator>::Iterator;
  537. using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
  538. };
  539. struct MatmulDOIVJ {
  540. /*
  541. doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul
  542. tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue?
  543. */
  544. using ThreadblockShape =
  545. cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
  546. using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
  547. using ElementC = accum_t; // CSY: Change it for better accuracy
  548. using ElementAccum = accum_t;
  549. // no-op output op - epilogue just stores result to global memory
  550. using BiasGradEpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination<
  551. ElementC,
  552. DefaultConfig::EpilogueOutputOp::kCount,
  553. typename DefaultConfig::EpilogueOutputOp::ElementAccumulator,
  554. typename DefaultConfig::EpilogueOutputOp::ElementCompute,
  555. cutlass::epilogue::thread::ScaleType::Nothing>;
  556. using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm<
  557. scalar_t, // ElementA
  558. cutlass::layout::RowMajor, // LayoutA
  559. kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
  560. scalar_t, // ElementB
  561. cutlass::layout::ColumnMajor, // LayoutB
  562. kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
  563. ElementC, // ElementC
  564. cutlass::layout::RowMajor, // LayoutC
  565. ElementAccum, // ElementAccumulator
  566. typename GemmType::OpClass,
  567. ArchTag,
  568. ThreadblockShape,
  569. WarpShape,
  570. typename GemmType::InstructionShape,
  571. BiasGradEpilogueOutputOp, // EpilogueOutputOp
  572. void, // ThreadblockSwizzle (not used)
  573. // multiple preloads, dropout Zij tile, and 3 stages push us over shared
  574. // memory capacity on A100. set a ceiling on number of stages to save
  575. // shared memory if dropout is in use.
  576. kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64)
  577. ? cutlass::const_min(2, DefaultConfig::kStages)
  578. : DefaultConfig::kStages, // Stages
  579. false, // SplitKSerial
  580. typename GemmType::Operator,
  581. cutlass::gemm::SharedMemoryClearOption::kNone>;
  582. using Mma = typename MakeCustomMma<typename DefaultGemm::Mma, kMaxK>::Mma;
  583. // epilogue used to write bias gradient, which is just the output of this
  584. // matmul with some operations applied to the fragment
  585. using BiasGradEpilogue = typename DefaultGemm::Epilogue;
  586. // Epilogue to store to shared-memory in a format that we can use later for
  587. // the second matmul
  588. using B2bGemm =
  589. typename cutlass::gemm::threadblock::B2bGemm<typename Mma::Operator::IteratorC,
  590. typename Mma::Operator,
  591. scalar_t,
  592. WarpShape,
  593. ThreadblockShape>;
  594. using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
  595. };
  596. struct MatmulGradQ {
  597. // grad_q <- tmp @ k_j
  598. using ThreadblockShape =
  599. cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
  600. using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
  601. using InstructionShape = typename GemmType::InstructionShape;
  602. using DefaultGemm =
  603. cutlass::gemm::kernel::DefaultGemm<scalar_t, // ElementA,
  604. cutlass::layout::RowMajor, // LayoutA,
  605. DefaultConfig::kAlignmentA,
  606. scalar_t, // ElementB,
  607. cutlass::layout::RowMajor, // LayoutB,
  608. kIsAligned ? DefaultConfig::kAlignmentB
  609. : GemmType::kMinimumAlignment,
  610. output_t,
  611. cutlass::layout::RowMajor, // LayoutC,
  612. accum_t,
  613. typename GemmType::OpClass,
  614. ArchTag,
  615. ThreadblockShape,
  616. WarpShape,
  617. typename GemmType::InstructionShape,
  618. typename DefaultConfig::EpilogueOutputOp,
  619. void, // ThreadblockSwizzle - not used
  620. DefaultConfig::kStages,
  621. false, // SplitKSerial
  622. typename GemmType::Operator>;
  623. using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
  624. typename DefaultGemm::Mma,
  625. typename MatmulDOIVJ::AccumulatorSharedStorage,
  626. false>; // kScaleOperandA
  627. using Mma = typename DefaultMmaFromSmem::Mma;
  628. using IteratorB = typename Mma::IteratorB;
  629. using WarpCount = typename Mma::WarpCount;
  630. // Epilogue
  631. using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
  632. using DefaultEpilogue = typename DefaultGemm::Epilogue;
  633. using OutputTileIterator =
  634. typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
  635. typename DefaultEpilogue::OutputTileIterator>::Iterator;
  636. using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
  637. };
  638. struct MatmulGradK {
  639. // grad_k <- tmp.transpose(-2, -1) @ q_i
  640. using ThreadblockShape =
  641. cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
  642. using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
  643. using InstructionShape = typename GemmType::InstructionShape;
  644. using DefaultGemm =
  645. cutlass::gemm::kernel::DefaultGemm<scalar_t, // ElementA,
  646. cutlass::layout::RowMajor, // LayoutA,
  647. DefaultConfig::kAlignmentA,
  648. scalar_t, // ElementB,
  649. cutlass::layout::RowMajor, // LayoutB,
  650. kIsAligned ? DefaultConfig::kAlignmentB
  651. : GemmType::kMinimumAlignment,
  652. output_t,
  653. cutlass::layout::RowMajor, // LayoutC,
  654. accum_t,
  655. typename GemmType::OpClass,
  656. ArchTag,
  657. ThreadblockShape,
  658. WarpShape,
  659. typename GemmType::InstructionShape,
  660. typename DefaultConfig::EpilogueOutputOp,
  661. void, // ThreadblockSwizzle - not used
  662. DefaultConfig::kStages,
  663. false, // SplitKSerial
  664. typename GemmType::Operator>;
  665. using DefaultMmaFromSmemN = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
  666. typename DefaultGemm::Mma,
  667. typename MatmulQK::AccumulatorSharedStorage,
  668. false>; // kScaleOperandA
  669. using DefaultMmaFromSmemT = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
  670. typename DefaultGemm::Mma,
  671. typename MatmulDOIVJ::AccumulatorSharedStorage,
  672. false, // kScaleOperandA
  673. kPreload>; // kTransposeA
  674. using DefaultMmaFromSmem =
  675. typename cutlass::platform::conditional<DefaultMmaFromSmemT::kIsTransposedA,
  676. DefaultMmaFromSmemT,
  677. DefaultMmaFromSmemN>::type;
  678. using Mma = typename DefaultMmaFromSmem::Mma;
  679. using IteratorB = typename Mma::IteratorB;
  680. using WarpCount = typename Mma::WarpCount;
  681. // Epilogue
  682. using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
  683. using DefaultEpilogue = typename DefaultGemm::Epilogue;
  684. using OutputTileIterator =
  685. typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
  686. typename DefaultEpilogue::OutputTileIterator>::Iterator;
  687. using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
  688. };
  689. using broadcast_1 = Broadcast1_<typename MatmulQK::BiasLoader::ThreadMap,
  690. typename MatmulQK::BiasLoader::Shape,
  691. scalar_t>;
  692. using broadcast_2 = Broadcast2_<typename MatmulQK::BiasLoader::ThreadMap,
  693. typename MatmulQK::BiasLoader::Shape,
  694. scalar_t>;
  695. // shared storage for keeping Zij matrix. not needed if we aren't using
  696. // dropout, in which case we use an empty array to save shared memory
  697. using ZijSharedStorage = typename cutlass::platform::conditional<
  698. kApplyDropout,
  699. typename MatmulQK::AccumulatorSharedStorage,
  700. // dummy shared storage object that takes up no space.
  701. typename cutlass::gemm::threadblock::AccumulatorSharedStorage<
  702. #ifdef _WIN32
  703. // windows builds throw the error:
  704. // "type containing an unknown-size array is not allowed"
  705. // if we try to make Zij shared storage zero-sized.
  706. // To get around this just make it sized 1 on windows.
  707. typename cutlass::gemm::GemmShape<1, 1, 0>,
  708. #else
  709. typename cutlass::gemm::GemmShape<0, 0, 0>,
  710. #endif
  711. typename MatmulQK::AccumulatorSharedStorage::Element,
  712. typename MatmulQK::AccumulatorSharedStorage::Layout,
  713. typename cutlass::MatrixShape<0, 0>>>::type;
  714. struct SharedStoragePrologue {
  715. struct {
  716. cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
  717. typename MatmulQK::Mma::SharedStorageA mm_qk_k;
  718. } persistent;
  719. union {
  720. struct {
  721. // part1 - after Q.K / dV / dO.V
  722. union {
  723. // 1. efficient load of bias tile Bij, which is then applied to Pij
  724. // typename MatmulQK::BiasLoader::SmemTile bias;
  725. cutlass::AlignedBuffer<float, MatmulQK::BiasLoader::Shape::kCount> bias;
  726. // 4. store Pij. it is needed:
  727. // - in dVj += (Pij.T * Zij) @ dOi
  728. // - in dSij = Pij * (dPij - Di)
  729. // 6. dVj += (Pij.T * Zij) @ dOi
  730. // 10. write to fragment
  731. typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
  732. };
  733. // 5. store Zij. it is needed:
  734. // - to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij
  735. // are loaded for the computation of dVj.
  736. // - to compute dPij = (dOi @ Vj.T) * Zij
  737. // 6. used in dVj += (Pij.T * Zij) @ dOi
  738. // 9. used in dPij = dPij_dropped * Zij
  739. ZijSharedStorage zij;
  740. union {
  741. // 2. prologue for dVj
  742. // 6. workspace for dVj += (Pij.T * Zij) @ dOi
  743. typename MatmulGradV::Mma::SharedStorage mm_gradV;
  744. // 7. dVj epilogue
  745. typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
  746. };
  747. // 3. prologue for dPij_dropped
  748. // 8. used in dPij_dropped = dOi @ Vj.T
  749. typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
  750. } part1;
  751. struct {
  752. // part2 - dQ
  753. union {
  754. typename MatmulQK::AccumulatorSharedStorage
  755. tmpT_shared_storage; // (from part1)
  756. typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
  757. };
  758. typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
  759. typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload)
  760. union {
  761. // store dB = dSij to global memory
  762. typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
  763. typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
  764. };
  765. } part2;
  766. struct {
  767. // part3 - after last iteration on dQ's epilogue / dK
  768. union {
  769. typename MatmulQK::AccumulatorSharedStorage
  770. tmpT_shared_storage; // (from part1)
  771. typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
  772. };
  773. typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
  774. typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter;
  775. typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
  776. } part3;
  777. struct {
  778. // part4 - after last iteration on dK's epilogue / preload next K.Q_t
  779. typename MatmulQK::Mma::SharedStorageB mm_qk_q;
  780. // If we reach end of current key, dump RF->gmem with "final" epilogues
  781. typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final;
  782. typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final;
  783. } part4;
  784. };
  785. // ===========================================
  786. #define FIELD(INSIDE_STRUCT, FIELDNAME) \
  787. CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; }
  788. FIELD(persistent, di)
  789. FIELD(persistent, mm_qk_k)
  790. FIELD(part1, bias)
  791. FIELD(part1, attn_shared_storage)
  792. FIELD(part1, zij)
  793. FIELD(part1, mm_gradV)
  794. FIELD(part1, gradV_epilogue)
  795. FIELD(part1, mm_doivj)
  796. FIELD(part2, mm_gradK)
  797. FIELD(part2, mm_gradQ)
  798. FIELD(part2, gradB_epilogue)
  799. FIELD(part2, gradQ_epilogue)
  800. FIELD(part2, tmp_shared_storage)
  801. FIELD(part3, tmpT_shared_storage)
  802. FIELD(part3, gradQ_epilogue_lastIter)
  803. FIELD(part3, gradK_epilogue)
  804. FIELD(part4, mm_qk_q)
  805. FIELD(part4, gradK_epilogue_final)
  806. FIELD(part4, gradV_epilogue_final)
  807. };
  808. struct SharedStorageNoPrologue {
  809. struct {
  810. cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
  811. } persistent;
  812. union {
  813. struct {
  814. // part1 - Q.K matmul
  815. typename MatmulQK::Mma::SharedStorageA mm_qk_k;
  816. typename MatmulQK::Mma::SharedStorageB mm_qk_q;
  817. } part1;
  818. struct {
  819. // part2 - compute gradV
  820. union {
  821. // 1. efficient load of bias tile Bij, which is then applied to Pij
  822. cutlass::AlignedBuffer<float, MatmulQK::BiasLoader::Shape::kCount> bias;
  823. // 2. store Pij to shared memory. it is needed:
  824. // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi
  825. // - in next step where it is used in dSij = Pij * (dPij - Di)
  826. typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
  827. };
  828. // 3. store Zij. it is needed:
  829. // - in this step, where it is used to compute Pij_dropped = Pij * Zij
  830. // on the
  831. // fly as fragments of Pij are loaded for the computation of dVj.
  832. // - later to compute dPij = (dOi @ Vj.T) * Zij
  833. ZijSharedStorage zij;
  834. union {
  835. typename MatmulGradV::Mma::SharedStorage mm_gradV;
  836. typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
  837. };
  838. } part2;
  839. struct {
  840. // part3 - DO.V matmul
  841. union {
  842. // first compute dPij = (dOi @ Vj.T) * Zij
  843. // and dSij = Pij * (dPij - Di)
  844. struct {
  845. // (from part2) - Pij for computing dSij = Pij * (dPij - Di)
  846. typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
  847. // (from part2) - Zij for computing dPij = dPij_dropped * Zij
  848. ZijSharedStorage zij;
  849. // matmul to compute dOiVj
  850. typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
  851. };
  852. // then store dB = dSij to global memory
  853. typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
  854. };
  855. } part3;
  856. struct {
  857. // part4 - compute gradQ
  858. typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2)
  859. typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
  860. union {
  861. typename MatmulGradQ::Mma::SharedStorage mm_gradQ;
  862. typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
  863. typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter;
  864. };
  865. } part4;
  866. struct {
  867. // part5 - compute gradK
  868. typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2)
  869. typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
  870. union {
  871. typename MatmulGradK::Mma::SharedStorage mm_gradK;
  872. typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
  873. };
  874. } part5;
  875. struct {
  876. // part6 - store RF accumulated into gmem
  877. typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final;
  878. typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final;
  879. } part6;
  880. };
  881. // ===========================================
  882. #define FIELD(INSIDE_STRUCT, FIELDNAME) \
  883. CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; }
  884. FIELD(persistent, di)
  885. FIELD(part1, mm_qk_k)
  886. FIELD(part1, mm_qk_q)
  887. FIELD(part2, bias)
  888. FIELD(part2, attn_shared_storage)
  889. FIELD(part2, zij)
  890. FIELD(part2, mm_gradV)
  891. FIELD(part2, gradV_epilogue)
  892. FIELD(part3, mm_doivj)
  893. FIELD(part3, gradB_epilogue)
  894. FIELD(part4, tmpT_shared_storage)
  895. FIELD(part4, tmp_shared_storage)
  896. FIELD(part4, mm_gradQ)
  897. FIELD(part4, gradQ_epilogue)
  898. FIELD(part4, gradQ_epilogue_lastIter)
  899. FIELD(part5, mm_gradK)
  900. FIELD(part5, gradK_epilogue)
  901. FIELD(part6, gradK_epilogue_final)
  902. FIELD(part6, gradV_epilogue_final)
  903. };
  904. using SharedStorage = typename cutlass::platform::
  905. conditional<kPreload, SharedStoragePrologue, SharedStorageNoPrologue>::type;
  906. struct OutputFragments {
  907. typename MatmulGradV::Mma::FragmentC gradV;
  908. typename MatmulGradK::Mma::FragmentC gradK;
  909. CUTLASS_DEVICE void clear()
  910. {
  911. gradV.clear();
  912. gradK.clear();
  913. }
  914. };
  915. static bool __host__ check_supported(Params const& p)
  916. {
  917. CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
  918. CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
  919. CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
  920. CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
  921. CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
  922. EVOFORMER_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned");
  923. EVOFORMER_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned");
  924. EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0,
  925. "query is not correctly aligned (strideH)");
  926. EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0,
  927. "key is not correctly aligned (strideH)");
  928. EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0,
  929. "value is not correctly aligned (strideH)");
  930. EVOFORMER_CHECK(p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
  931. "query is not correctly aligned (strideB)");
  932. EVOFORMER_CHECK(p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
  933. "key is not correctly aligned (strideB)");
  934. EVOFORMER_CHECK(p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0,
  935. "value is not correctly aligned (strideB)");
  936. EVOFORMER_CHECK(p.q_strideM % kMinimumAlignment == 0,
  937. "query is not correctly aligned (strideM)");
  938. EVOFORMER_CHECK(p.k_strideM % kMinimumAlignment == 0,
  939. "key is not correctly aligned (strideM)");
  940. EVOFORMER_CHECK(p.v_strideM % kMinimumAlignment == 0,
  941. "value is not correctly aligned (strideM)");
  942. EVOFORMER_CHECK(p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f,
  943. "Invalid value for `dropout_prob`");
  944. EVOFORMER_CHECK(kApplyDropout || p.dropout_prob == 0.0f,
  945. "Set `kApplyDropout`=True to support `dropout_prob > 0`");
  946. EVOFORMER_CHECK(p.head_dim > 0, "Invalid value for `head_dim`");
  947. EVOFORMER_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`");
  948. EVOFORMER_CHECK(p.num_queries > 0, "Invalid value for `num_queries`");
  949. EVOFORMER_CHECK(p.num_keys > 0, "Invalid value for `num_keys`");
  950. EVOFORMER_CHECK(p.num_heads > 0, "Invalid value for `num_heads`");
  951. EVOFORMER_CHECK(p.num_batches > 0, "Invalid value for `num_batches`");
  952. EVOFORMER_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim <= kMaxK`");
  953. EVOFORMER_CHECK(p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value <= kMaxK`");
  954. return true;
  955. }
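// Launch sketch (illustrative; `backward_trampoline` is the hypothetical wrapper
// sketched next to the launch-bounds constants above, and `stream` is the
// caller's CUDA stream):
//
//   typename Kernel::Params p;                  /* filled by the caller */
//   if (!Kernel::check_supported(p)) { /* fall back to another configuration */ }
//   int smem_bytes = int(sizeof(typename Kernel::SharedStorage));
//   if (smem_bytes > 48 * 1024) {
//       cudaFuncSetAttribute(backward_trampoline<Kernel>,
//                            cudaFuncAttributeMaxDynamicSharedMemorySize,
//                            smem_bytes);
//   }
//   backward_trampoline<Kernel><<<p.getBlocksGrid(), p.getThreadsGrid(),
//                                 smem_bytes, stream>>>(p);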
  956. static CUTLASS_DEVICE void attention_kernel(Params p)
  957. {
  958. extern __shared__ char smem_buffer[];
  959. SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
  960. uint16_t thread_id = threadIdx.x;
  961. uint8_t warp_id = warp_uniform(thread_id / 32);
  962. uint8_t lane_id = thread_id % 32;
  963. if (kPrologueQK) {
  964. prologueQkNextIteration<true>(shared_storage, p, 0, 0, warp_id, lane_id);
  965. }
  966. // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr`
  967. if (kKernelComputesDelta) {
  968. constexpr int kOptimalElements = 128 / cutlass::sizeof_bits<scalar_t>::value;
  969. if (p.head_dim_value % kOptimalElements == 0) {
  970. for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) {
  971. computeDelta<kOptimalElements>(p, query_start, warp_id, lane_id);
  972. }
  973. } else {
  974. for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) {
  975. computeDelta<1>(p, query_start, warp_id, lane_id);
  976. }
  977. }
  978. __syncthreads();
  979. }
  980. OutputFragments output_frags;
  981. int32_t key_start = 0;
  982. int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ;
  983. for (; key_start < key_end; key_start += kBlockSizeJ) {
  984. output_frags.clear();
  985. int32_t query_start = getQueryStart(p, key_start);
  986. int32_t query_end =
  987. query_start + (p.num_queries - query_start) / kBlockSizeI * kBlockSizeI;
  988. for (; query_start < query_end; query_start += kBlockSizeI) {
  989. processBlockIJ<true>(
  990. shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id);
  991. }
  992. // last (partial) query
  993. if (query_start < p.num_queries) {
  994. processBlockIJ<false>(
  995. shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id);
  996. }
  997. if (kOutputInRF) {
  998. writeFragsToGmem<true>(
  999. shared_storage, output_frags, p, key_start, warp_id, lane_id);
  1000. } else if (getQueryStart(p, key_start) >= p.num_queries) {
  1001. zfillGradKV<true>(p, key_start, warp_id, lane_id);
  1002. }
  1003. __syncthreads();
  1004. }
  1005. // Last (partial) key
  1006. if (key_start != p.num_keys) {
  1007. output_frags.clear();
  1008. int32_t query_start = getQueryStart(p, key_start);
  1009. for (; query_start < p.num_queries; query_start += kBlockSizeI) {
  1010. warp_id = warp_uniform(warp_id);
  1011. processBlockIJ<false>(
  1012. shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id);
  1013. }
  1014. if (kOutputInRF) {
  1015. writeFragsToGmem<false>(
  1016. shared_storage, output_frags, p, key_start, warp_id, lane_id);
  1017. } else if (getQueryStart(p, key_start) >= p.num_queries) {
  1018. zfillGradKV<false>(p, key_start, warp_id, lane_id);
  1019. }
  1020. }
  1021. }
  1022. static CUTLASS_DEVICE void loadDi(cutlass::Array<accum_t, kBlockSizeI>& di,
  1023. Params const& p,
  1024. int32_t query_start)
  1025. {
  1026. int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x;
  1027. if (thread_id < kBlockSizeI) {
  1028. accum_t di_rf = accum_t(0);
  1029. if (query_start + thread_id < p.num_queries) {
  1030. di_rf = p.delta_ptr[query_start + thread_id];
  1031. }
  1032. di[thread_id] = di_rf;
  1033. }
  1034. }
  1035. template <bool skipBoundsChecks>
  1036. static CUTLASS_DEVICE void zfillGradKV(Params const& p,
  1037. int32_t key_start,
  1038. uint8_t warp_id,
  1039. uint8_t lane_id)
  1040. {
  1041. constexpr int kThreadsPerKey = 8;
  1042. constexpr int kParallelKeys = kNumThreads / kThreadsPerKey;
  1043. static_assert(kBlockSizeJ % kParallelKeys == 0, "");
  1044. // This function is not really optimized, but should rarely be used
  1045. // It's only used when some keys are "useless" and don't attend to
  1046. // any query, due to causal masking
  1047. int thread_id = 32 * warp_id + lane_id;
  1048. int k_shift = lane_id % kThreadsPerKey;
  1049. CUTLASS_PRAGMA_UNROLL
  1050. for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) {
  1051. int key = key_start + j + (thread_id / kThreadsPerKey);
  1052. if (!skipBoundsChecks && key >= p.num_keys) { continue; }
  1053. auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM();
  1054. auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM();
  1055. for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) {
  1056. gv_ptr[k] = scalar_t(0);
  1057. }
  1058. for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { gk_ptr[k] = scalar_t(0); }
  1059. }
  1060. }
  1061. template <bool skipBoundsChecks>
  1062. static CUTLASS_DEVICE void processBlockIJ(SharedStorage& shared_storage,
  1063. OutputFragments& output_frags,
  1064. Params& p,
  1065. int32_t query_start,
  1066. int32_t key_start,
  1067. uint8_t warp_id,
  1068. uint8_t lane_id)
  1069. {
  1070. cutlass::MatrixCoord no_offset{0, 0};
  1071. accum_t scale = p.scale;
  1072. int16_t thread_id = 32 * warp_id + lane_id;
  1073. auto rematerializeThreadIds = [&]() {
  1074. // Prevents `nvcc` from keeping values deduced from
  1075. // `thread_id`, `warp_id`, ... in RF - to reduce register pressure
  1076. warp_id = warp_uniform(thread_id / 32);
  1077. lane_id = thread_id % 32;
  1078. thread_id = 32 * warp_id + lane_id;
  1079. };
        bool isFirstQuery = (query_start == getQueryStart(p, key_start));
        int32_t next_query, next_key;
        incrIteration(p, query_start, key_start, next_query, next_key);
        bool isLastQuery = next_key != key_start;
        __syncthreads();
        loadDi(shared_storage.di(), p, query_start);

        int32_t num_queries_in_block =
            skipBoundsChecks ? MatmulQK::Mma::Shape::kN
                             : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kN,
                                                              p.num_queries - query_start));
        int32_t num_keys_in_block =
            skipBoundsChecks ? MatmulQK::Mma::Shape::kM
                             : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM,
                                                              p.num_keys - key_start));

        auto prologueGradV = [&](int col) {
            typename MatmulGradV::Mma::IteratorB iterator_dO(
                {int32_t(p.gO_strideM)},
                p.grad_output_ptr + query_start * p.gO_strideM + col,
                {num_queries_in_block, p.head_dim_value - col},
                thread_id,
                no_offset);
            MatmulGradV::Mma::prologue(
                shared_storage.mm_gradV(), iterator_dO, thread_id, num_queries_in_block);
        };
        auto prologueGradQ = [&](int col) {
            typename MatmulGradQ::Mma::IteratorB iterator_K(
                {int32_t(p.k_strideM)},
                p.key_ptr + key_start * p.k_strideM + col,
                {num_keys_in_block, p.head_dim - col},
                thread_id,
                no_offset);
            MatmulGradQ::Mma::prologue(
                shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block);
        };
        auto prologueGradK = [&](int col) {
            typename MatmulGradK::Mma::IteratorB iterator_Q(
                {int32_t(p.q_strideM)},
                p.query_ptr + query_start * p.q_strideM + col,
                {num_queries_in_block, p.head_dim - col},
                thread_id,
                no_offset);
            MatmulGradK::Mma::prologue(
                shared_storage.mm_gradK(), iterator_Q, thread_id, num_queries_in_block);
        };
        auto prologueDOV = [&]() {
            typename MatmulDOIVJ::Mma::IteratorA iterator_A(
                {int32_t(p.gO_strideM)},
                p.grad_output_ptr + query_start * p.gO_strideM,
                {num_queries_in_block, p.head_dim_value},
                thread_id,
                no_offset);
            typename MatmulDOIVJ::Mma::IteratorB iterator_B({int32_t(p.v_strideM)},
                                                            p.value_ptr + key_start * p.v_strideM,
                                                            {p.head_dim_value, num_keys_in_block},
                                                            thread_id,
                                                            no_offset);
            MatmulDOIVJ::Mma::prologue(
                shared_storage.mm_doivj(), iterator_A, iterator_B, thread_id, p.head_dim_value);
        };
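        // The prologue* lambdas above only prefetch the global->shared operands
        // of an upcoming GEMM, so the copies overlap with math that is still in
        // flight. A hedged sketch of the pattern (which lambdas actually run is
        // governed by the compile-time kPrologue* flags):
        //
        //   if (kPrologueGV) { prologueGradV(0); }   // issue the loads early
        //   ...                                      // finish the current matmul / epilogue
        //   mma.set_prologue_done(kPrologueGV);      // tell the GEMM its smem is pre-staged
        //   mma(gemm_k_iterations, acc, iterator_B, acc);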
        /////////////////////////////////////////////////////////////////////////////////////////////////
        // MatmulQK
        /////////////////////////////////////////////////////////////////////////////////////////////////
        {
            using Mma = typename MatmulQK::Mma;
            cutlass::gemm::GemmCoord problem_size(num_keys_in_block,
                                                  num_queries_in_block,
                                                  p.head_dim  // k
            );
            // k_j
            typename Mma::IteratorA iterator_A({int32_t(p.k_strideM)},
                                               p.key_ptr + key_start * p.k_strideM,
                                               {problem_size.m(), problem_size.k()},
                                               thread_id,
                                               no_offset);
            // q_i.transpose(-2, -1)
            typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)},
                                               p.query_ptr + query_start * p.q_strideM,
                                               {problem_size.k(), problem_size.n()},
                                               thread_id,
                                               no_offset);
            Mma mma(
                shared_storage.mm_qk_k(), shared_storage.mm_qk_q(), thread_id, warp_id, lane_id);
            typename Mma::FragmentC accum;
            accum.clear();
            auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
            // Compute threadblock-scoped matrix multiply-add
            mma.set_prologue_done(kPrologueQK);
            mma.set_zero_outside_bounds(!skipBoundsChecks);
            mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);

            // Epilogue: add LSE + exp and store that to our shared memory buffer
            // shmem <- (matmul_result -
            // logsumexp[i_start:i_end].unsqueeze(1)).exp()
            int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
            auto output_tile_coords = cutlass::MatrixCoord{
                warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM};

            if (broadcast_1::kEnable || broadcast_2::kEnable) {
                cutlass::TensorRef<float, cutlass::layout::RowMajor> bias_tensor_ref(
                    shared_storage.bias().data(),
                    cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM));
                using Shape = cutlass::MatrixShape<MatmulQK::ThreadblockShape::kM,
                                                   MatmulQK::ThreadblockShape::kN>;
                AttentionBiasEpilogue<Shape,
                                      scalar_t,
                                      MatmulQK::MmaCore::kThreads,
                                      Broadcast1_,
                                      Broadcast2_>
                    bias_epilogue;
                bias_epilogue(bias_tensor_ref,
                              p.bias1_ptr + key_start,
                              p.bias2_ptr + query_start * p.num_keys + key_start,
                              thread_id,
                              {num_queries_in_block, num_keys_in_block},
                              p.num_keys);
                // Pij += Bij, Pij is in register fragment and Bij is in shared memory
                auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
                    lane_id, warp_id, output_tile_coords);
                MatmulQK::AccumLambdaIterator::iterateRows(
                    lane_offset,
                    [&](int accum_n) {},
                    [&](int accum_m, int accum_n, int idx) {
                        // remember we are transposed
                        accum[idx] = accum[idx] * scale + bias_tensor_ref.at({accum_n, accum_m});
                    },
                    [&](int accum_n) {});
            } else {
                accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
            }

            __syncthreads();
            if (kPrologueGV) { prologueGradV(0); }
            if (kPrologueDOV) { prologueDOV(); }

            MatmulQK::B2bGemm::accumApplyLSEToSmem(shared_storage.attn_shared_storage(),
                                                   accum,
                                                   p.logsumexp_ptr + query_start,
                                                   problem_size.n(),
                                                   thread_id,
                                                   warp_id,
                                                   lane_id,
                                                   output_tile_coords);
            __syncthreads();
        }
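        // What the block above recomputes, in hedged PyTorch-style pseudocode
        // (the result is kept transposed, [keys x queries], in shared memory):
        //
        //   Sij    = scale * (q_i @ k_j.transpose(-2, -1)) + Bij   # bias only if broadcast_{1,2}
        //   Pij    = exp(Sij - logsumexp[i_start:i_end].unsqueeze(1))
        //   attn_T = Pij.transpose(-2, -1)
        //
        // i.e. the softmax from the forward pass is rebuilt from the saved LSE
        // instead of being stored.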
        rematerializeThreadIds();

        /////////////////////////////////////////////////////////////////////////////////////////////////
        // GradV matmul
        //
        // grad_v[j_start:j_end] += attn_T @ do_i
        /////////////////////////////////////////////////////////////////////////////////////////////////
        constexpr bool kSingleIterationGradV = kMaxK <= MatmulGradV::ThreadblockShape::kN;
        for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value);
             col += MatmulGradV::ThreadblockShape::kN) {
            using Mma = typename MatmulGradV::Mma;
            using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;

            cutlass::gemm::GemmCoord problem_size(
                num_keys_in_block, p.head_dim_value - col, num_queries_in_block);
            auto createEpilogueIter = [&]() {
                return typename MatmulGradV::OutputTileIterator(
                    typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
                    p.grad_value_ptr + key_start * p.gV_strideM() + col,
                    {num_keys_in_block, p.head_dim_value - col},
                    thread_id);
            };
            typename Mma::IteratorB iterator_B({int32_t(p.gO_strideM)},
                                               p.grad_output_ptr + query_start * p.gO_strideM + col,
                                               {num_queries_in_block, p.head_dim_value - col},
                                               thread_id,
                                               no_offset);

            // if dropout: dVj += (Pij.T * Zij) @ dOi
            // otherwise:  dVj += Pij.T @ dOi
            Mma mma(shared_storage.mm_gradV(),
                    // operand A: Pij
                    typename MatmulGradV::WarpIteratorA(
                        shared_storage.attn_shared_storage().accum_ref(), lane_id),
                    // if we're using dropout, operand A is Pij_dropped = Pij * Zij
                    // which is computed on the fly as fragments of Pij are loaded in
                    typename Mma::WarpIteratorAScale(shared_storage.zij().accum_ref(), lane_id),
                    thread_id,
                    warp_id,
                    lane_id);

            int storage_id = col / MatmulGradV::ThreadblockShape::kN;
            AccumTileGmem gmem_tile{p.workspace_gv + storage_id * AccumTileGmem::kElementsStored};
            if (!kOutputInRF) {
                if (isFirstQuery || !kNeedsAccumGradV) {
                    output_frags.gradV.clear();
                } else {
                    gmem_tile.load(output_frags.gradV, thread_id);
                }
            }
            mma.set_prologue_done(kPrologueGV);

            auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
            // Compute threadblock-scoped matrix multiply-add
            __syncthreads();
            mma(gemm_k_iterations, output_frags.gradV, iterator_B, output_frags.gradV);
            __syncthreads();
            if (kPrologueGV && !kSingleIterationGradV &&
                col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) {
                prologueGradV(col + MatmulGradV::ThreadblockShape::kN);
            }

            if (!kOutputInRF) {
                if (kNeedsAccumGradV && !isLastQuery) {
                    gmem_tile.store(output_frags.gradV, thread_id);
                } else {
                    accumulateInGmem<MatmulGradV>(shared_storage.gradV_epilogue(),
                                                  output_frags.gradV,
                                                  createEpilogueIter(),
                                                  isFirstQuery || kNeedsAccumGradV,
                                                  warp_id,
                                                  lane_id);
                }
            }
        }
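        // dVj is accumulated over every query block that touches this key block,
        // roughly (pseudocode):
        //
        //   dV[j_start:j_end] += Pij.T @ dO[i_start:i_end]
        //
        // The partial sum is carried either in registers (kOutputInRF), in the
        // `workspace_gv` scratch buffer (kNeedsAccumGradV), or by read-modify-write
        // of grad_value in global memory through accumulateInGmem.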
        __syncthreads();
        /////////////////////////////////////////////////////////////////////////////////////////////////
        // MatmulDOIVJ
        /////////////////////////////////////////////////////////////////////////////////////////////////
        {
            using Mma = typename MatmulDOIVJ::Mma;
            // do_i
            typename Mma::IteratorA iterator_A({int32_t(p.gO_strideM)},
                                               p.grad_output_ptr + query_start * p.gO_strideM,
                                               {num_queries_in_block, p.head_dim_value},
                                               thread_id,
                                               no_offset);
            // v_j.transpose(-2, -1)
            typename Mma::IteratorB iterator_B({int32_t(p.v_strideM)},
                                               p.value_ptr + key_start * p.v_strideM,
                                               {p.head_dim_value, num_keys_in_block},
                                               thread_id,
                                               no_offset);
            Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id);
            mma.set_prologue_done(kPrologueDOV);
            mma.set_zero_outside_bounds(!skipBoundsChecks);

            typename Mma::FragmentC accum;
            accum.clear();

            auto gemm_k_iterations = (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK;
            // Compute threadblock-scoped matrix multiply-add
            mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
            __syncthreads();
            if (kPrologueGQ) { prologueGradQ(0); }
            if (kPrologueGK) { prologueGradK(0); }

            int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
            auto output_tile_coords = cutlass::MatrixCoord{
                warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM};
            // TODO: This must be terribly inefficient. There must be a better way
            // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem]
            // attn_shared_storage [smem] <- tmp.T
            // tmp_shared_storage [smem] <- tmp
            {
                using LambdaIterator =
                    typename DefaultMmaAccumLambdaIterator<typename Mma::Operator::IteratorC,
                                                           typename MatmulDOIVJ::ElementAccum,
                                                           kWarpSize>::Iterator;
                auto lane_offset =
                    LambdaIterator::get_lane_offset(lane_id, warp_id, output_tile_coords);
                auto attn_T = shared_storage.attn_shared_storage().accum_ref();
                accum_t current_di;
                // dSij = (dPij - Di) * Pij
                LambdaIterator::iterateRows(
                    lane_offset,
                    [&](int accum_m) { current_di = shared_storage.di()[accum_m]; },
                    [&](int accum_m, int accum_n, int idx) {
                        if (skipBoundsChecks ||
                            (accum_m < num_queries_in_block && accum_n < num_keys_in_block)) {
                            accum_t attn = attn_T.at({accum_n, accum_m});
                            accum[idx] = (accum[idx] - current_di) * attn;
                        } else {
                            accum[idx] = 0;
                        }
                    },
                    [&](int accum_m) {});

                using DefaultGemm = typename MatmulDOIVJ::DefaultGemm;
                using OutputOp = typename MatmulDOIVJ::BiasGradEpilogueOutputOp;
                if (broadcast_1::kEnable && p.grad_bias1_ptr) {
                    using Epilogue =
                        typename BiasGradEpilogueAffineRankN<ArchTag,
                                                             2,
                                                             typename MatmulDOIVJ::ThreadblockShape,
                                                             typename DefaultGemm::Mma::Operator,
                                                             DefaultGemm::kPartitionsK,
                                                             OutputOp,
                                                             OutputOp::kCount>::Epilogue;
                    cutlass::layout::AffineRankN<2> layout({0, 1});
                    auto dst_ptr = p.grad_bias1_ptr + key_start;
                    typename Epilogue::OutputTileIterator output_iter(
                        {layout},
                        dst_ptr,
                        {num_queries_in_block, num_keys_in_block},
                        (int)thread_id);
                    Epilogue epilogue(shared_storage.gradB_epilogue(),
                                      (int)thread_id,
                                      (int)warp_id,
                                      (int)lane_id);
                    epilogue(OutputOp(1), output_iter, accum);
                }
                if (broadcast_2::kEnable && p.grad_bias2_ptr) {
                    if (broadcast_1::kEnable) { __syncthreads(); }
                    using Epilogue =
                        typename BiasGradEpilogue<ArchTag,
                                                  typename MatmulDOIVJ::ThreadblockShape,
                                                  typename DefaultGemm::Mma::Operator,
                                                  DefaultGemm::kPartitionsK,
                                                  OutputOp,
                                                  OutputOp::kCount>::Epilogue;
                    typename Epilogue::OutputTileIterator::Params params{p.num_keys};
                    auto dst_ptr = p.grad_bias2_ptr + query_start * p.num_keys + key_start;
                    typename Epilogue::OutputTileIterator output_iter(
                        params, dst_ptr, {num_queries_in_block, num_keys_in_block}, (int)thread_id);
                    Epilogue epilogue(shared_storage.gradB_epilogue(),
                                      (int)thread_id,
                                      (int)warp_id,
                                      (int)lane_id);
                    epilogue(OutputOp(1), output_iter, accum);
                }
                accum = accum * scale;
                __syncthreads();
                if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) {
                    auto tmpT = shared_storage.tmpT_shared_storage().accum_ref();
                    // attn <- attn_T.T
                    LambdaIterator::iterateRows(
                        lane_offset,
                        [&](int accum_m) {},
                        [&](int accum_m, int accum_n, int idx) {
                            tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]);
                        },
                        [&](int accum_m) {});
                }
            }
            MatmulDOIVJ::B2bGemm::accumToSmem(
                shared_storage.tmp_shared_storage(), accum, lane_id, output_tile_coords);
            __syncthreads();
        }
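        // Math recovered by the MatmulDOIVJ block, in hedged pseudocode
        // (Di = delta[i] = rowsum(dO_i * O_i), written by computeDelta below):
        //
        //   dPij = dO_i @ v_j.transpose(-2, -1)          # the GEMM above
        //   dSij = (dPij - Di.unsqueeze(1)) * Pij
        //   dBij = dSij                                  # emitted if grad_bias1/2 is requested
        //   tmp  = scale * dSij                          # staged in smem for the dQ / dK GEMMs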
        p.head_dim = warp_uniform(p.head_dim);
        p.k_strideM = warp_uniform(p.k_strideM);
        rematerializeThreadIds();

        /////////////////////////////////////////////////////////////////////////////////////////////////
        // GradQ matmul
        //
        // grad_q[i_start:i_end] += tmp @ k_j
        /////////////////////////////////////////////////////////////////////////////////////////////////
        // Skip the loop & associated branches if we know at compile time the number
        // of iterations
        constexpr bool kSingleIterationGradQ = kMaxK <= MatmulGradQ::ThreadblockShape::kN;
        for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim);
             col += MatmulGradQ::ThreadblockShape::kN) {
            using Mma = typename MatmulGradQ::Mma;
            using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;

            cutlass::gemm::GemmCoord problem_size(
                num_queries_in_block,
                false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col,
                num_keys_in_block);

            // k_j
            typename Mma::IteratorB iterator_B({int32_t(p.k_strideM)},
                                               p.key_ptr + key_start * p.k_strideM + col,
                                               {problem_size.k(), problem_size.n()},
                                               thread_id,
                                               no_offset);
            auto a = shared_storage.tmp_shared_storage().accum_ref();
            Mma mma(shared_storage.mm_gradQ(),
                    shared_storage.tmp_shared_storage(),
                    thread_id,
                    warp_id,
                    lane_id,
                    problem_size.k());

            typename Mma::FragmentC accum;
            bool isFirst = key_start == 0;
            int col_id = col / MatmulGradQ::ThreadblockShape::kN;
            int num_cols =
                kSingleIterationGradQ ? 1 : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN);
            int storage_id = (col_id + query_start / kBlockSizeI * num_cols);
            AccumTileGmem gmem_tile{p.workspace_gq + storage_id * AccumTileGmem::kElementsStored};
            if (isFirst || !kNeedsAccumGradQ) {
                accum.clear();
            } else {
                gmem_tile.load(accum, thread_id);
            }

            auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
            // Compute threadblock-scoped matrix multiply-add
            __syncthreads();
            mma.set_prologue_done(kPrologueGQ);
            mma(gemm_k_iterations, accum, iterator_B, accum);
            __syncthreads();
            bool isLastColumn = kSingleIterationGradQ ||
                                (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim);
            if (kPrologueGQ && !isLastColumn) {
                prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN);
            }

            // Output results
            int32_t next_query, next_key;
            incrIteration(p, p.num_queries, key_start, next_query, next_key);
            bool isLast = next_query > query_start || next_key >= p.num_keys;
            if (kNeedsAccumGradQ && !isLast) {
                gmem_tile.store(accum, thread_id);
            } else {
                typename MatmulGradQ::OutputTileIterator output_it(
                    typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
                    p.grad_query_ptr + query_start * p.gQ_strideM() + col,
                    {problem_size.m(), problem_size.n()},
                    thread_id);
                accumulateInGmem<MatmulGradQ>(isLastColumn
                                                  ? shared_storage.gradQ_epilogue_lastIter()
                                                  : shared_storage.gradQ_epilogue(),
                                              accum,
                                              output_it,
                                              isFirst || kNeedsAccumGradQ,
                                              warp_id,
                                              lane_id);
            }
        }
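        // dQ is accumulated across key blocks; per iteration this is roughly
        //
        //   dQ[i_start:i_end] += tmp @ k_j               # tmp = scale * dSij
        //
        // using the same accumulation scheme as dV above, except that (when
        // kNeedsAccumGradQ) the running sum lives in `workspace_gq` until the
        // last key block for this query block is reached and the result is
        // flushed to grad_query through accumulateInGmem.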
        /////////////////////////////////////////////////////////////////////////////////////////////////
        // GradK matmul
        //
        // grad_k[j_start:j_end] += tmp.transpose(-2, -1) @ q_i
        /////////////////////////////////////////////////////////////////////////////////////////////////
        rematerializeThreadIds();

        constexpr bool kSingleIterationGradK = kMaxK <= MatmulGradK::ThreadblockShape::kN;
        for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim);
             col += MatmulGradK::ThreadblockShape::kN) {
            using Mma = typename MatmulGradK::Mma;
            using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;

            cutlass::gemm::GemmCoord problem_size(
                num_keys_in_block,
                false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col,
                num_queries_in_block);
            auto createEpilogueIter = [&]() {
                return typename MatmulGradK::OutputTileIterator(
                    typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
                    p.grad_key_ptr + key_start * p.gK_strideM() + col,
                    {num_keys_in_block,
                     false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col},
                    thread_id);
            };

            // q_i
            typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)},
                                               p.query_ptr + query_start * p.q_strideM + col,
                                               {problem_size.k(), problem_size.n()},
                                               thread_id,
                                               no_offset);

            auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); };
            auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); };
            // this is basically:
            // opA = kIsTransposedA ? getTmp() : getTmpT();
            bool constexpr kIsTransposedA = MatmulGradK::DefaultMmaFromSmem::kIsTransposedA;
            auto& opA =
                *call_conditional<kIsTransposedA, decltype(getTmp), decltype(getTmpT)>::apply(
                    getTmp, getTmpT, 0);
            Mma mma(shared_storage.mm_gradK(), opA, thread_id, warp_id, lane_id, problem_size.k());

            int storage_id = col / MatmulGradK::ThreadblockShape::kN;
            AccumTileGmem gmem_tile{p.workspace_gk + storage_id * AccumTileGmem::kElementsStored};
            if (!kOutputInRF) {
                if (isFirstQuery || !kNeedsAccumGradK) {
                    output_frags.gradK.clear();
                } else {
                    gmem_tile.load(output_frags.gradK, thread_id);
                }
            }
            mma.set_prologue_done(kPrologueGK);

            auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
            // Compute threadblock-scoped matrix multiply-add
            __syncthreads();
            mma(gemm_k_iterations, output_frags.gradK, iterator_B, output_frags.gradK);
            __syncthreads();
            bool isLastColumn = kSingleIterationGradK ||
                                col + MatmulGradK::ThreadblockShape::kN >= p.head_dim;
            if (kPrologueGK && !isLastColumn) {
                prologueGradK(col + MatmulGradK::ThreadblockShape::kN);
            }

            if (kPrologueQK && isLastColumn) {
                int32_t next_query, next_key;
                incrIteration(p, query_start, key_start, next_query, next_key);
                DISPATCH_BOOL(next_key != key_start, kForceReloadK, ([&]() {
                                  prologueQkNextIteration<kForceReloadK>(
                                      shared_storage, p, next_query, next_key, warp_id, lane_id);
                              }));
            }

            // Output results
            if (!kOutputInRF) {
                if (kNeedsAccumGradK && !isLastQuery) {
                    gmem_tile.store(output_frags.gradK, thread_id);
                } else {
                    accumulateInGmem<MatmulGradK>(isLastColumn
                                                      ? shared_storage.gradK_epilogue_final()
                                                      : shared_storage.gradK_epilogue(),
                                                  output_frags.gradK,
                                                  createEpilogueIter(),
                                                  isFirstQuery || kNeedsAccumGradK,
                                                  warp_id,
                                                  lane_id);
                    __syncthreads();
                }
            }
        }
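        // dK mirrors dQ but consumes the transposed dS tile:
        //
        //   dK[j_start:j_end] += tmp.transpose(-2, -1) @ q_i
        //
        // with partials carried in registers (kOutputInRF), in `workspace_gk`
        // (kNeedsAccumGradK), or accumulated directly into grad_key in gmem.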
    }

    static CUTLASS_DEVICE int32_t getQueryStart(Params const& p, int32_t key_start) { return 0; }
    static CUTLASS_DEVICE void incrIteration(Params const& p,
                                             int32_t query_start,
                                             int32_t key_start,
                                             int32_t& next_query,
                                             int32_t& next_key)
    {
        next_query = query_start + kBlockSizeI;
        next_key = key_start;
        if (next_query >= p.num_queries) {
            next_key = key_start + kBlockSizeJ;
            next_query = getQueryStart(p, next_key);
        }
    }
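    // Iteration order sketch: for a fixed key block we sweep all query blocks,
    // then move to the next key block. With getQueryStart returning 0 (as above)
    // and num_queries == 3 * kBlockSizeI, one key block j visits, in order:
    //
    //   (query_start, key_start) = (0, j) -> (kBlockSizeI, j) -> (2 * kBlockSizeI, j)
    //                              -> (0, j + kBlockSizeJ)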
    template <bool kForceReloadK>
    static CUTLASS_DEVICE void prologueQkNextIteration(SharedStorage& shared_storage,
                                                       Params const& p,
                                                       int32_t query_start,
                                                       int32_t key_start,
                                                       uint8_t warp_id,
                                                       uint8_t lane_id)
    {
        if (query_start >= p.num_queries || key_start >= p.num_keys) { return; }

        static constexpr bool kReloadK = kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat;
        int thread_id = 32 * warp_id + lane_id;
        typename MatmulQK::Mma::IteratorA iterator_A({int32_t(p.k_strideM)},
                                                     p.key_ptr + key_start * p.k_strideM,
                                                     {p.num_keys - key_start, p.head_dim},
                                                     thread_id,
                                                     cutlass::MatrixCoord{0, 0});
        typename MatmulQK::Mma::IteratorB iterator_B({int32_t(p.q_strideM)},
                                                     p.query_ptr + query_start * p.q_strideM,
                                                     {p.head_dim, p.num_queries - query_start},
                                                     thread_id,
                                                     cutlass::MatrixCoord{0, 0});
        MatmulQK::Mma::prologue<kReloadK, true>(shared_storage.mm_qk_k(),
                                                shared_storage.mm_qk_q(),
                                                iterator_A,
                                                iterator_B,
                                                thread_id,
                                                p.head_dim);
    }
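    // This hides the global-memory latency of the *next* (i, j) iteration's Q/K
    // tiles behind the tail of the current one. K is only re-staged when the key
    // block actually changes or when shared memory cannot hold the whole matrix,
    // which is what kForceReloadK / kReloadK encode. A condensed restatement of
    // the call site in the GradK loop above:
    //
    //   incrIteration(p, query_start, key_start, next_query, next_key);
    //   // kForceReloadK == (next_key != key_start)
    //   prologueQkNextIteration<kForceReloadK>(
    //       shared_storage, p, next_query, next_key, warp_id, lane_id);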
    template <bool skipBoundsChecks>
    static CUTLASS_DEVICE void writeFragsToGmem(SharedStorage& shared_storage,
                                                OutputFragments& output_frags,
                                                Params const& p,
                                                int32_t key_start,
                                                uint8_t warp_id,
                                                uint8_t lane_id)
    {
        uint16_t thread_id = 32 * warp_id + lane_id;
        int32_t num_keys_in_block =
            skipBoundsChecks
                ? MatmulQK::Mma::Shape::kM
                : cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start);
        typename MatmulGradV::OutputTileIterator outputV_it(
            typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
            p.grad_value_ptr + key_start * p.gV_strideM(),
            {num_keys_in_block, p.head_dim_value},
            thread_id);
        accumulateInGmem<MatmulGradV>(shared_storage.gradV_epilogue_final(),
                                      output_frags.gradV,
                                      outputV_it,
                                      true,
                                      warp_id,
                                      lane_id);

        typename MatmulGradK::OutputTileIterator outputK_it(
            typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
            p.grad_key_ptr + key_start * p.gK_strideM(),
            {num_keys_in_block, false ? MatmulGradK::ThreadblockShape::kN : p.head_dim},
            thread_id);
        accumulateInGmem<MatmulGradK>(shared_storage.gradK_epilogue_final(),
                                      output_frags.gradK,
                                      outputK_it,
                                      true,
                                      warp_id,
                                      lane_id);
    }
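    // Only relevant when kOutputInRF is enabled: dVj / dKj then stay in register
    // fragments for the whole sweep over query blocks and are flushed exactly once
    // per key block through this helper, presumably from the outer loop over
    // blocks (not part of this excerpt), roughly:
    //
    //   if (kOutputInRF && /* last query block for this key block */) {
    //       writeFragsToGmem<skipBoundsChecks>(
    //           shared_storage, output_frags, p, key_start, warp_id, lane_id);
    //   }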
    template <typename MatmulT>
    static CUTLASS_DEVICE void accumulateInGmem(
        typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem,
        typename MatmulT::Mma::FragmentC const& accum,
        typename MatmulT::OutputTileIterator output_it,
        bool first,
        uint8_t warp_id,
        uint8_t lane_id)
    {
        using DefaultEpilogue = typename MatmulT::DefaultEpilogue;
        using DefaultOutputOp = typename MatmulT::DefaultOutputOp;
        using Mma = typename MatmulT::Mma;
        int thread_id = 32 * warp_id + lane_id;
        DISPATCH_BOOL(
            first, kIsFirst, ([&]() {
                static constexpr auto ScaleType =
                    kIsFirst ? cutlass::epilogue::thread::ScaleType::Nothing
                             : cutlass::epilogue::thread::ScaleType::NoBetaScaling;
                using EpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination<
                    typename DefaultOutputOp::ElementOutput,
                    DefaultOutputOp::kCount,
                    typename DefaultOutputOp::ElementAccumulator,
                    typename DefaultOutputOp::ElementCompute,
                    ScaleType>;
                using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined<
                    typename DefaultEpilogue::Shape,
                    typename Mma::Operator,
                    DefaultEpilogue::kPartitionsK,
                    typename MatmulT::OutputTileIterator,
                    typename DefaultEpilogue::AccumulatorFragmentIterator,
                    typename DefaultEpilogue::WarpTileIterator,
                    typename DefaultEpilogue::SharedLoadIterator,
                    EpilogueOutputOp,
                    typename DefaultEpilogue::Padding,
                    DefaultEpilogue::kFragmentsPerIteration,
                    true  // IterationsUnroll
                    >;
                EpilogueOutputOp rescale({1, 1});
                Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id);
                epilogue(rescale, output_it, accum, output_it);
            }));
    }
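    // The two ScaleType branches select the output-op behaviour of the CUTLASS
    // LinearCombination epilogue: `Nothing` ignores the existing output (first
    // write, D = accum), while `NoBetaScaling` reads it back and adds
    // (D = accum + C). A hedged sketch of how callers use the `first` flag:
    //
    //   accumulateInGmem<MatmulGradV>(smem, frag, iter, /*first=*/isFirstQuery, warp_id, lane_id);
    //   // first == true  -> grad_v[tile]  = frag
    //   // first == false -> grad_v[tile] += frag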
    template <int kElementsPerAccess>
    static CUTLASS_DEVICE void computeDelta(Params const& p,
                                            int32_t query_start,
                                            uint8_t warp_id,
                                            uint8_t lane_id)
    {
        // Each thread computes one value for Delta
        // Depending on warp configuration, we might have multiple
        // threads of the same warp working on the same row
        using AccessType = cutlass::Array<scalar_t, kElementsPerAccess>;
        static_assert(kNumThreads >= kBlockSizeI, "");
        static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI;
        int16_t thread_id = 32 * warp_id + lane_id;

        int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine);
        int16_t laneRow = thread_id / kNumThreadsPerLine;
        bool rowPred = (query_start + laneRow) < p.num_queries;
        bool pred = rowPred;

        // on windows, previous syntax __restrict__ AccessType*
        // resulted in error: "restrict" is not allowed
        const AccessType* __restrict__ grad_output_ptr = reinterpret_cast<const AccessType*>(
            p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol);
        const AccessType* __restrict__ output_ptr = reinterpret_cast<const AccessType*>(
            p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol);

        static constexpr int64_t kMaxIters = kMaxK / (kElementsPerAccess * kNumThreadsPerLine);
        constexpr int kPipelineStages = 2;
        accum_t delta_value = accum_t(0);
        using GlobalLoad = cutlass::arch::global_load<AccessType, sizeof(AccessType)>;
        AccessType frag_grad_output[kPipelineStages];
        AccessType frag_output[kPipelineStages];

        auto loadAndIncrement = [&](int ld_pos, bool is_valid) {
            frag_grad_output[ld_pos].clear();
            frag_output[ld_pos].clear();
            GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid);
            GlobalLoad(frag_output[ld_pos], output_ptr, is_valid);
            grad_output_ptr += kNumThreadsPerLine;
            output_ptr += kNumThreadsPerLine;
        };

        CUTLASS_PRAGMA_UNROLL
        for (int iter = 0; iter < kPipelineStages - 1; ++iter) {
            int ld_pos = iter % kPipelineStages;
            pred = pred && (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) <
                               p.head_dim_value;
            loadAndIncrement(ld_pos, pred);
        }
        auto columnIteration = [&](int iter) {
            // Load for next iter
            int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages;
            pred = pred && (laneFirstCol + (iter + kPipelineStages - 1) * kElementsPerAccess *
                                               kNumThreadsPerLine) < p.head_dim_value;
            loadAndIncrement(ld_pos, pred);
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < AccessType::kElements; ++i) {
                delta_value += accum_t(frag_output[iter % kPipelineStages][i]) *
                               accum_t(frag_grad_output[iter % kPipelineStages][i]);
            }
        };
        // If we have a small upper bound for K, we can unroll the loop
        if (kMaxK <= 256) {
            CUTLASS_PRAGMA_UNROLL
            for (int iter = 0; iter < kMaxIters; ++iter) { columnIteration(iter); }
        } else {
            int num_iters = ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) *
                            (kElementsPerAccess * kNumThreadsPerLine);
            for (int iter = 0; iter < num_iters; ++iter) { columnIteration(iter); }
        }

        // Reduce between workers
        static_assert(kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || kNumThreadsPerLine == 4,
                      "");
        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kNumThreadsPerLine; i *= 2) {
            delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i);
        }

        // Store in gmem
        if (rowPred) { p.delta_ptr[query_start + laneRow] = delta_value; }
    }
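    // computeDelta produces the per-row correction term consumed above in
    // "dSij = (dPij - Di) * Pij"; in reference form:
    //
    //   delta[i] = sum_k grad_output[i, k] * output[i, k]      # rowsum(dO * O)
    //
    // Each group of kNumThreadsPerLine threads covers one query row, and the
    // per-thread partial sums are combined with the __shfl_xor_sync butterfly
    // reduction before being written to delta_ptr[i].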
};

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_backward_batched_impl(typename AK::Params p)
{
    if (!p.advance_to_block()) { return; }
    AK::attention_kernel(p);
}

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_backward_batched(typename AK::Params params);