- /***************************************************************************************************
- * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: BSD-3-Clause
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- *
- * 1. Redistributions of source code must retain the above copyright notice, this
- * list of conditions and the following disclaimer.
- *
- * 2. Redistributions in binary form must reproduce the above copyright notice,
- * this list of conditions and the following disclaimer in the documentation
- * and/or other materials provided with the distribution.
- *
- * 3. Neither the name of the copyright holder nor the names of its
- * contributors may be used to endorse or promote products derived from
- * this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- **************************************************************************************************/
- // Copyright (c) Microsoft Corporation.
- // SPDX-License-Identifier: Apache-2.0
- // DeepSpeed Team
- #pragma once
- #include <cmath>
- #include <type_traits>
- #include <vector>
- #include <cuda_fp16.h>
- #include "cutlass/cutlass.h"
- #include "cutlass/epilogue/thread/linear_combination.h"
- #include "cutlass/epilogue/thread/scale_type.h"
- #include "cutlass/fast_math.h"
- #include "cutlass/functional.h"
- #include "cutlass/gemm/gemm.h"
- #include "cutlass/layout/matrix.h"
- #include "cutlass/layout/vector.h"
- #include "cutlass/numeric_conversion.h"
- #include "cutlass/numeric_types.h"
- #include "cutlass/tensor_ref.h"
- #include "gemm_kernel_utils.h"
- #include "cutlass/epilogue/thread/linear_combination_relu.h"
- #include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
- #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
- #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
- #include "cutlass/gemm/device/default_gemm_configuration.h"
- #include "cutlass/gemm/kernel/default_gemm.h"
- #include "cutlass/gemm/threadblock/default_mma.h"
- #include "cutlass/gemm/threadblock/default_mma_core_simt.h"
- #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
- #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
- #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
- #include "cutlass/matrix_shape.h"
- #include "cutlass/platform/platform.h"
- #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
- #include "cutlass/transform/threadblock/vector_iterator.h"
- #include "epilogue/epilogue_pipelined.h"
- #include "iterators/epilogue_predicated_tile_iterator.h"
- #include "epilogue/epilogue_grad_bias.h"
- #include "gemm/custom_mma.h"
- #include "gemm/find_default_mma.h"
- #include "gemm/mma_accum_lambda_iterator.h"
- #include "gemm/mma_from_smem.h"
- #include "transform/bias_broadcast.h"
- #include "transform/tile_smem_loader.h"
- #include <inttypes.h>
- using namespace gemm_kernel_utils;
- namespace {
- template <typename FragmentType, int32_t kNumThreads>
- struct GmemTile {
- /*
- Helper class to efficiently store/load register fragments (RF) to/from gmem.
- GEMM accumulators have a particular layout on A100, and it takes some
- compute/shared memory to rearrange them into a RowMajor or ColumnMajor
- format in global memory through an Epilogue. The same complexity applies
- when loading back into RF. This class stores/loads fragments as they are,
- which makes it useful for efficient accumulation across GEMMs, for instance:
- ```
- GmemTile tile;
- for (int i = 0; i < N; ++i) {
- // ...
- Fragment accum;
- if (i == 0) {
- accum.clear();
- } else {
- tile.load(accum, thread_id);
- }
- mma(accum, ...);
- if (i < N-1) {
- // Store for next GEMM
- tile.store(accum, thread_id);
- } else {
- // Store in tensor (eg RowMajor)
- epilogue(accum);
- }
- // ...
- }
- ```
- */
- // 128 bits per thread
- using AccessType = cutlass::Array<float, 4>;
- static constexpr int32_t kBytes = sizeof(AccessType);
- static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
- static constexpr int32_t kNumIters = FragmentType::kElements / AccessType::kElements;
- static constexpr int32_t kElementsStored = kNumThreads * FragmentType::kElements;
- static_assert(FragmentType::kElements % AccessType::kElements == 0,
- "fragment not aligned on 128 bits");
- float* ptr;
- CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id)
- {
- CUTLASS_PRAGMA_UNROLL
- for (int i = 0; i < kNumIters; ++i) {
- AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
- ptr + thread_id * AccessType::kElements + i * kStride);
- AccessType sub_fragment;
- cutlass::arch::global_load<AccessType, kBytes>(sub_fragment, gmem_ptr, true);
- CUTLASS_PRAGMA_UNROLL
- for (int j = 0; j < AccessType::kElements; ++j) {
- fragment[i * AccessType::kElements + j] = sub_fragment[j];
- }
- }
- }
- CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id)
- {
- CUTLASS_PRAGMA_UNROLL
- for (int i = 0; i < kNumIters; ++i) {
- AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
- ptr + thread_id * AccessType::kElements + i * kStride);
- AccessType sub_fragment;
- CUTLASS_PRAGMA_UNROLL
- for (int j = 0; j < AccessType::kElements; ++j) {
- sub_fragment[j] = fragment[i * AccessType::kElements + j];
- }
- cutlass::arch::global_store<AccessType, kBytes>(sub_fragment, gmem_ptr, true);
- }
- }
- };
- template <typename scalar_t, typename Arch>
- constexpr int getWarpsPerSm()
- {
- constexpr bool is_half = !cutlass::platform::is_same<scalar_t, float>::value;
- if (Arch::kMinComputeCapability >= 80) { return is_half ? 12 : 8; }
- return 8;
- }
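- // Launch-bounds math (illustrative, for a hypothetical 64x64 block): on
- // Sm80 with fp16/bf16 this returns 12 warps per SM; combined with
- // kNumWarpsPerBlock = 4 in the kernel struct below, __launch_bounds__
- // gets kMinBlocksPerSm = 12 / 4 = 3 resident blocks.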
- } // namespace
- template <
- // which arch we target (eg `cutlass::arch::Sm80`)
- typename ArchTag_,
- // input/output type
- typename scalar_t_,
- // run optimized kernel because memory accesses will be aligned
- bool kIsAligned_,
- // use dropout if enabled
- bool kApplyDropout_,
- // when doing a GEMM, preload the next one (uses more shmem)
- bool kPreload_,
- // block dimensions
- int kBlockSizeI_,
- int kBlockSizeJ_,
- // upperbound on `max(value.shape[-1], query.shape[-1])`
- int kMaxK_ = (int)cutlass::platform::numeric_limits<int32_t>::max(),
- template <typename, typename, typename> class Broadcast1_ = BroadcastNoLoad,
- template <typename, typename, typename> class Broadcast2_ = BroadcastNoLoad>
- struct AttentionBackwardKernel {
- using scalar_t = scalar_t_;
- using output_t = scalar_t;
- using output_accum_t = float;
- using lse_scalar_t = float;
- using accum_t = float;
- using ArchTag = ArchTag_;
- static constexpr bool kIsAligned = kIsAligned_;
- static constexpr bool kApplyDropout = kApplyDropout_;
- static constexpr bool kPreload = kPreload_;
- static constexpr int kBlockSizeI = kBlockSizeI_;
- static constexpr int kBlockSizeJ = kBlockSizeJ_;
- static constexpr int kMaxK = kMaxK_;
- struct Params {
- // Input tensors
- scalar_t* query_ptr; // [Mq, nH, K]
- scalar_t* key_ptr; // [Mk, nH, K]
- scalar_t* value_ptr; // [Mk, nH, Kv]
- lse_scalar_t* logsumexp_ptr; // [nH, Mq]
- scalar_t* output_ptr; // [Mq, nH, Kv]
- scalar_t* grad_output_ptr; // [Mq, nH, Kv]
- accum_t* delta_ptr; // [nH, Mq]
- int32_t* cu_seqlens_q_ptr = nullptr;
- int32_t* cu_seqlens_k_ptr = nullptr;
- // Output tensors
- output_t* grad_query_ptr; // [Mq, nH, K]
- output_t* grad_key_ptr; // [Mk, nH, K]
- output_t* grad_value_ptr; // [Mk, nH, Kv]
- accum_t* grad_bias1_ptr = nullptr;
- accum_t* grad_bias2_ptr = nullptr;
- int32_t B = 0;
- int32_t N = 0;
- scalar_t* bias1_ptr = nullptr;
- scalar_t* bias2_ptr = nullptr;
- // Accumulators
- union {
- output_accum_t* workspace = nullptr; // [Mkv, K] + [Mkv, Kv] + [Mq, K]
- output_accum_t* workspace_gk;
- };
- output_accum_t* workspace_gv; // (will be calculated by the kernel)
- output_accum_t* workspace_gq; // (will be calculated by the kernel)
- // Scale
- accum_t scale;
- // Dimensions/strides
- int32_t head_dim = -1;
- int32_t head_dim_value = -1;
- int32_t num_queries = -1;
- int32_t num_keys = -1;
- int32_t num_heads = -1;
- int32_t q_strideM;
- int32_t k_strideM;
- int32_t v_strideM;
- int32_t gO_strideM;
- int32_t gB_strideM;
- int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise
- // RNG sequence offset based on batch_id and head_id
- unsigned long long dropout_batch_head_rng_offset;
- float dropout_prob = 0.0f;
- CUTLASS_HOST_DEVICE int32_t o_strideM() const { return head_dim_value * num_heads; }
- CUTLASS_HOST_DEVICE int32_t gQ_strideM() const
- {
- return gQKV_strideM_multiplier * num_heads * head_dim;
- }
- CUTLASS_HOST_DEVICE int32_t gK_strideM() const
- {
- return gQKV_strideM_multiplier * num_heads * head_dim;
- }
- CUTLASS_HOST_DEVICE int32_t gV_strideM() const
- {
- return gQKV_strideM_multiplier * num_heads * head_dim_value;
- }
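- // Example (hypothetical shapes): with num_heads = 8, head_dim = 64 and
- // non-packed gradients (gQKV_strideM_multiplier = 1),
- // gQ_strideM() = 1 * 8 * 64 = 512 elements between consecutive rows;
- // with packed QKV gradients (multiplier = 3) the same call returns 1536.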
- // Everything below is only used in `advance_to_block`
- // and shouldn't use registers
- int64_t o_strideH;
- int32_t q_strideH;
- int32_t k_strideH;
- int32_t v_strideH;
- int64_t o_strideB;
- int64_t q_strideB;
- int64_t k_strideB;
- int64_t v_strideB;
- int64_t lse_strideB;
- int64_t lse_strideH;
- int64_t delta_strideB;
- int64_t delta_strideH;
- int32_t num_batches;
- int64_t gO_strideB = 0;
- int64_t gQ_strideB = 0;
- int64_t gK_strideB = 0;
- int64_t gV_strideB = 0;
- int64_t gB_strideB = 0;
- int64_t gO_strideH = 0;
- int64_t gQ_strideH = 0;
- int64_t gK_strideH = 0;
- int64_t gV_strideH = 0;
- int64_t gB_strideH = 0;
- CUTLASS_DEVICE bool advance_to_block()
- {
- int64_t batch_id = blockIdx.z;
- int32_t head_id = blockIdx.y;
- if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) {
- assert(workspace_size() == 0 || workspace != nullptr);
- workspace += (batch_id * num_heads + head_id) * workspace_strideBH();
- workspace = warp_uniform(workspace);
- workspace_gv = workspace + workspace_elements_gk();
- workspace_gq = workspace_gv + workspace_elements_gv();
- } else {
- workspace = nullptr;
- }
- // Advance pointers that depend on the total concatenated
- // number of queries, as `num_queries` is modified in the block
- // below
- dropout_batch_head_rng_offset = batch_id * (num_heads * num_queries * num_keys) +
- head_id * (num_queries * num_keys);
- logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH;
- query_ptr += batch_id * q_strideB + head_id * q_strideH;
- key_ptr += batch_id * k_strideB + head_id * k_strideH;
- value_ptr += batch_id * v_strideB + head_id * v_strideH;
- output_ptr += batch_id * o_strideB + head_id * o_strideH;
- grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH;
- delta_ptr += batch_id * delta_strideB + head_id * delta_strideH;
- grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH;
- grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH;
- grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH;
- using broadcast_1 = Broadcast1_<typename MatmulQK::BiasLoader::ThreadMap,
- typename MatmulQK::BiasLoader::Shape,
- scalar_t>;
- using broadcast_2 = Broadcast2_<typename MatmulQK::BiasLoader::ThreadMap,
- typename MatmulQK::BiasLoader::Shape,
- scalar_t>;
- if (broadcast_1::kEnable && grad_bias1_ptr) {
- grad_bias1_ptr += batch_id * num_queries;
- }
- if (broadcast_2::kEnable && grad_bias2_ptr) {
- auto strideB = num_heads * num_queries * num_keys;
- auto strideH = num_queries * num_keys;
- grad_bias2_ptr += (batch_id / N) * strideB + head_id * strideH;
- }
- if (broadcast_1::kEnable && bias1_ptr) {
- bias1_ptr = broadcast_1::advance(bias1_ptr,
- batch_id / N,
- batch_id % N,
- head_id,
- num_queries * N,
- num_queries,
- 0);
- }
- if (broadcast_2::kEnable && bias2_ptr) {
- auto strideB = num_heads * num_queries * num_keys;
- auto strideH = num_queries * num_keys;
- bias2_ptr = broadcast_2::advance(
- bias2_ptr, batch_id / N, batch_id % N, head_id, strideB, 0, strideH);
- }
- num_queries = warp_uniform(num_queries);
- num_keys = warp_uniform(num_keys);
- query_ptr = warp_uniform(query_ptr);
- key_ptr = warp_uniform(key_ptr);
- value_ptr = warp_uniform(value_ptr);
- logsumexp_ptr = warp_uniform(logsumexp_ptr);
- output_ptr = warp_uniform(output_ptr);
- grad_output_ptr = warp_uniform(grad_output_ptr);
- delta_ptr = warp_uniform(delta_ptr);
- grad_query_ptr = warp_uniform(grad_query_ptr);
- grad_key_ptr = warp_uniform(grad_key_ptr);
- grad_value_ptr = warp_uniform(grad_value_ptr);
- if (broadcast_1::kEnable) {
- grad_bias1_ptr = warp_uniform(grad_bias1_ptr);
- bias1_ptr = warp_uniform(bias1_ptr);
- }
- if (broadcast_2::kEnable) {
- grad_bias2_ptr = warp_uniform(grad_bias2_ptr);
- bias2_ptr = warp_uniform(bias2_ptr);
- }
- return true;
- }
- __host__ dim3 getBlocksGrid() const { return dim3(1, num_heads, num_batches); }
- __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); }
- CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const
- {
- if (!kNeedsAccumGradK) { return 0; }
- return align_up(num_keys, (int32_t)kBlockSizeJ) *
- align_up(head_dim, (int32_t)kBlockSizeI);
- }
- CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const
- {
- if (!kNeedsAccumGradV) { return 0; }
- return align_up(num_keys, (int32_t)kBlockSizeJ) *
- align_up(head_dim_value, (int32_t)kBlockSizeI);
- }
- CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const
- {
- if (!kNeedsAccumGradQ) { return 0; }
- if (num_keys <= kBlockSizeJ) { return 0; }
- return align_up(num_queries, (int32_t)kBlockSizeI) *
- align_up(head_dim, (int32_t)kBlockSizeJ);
- }
- CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const
- {
- // Aligned on 128 bits (4 floats)
- return align_up(
- workspace_elements_gk() + workspace_elements_gv() + workspace_elements_gq(),
- int64_t(4));
- }
- CUTLASS_HOST_DEVICE int64_t workspace_size() const
- {
- // Returns size of buffer we need to run this kernel
- return num_batches * num_heads * workspace_strideBH() * sizeof(float);
- }
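- // Worked example (hypothetical sizes, assuming all three accumulation
- // buffers are needed): num_queries = num_keys = 512,
- // head_dim = head_dim_value = 64, kBlockSizeI = kBlockSizeJ = 64 gives
- //   workspace_elements_gk() = 512 * 64 = 32768
- //   workspace_elements_gv() = 512 * 64 = 32768
- //   workspace_elements_gq() = 512 * 64 = 32768
- // so workspace_strideBH() = 98304 floats per (batch, head) and
- // workspace_size() = num_batches * num_heads * 98304 * sizeof(float).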
- };
- static constexpr int64_t kWarpSize = 32;
- // If this is true, we store and accumulate dK/dV in RF
- // rather than going back to gmem every time
- static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value <= 16;
- static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
- static_assert(!kPreload || (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF),
- "preload MMA not supported");
- static constexpr bool kPrologueQK = kPreload;
- static constexpr bool kPrologueGV = kPreload;
- static constexpr bool kPrologueDOV = kPreload;
- static constexpr bool kPrologueGQ = kPreload;
- static constexpr bool kPrologueGK = kPreload;
- static constexpr int64_t kNumWarpsPerBlock = (kBlockSizeI * kBlockSizeJ) / (32 * 32);
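- // e.g. kBlockSizeI = kBlockSizeJ = 64 gives (64 * 64) / (32 * 32) = 4
- // warps, i.e. kNumThreads = 128 threads per block.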
- // Compute delta for the f16 kernels
- // TODO: Figure out why it's slower on the f32 kernels
- // (something due to RF pressure?)
- // TODO: Remove condition on `kOutputInRF` - this is needed to work
- // around a compiler bug on V100, not exactly sure why but I spent
- // too much time on this already. Reproducible with
- // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance
- static constexpr bool kKernelComputesDelta =
- kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70);
- static constexpr bool kNeedsAccumGradQ =
- !cutlass::platform::is_same<output_accum_t, output_t>::value;
- static constexpr bool kNeedsAccumGradK =
- !kOutputInRF && !cutlass::platform::is_same<output_accum_t, output_t>::value;
- static constexpr bool kNeedsAccumGradV =
- !kOutputInRF && !cutlass::platform::is_same<output_accum_t, output_t>::value;
- // Launch bounds
- static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock;
- static constexpr int64_t kMinBlocksPerSm =
- getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
- using GemmType = DefaultGemmType<ArchTag, scalar_t>;
- using DefaultConfig =
- typename cutlass::gemm::device::DefaultGemmConfiguration<typename GemmType::OpClass,
- ArchTag,
- scalar_t,
- scalar_t,
- scalar_t, // ElementC
- accum_t // ElementAccumulator
- >;
- static constexpr auto kOptimalAlignement =
- cutlass::platform::max(DefaultConfig::kAlignmentA, DefaultConfig::kAlignmentB);
- static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment;
- struct MatmulQK {
- /*
- attn_T = k_j @ q_i.transpose(-2, -1) # matmul
- attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2,
- -1)).exp() # epilogue
- with attn_T.shape = (kBlockSizeJ, kBlockSizeI)
- */
- using ThreadblockShape =
- cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
- using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
- using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma<
- scalar_t, // ElementA
- cutlass::layout::RowMajor, // LayoutA
- kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
- scalar_t, // ElementB
- cutlass::layout::ColumnMajor, // LayoutB
- kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
- accum_t, // ElementC
- cutlass::layout::RowMajor, // LayoutC
- typename GemmType::OpClass,
- ArchTag,
- ThreadblockShape,
- WarpShape,
- typename GemmType::InstructionShape,
- DefaultConfig::kStages,
- typename GemmType::Operator,
- false, // AccumulatorsInRowMajor = false,
- cutlass::gemm::SharedMemoryClearOption::kNone>;
- using MmaCore = typename DefaultMma::MmaCore;
- using Mma = typename MakeCustomMma<typename DefaultMma::ThreadblockMma, kMaxK>::Mma;
- // used for efficient load of bias tile (Bij) from global memory to shared
- // memory
- using BiasLoader =
- TileSmemLoader<scalar_t,
- // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded
- // row-major but needs to have transposed shape so we get the same
- // elements.
- cutlass::MatrixShape<ThreadblockShape::kN, ThreadblockShape::kM>,
- MmaCore::kThreads,
- // input restriction: kv_len has to be a multiple of this value
- 128 / cutlass::sizeof_bits<scalar_t>::value>;
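- // e.g. for fp16 this access width is 128 / 16 = 8 elements (fp32: 4), so
- // the restriction above means kv_len must be a multiple of 8 for fp16.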
- // Epilogue to store to shared-memory in a format that we can use later for
- // the second matmul
- using B2bGemm =
- typename cutlass::gemm::threadblock::B2bGemm<typename Mma::Operator::IteratorC,
- typename Mma::Operator,
- scalar_t,
- WarpShape,
- ThreadblockShape>;
- using AccumLambdaIterator =
- typename DefaultMmaAccumLambdaIterator<typename Mma::Operator::IteratorC,
- accum_t,
- kWarpSize>::Iterator;
- using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
- };
- struct MatmulGradV {
- /*
- grad_v[j_start:j_end] += attn_T @ do_i # matmul
- Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K)
- (we might need to iterate multiple times on K)
- */
- using ThreadblockShape =
- cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
- using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
- using InstructionShape = typename GemmType::InstructionShape;
- using DefaultGemm =
- cutlass::gemm::kernel::DefaultGemm<scalar_t, // ElementA,
- cutlass::layout::RowMajor, // LayoutA,
- DefaultConfig::kAlignmentA,
- scalar_t, // ElementB,
- cutlass::layout::RowMajor, // LayoutB,
- kIsAligned ? DefaultConfig::kAlignmentB
- : GemmType::kMinimumAlignment,
- output_t,
- cutlass::layout::RowMajor, // LayoutC,
- accum_t,
- typename GemmType::OpClass,
- ArchTag,
- ThreadblockShape,
- WarpShape,
- typename GemmType::InstructionShape,
- typename DefaultConfig::EpilogueOutputOp,
- void, // ThreadblockSwizzle - not used
- DefaultConfig::kStages,
- false, // SplitKSerial
- typename GemmType::Operator>;
- // if dropout:
- // for computing dVj += (Pij.T * Zij) @ dOi
- // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of
- // Pij.T are loaded in. The reason we do it this way is because Pij.T and
- // Zij are reused in later steps, while Pij_dropped.T is only needed in
- // this step. computing Pij_dropped.T on the fly allows us to avoid
- // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the
- // same time.
- // if no dropout:
- // for computing dVj += Pij.T @ dOi
- using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
- typename DefaultGemm::Mma,
- typename MatmulQK::AccumulatorSharedStorage,
- kApplyDropout>; // kScaleOperandA
- using Mma = typename DefaultMmaFromSmem::Mma;
- using WarpIteratorA = typename DefaultMmaFromSmem::WarpIteratorA;
- using IteratorB = typename Mma::IteratorB;
- using WarpCount = typename Mma::WarpCount;
- // Epilogue
- using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
- using DefaultEpilogue = typename DefaultGemm::Epilogue;
- using OutputTileIterator =
- typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
- typename DefaultEpilogue::OutputTileIterator>::Iterator;
- using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
- };
- struct MatmulDOIVJ {
- /*
- doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul
- tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue?
- */
- using ThreadblockShape =
- cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
- using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
- using ElementC = accum_t; // fp32 output type here for better accuracy
- using ElementAccum = accum_t;
- // no-op output op - epilogue just stores result to global memory
- using BiasGradEpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination<
- ElementC,
- DefaultConfig::EpilogueOutputOp::kCount,
- typename DefaultConfig::EpilogueOutputOp::ElementAccumulator,
- typename DefaultConfig::EpilogueOutputOp::ElementCompute,
- cutlass::epilogue::thread::ScaleType::Nothing>;
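- // With ScaleType::Nothing the linear combination degenerates to
- // D = accum (no alpha/beta scaling, no source load), i.e. a plain store
- // of the accumulator as described above.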
- using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm<
- scalar_t, // ElementA
- cutlass::layout::RowMajor, // LayoutA
- kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
- scalar_t, // ElementB
- cutlass::layout::ColumnMajor, // LayoutB
- kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
- ElementC, // ElementC
- cutlass::layout::RowMajor, // LayoutC
- ElementAccum, // ElementAccumulator
- typename GemmType::OpClass,
- ArchTag,
- ThreadblockShape,
- WarpShape,
- typename GemmType::InstructionShape,
- BiasGradEpilogueOutputOp, // EpilogueOutputOp
- void, // ThreadblockSwizzle (not used)
- // multiple preloads, dropout Zij tile, and 3 stages push us over shared
- // memory capacity on A100. set a ceiling on number of stages to save
- // shared memory if dropout is in use.
- kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64)
- ? cutlass::const_min(2, DefaultConfig::kStages)
- : DefaultConfig::kStages, // Stages
- false, // SplitKSerial
- typename GemmType::Operator,
- cutlass::gemm::SharedMemoryClearOption::kNone>;
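- // e.g. with kPreload + dropout on a 128x128 tile (128 * 128 > 64 * 64),
- // stages are capped at min(2, DefaultConfig::kStages); a 64x64 tile
- // keeps the default stage count.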
- using Mma = typename MakeCustomMma<typename DefaultGemm::Mma, kMaxK>::Mma;
- // epilogue used to write bias gradient, which is just the output of this
- // matmul with some operations applied to the fragment
- using BiasGradEpilogue = typename DefaultGemm::Epilogue;
- // Epilogue to store to shared-memory in a format that we can use later for
- // the second matmul
- using B2bGemm =
- typename cutlass::gemm::threadblock::B2bGemm<typename Mma::Operator::IteratorC,
- typename Mma::Operator,
- scalar_t,
- WarpShape,
- ThreadblockShape>;
- using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
- };
- struct MatmulGradQ {
- // grad_q <- tmp @ k_j
- using ThreadblockShape =
- cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
- using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
- using InstructionShape = typename GemmType::InstructionShape;
- using DefaultGemm =
- cutlass::gemm::kernel::DefaultGemm<scalar_t, // ElementA,
- cutlass::layout::RowMajor, // LayoutA,
- DefaultConfig::kAlignmentA,
- scalar_t, // ElementB,
- cutlass::layout::RowMajor, // LayoutB,
- kIsAligned ? DefaultConfig::kAlignmentB
- : GemmType::kMinimumAlignment,
- output_t,
- cutlass::layout::RowMajor, // LayoutC,
- accum_t,
- typename GemmType::OpClass,
- ArchTag,
- ThreadblockShape,
- WarpShape,
- typename GemmType::InstructionShape,
- typename DefaultConfig::EpilogueOutputOp,
- void, // ThreadblockSwizzle - not used
- DefaultConfig::kStages,
- false, // SplitKSerial
- typename GemmType::Operator>;
- using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
- typename DefaultGemm::Mma,
- typename MatmulDOIVJ::AccumulatorSharedStorage,
- false>; // kScaleOperandA
- using Mma = typename DefaultMmaFromSmem::Mma;
- using IteratorB = typename Mma::IteratorB;
- using WarpCount = typename Mma::WarpCount;
- // Epilogue
- using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
- using DefaultEpilogue = typename DefaultGemm::Epilogue;
- using OutputTileIterator =
- typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
- typename DefaultEpilogue::OutputTileIterator>::Iterator;
- using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
- };
- struct MatmulGradK {
- // grad_k <- tmp.transpose(-2, -1) @ q_i
- using ThreadblockShape =
- cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
- using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
- using InstructionShape = typename GemmType::InstructionShape;
- using DefaultGemm =
- cutlass::gemm::kernel::DefaultGemm<scalar_t, // ElementA,
- cutlass::layout::RowMajor, // LayoutA,
- DefaultConfig::kAlignmentA,
- scalar_t, // ElementB,
- cutlass::layout::RowMajor, // LayoutB,
- kIsAligned ? DefaultConfig::kAlignmentB
- : GemmType::kMinimumAlignment,
- output_t,
- cutlass::layout::RowMajor, // LayoutC,
- accum_t,
- typename GemmType::OpClass,
- ArchTag,
- ThreadblockShape,
- WarpShape,
- typename GemmType::InstructionShape,
- typename DefaultConfig::EpilogueOutputOp,
- void, // ThreadblockSwizzle - not used
- DefaultConfig::kStages,
- false, // SplitKSerial
- typename GemmType::Operator>;
- using DefaultMmaFromSmemN = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
- typename DefaultGemm::Mma,
- typename MatmulQK::AccumulatorSharedStorage,
- false>; // kScaleOperandA
- using DefaultMmaFromSmemT = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
- typename DefaultGemm::Mma,
- typename MatmulDOIVJ::AccumulatorSharedStorage,
- false, // kScaleOperandA
- kPreload>; // kTransposeA
- using DefaultMmaFromSmem =
- typename cutlass::platform::conditional<DefaultMmaFromSmemT::kIsTransposedA,
- DefaultMmaFromSmemT,
- DefaultMmaFromSmemN>::type;
- using Mma = typename DefaultMmaFromSmem::Mma;
- using IteratorB = typename Mma::IteratorB;
- using WarpCount = typename Mma::WarpCount;
- // Epilogue
- using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
- using DefaultEpilogue = typename DefaultGemm::Epilogue;
- using OutputTileIterator =
- typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
- typename DefaultEpilogue::OutputTileIterator>::Iterator;
- using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
- };
- using broadcast_1 = Broadcast1_<typename MatmulQK::BiasLoader::ThreadMap,
- typename MatmulQK::BiasLoader::Shape,
- scalar_t>;
- using broadcast_2 = Broadcast2_<typename MatmulQK::BiasLoader::ThreadMap,
- typename MatmulQK::BiasLoader::Shape,
- scalar_t>;
- // shared storage for keeping Zij matrix. not needed if we aren't using
- // dropout, in which case we use an empty array to save shared memory
- using ZijSharedStorage = typename cutlass::platform::conditional<
- kApplyDropout,
- typename MatmulQK::AccumulatorSharedStorage,
- // dummy shared storage object that takes up no space.
- typename cutlass::gemm::threadblock::AccumulatorSharedStorage<
- #ifdef _WIN32
- // windows builds throw the error:
- // "type containing an unknown-size array is not allowed"
- // if we try to make Zij shared storage zero-sized.
- // To get around this just make it sized 1 on windows.
- typename cutlass::gemm::GemmShape<1, 1, 0>,
- #else
- typename cutlass::gemm::GemmShape<0, 0, 0>,
- #endif
- typename MatmulQK::AccumulatorSharedStorage::Element,
- typename MatmulQK::AccumulatorSharedStorage::Layout,
- typename cutlass::MatrixShape<0, 0>>>::type;
- struct SharedStoragePrologue {
- struct {
- cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
- typename MatmulQK::Mma::SharedStorageA mm_qk_k;
- } persistent;
- union {
- struct {
- // part1 - after Q.K / dV / dO.V
- union {
- // 1. efficient load of bias tile Bij, which is then applied to Pij
- // typename MatmulQK::BiasLoader::SmemTile bias;
- cutlass::AlignedBuffer<float, MatmulQK::BiasLoader::Shape::kCount> bias;
- // 4. store Pij. it is needed:
- // - in dVj += (Pij.T * Zij) @ dOi
- // - in dSij = Pij * (dPij - Di)
- // 6. dVj += (Pij.T * Zij) @ dOi
- // 10. write to fragment
- typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
- };
- // 5. store Zij. it is needed:
- // - to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij
- // are loaded for the computation of dVj.
- // - to compute dPij = (dOi @ Vj.T) * Zij
- // 6. used in dVj += (Pij.T * Zij) @ dOi
- // 9. used in dPij = dPij_dropped * Zij
- ZijSharedStorage zij;
- union {
- // 2. prologue for dVj
- // 6. workspace for dVj += (Pij.T * Zij) @ dOi
- typename MatmulGradV::Mma::SharedStorage mm_gradV;
- // 7. dVj epilogue
- typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
- };
- // 3. prologue for dPij_dropped
- // 8. used in dPij_dropped = dOi @ Vj.T
- typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
- } part1;
- struct {
- // part2 - dQ
- union {
- typename MatmulQK::AccumulatorSharedStorage
- tmpT_shared_storage; // (from part1)
- typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
- };
- typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
- typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload)
- union {
- // store dB = dSij to global memory
- typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
- typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
- };
- } part2;
- struct {
- // part3 - after last iteration on dQ's epilogue / dK
- union {
- typename MatmulQK::AccumulatorSharedStorage
- tmpT_shared_storage; // (from part1)
- typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
- };
- typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
- typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter;
- typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
- } part3;
- struct {
- // part4 - after last iteration on dK's epilogue / preload next K.Q_t
- typename MatmulQK::Mma::SharedStorageB mm_qk_q;
- // If we reach end of current key, dump RF->gmem with "final" epilogues
- typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final;
- typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final;
- } part4;
- };
- // ===========================================
- #define FIELD(INSIDE_STRUCT, FIELDNAME) \
- CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; }
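- // e.g. FIELD(part1, bias) expands to
- //   CUTLASS_DEVICE auto& bias() { return part1.bias; }
- // so call sites can write shared_storage.bias() without spelling out
- // which union member is currently live.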
- FIELD(persistent, di)
- FIELD(persistent, mm_qk_k)
- FIELD(part1, bias)
- FIELD(part1, attn_shared_storage)
- FIELD(part1, zij)
- FIELD(part1, mm_gradV)
- FIELD(part1, gradV_epilogue)
- FIELD(part1, mm_doivj)
- FIELD(part2, mm_gradK)
- FIELD(part2, mm_gradQ)
- FIELD(part2, gradB_epilogue)
- FIELD(part2, gradQ_epilogue)
- FIELD(part2, tmp_shared_storage)
- FIELD(part3, tmpT_shared_storage)
- FIELD(part3, gradQ_epilogue_lastIter)
- FIELD(part3, gradK_epilogue)
- FIELD(part4, mm_qk_q)
- FIELD(part4, gradK_epilogue_final)
- FIELD(part4, gradV_epilogue_final)
- };
- struct SharedStorageNoPrologue {
- struct {
- cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
- } persistent;
- union {
- struct {
- // part1 - Q.K matmul
- typename MatmulQK::Mma::SharedStorageA mm_qk_k;
- typename MatmulQK::Mma::SharedStorageB mm_qk_q;
- } part1;
- struct {
- // part2 - compute gradV
- union {
- // 1. efficient load of bias tile Bij, which is then applied to Pij
- cutlass::AlignedBuffer<float, MatmulQK::BiasLoader::Shape::kCount> bias;
- // 2. store Pij to shared memory. it is needed:
- // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi
- // - in next step where it is used in dSij = Pij * (dPij - Di)
- typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
- };
- // 3. store Zij. it is needed:
- // - in this step, where it is used to compute Pij_dropped = Pij * Zij on
- //   the fly as fragments of Pij are loaded for the computation of dVj.
- // - later to compute dPij = (dOi @ Vj.T) * Zij
- ZijSharedStorage zij;
- union {
- typename MatmulGradV::Mma::SharedStorage mm_gradV;
- typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
- };
- } part2;
- struct {
- // part3 - DO.V matmul
- union {
- // first compute dPij = (dOi @ Vj.T) * Zij
- // and dSij = Pij * (dPij - Di)
- struct {
- // (from part2) - Pij for computing dSij = Pij * (dPij - Di)
- typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
- // (from part2) - Zij for computing dPij = dPij_dropped * Zij
- ZijSharedStorage zij;
- // matmul to compute dOiVj
- typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
- };
- // then store dB = dSij to global memory
- typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
- };
- } part3;
- struct {
- // part4 - compute gradQ
- typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2)
- typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
- union {
- typename MatmulGradQ::Mma::SharedStorage mm_gradQ;
- typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
- typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter;
- };
- } part4;
- struct {
- // part5 - compute gradK
- typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2)
- typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
- union {
- typename MatmulGradK::Mma::SharedStorage mm_gradK;
- typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
- };
- } part5;
- struct {
- // part6 - store RF accumulated into gmem
- typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final;
- typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final;
- } part6;
- };
- // ===========================================
- #define FIELD(INSIDE_STRUCT, FIELDNAME) \
- CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; }
- FIELD(persistent, di)
- FIELD(part1, mm_qk_k)
- FIELD(part1, mm_qk_q)
- FIELD(part2, bias)
- FIELD(part2, attn_shared_storage)
- FIELD(part2, zij)
- FIELD(part2, mm_gradV)
- FIELD(part2, gradV_epilogue)
- FIELD(part3, mm_doivj)
- FIELD(part3, gradB_epilogue)
- FIELD(part4, tmpT_shared_storage)
- FIELD(part4, tmp_shared_storage)
- FIELD(part4, mm_gradQ)
- FIELD(part4, gradQ_epilogue)
- FIELD(part4, gradQ_epilogue_lastIter)
- FIELD(part5, mm_gradK)
- FIELD(part5, gradK_epilogue)
- FIELD(part6, gradK_epilogue_final)
- FIELD(part6, gradV_epilogue_final)
- };
- using SharedStorage = typename cutlass::platform::
- conditional<kPreload, SharedStoragePrologue, SharedStorageNoPrologue>::type;
- struct OutputFragments {
- typename MatmulGradV::Mma::FragmentC gradV;
- typename MatmulGradK::Mma::FragmentC gradK;
- CUTLASS_DEVICE void clear()
- {
- gradV.clear();
- gradK.clear();
- }
- };
- static bool __host__ check_supported(Params const& p)
- {
- CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
- CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
- CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
- CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
- CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
- EVOFORMER_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned (strideH)");
- EVOFORMER_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned (strideB)");
- EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0,
- "query is not correctly aligned (strideH)");
- EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0,
- "key is not correctly aligned (strideH)");
- EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0,
- "value is not correctly aligned (strideH)");
- EVOFORMER_CHECK(p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
- "query is not correctly aligned (strideB)");
- EVOFORMER_CHECK(p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
- "key is not correctly aligned (strideB)");
- EVOFORMER_CHECK(p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0,
- "value is not correctly aligned (strideB)");
- EVOFORMER_CHECK(p.q_strideM % kMinimumAlignment == 0,
- "query is not correctly aligned (strideM)");
- EVOFORMER_CHECK(p.k_strideM % kMinimumAlignment == 0,
- "key is not correctly aligned (strideM)");
- EVOFORMER_CHECK(p.v_strideM % kMinimumAlignment == 0,
- "value is not correctly aligned (strideM)");
- EVOFORMER_CHECK(p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f,
- "Invalid value for `dropout_prob`");
- EVOFORMER_CHECK(kApplyDropout || p.dropout_prob == 0.0f,
- "Set `kApplyDropout`=True to support `dropout_prob > 0`");
- EVOFORMER_CHECK(p.head_dim > 0, "Invalid value for `head_dim`");
- EVOFORMER_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`");
- EVOFORMER_CHECK(p.num_queries > 0, "Invalid value for `num_queries`");
- EVOFORMER_CHECK(p.num_keys > 0, "Invalid value for `num_keys`");
- EVOFORMER_CHECK(p.num_heads > 0, "Invalid value for `num_heads`");
- EVOFORMER_CHECK(p.num_batches > 0, "Invalid value for `num_batches`");
- EVOFORMER_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim <= kMaxK`");
- EVOFORMER_CHECK(p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value <= kMaxK`");
- return true;
- }
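- /*
- Minimal host-side sketch of the intended Params workflow (hedged: the
- actual launch wrapper lives outside this header, and `kernel_fn` is a
- hypothetical name for the __global__ entry point wrapping
- attention_kernel):
- ```
- Params p;
- // ... fill pointers, dimensions and strides ...
- if (!check_supported(p)) { return; }
- if (p.workspace_size() > 0) {
-     // fp32 accumulation buffer, one workspace_strideBH() per (batch, head)
-     cudaMalloc((void**)&p.workspace, p.workspace_size());
- }
- kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), sizeof(SharedStorage)>>>(p);
- ```
- */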
- static CUTLASS_DEVICE void attention_kernel(Params p)
- {
- extern __shared__ char smem_buffer[];
- SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
- uint16_t thread_id = threadIdx.x;
- uint8_t warp_id = warp_uniform(thread_id / 32);
- uint8_t lane_id = thread_id % 32;
- if (kPrologueQK) {
- prologueQkNextIteration<true>(shared_storage, p, 0, 0, warp_id, lane_id);
- }
- // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr`
- if (kKernelComputesDelta) {
- constexpr int kOptimalElements = 128 / cutlass::sizeof_bits<scalar_t>::value;
- if (p.head_dim_value % kOptimalElements == 0) {
- for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) {
- computeDelta<kOptimalElements>(p, query_start, warp_id, lane_id);
- }
- } else {
- for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) {
- computeDelta<1>(p, query_start, warp_id, lane_id);
- }
- }
- __syncthreads();
- }
- OutputFragments output_frags;
- int32_t key_start = 0;
- int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ;
- for (; key_start < key_end; key_start += kBlockSizeJ) {
- output_frags.clear();
- int32_t query_start = getQueryStart(p, key_start);
- int32_t query_end =
- query_start + (p.num_queries - query_start) / kBlockSizeI * kBlockSizeI;
- for (; query_start < query_end; query_start += kBlockSizeI) {
- processBlockIJ<true>(
- shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id);
- }
- // last (partial) query
- if (query_start < p.num_queries) {
- processBlockIJ<false>(
- shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id);
- }
- if (kOutputInRF) {
- writeFragsToGmem<true>(
- shared_storage, output_frags, p, key_start, warp_id, lane_id);
- } else if (getQueryStart(p, key_start) >= p.num_queries) {
- zfillGradKV<true>(p, key_start, warp_id, lane_id);
- }
- __syncthreads();
- }
- // Last (partial) key
- if (key_start != p.num_keys) {
- output_frags.clear();
- int32_t query_start = getQueryStart(p, key_start);
- for (; query_start < p.num_queries; query_start += kBlockSizeI) {
- warp_id = warp_uniform(warp_id);
- processBlockIJ<false>(
- shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id);
- }
- if (kOutputInRF) {
- writeFragsToGmem<false>(
- shared_storage, output_frags, p, key_start, warp_id, lane_id);
- } else if (getQueryStart(p, key_start) >= p.num_queries) {
- zfillGradKV<false>(p, key_start, warp_id, lane_id);
- }
- }
- }
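- // Iteration structure recap: the outer loop walks key blocks (all full
- // blocks, then one partial block), and for each key block the inner loop
- // walks the query blocks attending to it. dK/dV either stay in RF for the
- // whole key block (kOutputInRF) or are carried between query iterations
- // through the gmem workspace.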
- static CUTLASS_DEVICE void loadDi(cutlass::Array<accum_t, kBlockSizeI>& di,
- Params const& p,
- int32_t query_start)
- {
- int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x;
- if (thread_id < kBlockSizeI) {
- accum_t di_rf = accum_t(0);
- if (query_start + thread_id < p.num_queries) {
- di_rf = p.delta_ptr[query_start + thread_id];
- }
- di[thread_id] = di_rf;
- }
- }
- template <bool skipBoundsChecks>
- static CUTLASS_DEVICE void zfillGradKV(Params const& p,
- int32_t key_start,
- uint8_t warp_id,
- uint8_t lane_id)
- {
- constexpr int kThreadsPerKey = 8;
- constexpr int kParallelKeys = kNumThreads / kThreadsPerKey;
- static_assert(kBlockSizeJ % kParallelKeys == 0,
- "kBlockSizeJ must be a multiple of kParallelKeys");
- // This function is not really optimized, but should rarely be used
- // It's only used when some keys are "useless" and don't attend to
- // any query, due to causal masking
- int thread_id = 32 * warp_id + lane_id;
- int k_shift = lane_id % kThreadsPerKey;
- CUTLASS_PRAGMA_UNROLL
- for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) {
- int key = key_start + j + (thread_id / kThreadsPerKey);
- if (!skipBoundsChecks && key >= p.num_keys) { continue; }
- auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM();
- auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM();
- for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) {
- gv_ptr[k] = scalar_t(0);
- }
- for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { gk_ptr[k] = scalar_t(0); }
- }
- }
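- // Work partition example: with kNumThreads = 128 and kThreadsPerKey = 8,
- // kParallelKeys = 16 key rows are zero-filled concurrently; each group of
- // 8 threads strides across the head dimension of one row.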
- template <bool skipBoundsChecks>
- static CUTLASS_DEVICE void processBlockIJ(SharedStorage& shared_storage,
- OutputFragments& output_frags,
- Params& p,
- int32_t query_start,
- int32_t key_start,
- uint8_t warp_id,
- uint8_t lane_id)
- {
- cutlass::MatrixCoord no_offset{0, 0};
- accum_t scale = p.scale;
- int16_t thread_id = 32 * warp_id + lane_id;
- auto rematerializeThreadIds = [&]() {
- // Prevents `nvcc` from keeping values deduced from
- // `thread_id`, `warp_id`, ... in RF - to reduce register pressure
- warp_id = warp_uniform(thread_id / 32);
- lane_id = thread_id % 32;
- thread_id = 32 * warp_id + lane_id;
- };
- bool isFirstQuery = (query_start == getQueryStart(p, key_start));
- int32_t next_query, next_key;
- incrIteration(p, query_start, key_start, next_query, next_key);
- bool isLastQuery = next_key != key_start;
- __syncthreads();
- loadDi(shared_storage.di(), p, query_start);
- int32_t num_queries_in_block =
- skipBoundsChecks ? MatmulQK::Mma::Shape::kN
- : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kN,
- p.num_queries - query_start));
- int32_t num_keys_in_block =
- skipBoundsChecks ? MatmulQK::Mma::Shape::kM
- : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM,
- p.num_keys - key_start));
- auto prologueGradV = [&](int col) {
- typename MatmulGradV::Mma::IteratorB iterator_dO(
- {int32_t(p.gO_strideM)},
- p.grad_output_ptr + query_start * p.gO_strideM + col,
- {num_queries_in_block, p.head_dim_value - col},
- thread_id,
- no_offset);
- MatmulGradV::Mma::prologue(
- shared_storage.mm_gradV(), iterator_dO, thread_id, num_queries_in_block);
- };
- auto prologueGradQ = [&](int col) {
- typename MatmulGradQ::Mma::IteratorB iterator_K(
- {int32_t(p.k_strideM)},
- p.key_ptr + key_start * p.k_strideM + col,
- {num_keys_in_block, p.head_dim - col},
- thread_id,
- no_offset);
- MatmulGradQ::Mma::prologue(
- shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block);
- };
- auto prologueGradK = [&](int col) {
- typename MatmulGradK::Mma::IteratorB iterator_Q(
- {int32_t(p.q_strideM)},
- p.query_ptr + query_start * p.q_strideM + col,
- {num_queries_in_block, p.head_dim - col},
- thread_id,
- no_offset);
- MatmulGradK::Mma::prologue(
- shared_storage.mm_gradK(), iterator_Q, thread_id, num_queries_in_block);
- };
- auto prologueDOV = [&]() {
- typename MatmulDOIVJ::Mma::IteratorA iterator_A(
- {int32_t(p.gO_strideM)},
- p.grad_output_ptr + query_start * p.gO_strideM,
- {num_queries_in_block, p.head_dim_value},
- thread_id,
- no_offset);
- typename MatmulDOIVJ::Mma::IteratorB iterator_B({int32_t(p.v_strideM)},
- p.value_ptr + key_start * p.v_strideM,
- {p.head_dim_value, num_keys_in_block},
- thread_id,
- no_offset);
- MatmulDOIVJ::Mma::prologue(
- shared_storage.mm_doivj(), iterator_A, iterator_B, thread_id, p.head_dim_value);
- };
- /////////////////////////////////////////////////////////////////////////////////////////////////
- // MatmulQK
- /////////////////////////////////////////////////////////////////////////////////////////////////
- {
- using Mma = typename MatmulQK::Mma;
- cutlass::gemm::GemmCoord problem_size(num_keys_in_block,
- num_queries_in_block,
- p.head_dim // k
- );
- // k_j
- typename Mma::IteratorA iterator_A({int32_t(p.k_strideM)},
- p.key_ptr + key_start * p.k_strideM,
- {problem_size.m(), problem_size.k()},
- thread_id,
- no_offset);
- // q_i.transpose(-2, -1)
- typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)},
- p.query_ptr + query_start * p.q_strideM,
- {problem_size.k(), problem_size.n()},
- thread_id,
- no_offset);
- Mma mma(
- shared_storage.mm_qk_k(), shared_storage.mm_qk_q(), thread_id, warp_id, lane_id);
- typename Mma::FragmentC accum;
- accum.clear();
- auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
- // Compute threadblock-scoped matrix multiply-add
- mma.set_prologue_done(kPrologueQK);
- mma.set_zero_outside_bounds(!skipBoundsChecks);
- mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
- // Epilogue: add LSE + exp and store that to our shared memory buffer
- // shmem <- (matmul_result -
- // logsumexp[i_start:i_end].unsqueeze(1)).exp()
- int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
- auto output_tile_coords = cutlass::MatrixCoord{
- warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM};
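- // e.g. for a 2x2 warp tiling, warp_id 0..3 maps to tile coords
- // {0,0}, {1,0}, {0,1}, {1,1} (fastest along M).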
- if (broadcast_1::kEnable || broadcast_2::kEnable) {
- cutlass::TensorRef<float, cutlass::layout::RowMajor> bias_tensor_ref(
- shared_storage.bias().data(),
- cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM));
- using Shape = cutlass::MatrixShape<MatmulQK::ThreadblockShape::kM,
- MatmulQK::ThreadblockShape::kN>;
- AttentionBiasEpilogue<Shape,
- scalar_t,
- MatmulQK::MmaCore::kThreads,
- Broadcast1_,
- Broadcast2_>
- bias_epilogue;
- bias_epilogue(bias_tensor_ref,
- p.bias1_ptr + key_start,
- p.bias2_ptr + query_start * p.num_keys + key_start,
- thread_id,
- {num_queries_in_block, num_keys_in_block},
- p.num_keys);
- // Pij += Bij, Pij is in register fragment and Bij is in shared memory
- auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
- lane_id, warp_id, output_tile_coords);
- MatmulQK::AccumLambdaIterator::iterateRows(
- lane_offset,
- [&](int accum_n) {},
- [&](int accum_m, int accum_n, int idx) {
- // remember we are transposed
- accum[idx] = accum[idx] * scale + bias_tensor_ref.at({accum_n, accum_m});
- },
- [&](int accum_n) {});
- } else {
- accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
- }
- __syncthreads();
- if (kPrologueGV) { prologueGradV(0); }
- if (kPrologueDOV) { prologueDOV(); }
- MatmulQK::B2bGemm::accumApplyLSEToSmem(shared_storage.attn_shared_storage(),
- accum,
- p.logsumexp_ptr + query_start,
- problem_size.n(),
- thread_id,
- warp_id,
- lane_id,
- output_tile_coords);
- __syncthreads();
- }
- rematerializeThreadIds();
- /////////////////////////////////////////////////////////////////////////////////////////////////
- // GradV matmul
- //
- // grad_v[j_start:j_end] += attn_T @ do_i
- /////////////////////////////////////////////////////////////////////////////////////////////////
- constexpr bool kSingleIterationGradV = kMaxK <= MatmulGradV::ThreadblockShape::kN;
- for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value);
- col += MatmulGradV::ThreadblockShape::kN) {
- using Mma = typename MatmulGradV::Mma;
- using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
- cutlass::gemm::GemmCoord problem_size(
- num_keys_in_block, p.head_dim_value - col, num_queries_in_block);
- auto createEpilogueIter = [&]() {
- return typename MatmulGradV::OutputTileIterator(
- typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
- p.grad_value_ptr + key_start * p.gV_strideM() + col,
- {num_keys_in_block, p.head_dim_value - col},
- thread_id);
- };
- typename Mma::IteratorB iterator_B({int32_t(p.gO_strideM)},
- p.grad_output_ptr + query_start * p.gO_strideM + col,
- {num_queries_in_block, p.head_dim_value - col},
- thread_id,
- no_offset);
- // if dropout: dVj += (Pij.T * Zij) @ dOi
- // otherwise: dVj += Pij.T @ dOi
- Mma mma(shared_storage.mm_gradV(),
- // operand A: Pij
- typename MatmulGradV::WarpIteratorA(
- shared_storage.attn_shared_storage().accum_ref(), lane_id),
- // if we're using dropout, operand A is Pij_dropped = Pij * Zij
- // which is computed on the fly as fragments of Pij are loaded in
- typename Mma::WarpIteratorAScale(shared_storage.zij().accum_ref(), lane_id),
- thread_id,
- warp_id,
- lane_id);
- int storage_id = col / MatmulGradV::ThreadblockShape::kN;
- AccumTileGmem gmem_tile{p.workspace_gv + storage_id * AccumTileGmem::kElementsStored};
- if (!kOutputInRF) {
- if (isFirstQuery || !kNeedsAccumGradV) {
- output_frags.gradV.clear();
- } else {
- gmem_tile.load(output_frags.gradV, thread_id);
- }
- }
- mma.set_prologue_done(kPrologueGV);
- auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
- // Compute threadblock-scoped matrix multiply-add
- __syncthreads();
- mma(gemm_k_iterations, output_frags.gradV, iterator_B, output_frags.gradV);
- __syncthreads();
- if (kPrologueGV && !kSingleIterationGradV &&
- col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) {
- prologueGradV(col + MatmulGradV::ThreadblockShape::kN);
- }
- if (!kOutputInRF) {
- if (kNeedsAccumGradV && !isLastQuery) {
- gmem_tile.store(output_frags.gradV, thread_id);
- } else {
- accumulateInGmem<MatmulGradV>(shared_storage.gradV_epilogue(),
- output_frags.gradV,
- createEpilogueIter(),
- isFirstQuery || kNeedsAccumGradV,
- warp_id,
- lane_id);
- }
- }
- }
- __syncthreads();
- /////////////////////////////////////////////////////////////////////////////////////////////////
- // MatmulDOIVJ
- /////////////////////////////////////////////////////////////////////////////////////////////////
- {
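- // dPij [RF] <- dOi @ Vj.T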
- using Mma = typename MatmulDOIVJ::Mma;
- // do_i
- typename Mma::IteratorA iterator_A({int32_t(p.gO_strideM)},
- p.grad_output_ptr + query_start * p.gO_strideM,
- {num_queries_in_block, p.head_dim_value},
- thread_id,
- no_offset);
- // v_j.transpose(-2, -1)
- typename Mma::IteratorB iterator_B({int32_t(p.v_strideM)},
- p.value_ptr + key_start * p.v_strideM,
- {p.head_dim_value, num_keys_in_block},
- thread_id,
- no_offset);
- Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id);
- mma.set_prologue_done(kPrologueDOV);
- mma.set_zero_outside_bounds(!skipBoundsChecks);
- typename Mma::FragmentC accum;
- accum.clear();
- auto gemm_k_iterations = (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK;
- // Compute threadblock-scoped matrix multiply-add
- mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
- __syncthreads();
- if (kPrologueGQ) { prologueGradQ(0); }
- if (kPrologueGK) { prologueGradK(0); }
- int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
- auto output_tile_coords = cutlass::MatrixCoord{
- warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM};
- // TODO: This must be terribly inefficient. There must be a better way
- // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem]
- // attn_shared_storage [smem] <- tmp.T
- // tmp_shared_storage [smem] <- tmp
- {
- using LambdaIterator =
- typename DefaultMmaAccumLambdaIterator<typename Mma::Operator::IteratorC,
- typename MatmulDOIVJ::ElementAccum,
- kWarpSize>::Iterator;
- auto lane_offset =
- LambdaIterator::get_lane_offset(lane_id, warp_id, output_tile_coords);
- auto attn_T = shared_storage.attn_shared_storage().accum_ref();
- accum_t current_di;
- // dSij = (dPij - Di) * Pij
- LambdaIterator::iterateRows(
- lane_offset,
- [&](int accum_m) { current_di = shared_storage.di()[accum_m]; },
- [&](int accum_m, int accum_n, int idx) {
- if (skipBoundsChecks ||
- (accum_m < num_queries_in_block && accum_n < num_keys_in_block)) {
- accum_t attn = attn_T.at({accum_n, accum_m});
- accum[idx] = (accum[idx] - current_di) * attn;
- } else {
- accum[idx] = 0;
- }
- },
- [&](int accum_m) {
- });
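- // This is the softmax backward: dS = P * (dP - Di), where
- //   Di = sum_j dP[i, j] * P[i, j] = sum_d dO[i, d] * O[i, d]
- // is the row-wise dot product precomputed by computeDelta.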
- using DefaultGemm = typename MatmulDOIVJ::DefaultGemm;
- using OutputOp = typename MatmulDOIVJ::BiasGradEpilogueOutputOp;
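- // The broadcast-bias gradients fall out of dSij directly:
- //   dB1[j]    = sum_i dS[i, j]   (bias1 broadcasts over queries)
- //   dB2[i, j] = dS[i, j]         (bias2 is a full (query, key) bias)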
- if (broadcast_1::kEnable && p.grad_bias1_ptr) {
- using Epilogue =
- typename BiasGradEpilogueAffineRankN<ArchTag,
- 2,
- typename MatmulDOIVJ::ThreadblockShape,
- typename DefaultGemm::Mma::Operator,
- DefaultGemm::kPartitionsK,
- OutputOp,
- OutputOp::kCount>::Epilogue;
- cutlass::layout::AffineRankN<2> layout({0, 1});
- auto dst_ptr = p.grad_bias1_ptr + key_start;
- typename Epilogue::OutputTileIterator output_iter(
- {layout},
- dst_ptr,
- {num_queries_in_block, num_keys_in_block},
- (int)thread_id);
- Epilogue epilogue(shared_storage.gradB_epilogue(),
- (int)thread_id,
- (int)warp_id,
- (int)lane_id);
- epilogue(OutputOp(1), output_iter, accum);
- }
- if (broadcast_2::kEnable && p.grad_bias2_ptr) {
- if (broadcast_1::kEnable) { __syncthreads(); }
- using Epilogue =
- typename BiasGradEpilogue<ArchTag,
- typename MatmulDOIVJ::ThreadblockShape,
- typename DefaultGemm::Mma::Operator,
- DefaultGemm::kPartitionsK,
- OutputOp,
- OutputOp::kCount>::Epilogue;
- typename Epilogue::OutputTileIterator::Params params{p.num_keys};
- auto dst_ptr = p.grad_bias2_ptr + query_start * p.num_keys + key_start;
- typename Epilogue::OutputTileIterator output_iter(
- params, dst_ptr, {num_queries_in_block, num_keys_in_block}, (int)thread_id);
- Epilogue epilogue(shared_storage.gradB_epilogue(),
- (int)thread_id,
- (int)warp_id,
- (int)lane_id);
- epilogue(OutputOp(1), output_iter, accum);
- }
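- // Chain rule through Sij = scale * (Qi @ Kj.T) + Bij: the GradQ/GradK
- // matmuls below need d(Qi @ Kj.T) = scale * dS, while the bias gradients
- // above consume dS unscaled.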
- accum = accum * scale;
- __syncthreads();
- if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) {
- auto tmpT = shared_storage.tmpT_shared_storage().accum_ref();
- // tmpT [smem] <- dSij.T (the GradK MMA cannot transpose operand A on the fly)
- LambdaIterator::iterateRows(
- lane_offset,
- [&](int accum_m) {},
- [&](int accum_m, int accum_n, int idx) {
- tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]);
- },
- [&](int accum_m) {});
- }
- }
- MatmulDOIVJ::B2bGemm::accumToSmem(
- shared_storage.tmp_shared_storage(), accum, lane_id, output_tile_coords);
- __syncthreads();
- }
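- // Re-broadcast values that are uniform across the warp so the compiler
- // does not track per-thread copies (reduces register pressure).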
- p.head_dim = warp_uniform(p.head_dim);
- p.k_strideM = warp_uniform(p.k_strideM);
- rematerializeThreadIds();
- /////////////////////////////////////////////////////////////////////////////////////////////////
- // GradQ matmul
- //
- // grad_q[i_start:i_end] += tmp @ k_j
- /////////////////////////////////////////////////////////////////////////////////////////////////
- // Skip the loop & associated branches when the number of iterations
- // is known at compile time
- constexpr bool kSingleIterationGradQ = kMaxK <= MatmulGradQ::ThreadblockShape::kN;
- for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim);
- col += MatmulGradQ::ThreadblockShape::kN) {
- using Mma = typename MatmulGradQ::Mma;
- using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
- cutlass::gemm::GemmCoord problem_size(
- num_queries_in_block, p.head_dim - col, num_keys_in_block);
- // k_j
- typename Mma::IteratorB iterator_B({int32_t(p.k_strideM)},
- p.key_ptr + key_start * p.k_strideM + col,
- {problem_size.k(), problem_size.n()},
- thread_id,
- no_offset);
- Mma mma(shared_storage.mm_gradQ(),
- shared_storage.tmp_shared_storage(),
- thread_id,
- warp_id,
- lane_id,
- problem_size.k());
- typename Mma::FragmentC accum;
- bool isFirst = key_start == 0;
- int col_id = col / MatmulGradQ::ThreadblockShape::kN;
- int num_cols =
- kSingleIterationGradQ ? 1 : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN);
- int storage_id = (col_id + query_start / kBlockSizeI * num_cols);
- AccumTileGmem gmem_tile{p.workspace_gq + storage_id * AccumTileGmem::kElementsStored};
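- // dQi receives a contribution from every key block, so partial sums are
- // staged in a per-(query block, column tile) workspace slot in gmem and
- // only written through the epilogue on the last key iteration.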
- if (isFirst || !kNeedsAccumGradQ) {
- accum.clear();
- } else {
- gmem_tile.load(accum, thread_id);
- }
- auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
- // Compute threadblock-scoped matrix multiply-add
- __syncthreads();
- mma.set_prologue_done(kPrologueGQ);
- mma(gemm_k_iterations, accum, iterator_B, accum);
- __syncthreads();
- bool isLastColumn = kSingleIterationGradQ ||
- (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim);
- if (kPrologueGQ && !isLastColumn) {
- prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN);
- }
- // Output results
- int32_t next_query, next_key;
- incrIteration(p, p.num_queries, key_start, next_query, next_key);
- bool isLast = next_query > query_start || next_key >= p.num_keys;
- if (kNeedsAccumGradQ && !isLast) {
- gmem_tile.store(accum, thread_id);
- } else {
- typename MatmulGradQ::OutputTileIterator output_it(
- typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
- p.grad_query_ptr + query_start * p.gQ_strideM() + col,
- {problem_size.m(), problem_size.n()},
- thread_id);
- accumulateInGmem<MatmulGradQ>(isLastColumn
- ? shared_storage.gradQ_epilogue_lastIter()
- : shared_storage.gradQ_epilogue(),
- accum,
- output_it,
- isFirst || kNeedsAccumGradQ,
- warp_id,
- lane_id);
- }
- }
- /////////////////////////////////////////////////////////////////////////////////////////////////
- // GradK matmul
- //
- // grad_k[j_start:j_end] += tmp.transpose(-2, -1) @ q_i
- /////////////////////////////////////////////////////////////////////////////////////////////////
- rematerializeThreadIds();
- constexpr bool kSingleIterationGradK = kMaxK <= MatmulGradK::ThreadblockShape::kN;
- for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim);
- col += MatmulGradK::ThreadblockShape::kN) {
- using Mma = typename MatmulGradK::Mma;
- using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
- cutlass::gemm::GemmCoord problem_size(
- num_keys_in_block, p.head_dim - col, num_queries_in_block);
- auto createEpilogueIter = [&]() {
- return typename MatmulGradK::OutputTileIterator(
- typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
- p.grad_key_ptr + key_start * p.gK_strideM() + col,
- {num_keys_in_block, p.head_dim - col},
- thread_id);
- };
- // q_i
- typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)},
- p.query_ptr + query_start * p.q_strideM + col,
- {problem_size.k(), problem_size.n()},
- thread_id,
- no_offset);
- auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); };
- auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); };
- // this is basically:
- // opA = kIsTransposedA ? getTmp() : getTmpT();
- bool constexpr kIsTransposedA = MatmulGradK::DefaultMmaFromSmem::kIsTransposedA;
- auto& opA =
- *call_conditional<kIsTransposedA, decltype(getTmp), decltype(getTmpT)>::apply(
- getTmp, getTmpT, 0);
- Mma mma(shared_storage.mm_gradK(), opA, thread_id, warp_id, lane_id, problem_size.k());
- int storage_id = col / MatmulGradK::ThreadblockShape::kN;
- AccumTileGmem gmem_tile{p.workspace_gk + storage_id * AccumTileGmem::kElementsStored};
- if (!kOutputInRF) {
- if (isFirstQuery || !kNeedsAccumGradK) {
- output_frags.gradK.clear();
- } else {
- gmem_tile.load(output_frags.gradK, thread_id);
- }
- }
- mma.set_prologue_done(kPrologueGK);
- auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
- // Compute threadblock-scoped matrix multiply-add
- __syncthreads();
- mma(gemm_k_iterations, output_frags.gradK, iterator_B, output_frags.gradK);
- __syncthreads();
- bool isLastColumn = kSingleIterationGradK ||
- col + MatmulGradK::ThreadblockShape::kN >= p.head_dim;
- if (kPrologueGK && !isLastColumn) {
- prologueGradK(col + MatmulGradK::ThreadblockShape::kN);
- }
- if (kPrologueQK && isLastColumn) {
- int32_t next_query, next_key;
- incrIteration(p, query_start, key_start, next_query, next_key);
- DISPATCH_BOOL(next_key != key_start, kForceReloadK, ([&]() {
- prologueQkNextIteration<kForceReloadK>(
- shared_storage, p, next_query, next_key, warp_id, lane_id);
- }));
- }
- // Output results
- if (!kOutputInRF) {
- if (kNeedsAccumGradK && !isLastQuery) {
- gmem_tile.store(output_frags.gradK, thread_id);
- } else {
- accumulateInGmem<MatmulGradK>(isLastColumn
- ? shared_storage.gradK_epilogue_final()
- : shared_storage.gradK_epilogue(),
- output_frags.gradK,
- createEpilogueIter(),
- isFirstQuery || kNeedsAccumGradK,
- warp_id,
- lane_id);
- __syncthreads();
- }
- }
- }
- }
- // every key block starts from the first query block (non-causal traversal)
- static CUTLASS_DEVICE int32_t getQueryStart(Params const& p, int32_t key_start) { return 0; }
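- // Traversal order (queries inner, keys outer), e.g. with kBlockSizeI = 64,
- // kBlockSizeJ = 128 and num_queries = 128:
- //   (q=0, k=0) -> (q=64, k=0) -> (q=0, k=128) -> (q=64, k=128) -> ...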
- static CUTLASS_DEVICE void incrIteration(Params const& p,
- int32_t query_start,
- int32_t key_start,
- int32_t& next_query,
- int32_t& next_key)
- {
- next_query = query_start + kBlockSizeI;
- next_key = key_start;
- if (next_query >= p.num_queries) {
- next_key = key_start + kBlockSizeJ;
- next_query = getQueryStart(p, next_key);
- }
- }
- template <bool kForceReloadK>
- static CUTLASS_DEVICE void prologueQkNextIteration(SharedStorage& shared_storage,
- Params const& p,
- int32_t query_start,
- int32_t key_start,
- uint8_t warp_id,
- uint8_t lane_id)
- {
- if (query_start >= p.num_queries || key_start >= p.num_keys) { return; }
- static constexpr bool kReloadK = kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat;
- int thread_id = 32 * warp_id + lane_id;
- typename MatmulQK::Mma::IteratorA iterator_A({int32_t(p.k_strideM)},
- p.key_ptr + key_start * p.k_strideM,
- {p.num_keys - key_start, p.head_dim},
- thread_id,
- cutlass::MatrixCoord{0, 0});
- typename MatmulQK::Mma::IteratorB iterator_B({int32_t(p.q_strideM)},
- p.query_ptr + query_start * p.q_strideM,
- {p.head_dim, p.num_queries - query_start},
- thread_id,
- cutlass::MatrixCoord{0, 0});
- MatmulQK::Mma::prologue<kReloadK, true>(shared_storage.mm_qk_k(),
- shared_storage.mm_qk_q(),
- iterator_A,
- iterator_B,
- thread_id,
- p.head_dim);
- }
- template <bool skipBoundsChecks>
- static CUTLASS_DEVICE void writeFragsToGmem(SharedStorage& shared_storage,
- OutputFragments& output_frags,
- Params const& p,
- int32_t key_start,
- uint8_t warp_id,
- uint8_t lane_id)
- {
- uint16_t thread_id = 32 * warp_id + lane_id;
- int32_t num_keys_in_block =
- skipBoundsChecks
- ? MatmulQK::Mma::Shape::kM
- : cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start);
- typename MatmulGradV::OutputTileIterator outputV_it(
- typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
- p.grad_value_ptr + key_start * p.gV_strideM(),
- {num_keys_in_block, p.head_dim_value},
- thread_id);
- accumulateInGmem<MatmulGradV>(shared_storage.gradV_epilogue_final(),
- output_frags.gradV,
- outputV_it,
- true,
- warp_id,
- lane_id);
- typename MatmulGradK::OutputTileIterator outputK_it(
- typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
- p.grad_key_ptr + key_start * p.gK_strideM(),
- {num_keys_in_block, p.head_dim},
- thread_id);
- accumulateInGmem<MatmulGradK>(shared_storage.gradK_epilogue_final(),
- output_frags.gradK,
- outputK_it,
- true,
- warp_id,
- lane_id);
- }
- template <typename MatmulT>
- static CUTLASS_DEVICE void accumulateInGmem(
- typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem,
- typename MatmulT::Mma::FragmentC const& accum,
- typename MatmulT::OutputTileIterator output_it,
- bool first,
- uint8_t warp_id,
- uint8_t lane_id)
- {
- using DefaultEpilogue = typename MatmulT::DefaultEpilogue;
- using DefaultOutputOp = typename MatmulT::DefaultOutputOp;
- using Mma = typename MatmulT::Mma;
- int thread_id = 32 * warp_id + lane_id;
- DISPATCH_BOOL(
- first, kIsFirst, ([&]() {
- static constexpr auto ScaleType =
- kIsFirst ? cutlass::epilogue::thread::ScaleType::Nothing
- : cutlass::epilogue::thread::ScaleType::NoBetaScaling;
- using EpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination<
- typename DefaultOutputOp::ElementOutput,
- DefaultOutputOp::kCount,
- typename DefaultOutputOp::ElementAccumulator,
- typename DefaultOutputOp::ElementCompute,
- ScaleType>;
- using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined<
- typename DefaultEpilogue::Shape,
- typename Mma::Operator,
- DefaultEpilogue::kPartitionsK,
- typename MatmulT::OutputTileIterator,
- typename DefaultEpilogue::AccumulatorFragmentIterator,
- typename DefaultEpilogue::WarpTileIterator,
- typename DefaultEpilogue::SharedLoadIterator,
- EpilogueOutputOp,
- typename DefaultEpilogue::Padding,
- DefaultEpilogue::kFragmentsPerIteration,
- true // IterationsUnroll
- >;
- EpilogueOutputOp rescale({1, 1});
- Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id);
- epilogue(rescale, output_it, accum, output_it);
- }));
- }
- template <int kElementsPerAccess>
- static CUTLASS_DEVICE void computeDelta(Params const& p,
- int32_t query_start,
- uint8_t warp_id,
- uint8_t lane_id)
- {
- // Each thread computes one value for Delta
- // Depending on warp configuration, we might have multiple
- // threads of the same warp working on the same row
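- // Delta[i] = sum_d grad_output[i, d] * output[i, d],
- // i.e. the Di term consumed by the dSij computation above.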
- using AccessType = cutlass::Array<scalar_t, kElementsPerAccess>;
- static_assert(kNumThreads >= kBlockSizeI, "");
- static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI;
- int16_t thread_id = 32 * warp_id + lane_id;
- int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine);
- int16_t laneRow = thread_id / kNumThreadsPerLine;
- bool rowPred = (query_start + laneRow) < p.num_queries;
- bool pred = rowPred;
- // On Windows, the previous spelling "__restrict__ AccessType*"
- // failed to compile with: error: "restrict" is not allowed
- const AccessType* __restrict__ grad_output_ptr = reinterpret_cast<const AccessType*>(
- p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol);
- const AccessType* __restrict__ output_ptr = reinterpret_cast<const AccessType*>(
- p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol);
- static constexpr int64_t kMaxIters = kMaxK / (kElementsPerAccess * kNumThreadsPerLine);
- constexpr int kPipelineStages = 2;
- accum_t delta_value = accum_t(0);
- using GlobalLoad = cutlass::arch::global_load<AccessType, sizeof(AccessType)>;
- AccessType frag_grad_output[kPipelineStages];
- AccessType frag_output[kPipelineStages];
- auto loadAndIncrement = [&](int ld_pos, bool is_valid) {
- frag_grad_output[ld_pos].clear();
- frag_output[ld_pos].clear();
- GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid);
- GlobalLoad(frag_output[ld_pos], output_ptr, is_valid);
- grad_output_ptr += kNumThreadsPerLine;
- output_ptr += kNumThreadsPerLine;
- };
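- // Two-stage software pipeline: issue the global loads for iteration
- // k + 1 while the fragments of iteration k are multiply-accumulated.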
- CUTLASS_PRAGMA_UNROLL
- for (int iter = 0; iter < kPipelineStages - 1; ++iter) {
- int ld_pos = iter % kPipelineStages;
- pred = pred && (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) <
- p.head_dim_value;
- loadAndIncrement(ld_pos, pred);
- }
- auto columnIteration = [&](int iter) {
- // Load for next iter
- int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages;
- pred = pred && (laneFirstCol + (iter + kPipelineStages - 1) * kElementsPerAccess *
- kNumThreadsPerLine) < p.head_dim_value;
- loadAndIncrement(ld_pos, pred);
- CUTLASS_PRAGMA_UNROLL
- for (int i = 0; i < AccessType::kElements; ++i) {
- delta_value += accum_t(frag_output[iter % kPipelineStages][i]) *
- accum_t(frag_grad_output[iter % kPipelineStages][i]);
- }
- };
- // If we have a small lower-bound for K, we can unroll the loop
- if (kMaxK <= 256) {
- CUTLASS_PRAGMA_UNROLL
- for (int iter = 0; iter < kMaxIters; ++iter) { columnIteration(iter); }
- } else {
- // each columnIteration covers kElementsPerAccess * kNumThreadsPerLine
- // columns, matching kMaxIters in the unrolled branch
- int num_iters = ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine);
- for (int iter = 0; iter < num_iters; ++iter) { columnIteration(iter); }
- }
- // Reduce between workers
- static_assert(kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || kNumThreadsPerLine == 4,
- "");
- CUTLASS_PRAGMA_UNROLL
- for (int i = 1; i < kNumThreadsPerLine; i *= 2) {
- delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i);
- }
- // Store in gmem
- if (rowPred) { p.delta_ptr[query_start + laneRow] = delta_value; }
- }
- };
- template <typename AK>
- __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
- attention_kernel_backward_batched_impl(typename AK::Params p)
- {
- if (!p.advance_to_block()) { return; }
- AK::attention_kernel(p);
- }
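- // A minimal host-side launch sketch (illustrative only, not part of this
- // file). It assumes the kernel reads its SharedStorage from dynamic shared
- // memory, and `getBlocksGrid()`/`getThreadsGrid()` are assumed helpers on
- // Params; if this codebase does not provide them, derive the grid from the
- // batch/head counts instead.
- //
- //   template <typename AK>
- //   void launch_backward(typename AK::Params p, cudaStream_t stream) {
- //     constexpr int kSmemBytes = sizeof(typename AK::SharedStorage);
- //     if (kSmemBytes > 48 * 1024) {
- //       // opt in to large dynamic shared memory allocations (SM80+)
- //       cudaFuncSetAttribute(attention_kernel_backward_batched_impl<AK>,
- //                            cudaFuncAttributeMaxDynamicSharedMemorySize,
- //                            kSmemBytes);
- //     }
- //     attention_kernel_backward_batched_impl<AK>
- //         <<<p.getBlocksGrid(), p.getThreadsGrid(), kSmemBytes, stream>>>(p);
- //   }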
- template <typename AK>
- __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
- attention_kernel_backward_batched(typename AK::Params params);