// kernel_forward.h

/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <curand_kernel.h>
#include <cmath>
#include <vector>

#include "cutlass/bfloat16.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"

#include "epilogue/epilogue_pipelined.h"
#include "epilogue/epilogue_rescale_output.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
#include "gemm_kernel_utils.h"
#include "transform/bias_broadcast.h"
#include "transform/tile_smem_loader.h"

#include <inttypes.h>

using namespace gemm_kernel_utils;
namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm()
{
    return (Arch::kMinComputeCapability >= 80 && !cutlass::platform::is_same<scalar_t, float>::value
                ? 16
                : 12);
}

static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value)
{
    // source: https://stackoverflow.com/a/51549250
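    // Why this works: the bit patterns of non-negative IEEE-754 floats compare in the
    // same order as signed ints, while negative floats compare in reverse order as
    // unsigned ints, so an integer atomicMax/atomicMin gives a float max without a CAS loop.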
    return (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
                        : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
}  // namespace

template <
    // The datatype of Q/K/V
    typename scalar_t_,
    // Architecture we are targeting (eg `cutlass::arch::Sm80`)
    typename ArchTag,
    // If Q/K/V are correctly aligned in memory and we can run a fast kernel
    bool isAligned_,
    int kQueriesPerBlock,
    int kKeysPerBlock_,
    bool kSingleValueIteration_,  // = `value.shape[-1] <= kKeysPerBlock`
    // Set to false if you know at compile-time that you will never need an attention bias
    bool kSupportsBias_ = false,
    template <typename, typename, typename> class Broadcast1_ = BroadcastNoLoad,
    template <typename, typename, typename> class Broadcast2_ = BroadcastNoLoad>
struct AttentionKernel {
    using scalar_t = scalar_t_;
    using accum_t = float;
    using lse_scalar_t = float;
    using output_t = scalar_t;
    // Accumulator between 2 iterations
    // Using `accum_t` improves perf on f16 at the cost of
    // numerical errors
    using output_accum_t = accum_t;
    static constexpr bool kSupportsBias = kSupportsBias_;
    static constexpr int kKeysPerBlock = kKeysPerBlock_;
    static constexpr bool kIsAligned = isAligned_;
    static constexpr bool kSingleValueIteration = kSingleValueIteration_;
    static constexpr int32_t kAlignLSE = 32;  // block size of backward
    static constexpr bool kPreloadV =
        ArchTag::kMinComputeCapability >= 80 && cutlass::sizeof_bits<scalar_t>::value == 16;
    static constexpr bool kKeepOutputInRF = kSingleValueIteration;
    static constexpr bool kNeedsOutputAccumulatorBuffer =
        !kKeepOutputInRF && !cutlass::platform::is_same<output_accum_t, output_t>::value;

    static_assert(kQueriesPerBlock % 32 == 0, "");
    static_assert(kKeysPerBlock % 32 == 0, "");
    static constexpr int kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (32 * 32);
    static constexpr int kWarpSize = 32;

    // Launch bounds
    static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
    static constexpr int kMinBlocksPerSm = getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
    struct Params {
        // Input tensors
        scalar_t* query_ptr;  // [num_queries, num_heads, head_dim]
        scalar_t* key_ptr;    // [num_keys, num_heads, head_dim]
        scalar_t* value_ptr;  // [num_keys, num_heads, head_dim_value]

        // Output tensors
        output_t* output_ptr;              // [num_queries, num_heads, head_dim_value]
        output_accum_t* output_accum_ptr;  // [num_queries, num_heads, head_dim_value]
        lse_scalar_t* logsumexp_ptr;       // [num_heads, num_queries] - can be null

        // Scale
        accum_t scale;

        // Dimensions/strides
        int32_t head_dim;
        int32_t head_dim_value;
        int32_t num_queries;
        int32_t num_keys;

        int32_t q_strideM;
        int32_t k_strideM;
        int32_t v_strideM;
        // int32_t bias_strideM = 0;

        int32_t o_strideM = 0;

        // Everything below is only used in `advance_to_block`
        // and shouldn't use registers
        int32_t q_strideH;
        int32_t k_strideH;
        int32_t v_strideH;
        // int32_t bias_strideH = 0;

        int64_t q_strideB;
        int64_t k_strideB;
        int64_t v_strideB;
        // int32_t bias_strideB = 0;

        int32_t num_batches;
        int32_t num_heads;

        // Parameters for biases
        scalar_t* bias1_ptr = nullptr;
        scalar_t* bias2_ptr = nullptr;
        int32_t B = 0;
        int32_t N = 0;

        // Moves pointers to what we should process
        // Returns "false" if there is no work to do
        CUTLASS_DEVICE bool advance_to_block()
        {
            auto batch_id = blockIdx.z;
            auto head_id = blockIdx.y;
            auto query_start = blockIdx.x * kQueriesPerBlock;

            auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;

            query_ptr += batch_id * q_strideB;
            key_ptr += batch_id * k_strideB;
            value_ptr += batch_id * v_strideB;
            output_ptr += int64_t(batch_id * num_queries) * o_strideM;
            if (output_accum_ptr != nullptr) {
                output_accum_ptr += int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
            }

            int64_t q_start = 0, k_start = 0;
            // Advance to the current batch / head / query_start
            query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
            key_ptr += k_start * k_strideM + head_id * k_strideH;
            value_ptr += k_start * v_strideM + head_id * v_strideH;
            output_ptr += int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;

            if (output_accum_ptr != nullptr) {
                output_accum_ptr += int64_t(q_start + query_start) * (head_dim_value * num_heads) +
                                    head_id * head_dim_value;
            } else {
                // Accumulate directly in the destination buffer (eg for f32)
                output_accum_ptr = (accum_t*)output_ptr;
            }

            if (logsumexp_ptr != nullptr) {
                // lse[batch_id, head_id, query_start]
                logsumexp_ptr += batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
            }

            using broadcast_1 = Broadcast1_<typename MM0::BiasLoader::ThreadMap,
                                            typename MM0::BiasLoader::Shape,
                                            scalar_t>;
            if (kSupportsBias && broadcast_1::kEnable && bias1_ptr) {
                bias1_ptr = broadcast_1::advance(bias1_ptr,
                                                 batch_id / N,
                                                 batch_id % N,
                                                 head_id,
                                                 num_queries * N,
                                                 num_queries,
                                                 0);
            }
            using broadcast_2 = Broadcast2_<typename MM0::BiasLoader::ThreadMap,
                                            typename MM0::BiasLoader::Shape,
                                            scalar_t>;
            if (kSupportsBias && broadcast_2::kEnable && bias2_ptr) {
                auto strideB = num_heads * num_queries * num_keys;
                auto strideH = num_queries * num_keys;
                bias2_ptr = broadcast_2::advance(
                    bias2_ptr, batch_id / N, batch_id % N, head_id, strideB, 0, strideH);
            }

            num_queries -= query_start;
            num_batches = 0;  // no longer used after
            // If num_queries == 1 and there is only one key head, we're wasting
            // 15/16th of tensor core compute. In that case:
            // - we only launch kernels for head_id % kQueriesPerBlock == 0
            // - we iterate over heads instead of queries (strideM = strideH)
            if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) {
                if (head_id % kQueriesPerBlock != 0) return false;
                q_strideM = q_strideH;
                num_queries = num_heads;
                num_heads = 1;  // unused but here for intent
                o_strideM = head_dim_value;
            }
            // Make sure the compiler knows these variables are the same on all
            // the threads of the warp.
            query_ptr = warp_uniform(query_ptr);
            key_ptr = warp_uniform(key_ptr);
            value_ptr = warp_uniform(value_ptr);
            output_ptr = warp_uniform(output_ptr);
            output_accum_ptr = warp_uniform(output_accum_ptr);
            logsumexp_ptr = warp_uniform(logsumexp_ptr);
            num_queries = warp_uniform(num_queries);
            num_keys = warp_uniform(num_keys);
            num_heads = warp_uniform(num_heads);
            head_dim = warp_uniform(head_dim);
            head_dim_value = warp_uniform(head_dim_value);
            o_strideM = warp_uniform(o_strideM);
            if (kSupportsBias && broadcast_1::kEnable) { bias1_ptr = warp_uniform(bias1_ptr); }
            if (kSupportsBias && broadcast_2::kEnable) { bias2_ptr = warp_uniform(bias2_ptr); }
            return true;
        }

        __host__ dim3 getBlocksGrid() const
        {
            return dim3(ceil_div(num_queries, (int32_t)kQueriesPerBlock), num_heads, num_batches);
        }
        __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); }
    };
    struct MM0 {
        /*
          In this first matmul, we compute a block of `Q @ K.T`.
          While the calculation result is still hot in registers, we update
          `mi`, `m_prime`, `s_prime` in shared memory, and then store this value
          into a shared-memory buffer ("AccumulatorSharedStorage") that is used later
          as operand A for the second matmul (see MM1).
        */
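        // A sketch of the running-softmax invariants kept per query row r (matching
        // `iterative_softmax` below; no extra functionality is added here):
        //   mi[r]      : running max over processed keys of scale * (Q @ K.T)[r, j]
        //   s_prime[r] : running sum of exp(scale * (Q @ K.T)[r, j] - mi[r])
        //   m_prime[r] : correction factor exp(mi_old[r] - mi_new[r]) applied to s_prime
        //                (and to the output fragment when it is kept in registers)
        // The block of exp(scale * Q @ K.T - mi) is what lands in AccumulatorSharedStorage
        // and is consumed as operand A by MM1.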
        using GemmType = DefaultGemmType<ArchTag, scalar_t>;

        using OpClass = typename GemmType::OpClass;
        using DefaultConfig =
            typename cutlass::gemm::device::DefaultGemmConfiguration<OpClass,
                                                                     ArchTag,
                                                                     scalar_t,
                                                                     scalar_t,
                                                                     scalar_t,  // ElementC
                                                                     accum_t    // ElementAccumulator
                                                                     >;
        static constexpr int kAlignmentA =
            kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
        static constexpr int kAlignmentB =
            kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
        using ThreadblockShape =
            cutlass::gemm::GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
        using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
        using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
            scalar_t,                             // ElementA,
            cutlass::layout::RowMajor,            // LayoutA,
            kAlignmentA,
            scalar_t,                             // ElementB,
            cutlass::layout::ColumnMajor,         // LayoutB,
            kAlignmentB,
            accum_t,
            cutlass::layout::RowMajor,            // LayoutC,
            OpClass,
            ArchTag,                              // ArchTag
            ThreadblockShape,                     // ThreadblockShape
            WarpShape,                            // WarpShape
            typename GemmType::InstructionShape,  // InstructionShape
            DefaultConfig::kStages,               // Should use `DefaultConfig::kStages`, but that
                                                  // uses too much smem
            typename GemmType::Operator           // Operator
            >::DefaultMma;
        using MmaCore = typename DefaultMma::MmaCore;
        using IteratorA = typename DefaultMma::IteratorA;
        using IteratorB = typename DefaultMma::IteratorB;
        using Mma = typename DefaultMma::ThreadblockMma;
        using AccumLambdaIterator =
            typename DefaultMmaAccumLambdaIterator<typename Mma::Operator::IteratorC,
                                                   accum_t,
                                                   kWarpSize>::Iterator;
        static_assert(MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * MmaCore::WarpCount::kK ==
                          kNumWarpsPerBlock,
                      "");

        // used for efficient load of bias tile Bij from global to shared memory
        using BiasLoader =
            TileSmemLoader<scalar_t,
                           cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
                           MmaCore::kThreads,
                           // input restriction: kv_len has to be a multiple of this value
                           128 / cutlass::sizeof_bits<scalar_t>::value>;

        // Epilogue to store to shared-memory in a format that we can use later for
        // the second matmul
        using B2bGemm =
            typename cutlass::gemm::threadblock::B2bGemm<typename Mma::Operator::IteratorC,
                                                         typename Mma::Operator,
                                                         scalar_t,
                                                         WarpShape,
                                                         ThreadblockShape>;
        using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
    };
    struct MM1 {
        /**
          Second matmul: perform `attn @ V`, where `attn` is the (not yet normalized)
          attention matrix read from shared memory.
        */
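        // The division by the softmax normalizer is not done here: roughly, it happens
        // in the epilogue, where MemoryEfficientAttentionNormalize combines the partial
        // output with the previous one using `m_prime` and, on the last key block,
        // scales each row by 1 / s_prime.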
        using GemmType = DefaultGemmType<ArchTag, scalar_t>;
        using OpClass = typename GemmType::OpClass;
        using DefaultConfig =
            typename cutlass::gemm::device::DefaultGemmConfiguration<OpClass,
                                                                     ArchTag,
                                                                     scalar_t,
                                                                     scalar_t,
                                                                     output_accum_t,  // ElementC
                                                                     accum_t  // ElementAccumulator
                                                                     >;
        static constexpr int kAlignmentA = DefaultConfig::kAlignmentA;  // from smem
        static constexpr int kAlignmentB =
            kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
        using ThreadblockShape =
            cutlass::gemm::GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
        using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
        using InstructionShape = typename GemmType::InstructionShape;

        using LayoutB = cutlass::layout::RowMajor;
        using DefaultGemm =
            cutlass::gemm::kernel::DefaultGemm<scalar_t,                   // ElementA,
                                               cutlass::layout::RowMajor,  // LayoutA,
                                               kAlignmentA,
                                               scalar_t,  // ElementB,
                                               LayoutB,   // LayoutB,
                                               kAlignmentB,
                                               output_accum_t,
                                               cutlass::layout::RowMajor,  // LayoutC,
                                               accum_t,
                                               OpClass,
                                               ArchTag,
                                               ThreadblockShape,
                                               WarpShape,
                                               typename GemmType::InstructionShape,
                                               typename DefaultConfig::EpilogueOutputOp,
                                               void,  // ThreadblockSwizzle - not used
                                               DefaultConfig::kStages,
                                               false,  // SplitKSerial
                                               typename GemmType::Operator>;

        using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
            typename DefaultGemm::Mma,
            typename MM0::AccumulatorSharedStorage,
            false>;  // kScaleOperandA
        using Mma = typename DefaultMmaFromSmem::Mma;
        using IteratorB = typename Mma::IteratorB;
        using WarpCount = typename Mma::WarpCount;
        static_assert(WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, "");

        using DefaultEpilogue = typename DefaultGemm::Epilogue;
        using OutputTileIterator = typename cutlass::epilogue::threadblock::PredicatedTileIterator<
            typename DefaultEpilogue::OutputTileIterator::ThreadMap,
            output_t>;
        using OutputTileIteratorAccum =
            typename cutlass::epilogue::threadblock::PredicatedTileIterator<
                typename DefaultEpilogue::OutputTileIterator::ThreadMap,
                output_accum_t>;

        struct SharedStorageMM1 {
            typename Mma::SharedStorage mm;
        };
    };
    static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
    static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
    static constexpr int64_t kAlignmentV = 1;

    // Shared storage - depends on kernel params
    struct ScalingCoefs {
        cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
        cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
        cutlass::Array<accum_t, kQueriesPerBlock> mi;
    };

    struct SharedStorageEpilogueAtEnd : ScalingCoefs {
        struct SharedStorageAfterMM0 {
            // Everything here might be overwritten during MM0
            union {
                // typename MM0::BiasLoader::SmemTile bias;
                cutlass::AlignedBuffer<float, MM0::BiasLoader::Shape::kCount> bias;
                typename MM0::AccumulatorSharedStorage si;
            };
            typename MM1::SharedStorageMM1 mm1;
        };

        union {
            typename MM0::Mma::SharedStorage mm0;
            SharedStorageAfterMM0 after_mm0;
            typename MM1::DefaultEpilogue::SharedStorage epilogue;
        };

        CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& epilogue_shared_storage()
        {
            return epilogue;
        }
    };

    struct SharedStorageEpilogueInLoop : ScalingCoefs {
        struct SharedStorageAfterMM0 {
            // Everything here might be overwritten during MM0
            union {
                // typename MM0::BiasLoader::SmemTile bias;
                cutlass::AlignedBuffer<float, MM0::BiasLoader::Shape::kCount> bias;
                typename MM0::AccumulatorSharedStorage si;
            };
            typename MM1::SharedStorageMM1 mm1;
            typename MM1::DefaultEpilogue::SharedStorage epilogue;
        };

        union {
            typename MM0::Mma::SharedStorage mm0;
            SharedStorageAfterMM0 after_mm0;
        };

        CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& epilogue_shared_storage()
        {
            return after_mm0.epilogue;
        }
    };

    using SharedStorage =
        typename cutlass::platform::conditional<kSingleValueIteration || kKeepOutputInRF,
                                                SharedStorageEpilogueAtEnd,
                                                SharedStorageEpilogueInLoop>::type;

    static bool __host__ check_supported(Params const& p)
    {
        CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
        CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
        CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
        EVOFORMER_CHECK(p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned (strideM)");
        EVOFORMER_CHECK(p.k_strideM % kAlignmentK == 0, "key is not correctly aligned (strideM)");
        EVOFORMER_CHECK(p.v_strideM % kAlignmentV == 0, "value is not correctly aligned (strideM)");
        EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
                        "query is not correctly aligned (strideH)");
        EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
                        "key is not correctly aligned (strideH)");
        EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
                        "value is not correctly aligned (strideH)");
        return true;
    }
    static void CUTLASS_DEVICE attention_kernel(Params& p)
    {
        // In this block, we will only ever:
        // - read query[query_start:query_end, :]
        // - write to output[query_start:query_end, :]

        extern __shared__ char smem_buffer[];
        SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
        auto& m_prime = shared_storage.m_prime;
        auto& s_prime = shared_storage.s_prime;
        auto& mi = shared_storage.mi;
        const uint32_t query_start = blockIdx.x * kQueriesPerBlock;

        static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
        if (thread_id() < kQueriesPerBlock) {
            s_prime[thread_id()] = accum_t(0);
            m_prime[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
            mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
        }
        typename MM1::Mma::FragmentC accum_o;
        accum_o.clear();

        auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
            using OutputTileIterator = typename MM1::OutputTileIterator;
            return OutputTileIterator(
                typename OutputTileIterator::Params{(int32_t)p.o_strideM},
                p.output_ptr,
                typename OutputTileIterator::TensorCoord{p.num_queries, p.head_dim_value},
                thread_id(),
                {0, col});
        };

        auto createOutputAccumIter = [&](int col) -> typename MM1::OutputTileIteratorAccum {
            using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
            return OutputTileIteratorAccum(
                typename OutputTileIteratorAccum::Params{(int32_t)(p.head_dim_value * p.num_heads)},
                p.output_accum_ptr,
                typename OutputTileIteratorAccum::TensorCoord{p.num_queries, p.head_dim_value},
                thread_id(),
                {0, col});
        };

        // Iterate through keys
        for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
             iter_key_start += kKeysPerBlock) {
            int32_t problem_size_0_m = cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
            int32_t problem_size_0_n =
                cutlass::fast_min(int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
            int32_t const& problem_size_0_k = p.head_dim;
            int32_t const& problem_size_1_n = p.head_dim_value;
            int32_t const& problem_size_1_k = problem_size_0_n;

            auto prologueV = [&](int blockN) {
                typename MM1::Mma::IteratorB iterator_V(
                    typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
                    p.value_ptr + iter_key_start * p.v_strideM,
                    {problem_size_1_k, problem_size_1_n},
                    thread_id(),
                    cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
                MM1::Mma::prologue(
                    shared_storage.after_mm0.mm1.mm, iterator_V, thread_id(), problem_size_1_k);
            };

            __syncthreads();  // Need to have shared memory initialized, and `m_prime`
                              // updated from end of prev iter

            //
            // MATMUL: Q.K_t
            //
            // Computes the block-matrix product of:
            // (a) query[query_start:query_end, :]
            // with
            // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
            // and stores that into `shared_storage.si`
            //

            // Compute threadblock location
            cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};

            cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * MM0::Mma::Shape::kM,
                                             tb_tile_offset.k()};

            cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(),
                                             tb_tile_offset.n() * MM0::Mma::Shape::kN};

            // Construct iterators to A and B operands
            typename MM0::IteratorA iterator_A(
                typename MM0::IteratorA::Params(typename MM0::MmaCore::LayoutA(p.q_strideM)),
                p.query_ptr,
                {problem_size_0_m, problem_size_0_k},
                thread_id(),
                tb_offset_A);

            typename MM0::IteratorB iterator_B(
                typename MM0::IteratorB::Params(typename MM0::MmaCore::LayoutB(p.k_strideM)),
                p.key_ptr + iter_key_start * p.k_strideM,
                {problem_size_0_k, problem_size_0_n},
                thread_id(),
                tb_offset_B);

            auto my_warp_id = warp_id();
            auto my_lane_id = lane_id();

            // Construct thread-scoped matrix multiply
            typename MM0::Mma mma(shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);

            typename MM0::Mma::FragmentC accum;
            accum.clear();

            auto gemm_k_iterations =
                (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;

            // Compute threadblock-scoped matrix multiply-add
            mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
            __syncthreads();

            if (kPreloadV) {
                prologueV(0);
            } else {
                MM1::Mma::drain_cp_asyncs();
            }

            typename MM0::Mma::Operator::IteratorC::TensorCoord iteratorC_tile_offset = {
                (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
                    (my_warp_id % MM0::Mma::WarpCount::kM),
                (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
                    (my_warp_id / MM0::Mma::WarpCount::kM)};

            // multiply by scaling factor
            // if (kSupportsBias) {
            //     accum =
            //         cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
            // }

            if (kSupportsBias) {
                cutlass::TensorRef<float, cutlass::layout::RowMajor> bias_tensor_ref(
                    shared_storage.after_mm0.bias.data(),
                    cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
                using Shape =
                    cutlass::MatrixShape<MM0::ThreadblockShape::kM, MM0::ThreadblockShape::kN>;
                AttentionBiasEpilogue<Shape,
                                      scalar_t,
                                      MM0::MmaCore::kThreads,
                                      Broadcast1_,
                                      Broadcast2_>
                    bias_epilogue;
                bias_epilogue(bias_tensor_ref,
                              p.bias1_ptr + iter_key_start,
                              p.bias2_ptr + query_start * p.num_keys + iter_key_start,
                              thread_id(),
                              {problem_size_0_m, problem_size_0_n},
                              p.num_keys);
                // Pij += Bij, Pij is in register fragment and Bij is in shared memory
                auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
                    lane_id(), warp_id(), iteratorC_tile_offset);
                MM0::AccumLambdaIterator::iterateRows(
                    lane_offset,
                    [&](int accum_m) {},
                    [&](int accum_m, int accum_n, int idx) {
                        if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
                            accum[idx] =
                                accum[idx] * p.scale + bias_tensor_ref.at({accum_m, accum_n});
                        }
                    },
                    [&](int accum_m) {});
            }
            DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
                DISPATCH_BOOL(p.num_keys - iter_key_start >= kKeysPerBlock, kFullColumns, ([&] {
                    // Update `mi` from accum stored in registers
                    // Also does accum[i] <- exp(accum[i] - mi)
                    iterative_softmax<typename MM0::Mma::Operator::IteratorC,
                                      kFullColumns,
                                      kIsFirst>(accum_o,
                                                accum,
                                                mi,
                                                m_prime,
                                                s_prime,
                                                lane_id(),
                                                thread_id(),
                                                warp_id(),
                                                p.num_keys - iter_key_start,
                                                iteratorC_tile_offset,
                                                kSupportsBias ? 1.0f : p.scale);
                }));
            }));

            // Output results to shared-memory
            int warp_idx_mn_0 =
                my_warp_id % (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
            auto output_tile_coords =
                cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
                                     warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};

            MM0::B2bGemm::accumToSmem(
                shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);

            __syncthreads();

            //
            // MATMUL: Attn . V
            // Run the matmul `attn @ V` for a block of attn and V.
            // `attn` is read from shared memory (in `shared_storage_si`)
            // `V` is read from global memory (with iterator_B)
            //

            const int64_t nBlockN =
                kSingleValueIteration
                    ? 1
                    : ceil_div((int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
            for (int blockN = 0; blockN < nBlockN; ++blockN) {
                int gemm_k_iterations =
                    (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;

                // Compute threadblock-scoped matrix multiply-add and store it in accum
                // (in registers)
                if (!kPreloadV) {
                    __syncthreads();  // we share shmem between mma and epilogue
                }

                typename MM1::Mma::IteratorB iterator_V(
                    typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
                    p.value_ptr + iter_key_start * p.v_strideM,
                    {problem_size_1_k, problem_size_1_n},
                    thread_id(),
                    cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
                typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm,
                                         shared_storage.after_mm0.si,
                                         (int)thread_id(),
                                         (int)warp_id(),
                                         (int)lane_id(),
                                         (int)problem_size_1_k);
                mma_pv.set_prologue_done(kPreloadV);
                if (!kKeepOutputInRF) { accum_o.clear(); }

                mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
                __syncthreads();

                if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
                    prologueV(blockN + 1);
                }

                if (!kKeepOutputInRF) {
                    MM1::Mma::drain_cp_asyncs();
                    DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
                        DISPATCH_BOOL(
                            (iter_key_start + kKeysPerBlock) >= p.num_keys, kIsLast, ([&] {
                                using DefaultEpilogue = typename MM1::DefaultEpilogue;
                                using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
                                using ElementCompute = typename DefaultOp::ElementCompute;
                                using EpilogueOutputOp = typename cutlass::epilogue::thread::
                                    MemoryEfficientAttentionNormalize<
                                        typename cutlass::platform::
                                            conditional<kIsLast, output_t, output_accum_t>::type,
                                        output_accum_t,
                                        DefaultOp::kCount,
                                        typename DefaultOp::ElementAccumulator,
                                        ElementCompute,
                                        kIsFirst,
                                        kIsLast,
                                        cutlass::Array<ElementCompute, kQueriesPerBlock>>;
                                using Epilogue =
                                    typename cutlass::epilogue::threadblock::EpiloguePipelined<
                                        typename DefaultEpilogue::Shape,
                                        typename MM1::Mma::Operator,
                                        DefaultEpilogue::kPartitionsK,
                                        typename cutlass::platform::conditional<
                                            kIsLast,
                                            typename MM1::OutputTileIterator,
                                            typename MM1::OutputTileIteratorAccum>::type,
                                        typename DefaultEpilogue::AccumulatorFragmentIterator,
                                        typename DefaultEpilogue::WarpTileIterator,
                                        typename DefaultEpilogue::SharedLoadIterator,
                                        EpilogueOutputOp,
                                        typename DefaultEpilogue::Padding,
                                        DefaultEpilogue::kFragmentsPerIteration,
                                        true,  // IterationsUnroll
                                        typename MM1::OutputTileIteratorAccum  // Read iterator
                                        >;

                                int col = blockN * MM1::Mma::Shape::kN;
                                auto source_iter = createOutputAccumIter(col);
                                auto dest_iter =
                                    call_conditional<kIsLast,
                                                     decltype(createOutputIter),
                                                     decltype(createOutputAccumIter)>::
                                        apply(createOutputIter, createOutputAccumIter, col);
                                EpilogueOutputOp rescale(s_prime, m_prime);
                                Epilogue epilogue(shared_storage.epilogue_shared_storage(),
                                                  thread_id(),
                                                  warp_id(),
                                                  lane_id());
                                epilogue(rescale, dest_iter, accum_o, source_iter);
                            }));
                    }));
                    if (!kSingleValueIteration) { __syncthreads(); }
                }
            }
            __syncthreads();  // we modify `m_prime` after
        }
        if (kKeepOutputInRF) {
            constexpr bool kIsFirst = true;
            constexpr bool kIsLast = true;
            using DefaultEpilogue = typename MM1::DefaultEpilogue;
            using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
            using ElementCompute = typename DefaultOp::ElementCompute;
            using EpilogueOutputOp =
                typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
                    output_t,        // output
                    output_accum_t,  // source
                    DefaultOp::kCount,
                    typename DefaultOp::ElementAccumulator,  // accum
                    output_accum_t,                          // compute
                    kIsFirst,
                    kIsLast,
                    cutlass::Array<ElementCompute, kQueriesPerBlock>>;
            using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined<
                typename DefaultEpilogue::Shape,
                typename MM1::Mma::Operator,
                DefaultEpilogue::kPartitionsK,
                typename MM1::OutputTileIterator,  // destination
                typename DefaultEpilogue::AccumulatorFragmentIterator,
                typename DefaultEpilogue::WarpTileIterator,
                typename DefaultEpilogue::SharedLoadIterator,
                EpilogueOutputOp,
                typename DefaultEpilogue::Padding,
                DefaultEpilogue::kFragmentsPerIteration,
                true,                                  // IterationsUnroll
                typename MM1::OutputTileIteratorAccum  // source tile
                >;
            auto dest_iter = createOutputIter(0);
            EpilogueOutputOp rescale(s_prime, m_prime);
            Epilogue epilogue(
                shared_storage.epilogue_shared_storage(), thread_id(), warp_id(), lane_id());
            MM1::Mma::drain_cp_asyncs();
            epilogue(rescale, dest_iter, accum_o);
        }
        // 7. Calculate logsumexp
        // To make the backward pass easier, we pad logsumexp with `inf`;
        // this avoids a few bound checks and is not more expensive during the forward pass.
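        // For valid rows this stores lse[q] = mi[q] + log(s_prime[q]), i.e. the log of
        // the softmax denominator in the original (unshifted) scale.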
        static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
        if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
            auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
            if (thread_id() < p.num_queries) {
                p.logsumexp_ptr[thread_id()] =
                    accum_t(mi[thread_id()]) + cutlass::fast_log(accum_t(s_prime[thread_id()]));
            } else if (thread_id() < lse_dim) {
                p.logsumexp_ptr[thread_id()] =
                    cutlass::platform::numeric_limits<accum_t>::infinity();
            }
        }
    }
    template <typename WarpIteratorC, bool kFullColumns, bool kIsFirst>
    CUTLASS_DEVICE static void iterative_softmax(
        typename WarpIteratorC::Fragment& frag_o,  // output so far
        typename WarpIteratorC::Fragment& frag,
        cutlass::Array<accum_t, kQueriesPerBlock>& mi,
        cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
        cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
        int8_t lane_id,
        int8_t thread_id,
        int8_t warp_id,
        int16_t max_col,
        typename WarpIteratorC::TensorCoord const& tile_offset,
        float scaling)
    {
        /* Iterates on the accumulator and corresponding position on result matrix

        (1) Update `mi[r]` to the max value of the row `r`
        (2) In a second iteration do the following:
            (a) accum   <- exp(accum - mi)
            (b) m_prime <- exp(m_prime - mi)
            (c) s_prime <- s_prime * m_prime + sum(accum)

        All of this is done on registers, before we store all of this
        on shared memory for the next matmul with Value.
        */
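        // Sketch of why this is the online-softmax update (for a single row r, writing
        // P_j = scaling * (Q @ K.T)[r, j]): after processing key blocks B_1..B_t,
        //   mi      = max_{j in B_1..B_t} P_j
        //   s_prime = sum_{j in B_1..B_t} exp(P_j - mi)
        // When a new block arrives, the previous partial sum (and, if kKeepOutputInRF,
        // the partial output) is multiplied by exp(mi_old - mi_new), so the invariant
        // holds with the new maximum and every exponential stays <= 1.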
        using Fragment = typename WarpIteratorC::Fragment;
        using LambdaIterator =
            typename DefaultMmaAccumLambdaIterator<WarpIteratorC, accum_t, kWarpSize>::Iterator;
        // Convert to `accum_t` (rather than double)
        constexpr float kLog2e = 1.4426950408889634074;  // log_2(e) = M_LOG2E
        if (!kIsFirst) {
            if (thread_id < kQueriesPerBlock) { m_prime[thread_id] = mi[thread_id]; }
            __syncthreads();
        }

        auto lane_offset = LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);

        // First update `mi` to the max per-row
        {
            accum_t max;
            LambdaIterator::iterateRows(
                lane_offset,
                [&](int accum_m) { max = -cutlass::platform::numeric_limits<accum_t>::infinity(); },
                [&](int accum_m, int accum_n, int idx) {
                    if (kFullColumns || accum_n < max_col) {
                        max = cutlass::fast_max(max, frag[idx]);
                    }
                },
                [&](int accum_m) {
                    // Having 4x atomicMax seems faster than reduce within warp
                    // first...
                    atomicMaxFloat(&mi[accum_m], max * scaling);
                });
        }
        frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

        // Make sure we all share the update values for `mi`
        __syncthreads();

        if (thread_id < kQueriesPerBlock) {
            auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
            m_prime[thread_id] = m_prime_exp;
            s_prime[thread_id] *= m_prime_exp;
        }
        __syncthreads();  // Update output fragments
        if (kKeepOutputInRF && !kIsFirst) {
            accum_t mp;
            LambdaIterator::iterateRows(
                lane_offset,
                [&](int accum_m) { mp = m_prime[accum_m]; },
                [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
                [&](int accum_m) {});
            __syncthreads();
        }
        // Update accum_m, accum_n, ...
        {
            accum_t mi_row, total_row;
            LambdaIterator::iterateRows(
                lane_offset,
                [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
                [&](int accum_m, int accum_n, int idx) {
                    frag[idx] = (kFullColumns || accum_n < max_col) ? exp2f(frag[idx] - mi_row)
                                                                    : accum_t(0.0);
                },
                [&](int accum_m) {});
            LambdaIterator::iterateRows(
                lane_offset,
                [&](int accum_m) { total_row = 0.0; },
                [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
                [&](int accum_m) {
                    if (LambdaIterator::reduceSameRow(
                            lane_id, total_row, [](accum_t a, accum_t b) { return a + b; })) {
                        atomicAdd(&s_prime[accum_m], total_row);
                    }
                });
        }
    }

    static CUTLASS_DEVICE int8_t lane_id() { return threadIdx.x; }
    static CUTLASS_DEVICE int8_t warp_id() { return threadIdx.y; }
    static CUTLASS_DEVICE int16_t thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; }
};
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_batched_impl(typename AK::Params p)
{
    if (!p.advance_to_block()) { return; }
    AK::attention_kernel(p);
}

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_batched(typename AK::Params params);
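
// Example host-side usage — a minimal sketch, not part of this header. It assumes
// half-precision inputs on an SM80 GPU with 64x64 tiling, and that all Params fields
// (pointers, dimensions, strides, scale) are filled in from the caller's tensors:
//
//   using Kernel = AttentionKernel<cutlass::half_t,     // scalar_t
//                                  cutlass::arch::Sm80, // ArchTag
//                                  true,                // isAligned
//                                  64,                  // kQueriesPerBlock
//                                  64,                  // kKeysPerBlock
//                                  true>;               // kSingleValueIteration
//   typename Kernel::Params p;
//   /* ... fill p, then: */
//   Kernel::check_supported(p);
//   constexpr auto kernel_fn = attention_kernel_batched_impl<Kernel>;
//   int smem_bytes = sizeof(typename Kernel::SharedStorage);
//   if (smem_bytes > 0xc000) {
//       cudaFuncSetAttribute(
//           kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
//   }
//   kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);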