/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*! \file
    \brief Cutlass provides helper template functions to figure out the right
    datastructures to instantiate to run a GEMM with various parameters (see
    `cutlass/gemm/threadblock/default_mma.h`). However, due to template
    instantiation priority rules, it will only create an MmaMultistage with
    kStages=3 (otherwise creates an MmaPipelined - which is not compatible with
    FastF32). kStages=3 uses too much shared memory and we want to use kStages=2,
    so we just copy-pasted some code from `default_mma.h` and
    `default_mma_core.h` files and wrapped this template to allow our usecase.

    This is really only for the FastF32 case - aka using TensorCores with fp32.
*/
#pragma once

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
  51. namespace cutlass {
  52. namespace gemm {
  53. namespace threadblock {
  54. template <
  55. /// Element type for A matrix operand
  56. typename ElementA,
  57. /// Layout type for A matrix operand
  58. typename LayoutA,
  59. /// Access granularity of A matrix in units of elements
  60. int kAlignmentA,
  61. /// Element type for B matrix operand
  62. typename ElementB,
  63. /// Layout type for B matrix operand
  64. typename LayoutB,
  65. /// Access granularity of B matrix in units of elements
  66. int kAlignmentB,
  67. /// Element type for internal accumulation
  68. typename ElementAccumulator,
  69. /// Layout type for C and D matrix operand
  70. typename LayoutC,
  71. /// Operator class tag
  72. typename OperatorClass,
  73. /// Tag indicating architecture to tune for
  74. typename ArchTag,
  75. /// Threadblock-level tile size (concept: GemmShape)
  76. typename ThreadblockShape,
  77. /// Warp-level tile size (concept: GemmShape)
  78. typename WarpShape,
  79. /// Instruction-level tile size (concept: GemmShape)
  80. typename InstructionShape,
  81. /// Number of stages used in the pipelined mainloop
  82. int Stages,
  83. /// Operation performed by GEMM
  84. typename Operator,
  85. typename Enable_ = void>
  86. struct FindDefaultMma {
  87. static constexpr bool AccumulatorsInRowMajor = false;
  88. static constexpr SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone;
  89. using DefaultMma = cutlass::gemm::threadblock::DefaultMma<ElementA,
  90. LayoutA,
  91. kAlignmentA,
  92. ElementB,
  93. LayoutB,
  94. kAlignmentB,
  95. ElementAccumulator,
  96. LayoutC,
  97. OperatorClass,
  98. ArchTag,
  99. ThreadblockShape,
  100. WarpShape,
  101. InstructionShape,
  102. Stages,
  103. Operator,
  104. AccumulatorsInRowMajor,
  105. SharedMemoryClear>;
  106. };
  107. /// Specialization for sm80 / FastF32 / multistage with kStages=2
  108. template <typename ElementA_,
  109. /// Layout type for A matrix operand
  110. typename LayoutA_,
  111. /// Access granularity of A matrix in units of elements
  112. int kAlignmentA,
  113. typename ElementB_,
  114. /// Layout type for B matrix operand
  115. typename LayoutB_,
  116. /// Access granularity of B matrix in units of elements
  117. int kAlignmentB,
  118. typename ElementAccumulator,
  119. /// Threadblock-level tile size (concept: GemmShape)
  120. typename ThreadblockShape,
  121. /// Warp-level tile size (concept: GemmShape)
  122. typename WarpShape,
  123. /// Instruction-level tile size (concept: GemmShape)
  124. typename InstructionShape,
  125. int kStages,
  126. typename Operator>
  127. struct FindDefaultMma<ElementA_,
  128. LayoutA_,
  129. kAlignmentA,
  130. ElementB_,
  131. LayoutB_,
  132. kAlignmentB,
  133. ElementAccumulator,
  134. layout::RowMajor,
  135. arch::OpClassTensorOp,
  136. arch::Sm80,
  137. ThreadblockShape,
  138. WarpShape,
  139. InstructionShape,
  140. kStages,
  141. Operator,
  142. typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
  143. using LayoutC = layout::RowMajor;
  144. using OperatorClass = arch::OpClassTensorOp;
  145. using ArchTag = arch::Sm80;
  146. using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma<ElementA_,
  147. LayoutA_,
  148. kAlignmentA,
  149. ElementB_,
  150. LayoutB_,
  151. kAlignmentB,
  152. ElementAccumulator,
  153. LayoutC,
  154. OperatorClass,
  155. ArchTag,
  156. ThreadblockShape,
  157. WarpShape,
  158. InstructionShape,
  159. 3,
  160. Operator>;
  161. struct DefaultMma : DefaultMma_ {
  162. using MmaCore_ = typename DefaultMma_::MmaCore;
  163. // Define the threadblock-scoped multistage matrix multiply
  164. using ThreadblockMma =
  165. cutlass::gemm::threadblock::MmaMultistage<typename MmaCore_::Shape,
  166. typename DefaultMma_::IteratorA,
  167. typename MmaCore_::SmemIteratorA,
  168. MmaCore_::kCacheOpA,
  169. typename DefaultMma_::IteratorB,
  170. typename MmaCore_::SmemIteratorB,
  171. MmaCore_::kCacheOpB,
  172. ElementAccumulator,
  173. LayoutC,
  174. typename MmaCore_::MmaPolicy,
  175. kStages>;
  176. };
  177. };
  178. } // namespace threadblock
  179. } // namespace gemm
  180. } // namespace cutlass