/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <type_traits>

#include "cutlass/arch/mma.h"
#include "cutlass/gemm/gemm.h"          // cutlass::gemm::GemmShape
#include "cutlass/numeric_types.h"      // cutlass::half_t, cutlass::bfloat16_t, sizeof_bits
#include "cutlass/platform/platform.h"  // cutlass::platform::enable_if
template <typename arch, typename scalar_t>
struct CheckArch {
    static constexpr bool isPreVolta = arch::kMinComputeCapability < 70;
    static constexpr bool isPreAmpere =
        arch::kMinComputeCapability < 80 && arch::kMinComputeCapability >= 70;
    static constexpr bool isAmpere = arch::kMinComputeCapability >= 80;
#if defined(__CUDA_ARCH__)
    // In device compilation passes, additionally require that the target
    // architecture is at least the kernel's minimum compute capability.
    static constexpr bool compiler_cc = arch::kMinComputeCapability * 10 <= __CUDA_ARCH__;
#else
    static constexpr bool compiler_cc = true;
#endif
    // Pre-Volta: fp32 only. Volta/Turing: fp16 but no bf16. Ampere+: all types,
    // provided the compiler actually targets that architecture.
    static constexpr bool value = (isPreVolta && std::is_same_v<scalar_t, float>) ||
                                  (isPreAmpere && !std::is_same_v<scalar_t, cutlass::bfloat16_t>) ||
                                  (isAmpere && compiler_cc);
};
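
// Usage sketch (illustrative, not part of the original file): `value` can gate
// kernel instantiation at compile time, e.g.
//
//   static_assert(CheckArch<cutlass::arch::Sm80, cutlass::bfloat16_t>::value,
//                 "bf16 requires an SM80+ build");
//
// whereas CheckArch<cutlass::arch::Sm70, cutlass::bfloat16_t>::value is false,
// since bf16 Tensor Core support only exists on Ampere and newer.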
// Selects a cutlass ArchTag for the runtime compute capability CC. GPU_ARCH is
// the architecture the extension was compiled for and is expected to be
// defined by the build system.
#define DISPATCH_ARCHTAG(CC, func)                                                          \
    {                                                                                       \
        if constexpr (GPU_ARCH >= 80) {                                                     \
            if (CC >= 80) {                                                                 \
                using ArchTag = cutlass::arch::Sm80;                                        \
                func;                                                                       \
            } else {                                                                        \
                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");               \
            }                                                                               \
        } else if constexpr (GPU_ARCH >= 75) {                                              \
            if (CC >= 75) {                                                                 \
                using ArchTag = cutlass::arch::Sm75;                                        \
                func;                                                                       \
            } else {                                                                        \
                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");               \
            }                                                                               \
        } else if constexpr (GPU_ARCH >= 70) {                                              \
            if (CC >= 70) {                                                                 \
                using ArchTag = cutlass::arch::Sm70;                                        \
                func;                                                                       \
            } else {                                                                        \
                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");               \
            }                                                                               \
        } else {                                                                            \
            EVOFORMER_CHECK(false, "Only GPUs with Tensor Core are supported for now");     \
        }                                                                                   \
    }
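
// Usage sketch (illustrative; `run_kernel` and `params` are hypothetical):
//
//   auto props = at::cuda::getCurrentDeviceProperties();
//   int cc = props->major * 10 + props->minor;
//   DISPATCH_ARCHTAG(cc, run_kernel<ArchTag>(params));
//
// Note: because this is a plain macro, a `func` argument containing top-level
// commas must be wrapped in parentheses.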
// Binds `scalar_t` (the cutlass type) and `torch_scalar_t` (the matching ATen
// type) for the code in `func`, based on the tensor's runtime dtype.
#define DISPATCH_TYPES(tensor, func)                                                  \
    {                                                                                 \
        if (tensor.scalar_type() == at::ScalarType::Half) {                           \
            using scalar_t = cutlass::half_t;                                         \
            using torch_scalar_t = at::Half;                                          \
            func;                                                                     \
        } else if (tensor.scalar_type() == at::ScalarType::BFloat16) {                \
            using scalar_t = cutlass::bfloat16_t;                                     \
            using torch_scalar_t = at::BFloat16;                                      \
            func;                                                                     \
        } else {                                                                      \
            EVOFORMER_CHECK(false, "Only fp16 and bf16 supported at the moment");     \
        }                                                                             \
    }
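
// Usage sketch (illustrative; `attention_impl` is hypothetical):
//
//   DISPATCH_TYPES(query, attention_impl<scalar_t>(query));
//
// inside which `query.data_ptr<torch_scalar_t>()` can be reinterpreted as
// `scalar_t*`, since the paired types share the same bit layout.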
// Lifts a runtime bool into a compile-time constant by instantiating `F` for
// both values.
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F)   \
    {                                         \
        if (BOOL_V) {                         \
            constexpr bool BOOL_NAME = true;  \
            F();                              \
        } else {                              \
            constexpr bool BOOL_NAME = false; \
            F();                              \
        }                                     \
    }
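
// Usage sketch (illustrative; `run_kernel` and `params` are hypothetical):
//
//   DISPATCH_BOOL(bias != nullptr, kHasBias, ([&] {
//       run_kernel<scalar_t, kHasBias>(params);
//   }));
//
// The lambda is parenthesized so that the comma in its template argument list
// does not split the macro arguments.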
// Error-reporting helpers. With PyTorch, failures throw via TORCH_CHECK; under
// NVRTC or plain C++, the enclosing function must return bool, and failures
// report (where possible) and return false.
#ifdef TORCH_CHECK
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
    EVOFORMER_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define EVOFORMER_CHECK TORCH_CHECK
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
    if (!(uint64_t(PTR) % ALIGNMENT == 0)) { return false; }
#define EVOFORMER_CHECK(COND, ERR) \
    if (!(COND)) { return false; }
#else
#include <iostream>
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT)                \
    if (!(uint64_t(PTR) % ALIGNMENT == 0)) {             \
        std::cerr << #PTR " is not correctly aligned\n"; \
        return false;                                    \
    }
#define EVOFORMER_CHECK(COND, ERR)                          \
    if (!(COND)) {                                          \
        std::cerr << "[Evoformer Attention]"                \
                  << "'" #COND "' failed: " << ERR << "\n"; \
        return false;                                       \
    }
#endif
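
// Usage sketch (illustrative): a validation helper that works under all three
// configurations (throwing with TORCH_CHECK, returning false otherwise):
//
//   bool check_inputs(const void* ptr)
//   {
//       CHECK_ALIGNED_PTR(ptr, 16);
//       EVOFORMER_CHECK(ptr != nullptr, "null input");
//       return true;
//   }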
namespace gemm_kernel_utils {

// Integer division, rounding up.
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m)
{
    return (n + m - 1) / m;
}

// Rounds n up to the nearest multiple of m.
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m)
{
    return ((n + m - 1) / m) * m;
}
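// For example, ceil_div(10, 4) == 3 and align_up(10, 4) == 12.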
////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////

// Fallback to SIMT (FMA on CUDA cores) if none of the specializations below apply
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
    static constexpr int ThreadK = 8;
    static constexpr int WarpK = 8;
    static constexpr int kMinimumAlignment = 1;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
    using OpClass = cutlass::arch::OpClassSimt;
    using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for Tensor Cores with f32 (SM80+; uses the OpMultiplyAddFastF32 path)
template <typename ArchTag>
struct DefaultGemmType<
    ArchTag,
    float,
    typename cutlass::platform::enable_if<ArchTag::kMinComputeCapability >= 80>::type> {
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 4;
    using OpClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
    using Operator = cutlass::arch::OpMultiplyAddFastF32;
};
// Specialization for Tensor Cores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
    ArchTag,
    scalar_t,
    typename cutlass::platform::enable_if<ArchTag::kMinComputeCapability >= 75 &&
                                          cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 4;
    using OpClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
    using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for Tensor Cores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 2;
    using OpClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
    using Operator = cutlass::arch::OpMultiplyAdd;
};
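
// For instance, DefaultGemmType<cutlass::arch::Sm80, cutlass::half_t> selects
// the Sm75+ 16-bit specialization (TensorOp, 16x8x8 MMA), while
// DefaultGemmType<cutlass::arch::Sm70, float> falls back to the SIMT default.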
// Enables writing
// `auto x = kCondition ? fa(arg) : fb(arg)`
// even when `fa` and `fb` return different types
template <bool kVal, typename TA, typename TB>
struct call_conditional;

template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
    template <typename Arg>
    static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(ta(arg))
    {
        return ta(arg);
    }
};

template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
    template <typename Arg>
    static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(tb(arg))
    {
        return tb(arg);
    }
};
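
// Usage sketch (illustrative; the lambdas are hypothetical):
//
//   auto fa = [](int x) { return float(x); };
//   auto fb = [](int x) { return double(x); };
//   auto v = call_conditional<true, decltype(fa), decltype(fb)>::apply(fa, fb, 42);
//   // v is a float; with false, it would be a double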
////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////

CUTLASS_DEVICE int32_t warp_uniform(int32_t value)
{
    return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
}

template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr)
{
    // Broadcast the two 32-bit halves of the (64-bit) pointer from lane 0.
    struct {
        union {
            T* ptr;
            uint32_t asInt[2];
        };
    } p;
    p.ptr = ptr;
    p.asInt[0] = warp_uniform(p.asInt[0]);
    p.asInt[1] = warp_uniform(p.asInt[1]);
    return p.ptr;
}
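
// Usage sketch (illustrative, device code only; the variables are hypothetical):
//
//   int32_t row = warp_uniform(per_lane_row);        // every lane takes lane 0's value
//   const float* tile = warp_uniform(base + offset); // same for the pointer overload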
}  // namespace gemm_kernel_utils