gemm_kernel_utils.h

/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <type_traits>

#include "cutlass/arch/mma.h"

template <typename arch, typename scalar_t>
struct CheckArch {
    static constexpr bool isPreVolta = arch::kMinComputeCapability < 70;
    static constexpr bool isPreAmpere =
        arch::kMinComputeCapability < 80 && arch::kMinComputeCapability >= 70;
    static constexpr bool isAmpere = arch::kMinComputeCapability >= 80;
#if defined(__CUDA_ARCH__)
    static constexpr bool compiler_cc = arch::kMinComputeCapability * 10 <= __CUDA_ARCH__;
#else
    static constexpr bool compiler_cc = true;
#endif
    static constexpr bool value = (isPreVolta && std::is_same_v<scalar_t, float>) ||
                                  (isPreAmpere && !std::is_same_v<scalar_t, cutlass::bfloat16_t>) ||
                                  (isAmpere && compiler_cc);
};
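
// Illustrative (hypothetical) usage: guard a kernel instantiation at compile time.
//   static_assert(CheckArch<cutlass::arch::Sm80, cutlass::bfloat16_t>::value,
//                 "bf16 kernels require an SM80+ architecture tag");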

#define DISPATCH_ARCHTAG(CC, func)                                                      \
    {                                                                                   \
        if constexpr (GPU_ARCH >= 80) {                                                 \
            if (CC >= 80) {                                                             \
                using ArchTag = cutlass::arch::Sm80;                                    \
                func;                                                                   \
            } else {                                                                    \
                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");           \
            }                                                                           \
        } else if constexpr (GPU_ARCH >= 75) {                                          \
            if (CC >= 75) {                                                             \
                using ArchTag = cutlass::arch::Sm75;                                    \
                func;                                                                   \
            } else {                                                                    \
                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");           \
            }                                                                           \
        } else if constexpr (GPU_ARCH >= 70) {                                          \
            if (CC >= 70) {                                                             \
                using ArchTag = cutlass::arch::Sm70;                                    \
                func;                                                                   \
            } else {                                                                    \
                EVOFORMER_CHECK(false, "Compile flag error. Unexpected GPU");           \
            }                                                                           \
        } else {                                                                        \
            EVOFORMER_CHECK(false, "Only GPUs with Tensor Core are supported for now"); \
        }                                                                               \
    }
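
// Illustrative (hypothetical) usage, assuming `cc` holds the device compute capability
// encoded as major * 10 + minor (e.g. 80 for A100) and GPU_ARCH is set by the build:
//   DISPATCH_ARCHTAG(cc, launch_attention_kernel<ArchTag>(params));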

#define DISPATCH_TYPES(tensor, func)                                               \
    {                                                                              \
        if (tensor.scalar_type() == at::ScalarType::Half) {                        \
            using scalar_t = cutlass::half_t;                                      \
            using torch_scalar_t = at::Half;                                       \
            func;                                                                  \
        } else if (tensor.scalar_type() == at::ScalarType::BFloat16) {             \
            using scalar_t = cutlass::bfloat16_t;                                  \
            using torch_scalar_t = at::BFloat16;                                   \
            func;                                                                  \
        } else {                                                                   \
            EVOFORMER_CHECK(false, "Only fp16 and bf16 supported at the moment");  \
        }                                                                          \
    }
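
// Illustrative (hypothetical) usage: dispatch on a torch::Tensor's dtype and run the body
// with `scalar_t` / `torch_scalar_t` bound to the matching CUTLASS / ATen types.
//   DISPATCH_TYPES(query, ([&] { run_kernel<scalar_t>(query); })());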

#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F)       \
    {                                             \
        if (BOOL_V) {                             \
            constexpr bool BOOL_NAME = true;      \
            F();                                  \
        } else {                                  \
            constexpr bool BOOL_NAME = false;     \
            F();                                  \
        }                                         \
    }
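
// Illustrative (hypothetical) usage: turn a runtime flag into a compile-time constant that
// the callable can use as a template parameter.
//   DISPATCH_BOOL(seq_len % 32 == 0, kIsAligned, ([&] { run<kIsAligned>(params); }));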

#ifdef TORCH_CHECK
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
    EVOFORMER_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define EVOFORMER_CHECK TORCH_CHECK
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
    if (!(uint64_t(PTR) % ALIGNMENT == 0)) { return false; }
#define EVOFORMER_CHECK(COND, ERR) \
    if (!(COND)) { return false; }
#else
#include <iostream>
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT)                \
    if (!(uint64_t(PTR) % ALIGNMENT == 0)) {             \
        std::cerr << #PTR " is not correctly aligned\n"; \
        return false;                                    \
    }
#define EVOFORMER_CHECK(COND, ERR)                          \
    if (!(COND)) {                                          \
        std::cerr << "[Evoformer Attention]"                \
                  << "'" #COND "' failed: " << ERR << "\n"; \
        return false;                                       \
    }
#endif
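
// Illustrative (hypothetical) usage inside a parameter-validation routine that returns bool
// (or throws, when EVOFORMER_CHECK maps to TORCH_CHECK):
//   CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
//   EVOFORMER_CHECK(p.num_heads > 0, "num_heads must be positive");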

namespace gemm_kernel_utils {

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m)
{
    return (n + m - 1) / m;
}

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m)
{
    return ((n + m - 1) / m) * m;
}
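
// For example (a reading aid): ceil_div(10, 4) == 3 and align_up(10, 4) == 12,
// i.e. align_up(n, m) rounds n up to the next multiple of m.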

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////
// Fallback to Simt (FMA on cuda cores) if not in a special case below
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
    static constexpr int ThreadK = 8;
    static constexpr int WarpK = 8;
    static constexpr int kMinimumAlignment = 1;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
    using OpClass = cutlass::arch::OpClassSimt;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f32
template <typename ArchTag>
struct DefaultGemmType<
    ArchTag,
    float,
    typename cutlass::platform::enable_if<ArchTag::kMinComputeCapability >= 80>::type> {
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 4;
    using OpClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
    using Operator = cutlass::arch::OpMultiplyAddFastF32;
};

// Specialization for tensorcores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
    ArchTag,
    scalar_t,
    typename cutlass::platform::enable_if<ArchTag::kMinComputeCapability >= 75 &&
                                          cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 4;
    using OpClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 2;
    using OpClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
    using Operator = cutlass::arch::OpMultiplyAdd;
};
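
// Summary of the specializations above (a reading aid, not part of the CUTLASS API):
//   - SM80+ with float         -> TensorOp, OpMultiplyAddFastF32, m16n8k8
//   - SM75+ with 16-bit types  -> TensorOp, OpMultiplyAdd, m16n8k8
//   - SM70 with half           -> TensorOp, OpMultiplyAdd, m8n8k4
//   - anything else            -> SIMT FMA on CUDA cores (the unspecialized fallback)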

// Allows writing
// `auto x = kCondition ? fa(arg) : fb(arg)`
// even when `fa` and `fb` have different return types
template <bool kVal, typename TA, typename TB>
struct call_conditional;

template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
    template <typename Arg>
    static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(ta(arg))
    {
        return ta(arg);
    }
};

template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
    template <typename Arg>
    static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(tb(arg))
    {
        return tb(arg);
    }
};
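
// Illustrative (hypothetical) usage: select between two lambdas whose return types differ,
// which a plain ternary expression would reject.
//   auto vec = [&](int i) { return load_vectorized(i); };
//   auto elt = [&](int i) { return load_elementwise(i); };
//   auto x = call_conditional<kIsAligned, decltype(vec), decltype(elt)>::apply(vec, elt, idx);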

////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////

CUTLASS_DEVICE int32_t warp_uniform(int32_t value)
{
    return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
}

template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr)
{
    struct {
        union {
            T* ptr;
            uint32_t asInt[2];
        };
    } p;
    p.ptr = ptr;
    p.asInt[0] = warp_uniform(p.asInt[0]);
    p.asInt[1] = warp_uniform(p.asInt[1]);
    return p.ptr;
}
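
// Illustrative (hypothetical) usage inside a kernel, after a data-dependent address
// computation that is in fact identical across the warp:
//   scalar_t* tile_ptr = warp_uniform(base_ptr + block_offset);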

}  // namespace gemm_kernel_utils