/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #pragma once #include #include #include #include #include "cutlass/cutlass.h" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/thread/scale_type.h" #include "cutlass/fast_math.h" #include "cutlass/functional.h" #include "cutlass/gemm/gemm.h" #include "cutlass/layout/matrix.h" #include "cutlass/layout/vector.h" #include "cutlass/numeric_conversion.h" #include "cutlass/numeric_types.h" #include "cutlass/tensor_ref.h" #include "gemm_kernel_utils.h" #include "cutlass/epilogue/thread/linear_combination_relu.h" #include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" #include "cutlass/gemm/device/default_gemm_configuration.h" #include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass/gemm/threadblock/default_mma_core_simt.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm80.h" #include "cutlass/matrix_shape.h" #include "cutlass/platform/platform.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/vector_iterator.h" #include "epilogue/epilogue_pipelined.h" #include "iterators/epilogue_predicated_tile_iterator.h" #include "epilogue/epilogue_grad_bias.h" #include "gemm/custom_mma.h" #include "gemm/find_default_mma.h" #include "gemm/mma_accum_lambda_iterator.h" #include "gemm/mma_from_smem.h" #include "transform/bias_broadcast.h" #include "transform/tile_smem_loader.h" #include using namespace gemm_kernel_utils; namespace { template struct GmemTile { /* Helper functions to efficient store/load RF to gmem GEMM accumulators have a particular format on A100, and it takes some compute/shared-memory to rearrange them to a RowMajor or ColumnMajor format in global memory through an Epilogue. The same complexity goes for loading into RF. This class loads/stores RF as they are, and can be used for efficient accumulation across gemms for instance: ``` GmemTile tile; for (int i = 0; i < N; ++i) { // ... Fragment accum; if (i == 0) { accum.clear(); } else { tile.load(accum); } mma(accum, ...); if (i < N-1) { // Store for next GEMM tile.store(accum); } else { // Store in tensor (eg RowMajor) epilogue(accum); } // ... 
} ``` */ // 128bits per thread using AccessType = cutlass::Array; static constexpr int32_t kBytes = sizeof(AccessType); static constexpr int32_t kStride = kNumThreads * AccessType::kElements; static constexpr int32_t kNumIters = FragmentType::kElements / AccessType::kElements; static constexpr int32_t kElementsStored = kNumThreads * FragmentType::kElements; static_assert(FragmentType::kElements % AccessType::kElements == 0, "fragment not aligned on 128 bits"); float* ptr; CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kNumIters; ++i) { AccessType* __restrict__ gmem_ptr = reinterpret_cast( ptr + thread_id * AccessType::kElements + i * kStride); AccessType sub_fragment; cutlass::arch::global_load(sub_fragment, gmem_ptr, true); CUTLASS_PRAGMA_UNROLL for (int j = 0; j < AccessType::kElements; ++j) { fragment[i * AccessType::kElements + j] = sub_fragment[j]; } } } CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kNumIters; ++i) { AccessType* __restrict__ gmem_ptr = reinterpret_cast( ptr + thread_id * AccessType::kElements + i * kStride); AccessType sub_fragment; CUTLASS_PRAGMA_UNROLL for (int j = 0; j < AccessType::kElements; ++j) { sub_fragment[j] = fragment[i * AccessType::kElements + j]; } cutlass::arch::global_store(sub_fragment, gmem_ptr, true); } } }; template constexpr int getWarpsPerSm() { constexpr bool is_half = !cutlass::platform::is_same::value; if (Arch::kMinComputeCapability >= 80) { return is_half ? 12 : 8; } return 8; } } // namespace template < // which arch we target (eg `cutlass::arch::Sm80`) typename ArchTag_, // input/output type typename scalar_t_, // run optimized kernel because memory accesses will be aligned bool kIsAligned_, // use dropout if enabled bool kApplyDropout_, // when doing a GEMM, preload the next one (uses more shmem) bool kPreload_, // block dimensions int kBlockSizeI_, int kBlockSizeJ_, // upperbound on `max(value.shape[-1], query.shape[-1])` int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), template class Broadcast1_ = BroadcastNoLoad, template class Broadcast2_ = BroadcastNoLoad> struct AttentionBackwardKernel { using scalar_t = scalar_t_; using output_t = scalar_t; using output_accum_t = float; using lse_scalar_t = float; using accum_t = float; using ArchTag = ArchTag_; static constexpr bool kIsAligned = kIsAligned_; static constexpr bool kApplyDropout = kApplyDropout_; static constexpr bool kPreload = kPreload_; static constexpr int kBlockSizeI = kBlockSizeI_; static constexpr int kBlockSizeJ = kBlockSizeJ_; static constexpr int kMaxK = kMaxK_; struct Params { // Input tensors scalar_t* query_ptr; // [Mq, nH, K] scalar_t* key_ptr; // [Mk, nH, K] scalar_t* value_ptr; // [Mk, nH, Kv] lse_scalar_t* logsumexp_ptr; // [nH, Mq] scalar_t* output_ptr; // [Mq, nH, Kv] scalar_t* grad_output_ptr; // [Mq, nH, Kv] accum_t* delta_ptr; // [nH, Mq] int32_t* cu_seqlens_q_ptr = nullptr; int32_t* cu_seqlens_k_ptr = nullptr; // Output tensors output_t* grad_query_ptr; // [Mq, nH, K] output_t* grad_key_ptr; // [Mk, nH, K] output_t* grad_value_ptr; // [Mk, nH, Kv] accum_t* grad_bias1_ptr = nullptr; accum_t* grad_bias2_ptr = nullptr; int32_t B = 0; int32_t N = 0; scalar_t* bias1_ptr = nullptr; scalar_t* bias2_ptr = nullptr; // Accumulators union { output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] output_accum_t* workspace_gk; }; output_accum_t* workspace_gv; // (will be calculated by the kernel) output_accum_t* 
workspace_gq; // (will be calculated by the kernel) // Scale accum_t scale; // Dimensions/strides int32_t head_dim = -1; int32_t head_dim_value = -1; int32_t num_queries = -1; int32_t num_keys = -1; int32_t num_heads = -1; int32_t q_strideM; int32_t k_strideM; int32_t v_strideM; int32_t gO_strideM; int32_t gB_strideM; int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise // RNG sequence offset based on batch_id and head_id unsigned long long dropout_batch_head_rng_offset; float dropout_prob = 0.0f; CUTLASS_HOST_DEVICE int32_t o_strideM() const { return head_dim_value * num_heads; } CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { return gQKV_strideM_multiplier * num_heads * head_dim; } CUTLASS_HOST_DEVICE int32_t gK_strideM() const { return gQKV_strideM_multiplier * num_heads * head_dim; } CUTLASS_HOST_DEVICE int32_t gV_strideM() const { return gQKV_strideM_multiplier * num_heads * head_dim_value; } // Everything below is only used in `advance_to_block` // and shouldn't use registers int64_t o_strideH; int32_t q_strideH; int32_t k_strideH; int32_t v_strideH; int64_t o_strideB; int64_t q_strideB; int64_t k_strideB; int64_t v_strideB; int64_t lse_strideB; int64_t lse_strideH; int64_t delta_strideB; int64_t delta_strideH; int32_t num_batches; int64_t gO_strideB = 0; int64_t gQ_strideB = 0; int64_t gK_strideB = 0; int64_t gV_strideB = 0; int64_t gB_strideB = 0; int64_t gO_strideH = 0; int64_t gQ_strideH = 0; int64_t gK_strideH = 0; int64_t gV_strideH = 0; int64_t gB_strideH = 0; CUTLASS_DEVICE bool advance_to_block() { int64_t batch_id = blockIdx.z; int32_t head_id = blockIdx.y; if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { assert(workspace_size() == 0 || workspace != nullptr); workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); workspace = warp_uniform(workspace); workspace_gv = workspace + workspace_elements_gk(); workspace_gq = workspace_gv + workspace_elements_gv(); } else { workspace = nullptr; } // Advance pointers that depend on the total concatenated // number of queries, as `num_queries` is modified in the block // below dropout_batch_head_rng_offset = batch_id * (num_heads * num_queries * num_keys) + head_id * (num_queries * num_keys); logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; query_ptr += batch_id * q_strideB + head_id * q_strideH; key_ptr += batch_id * k_strideB + head_id * k_strideH; value_ptr += batch_id * v_strideB + head_id * v_strideH; output_ptr += batch_id * o_strideB + head_id * o_strideH; grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; using broadcast_1 = Broadcast1_; using broadcast_2 = Broadcast2_; if (broadcast_1::kEnable && grad_bias1_ptr) { grad_bias1_ptr += batch_id * num_queries; } if (broadcast_2::kEnable && grad_bias2_ptr) { auto strideB = num_heads * num_queries * num_keys; auto strideH = num_queries * num_keys; grad_bias2_ptr += (batch_id / N) * strideB + head_id * strideH; } if (broadcast_1::kEnable && bias1_ptr) { bias1_ptr = broadcast_1::advance(bias1_ptr, batch_id / N, batch_id % N, head_id, num_queries * N, num_queries, 0); } if (broadcast_2::kEnable && bias2_ptr) { auto strideB = num_heads * num_queries * num_keys; auto strideH = num_queries * num_keys; bias2_ptr = broadcast_2::advance( bias2_ptr, batch_id / N, batch_id 
% N, head_id, strideB, 0, strideH); } num_queries = warp_uniform(num_queries); num_keys = warp_uniform(num_keys); query_ptr = warp_uniform(query_ptr); key_ptr = warp_uniform(key_ptr); value_ptr = warp_uniform(value_ptr); logsumexp_ptr = warp_uniform(logsumexp_ptr); output_ptr = warp_uniform(output_ptr); grad_output_ptr = warp_uniform(grad_output_ptr); delta_ptr = warp_uniform(delta_ptr); grad_query_ptr = warp_uniform(grad_query_ptr); grad_key_ptr = warp_uniform(grad_key_ptr); grad_value_ptr = warp_uniform(grad_value_ptr); if (broadcast_1::kEnable) { grad_bias1_ptr = warp_uniform(grad_bias1_ptr); bias1_ptr = warp_uniform(bias1_ptr); } if (broadcast_2::kEnable) { grad_bias2_ptr = warp_uniform(grad_bias2_ptr); bias2_ptr = warp_uniform(bias2_ptr); } return true; } __host__ dim3 getBlocksGrid() const { return dim3(1, num_heads, num_batches); } __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const { if (!kNeedsAccumGradK) { return 0; } return align_up(num_keys, (int32_t)kBlockSizeJ) * align_up(head_dim, (int32_t)kBlockSizeI); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { if (!kNeedsAccumGradV) { return 0; } return align_up(num_keys, (int32_t)kBlockSizeJ) * align_up(head_dim_value, (int32_t)kBlockSizeI); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { if (!kNeedsAccumGradQ) { return 0; } if (num_keys <= kBlockSizeJ) { return 0; } return align_up(num_queries, (int32_t)kBlockSizeI) * align_up(head_dim, (int32_t)kBlockSizeJ); } CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const { // Aligned on 128bits return align_up( workspace_elements_gk() + workspace_elements_gv() + workspace_elements_gq(), int64_t(4)); } CUTLASS_HOST_DEVICE int64_t workspace_size() const { // Returns size of buffer we need to run this kernel return num_batches * num_heads * workspace_strideBH() * sizeof(float); } }; static constexpr int64_t kWarpSize = 32; // If this is true, we store and accumulate dK/dV in RF // rather than going back to gmem every time static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; static_assert(!kPreload || (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), "preload MMA not supported"); static constexpr bool kPrologueQK = kPreload; static constexpr bool kPrologueGV = kPreload; static constexpr bool kPrologueDOV = kPreload; static constexpr bool kPrologueGQ = kPreload; static constexpr bool kPrologueGK = kPreload; static constexpr int64_t kNumWarpsPerBlock = (kBlockSizeI * kBlockSizeJ) / (32 * 32); // Compute delta for the f16 kernels // TODO: Figure out why it's slower on the f32 kernels // (something due to RF pressure?) // TODO: Remove condition on `kOutputInRF` - this is needed to work // around a compiler bug on V100, not exactly sure why but I spent // too much time on this already. 
Reproducible with // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance static constexpr bool kKernelComputesDelta = kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); static constexpr bool kNeedsAccumGradQ = !cutlass::platform::is_same::value; static constexpr bool kNeedsAccumGradK = !kOutputInRF && !cutlass::platform::is_same::value; static constexpr bool kNeedsAccumGradV = !kOutputInRF && !cutlass::platform::is_same::value; // Launch bounds static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; static constexpr int64_t kMinBlocksPerSm = getWarpsPerSm() / kNumWarpsPerBlock; using GemmType = DefaultGemmType; using DefaultConfig = typename cutlass::gemm::device::DefaultGemmConfiguration; static constexpr auto kOptimalAlignement = cutlass::platform::max(DefaultConfig::kAlignmentA, DefaultConfig::kAlignmentB); static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; struct MatmulQK { /* attn_T = k_j @ q_i.transpose(-2, -1) # matmul attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, -1)).exp() # epilogue with attn_T.shape = (kBlockSizeJ, kBlockSizeI) */ using ThreadblockShape = cutlass::gemm::GemmShape; using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< scalar_t, // ElementA cutlass::layout::RowMajor, // LayoutA kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, scalar_t, // ElementB cutlass::layout::ColumnMajor, // LayoutB kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, accum_t, // ElementC cutlass::layout::RowMajor, // LayoutC typename GemmType::OpClass, ArchTag, ThreadblockShape, WarpShape, typename GemmType::InstructionShape, DefaultConfig::kStages, typename GemmType::Operator, false, // AccumulatorsInRowMajor = false, cutlass::gemm::SharedMemoryClearOption::kNone>; using MmaCore = typename DefaultMma::MmaCore; using Mma = typename MakeCustomMma::Mma; // used for efficient load of bias tile (Bij) from global memory to shared // memory using BiasLoader = TileSmemLoader, MmaCore::kThreads, // input restriction: kv_len has to be a multiple of this value 128 / cutlass::sizeof_bits::value>; // Epilogue to store to shared-memory in a format that we can use later for // the second matmul using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm; using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator::Iterator; using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; }; struct MatmulGradV { /* grad_v[j_start:j_end] += attn_T @ do_i # matmul Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) (we might need to iterate multiple times on K) */ using ThreadblockShape = cutlass::gemm::GemmShape; using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; using InstructionShape = typename GemmType::InstructionShape; using DefaultGemm = cutlass::gemm::kernel::DefaultGemm; // if dropout: // for computing dVj += (Pij.T * Zij) @ dOi // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of // Pij.T are loaded in. The reason we do it this way is because Pij.T and // Zij are reused in later steps, while Pij_dropped.T is only needed in // this step. computing Pij_dropped.T on the fly allows us to avoid // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the // same time. 
// if no dropout: // for computing dVj += Pij.T @ dOi using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, typename MatmulQK::AccumulatorSharedStorage, kApplyDropout>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; using WarpIteratorA = typename DefaultMmaFromSmem::WarpIteratorA; using IteratorB = typename Mma::IteratorB; using WarpCount = typename Mma::WarpCount; // Epilogue using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; using DefaultEpilogue = typename DefaultGemm::Epilogue; using OutputTileIterator = typename cutlass::epilogue::threadblock::MakePrefetchableIterator< typename DefaultEpilogue::OutputTileIterator>::Iterator; using AccumTileGmem = GmemTile; }; struct MatmulDOIVJ { /* doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? */ using ThreadblockShape = cutlass::gemm::GemmShape; using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; using ElementC = accum_t; // CSY: Change it for better accuracy using ElementAccum = accum_t; // no-op output op - epilogue just stores result to global memory using BiasGradEpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination< ElementC, DefaultConfig::EpilogueOutputOp::kCount, typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, typename DefaultConfig::EpilogueOutputOp::ElementCompute, cutlass::epilogue::thread::ScaleType::Nothing>; using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< scalar_t, // ElementA cutlass::layout::RowMajor, // LayoutA kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, scalar_t, // ElementB cutlass::layout::ColumnMajor, // LayoutB kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, ElementC, // ElementC cutlass::layout::RowMajor, // LayoutC ElementAccum, // ElementAccumulator typename GemmType::OpClass, ArchTag, ThreadblockShape, WarpShape, typename GemmType::InstructionShape, BiasGradEpilogueOutputOp, // EpilogueOutputOp void, // ThreadblockSwizzle (not used) // multiple preloads, dropout Zij tile, and 3 stages push us over shared // memory capacity on A100. set a ceiling on number of stages to save // shared memory if dropout is in use. kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) ? 
cutlass::const_min(2, DefaultConfig::kStages) : DefaultConfig::kStages, // Stages false, // SplitKSerial typename GemmType::Operator, cutlass::gemm::SharedMemoryClearOption::kNone>; using Mma = typename MakeCustomMma::Mma; // epilogue used to write bias gradient, which is just the output of this // matmul with some operations applied to the fragment using BiasGradEpilogue = typename DefaultGemm::Epilogue; // Epilogue to store to shared-memory in a format that we can use later for // the second matmul using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm; using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; }; struct MatmulGradQ { // grad_q <- tmp @ k_j using ThreadblockShape = cutlass::gemm::GemmShape; using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; using InstructionShape = typename GemmType::InstructionShape; using DefaultGemm = cutlass::gemm::kernel::DefaultGemm; using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, typename MatmulDOIVJ::AccumulatorSharedStorage, false>; // kScaleOperandA using Mma = typename DefaultMmaFromSmem::Mma; using IteratorB = typename Mma::IteratorB; using WarpCount = typename Mma::WarpCount; // Epilogue using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; using DefaultEpilogue = typename DefaultGemm::Epilogue; using OutputTileIterator = typename cutlass::epilogue::threadblock::MakePrefetchableIterator< typename DefaultEpilogue::OutputTileIterator>::Iterator; using AccumTileGmem = GmemTile; }; struct MatmulGradK { // grad_k <- tmp.transpose(-2, -1) @ q_i using ThreadblockShape = cutlass::gemm::GemmShape; using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; using InstructionShape = typename GemmType::InstructionShape; using DefaultGemm = cutlass::gemm::kernel::DefaultGemm; using DefaultMmaFromSmemN = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, typename MatmulQK::AccumulatorSharedStorage, false>; // kScaleOperandA using DefaultMmaFromSmemT = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< typename DefaultGemm::Mma, typename MatmulDOIVJ::AccumulatorSharedStorage, false, // kScaleOperandA kPreload>; // kTransposeA using DefaultMmaFromSmem = typename cutlass::platform::conditional::type; using Mma = typename DefaultMmaFromSmem::Mma; using IteratorB = typename Mma::IteratorB; using WarpCount = typename Mma::WarpCount; // Epilogue using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; using DefaultEpilogue = typename DefaultGemm::Epilogue; using OutputTileIterator = typename cutlass::epilogue::threadblock::MakePrefetchableIterator< typename DefaultEpilogue::OutputTileIterator>::Iterator; using AccumTileGmem = GmemTile; }; using broadcast_1 = Broadcast1_; using broadcast_2 = Broadcast2_; // shared storage for keeping Zij matrix. not needed if we aren't using // dropout, in which case we use an empty array to save shared memory using ZijSharedStorage = typename cutlass::platform::conditional< kApplyDropout, typename MatmulQK::AccumulatorSharedStorage, // dummy shared storage object that takes up no space. typename cutlass::gemm::threadblock::AccumulatorSharedStorage< #ifdef _WIN32 // windows builds throw the error: // "type containing an unknown-size array is not allowed" // if we try to make Zij shared storage zero-sized. // To get around this just make it sized 1 on windows. 
typename cutlass::gemm::GemmShape<1, 1, 0>, #else typename cutlass::gemm::GemmShape<0, 0, 0>, #endif typename MatmulQK::AccumulatorSharedStorage::Element, typename MatmulQK::AccumulatorSharedStorage::Layout, typename cutlass::MatrixShape<0, 0>>>::type; struct SharedStoragePrologue { struct { cutlass::Array di; // (do_i * o_i).sum(-1) typename MatmulQK::Mma::SharedStorageA mm_qk_k; } persistent; union { struct { // part1 - after Q.K / dV / dO.V union { // 1. efficient load of bias tile Bij, which is then applied to Pij // typename MatmulQK::BiasLoader::SmemTile bias; cutlass::AlignedBuffer bias; // 4. store Pij. it is needed: // - in dVj += (Pij.T * Zij) @ dOi // - in dSij = Pij * (dPij - Di) // 6. dVj += (Pij.T * Zij) @ dOi // 10. write to fragment typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; }; // 5. store Zij. it is needed: // - to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij // are loaded for the computation of dVj. // - to compute dPij = (dOi @ Vj.T) * Zij // 6. used in dVj += (Pij.T * Zij) @ dOi // 9. used in dPij = dPij_dropped * Zij ZijSharedStorage zij; union { // 2. prologue for dVj // 6. workspace for dVj += (Pij.T * Zij) @ dOi typename MatmulGradV::Mma::SharedStorage mm_gradV; // 7. dVj epilogue typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; }; // 3. prologue for dPij_dropped // 8. used in dPij_dropped = dOi @ Vj.T typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; } part1; struct { // part2 - dQ union { typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part1) typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; }; typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) union { // store dB = dSij to global memory typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; }; } part2; struct { // part3 - after last iteration on dQ's epilogue / dK union { typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part1) typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; }; typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter; typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; } part3; struct { // part4 - after last iteration on dK's epilogue / preload next K.Q_t typename MatmulQK::Mma::SharedStorageB mm_qk_q; // If we reach end of current key, dump RF->gmem with "final" epilogues typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final; typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final; } part4; }; // =========================================== #define FIELD(INSIDE_STRUCT, FIELDNAME) \ CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; } FIELD(persistent, di) FIELD(persistent, mm_qk_k) FIELD(part1, bias) FIELD(part1, attn_shared_storage) FIELD(part1, zij) FIELD(part1, mm_gradV) FIELD(part1, gradV_epilogue) FIELD(part1, mm_doivj) FIELD(part2, mm_gradK) FIELD(part2, mm_gradQ) FIELD(part2, gradB_epilogue) FIELD(part2, gradQ_epilogue) FIELD(part2, tmp_shared_storage) FIELD(part3, tmpT_shared_storage) FIELD(part3, gradQ_epilogue_lastIter) FIELD(part3, gradK_epilogue) FIELD(part4, mm_qk_q) FIELD(part4, gradK_epilogue_final) FIELD(part4, gradV_epilogue_final) }; struct SharedStorageNoPrologue { struct { cutlass::Array di; // (do_i * o_i).sum(-1) } 
persistent; union { struct { // part1 - Q.K matmul typename MatmulQK::Mma::SharedStorageA mm_qk_k; typename MatmulQK::Mma::SharedStorageB mm_qk_q; } part1; struct { // part2 - compute gradV union { // 1. efficient load of bias tile Bij, which is then applied to Pij cutlass::AlignedBuffer bias; // 2. store Pij to shared memory. it is needed: // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi // - in next step where it is used in dSij = Pij * (dPij - Di) typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; }; // 3. store Zij. it is needed: // - in this step, where it is used to compute Pij_dropped = Pij * Zij // on the // fly as fragments of Pij are loaded for the computation of dVj. // - later to compute dPij = (dOi @ Vj.T) * Zij ZijSharedStorage zij; union { typename MatmulGradV::Mma::SharedStorage mm_gradV; typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; }; } part2; struct { // part3 - DO.V matmul union { // first compute dPij = (dOi @ Vj.T) * Zij // and dSij = Pij * (dPij - Di) struct { // (from part2) - Pij for computing dSij = Pij * (dPij - Di) typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; // (from part2) - Zij for computing dPij = dPij_dropped * Zij ZijSharedStorage zij; // matmul to compute dOiVj typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; }; // then store dB = dSij to global memory typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; }; } part3; struct { // part4 - compute gradQ typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2) typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; union { typename MatmulGradQ::Mma::SharedStorage mm_gradQ; typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter; }; } part4; struct { // part5 - compute gradK typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2) typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; union { typename MatmulGradK::Mma::SharedStorage mm_gradK; typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; }; } part5; struct { // part6 - store RF accumulated into gmem typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final; typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final; } part6; }; // =========================================== #define FIELD(INSIDE_STRUCT, FIELDNAME) \ CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; } FIELD(persistent, di) FIELD(part1, mm_qk_k) FIELD(part1, mm_qk_q) FIELD(part2, bias) FIELD(part2, attn_shared_storage) FIELD(part2, zij) FIELD(part2, mm_gradV) FIELD(part2, gradV_epilogue) FIELD(part3, mm_doivj) FIELD(part3, gradB_epilogue) FIELD(part4, tmpT_shared_storage) FIELD(part4, tmp_shared_storage) FIELD(part4, mm_gradQ) FIELD(part4, gradQ_epilogue) FIELD(part4, gradQ_epilogue_lastIter) FIELD(part5, mm_gradK) FIELD(part5, gradK_epilogue) FIELD(part6, gradK_epilogue_final) FIELD(part6, gradV_epilogue_final) }; using SharedStorage = typename cutlass::platform:: conditional::type; struct OutputFragments { typename MatmulGradV::Mma::FragmentC gradV; typename MatmulGradK::Mma::FragmentC gradK; CUTLASS_DEVICE void clear() { gradV.clear(); gradK.clear(); } }; static bool __host__ check_supported(Params const& p) { CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); 
CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); EVOFORMER_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); EVOFORMER_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned (strideH)"); EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned (strideH)"); EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned (strideH)"); EVOFORMER_CHECK(p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, "query is not correctly aligned (strideB)"); EVOFORMER_CHECK(p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, "key is not correctly aligned (strideB)"); EVOFORMER_CHECK(p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, "value is not correctly aligned (strideB)"); EVOFORMER_CHECK(p.q_strideM % kMinimumAlignment == 0, "query is not correctly aligned (strideM)"); EVOFORMER_CHECK(p.k_strideM % kMinimumAlignment == 0, "key is not correctly aligned (strideM)"); EVOFORMER_CHECK(p.v_strideM % kMinimumAlignment == 0, "value is not correctly aligned (strideM)"); EVOFORMER_CHECK(p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, "Invalid value for `dropout_prob`"); EVOFORMER_CHECK(kApplyDropout || p.dropout_prob == 0.0f, "Set `kApplyDropout`=True to support `dropout_prob > 0`"); EVOFORMER_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); EVOFORMER_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); EVOFORMER_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); EVOFORMER_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); EVOFORMER_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); EVOFORMER_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); EVOFORMER_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); EVOFORMER_CHECK(p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); return true; } static CUTLASS_DEVICE void attention_kernel(Params p) { extern __shared__ char smem_buffer[]; SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); uint16_t thread_id = threadIdx.x; uint8_t warp_id = warp_uniform(thread_id / 32); uint8_t lane_id = thread_id % 32; if (kPrologueQK) { prologueQkNextIteration(shared_storage, p, 0, 0, warp_id, lane_id); } // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` if (kKernelComputesDelta) { constexpr int kOptimalElements = 128 / cutlass::sizeof_bits::value; if (p.head_dim_value % kOptimalElements == 0) { for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) { computeDelta(p, query_start, warp_id, lane_id); } } else { for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) { computeDelta<1>(p, query_start, warp_id, lane_id); } } __syncthreads(); } OutputFragments output_frags; int32_t key_start = 0; int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; for (; key_start < key_end; key_start += kBlockSizeJ) { output_frags.clear(); int32_t query_start = getQueryStart(p, key_start); int32_t query_end = query_start + (p.num_queries - query_start) / kBlockSizeI * kBlockSizeI; for (; query_start < query_end; query_start += kBlockSizeI) { processBlockIJ( shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id); } // last (partial) query if (query_start < p.num_queries) { processBlockIJ( shared_storage, output_frags, p, 
query_start, key_start, warp_id, lane_id); } if (kOutputInRF) { writeFragsToGmem( shared_storage, output_frags, p, key_start, warp_id, lane_id); } else if (getQueryStart(p, key_start) >= p.num_queries) { zfillGradKV(p, key_start, warp_id, lane_id); } __syncthreads(); } // Last (partial) key if (key_start != p.num_keys) { output_frags.clear(); int32_t query_start = getQueryStart(p, key_start); for (; query_start < p.num_queries; query_start += kBlockSizeI) { warp_id = warp_uniform(warp_id); processBlockIJ( shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id); } if (kOutputInRF) { writeFragsToGmem( shared_storage, output_frags, p, key_start, warp_id, lane_id); } else if (getQueryStart(p, key_start) >= p.num_queries) { zfillGradKV(p, key_start, warp_id, lane_id); } } } static CUTLASS_DEVICE void loadDi(cutlass::Array& di, Params const& p, int32_t query_start) { int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; if (thread_id < kBlockSizeI) { accum_t di_rf = accum_t(0); if (query_start + thread_id < p.num_queries) { di_rf = p.delta_ptr[query_start + thread_id]; } di[thread_id] = di_rf; } } template static CUTLASS_DEVICE void zfillGradKV(Params const& p, int32_t key_start, uint8_t warp_id, uint8_t lane_id) { constexpr int kThreadsPerKey = 8; constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; static_assert(kBlockSizeJ % kParallelKeys == 0, ""); // This function is not really optimized, but should rarely be used // It's only used when some keys are "useless" and don't attend to // any query, due to causal masking int thread_id = 32 * warp_id + lane_id; int k_shift = lane_id % kThreadsPerKey; CUTLASS_PRAGMA_UNROLL for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { int key = key_start + j + (thread_id / kThreadsPerKey); if (!skipBoundsChecks && key >= p.num_keys) { continue; } auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { gv_ptr[k] = scalar_t(0); } for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { gk_ptr[k] = scalar_t(0); } } } template static CUTLASS_DEVICE void processBlockIJ(SharedStorage& shared_storage, OutputFragments& output_frags, Params& p, int32_t query_start, int32_t key_start, uint8_t warp_id, uint8_t lane_id) { cutlass::MatrixCoord no_offset{0, 0}; accum_t scale = p.scale; int16_t thread_id = 32 * warp_id + lane_id; auto rematerializeThreadIds = [&]() { // Prevents `nvcc` from keeping values deduced from // `thread_id`, `warp_id`, ... in RF - to reduce register pressure warp_id = warp_uniform(thread_id / 32); lane_id = thread_id % 32; thread_id = 32 * warp_id + lane_id; }; bool isFirstQuery = (query_start == getQueryStart(p, key_start)); int32_t next_query, next_key; incrIteration(p, query_start, key_start, next_query, next_key); bool isLastQuery = next_key != key_start; __syncthreads(); loadDi(shared_storage.di(), p, query_start); int32_t num_queries_in_block = skipBoundsChecks ? MatmulQK::Mma::Shape::kN : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start)); int32_t num_keys_in_block = skipBoundsChecks ? 
MatmulQK::Mma::Shape::kM : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( {int32_t(p.gO_strideM)}, p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); MatmulGradV::Mma::prologue( shared_storage.mm_gradV(), iterator_dO, thread_id, num_queries_in_block); }; auto prologueGradQ = [&](int col) { typename MatmulGradQ::Mma::IteratorB iterator_K( {int32_t(p.k_strideM)}, p.key_ptr + key_start * p.k_strideM + col, {num_keys_in_block, p.head_dim - col}, thread_id, no_offset); MatmulGradQ::Mma::prologue( shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); }; auto prologueGradK = [&](int col) { typename MatmulGradK::Mma::IteratorB iterator_Q( {int32_t(p.q_strideM)}, p.query_ptr + query_start * p.q_strideM + col, {num_queries_in_block, p.head_dim - col}, thread_id, no_offset); MatmulGradK::Mma::prologue( shared_storage.mm_gradK(), iterator_Q, thread_id, num_queries_in_block); }; auto prologueDOV = [&]() { typename MatmulDOIVJ::Mma::IteratorA iterator_A( {int32_t(p.gO_strideM)}, p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); typename MatmulDOIVJ::Mma::IteratorB iterator_B({int32_t(p.v_strideM)}, p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); MatmulDOIVJ::Mma::prologue( shared_storage.mm_doivj(), iterator_A, iterator_B, thread_id, p.head_dim_value); }; ///////////////////////////////////////////////////////////////////////////////////////////////// // MatmulQK ///////////////////////////////////////////////////////////////////////////////////////////////// { using Mma = typename MatmulQK::Mma; cutlass::gemm::GemmCoord problem_size(num_keys_in_block, num_queries_in_block, p.head_dim // k ); // k_j typename Mma::IteratorA iterator_A({int32_t(p.k_strideM)}, p.key_ptr + key_start * p.k_strideM, {problem_size.m(), problem_size.k()}, thread_id, no_offset); // q_i.transpose(-2, -1) typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, p.query_ptr + query_start * p.q_strideM, {problem_size.k(), problem_size.n()}, thread_id, no_offset); Mma mma( shared_storage.mm_qk_k(), shared_storage.mm_qk_q(), thread_id, warp_id, lane_id); typename Mma::FragmentC accum; accum.clear(); auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add mma.set_prologue_done(kPrologueQK); mma.set_zero_outside_bounds(!skipBoundsChecks); mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); // Epilogue: add LSE + exp and store that to our shared memory buffer // shmem <- (matmul_result - // logsumexp[i_start:i_end].unsqueeze(1)).exp() int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); auto output_tile_coords = cutlass::MatrixCoord{ warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM}; if (broadcast_1::kEnable || broadcast_2::kEnable) { cutlass::TensorRef bias_tensor_ref( shared_storage.bias().data(), cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); using Shape = cutlass::MatrixShape; AttentionBiasEpilogue bias_epilogue; bias_epilogue(bias_tensor_ref, p.bias1_ptr + key_start, p.bias2_ptr + query_start * p.num_keys + key_start, thread_id, {num_queries_in_block, num_keys_in_block}, p.num_keys); // Pij += Bij, Pij is in register fragment and Bij is 
in shared memory auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( lane_id, warp_id, output_tile_coords); MatmulQK::AccumLambdaIterator::iterateRows( lane_offset, [&](int accum_n) {}, [&](int accum_m, int accum_n, int idx) { // remember we are transposed accum[idx] = accum[idx] * scale + bias_tensor_ref.at({accum_n, accum_m}); }, [&](int accum_n) {}); } else { accum = cutlass::multiplies()(scale, accum); } __syncthreads(); if (kPrologueGV) { prologueGradV(0); } if (kPrologueDOV) { prologueDOV(); } MatmulQK::B2bGemm::accumApplyLSEToSmem(shared_storage.attn_shared_storage(), accum, p.logsumexp_ptr + query_start, problem_size.n(), thread_id, warp_id, lane_id, output_tile_coords); __syncthreads(); } rematerializeThreadIds(); ///////////////////////////////////////////////////////////////////////////////////////////////// // GradV matmul // // grad_v[j_start:j_end] += attn_T @ do_i ///////////////////////////////////////////////////////////////////////////////////////////////// constexpr bool kSingleIterationGradV = kMaxK <= MatmulGradV::ThreadblockShape::kN; for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value); col += MatmulGradV::ThreadblockShape::kN) { using Mma = typename MatmulGradV::Mma; using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; cutlass::gemm::GemmCoord problem_size( num_keys_in_block, p.head_dim_value - col, num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradV::OutputTileIterator( typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, p.grad_value_ptr + key_start * p.gV_strideM() + col, {num_keys_in_block, p.head_dim_value - col}, thread_id); }; typename Mma::IteratorB iterator_B({int32_t(p.gO_strideM)}, p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); // if dropout: dVj += (Pij.T * Zij) @ dOi // otherwise: dVj += Pij.T @ dOi Mma mma(shared_storage.mm_gradV(), // operand A: Pij typename MatmulGradV::WarpIteratorA( shared_storage.attn_shared_storage().accum_ref(), lane_id), // if we're using dropout, operand A is Pij_dropped = Pij * Zij // which is computed on the fly as fragments of Pij are loaded in typename Mma::WarpIteratorAScale(shared_storage.zij().accum_ref(), lane_id), thread_id, warp_id, lane_id); int storage_id = col / MatmulGradV::ThreadblockShape::kN; AccumTileGmem gmem_tile{p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; if (!kOutputInRF) { if (isFirstQuery || !kNeedsAccumGradV) { output_frags.gradV.clear(); } else { gmem_tile.load(output_frags.gradV, thread_id); } } mma.set_prologue_done(kPrologueGV); auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add __syncthreads(); mma(gemm_k_iterations, output_frags.gradV, iterator_B, output_frags.gradV); __syncthreads(); if (kPrologueGV && !kSingleIterationGradV && col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { prologueGradV(col + MatmulGradV::ThreadblockShape::kN); } if (!kOutputInRF) { if (kNeedsAccumGradV && !isLastQuery) { gmem_tile.store(output_frags.gradV, thread_id); } else { accumulateInGmem(shared_storage.gradV_epilogue(), output_frags.gradV, createEpilogueIter(), isFirstQuery || kNeedsAccumGradV, warp_id, lane_id); } } } __syncthreads(); ///////////////////////////////////////////////////////////////////////////////////////////////// // MatmulDOIVJ ///////////////////////////////////////////////////////////////////////////////////////////////// { 
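        // Reference math for this block, in the same notation as the struct-level comments above:
        //   dPij_dropped = dOi @ Vj.transpose(-2, -1)   # this matmul
        //   dPij         = dPij_dropped * Zij           # only when dropout is enabled
        //   dSij         = Pij * (dPij - Di)            # applied to the accumulator below
        // dSij is optionally written out as the bias gradient, and is staged to shared memory
        // (tmp / tmpT) for the GradQ and GradK matmuls that follow.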
using Mma = typename MatmulDOIVJ::Mma; // do_i typename Mma::IteratorA iterator_A({int32_t(p.gO_strideM)}, p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); // v_j.transpose(-2, -1) typename Mma::IteratorB iterator_B({int32_t(p.v_strideM)}, p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); mma.set_prologue_done(kPrologueDOV); mma.set_zero_outside_bounds(!skipBoundsChecks); typename Mma::FragmentC accum; accum.clear(); auto gemm_k_iterations = (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); __syncthreads(); if (kPrologueGQ) { prologueGradQ(0); } if (kPrologueGK) { prologueGradK(0); } int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); auto output_tile_coords = cutlass::MatrixCoord{ warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM}; // TODO: This must be terribly inefficient. There must be a better way // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem] // attn_shared_storage [smem] <- tmp.T // tmp_shared_storage [smem] <- tmp { using LambdaIterator = typename DefaultMmaAccumLambdaIterator::Iterator; auto lane_offset = LambdaIterator::get_lane_offset(lane_id, warp_id, output_tile_coords); auto attn_T = shared_storage.attn_shared_storage().accum_ref(); accum_t current_di; // dSij = (dPij - Di) * Pij LambdaIterator::iterateRows( lane_offset, [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, [&](int accum_m, int accum_n, int idx) { if (skipBoundsChecks || (accum_m < num_queries_in_block && accum_n < num_keys_in_block)) { accum_t attn = attn_T.at({accum_n, accum_m}); accum[idx] = (accum[idx] - current_di) * attn; } else { accum[idx] = 0; } }, [&](int accum_m) { }); using DefaultGemm = typename MatmulDOIVJ::DefaultGemm; using OutputOp = typename MatmulDOIVJ::BiasGradEpilogueOutputOp; if (broadcast_1::kEnable && p.grad_bias1_ptr) { using Epilogue = typename BiasGradEpilogueAffineRankN::Epilogue; cutlass::layout::AffineRankN<2> layout({0, 1}); auto dst_ptr = p.grad_bias1_ptr + key_start; typename Epilogue::OutputTileIterator output_iter( {layout}, dst_ptr, {num_queries_in_block, num_keys_in_block}, (int)thread_id); Epilogue epilogue(shared_storage.gradB_epilogue(), (int)thread_id, (int)warp_id, (int)lane_id); epilogue(OutputOp(1), output_iter, accum); } if (broadcast_2::kEnable && p.grad_bias2_ptr) { if (broadcast_1::kEnable) { __syncthreads(); } using Epilogue = typename BiasGradEpilogue::Epilogue; typename Epilogue::OutputTileIterator::Params params{p.num_keys}; auto dst_ptr = p.grad_bias2_ptr + query_start * p.num_keys + key_start; typename Epilogue::OutputTileIterator output_iter( params, dst_ptr, {num_queries_in_block, num_keys_in_block}, (int)thread_id); Epilogue epilogue(shared_storage.gradB_epilogue(), (int)thread_id, (int)warp_id, (int)lane_id); epilogue(OutputOp(1), output_iter, accum); } accum = accum * scale; __syncthreads(); if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); // attn <- attn_T.T LambdaIterator::iterateRows( lane_offset, [&](int accum_m) {}, [&](int accum_m, int accum_n, int idx) { tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); }, [&](int accum_m) {}); } } MatmulDOIVJ::B2bGemm::accumToSmem( 
shared_storage.tmp_shared_storage(), accum, lane_id, output_tile_coords); __syncthreads(); } p.head_dim = warp_uniform(p.head_dim); p.k_strideM = warp_uniform(p.k_strideM); rematerializeThreadIds(); ///////////////////////////////////////////////////////////////////////////////////////////////// // GradQ matmul // // grad_q[i_start:i_end] += tmp @ k_j ///////////////////////////////////////////////////////////////////////////////////////////////// // Skip the loop & associated branches if we know at compile time the number // of iterations constexpr bool kSingleIterationGradQ = kMaxK <= MatmulGradQ::ThreadblockShape::kN; for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim); col += MatmulGradQ::ThreadblockShape::kN) { using Mma = typename MatmulGradQ::Mma; using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; cutlass::gemm::GemmCoord problem_size( num_queries_in_block, false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, num_keys_in_block); // k_j typename Mma::IteratorB iterator_B({int32_t(p.k_strideM)}, p.key_ptr + key_start * p.k_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); auto a = shared_storage.tmp_shared_storage().accum_ref(); Mma mma(shared_storage.mm_gradQ(), shared_storage.tmp_shared_storage(), thread_id, warp_id, lane_id, problem_size.k()); typename Mma::FragmentC accum; bool isFirst = key_start == 0; int col_id = col / MatmulGradQ::ThreadblockShape::kN; int num_cols = kSingleIterationGradQ ? 1 : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); int storage_id = (col_id + query_start / kBlockSizeI * num_cols); AccumTileGmem gmem_tile{p.workspace_gq + storage_id * AccumTileGmem::kElementsStored}; if (isFirst || !kNeedsAccumGradQ) { accum.clear(); } else { gmem_tile.load(accum, thread_id); } auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add __syncthreads(); mma.set_prologue_done(kPrologueGQ); mma(gemm_k_iterations, accum, iterator_B, accum); __syncthreads(); bool isLastColumn = kSingleIterationGradQ || (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); if (kPrologueGQ && !isLastColumn) { prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); } // Output results int32_t next_query, next_key; incrIteration(p, p.num_queries, key_start, next_query, next_key); bool isLast = next_query > query_start || next_key >= p.num_keys; if (kNeedsAccumGradQ && !isLast) { gmem_tile.store(accum, thread_id); } else { typename MatmulGradQ::OutputTileIterator output_it( typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, p.grad_query_ptr + query_start * p.gQ_strideM() + col, {problem_size.m(), problem_size.n()}, thread_id); accumulateInGmem(isLastColumn ? shared_storage.gradQ_epilogue_lastIter() : shared_storage.gradQ_epilogue(), accum, output_it, isFirst || kNeedsAccumGradQ, warp_id, lane_id); } } ///////////////////////////////////////////////////////////////////////////////////////////////// // GradK matmul // // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i ///////////////////////////////////////////////////////////////////////////////////////////////// rematerializeThreadIds(); constexpr bool kSingleIterationGradK = kMaxK <= MatmulGradK::ThreadblockShape::kN; for (int col = 0; col < (kSingleIterationGradK ? 
1 : p.head_dim); col += MatmulGradK::ThreadblockShape::kN) { using Mma = typename MatmulGradK::Mma; using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; cutlass::gemm::GemmCoord problem_size( num_keys_in_block, false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradK::OutputTileIterator( typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, p.grad_key_ptr + key_start * p.gK_strideM() + col, {num_keys_in_block, false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, thread_id); }; // q_i typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, p.query_ptr + query_start * p.q_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; // this is basically: // opA = kIsTransposedA ? getTmp() : getTmpT(); bool constexpr kIsTransposedA = MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; auto& opA = *call_conditional::apply( getTmp, getTmpT, 0); Mma mma(shared_storage.mm_gradK(), opA, thread_id, warp_id, lane_id, problem_size.k()); int storage_id = col / MatmulGradK::ThreadblockShape::kN; AccumTileGmem gmem_tile{p.workspace_gk + storage_id * AccumTileGmem::kElementsStored}; if (!kOutputInRF) { if (isFirstQuery || !kNeedsAccumGradK) { output_frags.gradK.clear(); } else { gmem_tile.load(output_frags.gradK, thread_id); } } mma.set_prologue_done(kPrologueGK); auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add __syncthreads(); mma(gemm_k_iterations, output_frags.gradK, iterator_B, output_frags.gradK); __syncthreads(); bool isLastColumn = kSingleIterationGradK || col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; if (kPrologueGK && !isLastColumn) { prologueGradK(col + MatmulGradK::ThreadblockShape::kN); } if (kPrologueQK && isLastColumn) { int32_t next_query, next_key; incrIteration(p, query_start, key_start, next_query, next_key); DISPATCH_BOOL(next_key != key_start, kForceReloadK, ([&]() { prologueQkNextIteration( shared_storage, p, next_query, next_key, warp_id, lane_id); })); } // Output results if (!kOutputInRF) { if (kNeedsAccumGradK && !isLastQuery) { gmem_tile.store(output_frags.gradK, thread_id); } else { accumulateInGmem(isLastColumn ? 
shared_storage.gradK_epilogue_final() : shared_storage.gradK_epilogue(), output_frags.gradK, createEpilogueIter(), isFirstQuery || kNeedsAccumGradK, warp_id, lane_id); __syncthreads(); } } } } static CUTLASS_DEVICE int32_t getQueryStart(Params const& p, int32_t key_start) { return 0; }; static CUTLASS_DEVICE void incrIteration(Params const& p, int32_t query_start, int32_t key_start, int32_t& next_query, int32_t& next_key) { next_query = query_start + kBlockSizeI; next_key = key_start; if (next_query >= p.num_queries) { next_key = key_start + kBlockSizeJ; next_query = getQueryStart(p, next_key); } } template static CUTLASS_DEVICE void prologueQkNextIteration(SharedStorage& shared_storage, Params const& p, int32_t query_start, int32_t key_start, uint8_t warp_id, uint8_t lane_id) { if (query_start >= p.num_queries || key_start >= p.num_keys) { return; } static constexpr bool kReloadK = kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; int thread_id = 32 * warp_id + lane_id; typename MatmulQK::Mma::IteratorA iterator_A({int32_t(p.k_strideM)}, p.key_ptr + key_start * p.k_strideM, {p.num_keys - key_start, p.head_dim}, thread_id, cutlass::MatrixCoord{0, 0}); typename MatmulQK::Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, p.query_ptr + query_start * p.q_strideM, {p.head_dim, p.num_queries - query_start}, thread_id, cutlass::MatrixCoord{0, 0}); MatmulQK::Mma::prologue(shared_storage.mm_qk_k(), shared_storage.mm_qk_q(), iterator_A, iterator_B, thread_id, p.head_dim); } template static CUTLASS_DEVICE void writeFragsToGmem(SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, int32_t key_start, uint8_t warp_id, uint8_t lane_id) { uint16_t thread_id = 32 * warp_id + lane_id; int32_t num_keys_in_block = skipBoundsChecks ? MatmulQK::Mma::Shape::kM : cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); typename MatmulGradV::OutputTileIterator outputV_it( typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, p.grad_value_ptr + key_start * p.gV_strideM(), {num_keys_in_block, p.head_dim_value}, thread_id); accumulateInGmem(shared_storage.gradV_epilogue_final(), output_frags.gradV, outputV_it, true, warp_id, lane_id); typename MatmulGradK::OutputTileIterator outputK_it( typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, p.grad_key_ptr + key_start * p.gK_strideM(), {num_keys_in_block, false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, thread_id); accumulateInGmem(shared_storage.gradK_epilogue_final(), output_frags.gradK, outputK_it, true, warp_id, lane_id); } template static CUTLASS_DEVICE void accumulateInGmem( typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, typename MatmulT::Mma::FragmentC const& accum, typename MatmulT::OutputTileIterator output_it, bool first, uint8_t warp_id, uint8_t lane_id) { using DefaultEpilogue = typename MatmulT::DefaultEpilogue; using DefaultOutputOp = typename MatmulT::DefaultOutputOp; using Mma = typename MatmulT::Mma; int thread_id = 32 * warp_id + lane_id; DISPATCH_BOOL( first, kIsFirst, ([&]() { static constexpr auto ScaleType = kIsFirst ? 
cutlass::epilogue::thread::ScaleType::Nothing : cutlass::epilogue::thread::ScaleType::NoBetaScaling; using EpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination< typename DefaultOutputOp::ElementOutput, DefaultOutputOp::kCount, typename DefaultOutputOp::ElementAccumulator, typename DefaultOutputOp::ElementCompute, ScaleType>; using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined< typename DefaultEpilogue::Shape, typename Mma::Operator, DefaultEpilogue::kPartitionsK, typename MatmulT::OutputTileIterator, typename DefaultEpilogue::AccumulatorFragmentIterator, typename DefaultEpilogue::WarpTileIterator, typename DefaultEpilogue::SharedLoadIterator, EpilogueOutputOp, typename DefaultEpilogue::Padding, DefaultEpilogue::kFragmentsPerIteration, true // IterationsUnroll >; EpilogueOutputOp rescale({1, 1}); Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); epilogue(rescale, output_it, accum, output_it); })); } template static CUTLASS_DEVICE void computeDelta(Params const& p, int32_t query_start, uint8_t warp_id, uint8_t lane_id) { // Each thread computes one value for Delta // Depending on warp configuration, we might have multiple // threads of the same warp working on the same row using AccessType = cutlass::Array; static_assert(kNumThreads >= kBlockSizeI, ""); static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; int16_t thread_id = 32 * warp_id + lane_id; int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); int16_t laneRow = thread_id / kNumThreadsPerLine; bool rowPred = (query_start + laneRow) < p.num_queries; bool pred = rowPred; // on windows, previous syntax __restrict__ AccessType* // resulted in error: "restrict" is not allowed const AccessType* __restrict__ grad_output_ptr = reinterpret_cast( p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol); const AccessType* __restrict__ output_ptr = reinterpret_cast( p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol); static constexpr int64_t kMaxIters = kMaxK / (kElementsPerAccess * kNumThreadsPerLine); constexpr int kPipelineStages = 2; accum_t delta_value = accum_t(0); using GlobalLoad = cutlass::arch::global_load; AccessType frag_grad_output[kPipelineStages]; AccessType frag_output[kPipelineStages]; auto loadAndIncrement = [&](int ld_pos, bool is_valid) { frag_grad_output[ld_pos].clear(); frag_output[ld_pos].clear(); GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); grad_output_ptr += kNumThreadsPerLine; output_ptr += kNumThreadsPerLine; }; CUTLASS_PRAGMA_UNROLL for (int iter = 0; iter < kPipelineStages - 1; ++iter) { int ld_pos = iter % kPipelineStages; pred = pred && (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < p.head_dim_value; loadAndIncrement(ld_pos, pred); } auto columnIteration = [&](int iter) { // Load for next iter int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; pred = pred && (laneFirstCol + (iter + kPipelineStages - 1) * kElementsPerAccess * kNumThreadsPerLine) < p.head_dim_value; loadAndIncrement(ld_pos, pred); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < AccessType::kElements; ++i) { delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * accum_t(frag_grad_output[iter % kPipelineStages][i]); } }; // If we have a small lower-bound for K, we can unroll the loop if (kMaxK <= 256) { CUTLASS_PRAGMA_UNROLL for (int iter = 0; iter < kMaxIters; ++iter) { columnIteration(iter); } } else { int 
            num_iters = ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) *
                        (kElementsPerAccess * kNumThreadsPerLine);
        for (int iter = 0; iter < num_iters; ++iter) { columnIteration(iter); }
    }

    // Reduce between workers
    static_assert(kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || kNumThreadsPerLine == 4,
                  "");
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < kNumThreadsPerLine; i *= 2) {
        delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i);
    }

    // Store in gmem
    if (rowPred) { p.delta_ptr[query_start + laneRow] = delta_value; }
  }
};

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_backward_batched_impl(typename AK::Params p)
{
    if (!p.advance_to_block()) { return; }
    AK::attention_kernel(p);
}

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_backward_batched(typename AK::Params params);
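
// ------------------------------------------------------------------------------------------------
// Illustrative host-side usage: a minimal sketch of how a caller could launch the kernel defined
// above. The helper name `launch_attention_backward` and its `stream` parameter are assumptions
// made for this example only; the launch configuration, shared-memory size, and workspace size
// all come from `Params` and the kernel's static members defined earlier in this file.
// ------------------------------------------------------------------------------------------------
template <typename AK>
cudaError_t launch_attention_backward(typename AK::Params& p, cudaStream_t stream = nullptr)
{
    // Validate alignment / shape constraints on the host before launching.
    if (!AK::check_supported(p)) { return cudaErrorInvalidValue; }

    // If gradients are accumulated through global memory (kNeedsAccumGradQ/K/V), the caller must
    // have set `p.workspace` to a buffer of at least `p.workspace_size()` bytes beforehand.

    // The kernel keeps its SharedStorage in dynamic shared memory; on recent architectures this
    // can exceed the 48KB default carve-out, so opt in explicitly.
    int smem_bytes = int(sizeof(typename AK::SharedStorage));
    if (smem_bytes > 48 * 1024) {
        cudaFuncSetAttribute(attention_kernel_backward_batched_impl<AK>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_bytes);
    }

    // Grid is (1, num_heads, num_batches); block is kWarpSize * kNumWarpsPerBlock threads.
    attention_kernel_backward_batched_impl<AK>
        <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
    return cudaGetLastError();
}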