- /***************************************************************************************************
- * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
- *reserved. SPDX-License-Identifier: BSD-3-Clause
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- *
- * 1. Redistributions of source code must retain the above copyright notice,
- *this list of conditions and the following disclaimer.
- *
- * 2. Redistributions in binary form must reproduce the above copyright notice,
- * this list of conditions and the following disclaimer in the documentation
- * and/or other materials provided with the distribution.
- *
- * 3. Neither the name of the copyright holder nor the names of its
- * contributors may be used to endorse or promote products derived from
- * this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
- *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
- *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
- *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
- *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
- *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
- *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
- *POSSIBILITY OF SUCH DAMAGE.
- *
- **************************************************************************************************/
- // Copyright (c) Microsoft Corporation.
- // SPDX-License-Identifier: Apache-2.0
- // DeepSpeed Team
- /*! \file
- \brief Template for a double-buffered threadblock-scoped GEMM kernel.
- */
- #pragma once
- #include "cutlass/aligned_buffer.h"
- #include "cutlass/array.h"
- #include "cutlass/cutlass.h"
- #include "cutlass/numeric_conversion.h"
- #include "cutlass/matrix_shape.h"
- #include "cutlass/numeric_types.h"
- #include "custom_mma_base.h"
- #include "cutlass/gemm/gemm.h"
- /////////////////////////////////////////////////////////////////////////////////////////////////
- namespace cutlass {
- namespace gemm {
- namespace threadblock {
- /////////////////////////////////////////////////////////////////////////////////////////////////
- /// Structure to compute the matrix product targeting CUDA cores and SIMT math
- /// instructions.
- template <
- /// Size of the Gemm problem - concept: gemm::GemmShape<>
- typename Shape_,
- /// Iterates over tiles of A operand in global memory
- // (concept: ReadableTileIterator | ForwardTileIterator |
- // MaskedTileIterator)
- typename IteratorA_,
- /// Iterates over tiles of A operand in shared memory
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
- typename SmemIteratorA_,
- /// Iterates over tiles of B operand in global memory
- // (concept: ReadableTileIterator | ForwardTileIterator |
- // MaskedTileIterator)
- typename IteratorB_,
- /// Iterates over tiles of B operand in shared memory
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
- typename SmemIteratorB_,
- /// Data type of accumulator matrix
- typename ElementC_,
- /// Data type of accumulator matrix
- typename LayoutC_,
- /// Policy describing tuning details (concept: MmaPolicy)
- typename Policy_,
- /// Transformation applied to A operand
- typename TransformA_ = NumericArrayConverter<typename SmemIteratorA_::Element,
- typename IteratorA_::Element,
- IteratorA_::Fragment::kElements>,
- ///
- /// Transformation applied to B operand
- typename TransformB_ = NumericArrayConverter<typename SmemIteratorB_::Element,
- typename IteratorB_::Element,
- IteratorB_::Fragment::kElements>,
- /// Used for partial specialization
- typename Enable = bool>
- class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
- public:
- ///< Base class
- using Base = CustomMmaBase<Shape_, Policy_, 2>;
- using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
- using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
- using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
- using ElementC = ElementC_; ///< Data type of accumulator matrix
- using LayoutC = LayoutC_; ///< Layout of accumulator matrix
- using Policy = Policy_; ///< Policy describing tuning details
- using SmemIteratorA = SmemIteratorA_;
- using SmemIteratorB = SmemIteratorB_;
- using TransformA = TransformA_;
- using TransformB = TransformB_;
- //
- // Dependent types
- //
- /// Fragment of operand A loaded from global memory
- using FragmentA = typename IteratorA::Fragment;
- /// Fragment of operand B loaded from global memory
- using FragmentB = typename IteratorB::Fragment;
- /// Fragment of accumulator tile
- using FragmentC = typename Policy::Operator::FragmentC;
- /// Warp-level Mma
- using Operator = typename Policy::Operator;
- /// Obtain the arch tag from the warp-level operator
- using ArchTag = typename Policy::Operator::ArchTag;
- /// Complex transform on A operand
- static ComplexTransform const kTransformA = Operator::kTransformA;
- /// Complex transform on B operand
- static ComplexTransform const kTransformB = Operator::kTransformB;
- // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
- static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2");
- static bool const kSmemContainsEntireMat = false;
- private:
- using WarpFragmentA = typename Operator::FragmentA;
- using WarpFragmentB = typename Operator::FragmentB;
- protected:
- /// Iterator to write threadblock-scoped tile of A operand to shared memory
- SmemIteratorA smem_iterator_A_;
- /// Iterator to write threadblock-scoped tile of B operand to shared memory
- SmemIteratorB smem_iterator_B_;
- public:
- /// Construct from tensor references
- CUTLASS_DEVICE
- CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA,
- typename Base::SharedStorageB& shared_storageB,
- int thread_idx, ///< ID within the threadblock
- int warp_idx, ///< ID of warp
- int lane_idx ///< ID of each thread within a warp
- )
- : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
- smem_iterator_A_(shared_storageA.ref(), thread_idx),
- smem_iterator_B_(shared_storageB.ref(), thread_idx)
- {
- // Compute warp location within threadblock tile by mapping the warp_id to
- // three coordinates:
- // _m: the warp's position within the threadblock along the M dimension
- // _n: the warp's position within the threadblock along the N dimension
- // _k: the warp's position within the threadblock along the K dimension
- int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
- int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
- int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
- int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
- // Add per-warp offsets in units of warp-level tiles
- this->warp_tile_iterator_A_.add_tile_offset(
- {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
- this->warp_tile_iterator_B_.add_tile_offset(
- {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
- }
- CUTLASS_DEVICE
- CustomMmaPipelined(
- ///< Shared storage needed for internal use by threadblock-scoped GEMM
- typename Base::SharedStorage& st,
- ///< ID within the threadblock
- int thread_idx,
- ///< ID of warp
- int warp_idx,
- ///< ID of each thread within a warp
- int lane_idx)
- : CustomMmaPipelined(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx)
- {
- }
- CUTLASS_DEVICE
- bool set_prologue_done(bool value)
- {
- // NOT IMPLEMENTED FOR PIPELINED
- }
- CUTLASS_DEVICE
- bool set_zero_outside_bounds(bool value)
- {
- // NOT NEEDED FOR PIPELINED
- // shared memory will always be zero-filled
- }
- template <bool kLoadA = true, bool kLoadB = true>
- CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
- ///< iterator over A operand in global memory
- IteratorA iterator_A,
- ///< iterator over B operand in global memory
- IteratorB iterator_B,
- int thread_idx,
- int problem_size_k)
- {
- prologue<kLoadA, kLoadB>(shared_storage.operand_A,
- shared_storage.operand_B,
- iterator_A,
- iterator_B,
- thread_idx,
- problem_size_k);
- }
- template <bool kLoadA = true, bool kLoadB = true>
- CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA,
- typename Base::SharedStorageB& shared_storageB,
- ///< iterator over A operand in global memory
- IteratorA iterator_A,
- ///< iterator over B operand in global memory
- IteratorB iterator_B,
- int thread_idx,
- int problem_size_k)
- {
- // NOT IMPLEMENTED FOR PIPELINED
- }
- /// Perform a threadblock-scoped matrix multiply-accumulate
- CUTLASS_DEVICE
- void operator()(
- int gemm_k_iterations, ///< number of iterations of the mainloop
- FragmentC& accum, ///< destination accumulator tile
- IteratorA iterator_A, ///< iterator over A operand in global memory
- IteratorB iterator_B, ///< iterator over B operand in global memory
- FragmentC const& src_accum, ///< source accumulator tile
- TransformA transform_A = TransformA(), ///< transformation applied to A fragment
- TransformB transform_B = TransformB())
- { ///< transformation applied to B fragment
- //
- // Prologue
- //
- // Perform accumulation in the 'd' output operand
- accum = src_accum;
- FragmentA tb_frag_A;
- FragmentB tb_frag_B;
- tb_frag_A.clear();
- tb_frag_B.clear();
- // The last kblock is loaded in the prolog
- iterator_A.load(tb_frag_A);
- iterator_B.load(tb_frag_B);
- ++iterator_A;
- ++iterator_B;
- this->smem_iterator_A_.store(transform_A(tb_frag_A));
- this->smem_iterator_B_.store(transform_B(tb_frag_B));
- ++this->smem_iterator_A_;
- ++this->smem_iterator_B_;
- __syncthreads();
- // Pair of fragments used to overlap shared memory loads and math
- // instructions
- WarpFragmentA warp_frag_A[2];
- WarpFragmentB warp_frag_B[2];
- this->warp_tile_iterator_A_.set_kgroup_index(0);
- this->warp_tile_iterator_B_.set_kgroup_index(0);
- this->warp_tile_iterator_A_.load(warp_frag_A[0]);
- this->warp_tile_iterator_B_.load(warp_frag_B[0]);
- ++this->warp_tile_iterator_A_;
- ++this->warp_tile_iterator_B_;
- Operator warp_mma;
- int smem_write_stage_idx = 1;
- // Avoid reading out of bounds
- iterator_A.clear_mask(gemm_k_iterations <= 1);
- iterator_B.clear_mask(gemm_k_iterations <= 1);
- // Issue loads during the first warp-level matrix multiply-add *AFTER*
- // issuing shared memory loads (which have the tightest latency requirement).
- //
- // Mainloop
- //
- // Note: The main loop does not support Base::kWarpGemmIterations == 2.
- CUTLASS_GEMM_LOOP
- for (; gemm_k_iterations > 0; --gemm_k_iterations) {
- //
- // Loop over GEMM K dimension
- //
- CUTLASS_PRAGMA_UNROLL
- for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
- // Load warp-level tiles from shared memory, wrapping to k offset if
- // this is the last group as the case may be.
- if (warp_mma_k == Base::kWarpGemmIterations - 1) {
- // Write fragments to shared memory
- this->smem_iterator_A_.store(transform_A(tb_frag_A));
- this->smem_iterator_B_.store(transform_B(tb_frag_B));
- __syncthreads();
- ++this->smem_iterator_A_;
- ++this->smem_iterator_B_;
- // Add negative offsets to return iterators to the 'start' of the
- // circular buffer in shared memory
- if (smem_write_stage_idx == 1) {
- this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
- this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
- } else {
- this->warp_tile_iterator_A_.add_tile_offset(
- {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
- this->warp_tile_iterator_B_.add_tile_offset(
- {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
- }
- smem_write_stage_idx ^= 1;
- }
- this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
- Base::kWarpGemmIterations);
- this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
- Base::kWarpGemmIterations);
- this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
- this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
- ++this->warp_tile_iterator_A_;
- ++this->warp_tile_iterator_B_;
- if (warp_mma_k == 0) {
- iterator_A.load(tb_frag_A);
- iterator_B.load(tb_frag_B);
- ++iterator_A;
- ++iterator_B;
- // Avoid reading out of bounds if this was the last loop iteration
- iterator_A.clear_mask(gemm_k_iterations <= 2);
- iterator_B.clear_mask(gemm_k_iterations <= 2);
- }
- warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
- }
- }
- }
- };
- /////////////////////////////////////////////////////////////////////////////////////////////////
- } // namespace threadblock
- } // namespace gemm
- } // namespace cutlass
- /////////////////////////////////////////////////////////////////////////////////////////////////
|