custom_mma_pipelined.h

/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "custom_mma_base.h"
#include "cutlass/gemm/gemm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to A operand
    typename TransformA_ = NumericArrayConverter<typename SmemIteratorA_::Element,
                                                 typename IteratorA_::Element,
                                                 IteratorA_::Fragment::kElements>,
    ///
    /// Transformation applied to B operand
    typename TransformB_ = NumericArrayConverter<typename SmemIteratorB_::Element,
                                                 typename IteratorB_::Element,
                                                 IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
public:
    ///< Base class
    using Base = CustomMmaBase<Shape_, Policy_, 2>;

    using Shape = Shape_;          ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_;  ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_;  ///< Iterates over tiles of B operand in global memory
    using ElementC = ElementC_;    ///< Data type of accumulator matrix
    using LayoutC = LayoutC_;      ///< Layout of accumulator matrix
    using Policy = Policy_;        ///< Policy describing tuning details

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;

    using TransformA = TransformA_;
    using TransformB = TransformB_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;
    // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
    static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2");
    static bool const kSmemContainsEntireMat = false;

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA,
                       typename Base::SharedStorageB& shared_storageB,
                       int thread_idx,  ///< ID within the threadblock
                       int warp_idx,    ///< ID of warp
                       int lane_idx     ///< ID of each thread within a warp
                       )
        : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
          smem_iterator_A_(shared_storageA.ref(), thread_idx),
          smem_iterator_B_(shared_storageB.ref(), thread_idx)
    {
        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
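
        // Illustrative example (hypothetical shape, not taken from this file): with a
        // 2x2x2 Base::WarpCount (8 warps) and Base::kWarpGemmIterations == 8, a warp
        // with warp_idx == 5 gets warp_idx_mn == 1, warp_idx_k == 1, warp_idx_m == 1
        // and warp_idx_n == 0, so the offsets below place its A iterator at warp tile
        // {1, 8} and its B iterator at warp tile {8, 0}.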
        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset(
            {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset(
            {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
    }

    CUTLASS_DEVICE
    CustomMmaPipelined(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        typename Base::SharedStorage& st,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx)
        : CustomMmaPipelined(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx)
    {
    }
    CUTLASS_DEVICE
    void set_prologue_done(bool value)
    {
        // NOT IMPLEMENTED FOR PIPELINED
    }

    CUTLASS_DEVICE
    void set_zero_outside_bounds(bool value)
    {
        // NOT NEEDED FOR PIPELINED
        // shared memory will always be zero-filled
    }
    template <bool kLoadA = true, bool kLoadB = true>
    CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
                                        ///< iterator over A operand in global memory
                                        IteratorA iterator_A,
                                        ///< iterator over B operand in global memory
                                        IteratorB iterator_B,
                                        int thread_idx,
                                        int problem_size_k)
    {
        prologue<kLoadA, kLoadB>(shared_storage.operand_A,
                                 shared_storage.operand_B,
                                 iterator_A,
                                 iterator_B,
                                 thread_idx,
                                 problem_size_k);
    }

    template <bool kLoadA = true, bool kLoadB = true>
    CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA,
                                        typename Base::SharedStorageB& shared_storageB,
                                        ///< iterator over A operand in global memory
                                        IteratorA iterator_A,
                                        ///< iterator over B operand in global memory
                                        IteratorB iterator_B,
                                        int thread_idx,
                                        int problem_size_k)
    {
        // NOT IMPLEMENTED FOR PIPELINED
    }
    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(
        int gemm_k_iterations,                  ///< number of iterations of the mainloop
        FragmentC& accum,                       ///< destination accumulator tile
        IteratorA iterator_A,                   ///< iterator over A operand in global memory
        IteratorB iterator_B,                   ///< iterator over B operand in global memory
        FragmentC const& src_accum,             ///< source accumulator tile
        TransformA transform_A = TransformA(),  ///< transformation applied to A fragment
        TransformB transform_B = TransformB())  ///< transformation applied to B fragment
    {
        //
        // Prologue
        //

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA tb_frag_A;
        FragmentB tb_frag_B;

        tb_frag_A.clear();
        tb_frag_B.clear();

        // The last kblock is loaded in the prologue
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);

        ++iterator_A;
        ++iterator_B;

        this->smem_iterator_A_.store(transform_A(tb_frag_A));
        this->smem_iterator_B_.store(transform_B(tb_frag_B));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        __syncthreads();

        // Pair of fragments used to overlap shared memory loads and math
        // instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        Operator warp_mma;

        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);

        // Issue loads during the first warp-level matrix multiply-add *AFTER*
        // issuing shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
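
        // Schedule sketch (illustrative, assuming Base::kWarpGemmIterations == 4):
        // within one mainloop iteration warp_mma_k runs 0,1,2,3 while the warp
        // fragments ping-pong between indices 0 and 1. At warp_mma_k == 0 the next
        // threadblock tile is fetched from global memory into tb_frag_A/tb_frag_B
        // (with predicates cleared once gemm_k_iterations <= 2 so the final
        // iteration fetches nothing), and at warp_mma_k == 3 those fragments are
        // stored to the other shared memory stage, followed by __syncthreads().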
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations) {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
                // Load warp-level tiles from shared memory, wrapping to k offset if
                // this is the last group as the case may be.

                if (warp_mma_k == Base::kWarpGemmIterations - 1) {
                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transform_A(tb_frag_A));
                    this->smem_iterator_B_.store(transform_B(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the
                    // circular buffer in shared memory
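                    // Concrete wrap-around example (illustrative values): with
                    // kStages == 2, Policy::kPartitionsK == 1 and
                    // Base::kWarpGemmIterations == 4, the smem write iterators are
                    // stepped back by 2 tiles once both stages have been filled, and
                    // the warp read iterators are rewound by 2 * 1 * 4 = 8 k-groups
                    // once both stages have been consumed.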
                    if (smem_write_stage_idx == 1) {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                    } else {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
                    }

                    smem_write_stage_idx ^= 1;
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
                                                             Base::kWarpGemmIterations);
                this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
                                                             Base::kWarpGemmIterations);

                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

                ++this->warp_tile_iterator_A_;
                ++this->warp_tile_iterator_B_;

                if (warp_mma_k == 0) {
                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                }

                warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
            }
        }
    }
};
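
// Usage sketch (illustrative only; the ThreadblockMma alias, the kernel-side
// variables and the iterator construction are assumptions following the usual
// CUTLASS threadblock-GEMM pattern rather than code from this file):
//
//   using ThreadblockMma = cutlass::gemm::threadblock::CustomMmaPipelined<...>;
//
//   extern __shared__ char smem[];
//   auto& shared_storage = *reinterpret_cast<typename ThreadblockMma::SharedStorage*>(smem);
//
//   ThreadblockMma mma(shared_storage, thread_idx, warp_idx, lane_idx);
//
//   typename ThreadblockMma::FragmentC accum;
//   accum.clear();
//
//   // iterator_A / iterator_B are the global-memory tile iterators for this
//   // threadblock; gemm_k_iterations counts K tiles of Shape::kK.
//   mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);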
/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////