custom_mma.h 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. /***************************************************************************************************
  2. * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  3. * SPDX-License-Identifier: BSD-3-Clause
  4. *
  5. * Redistribution and use in source and binary forms, with or without
  6. * modification, are permitted provided that the following conditions are met:
  7. *
  8. * 1. Redistributions of source code must retain the above copyright notice, this
  9. * list of conditions and the following disclaimer.
  10. *
  11. * 2. Redistributions in binary form must reproduce the above copyright notice,
  12. * this list of conditions and the following disclaimer in the documentation
  13. * and/or other materials provided with the distribution.
  14. *
  15. * 3. Neither the name of the copyright holder nor the names of its
  16. * contributors may be used to endorse or promote products derived from
  17. * this software without specific prior written permission.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  20. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  21. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  22. * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  23. * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  24. * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  25. * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  26. * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  27. * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  28. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *
  30. **************************************************************************************************/
  31. // Copyright (c) Microsoft Corporation.
  32. // SPDX-License-Identifier: Apache-2.0
  33. // DeepSpeed Team
  34. #pragma once
  35. #include "custom_mma_multistage.h"
  36. #include "custom_mma_pipelined.h"
  37. #include "cutlass/gemm/threadblock/mma_multistage.h"
  38. #include "cutlass/gemm/threadblock/mma_pipelined.h"
  /// Compile-time trait mapping a CUTLASS threadblock-level MMA type onto its
  /// "custom" counterpart, exposed as the nested alias `Mma`.
  ///
  /// Only partial specializations (for `MmaMultistage` and `MmaPipelined`) are
  /// defined; this primary template is intentionally left incomplete, so naming
  /// `MakeCustomMma<X, K>::Mma` for any other `X` fails to compile.
  ///
  /// @tparam MmaT  the CUTLASS threadblock MMA type to translate.
  /// @tparam MaxK  forwarded to the specializations; the multistage one treats
  ///               `numeric_limits<int>::max()` as "no bound on K".
  template <class MmaT, int MaxK>
  struct MakeCustomMma;
  41. template <typename Shape,
  42. typename IteratorA,
  43. typename SmemIteratorA,
  44. cutlass::arch::CacheOperation::Kind CacheOpA,
  45. typename IteratorB,
  46. typename SmemIteratorB,
  47. cutlass::arch::CacheOperation::Kind CacheOpB,
  48. typename ElementC,
  49. typename LayoutC,
  50. typename Policy,
  51. int Stages,
  52. cutlass::gemm::SharedMemoryClearOption SharedMemoryClear,
  53. int kMaxK>
  54. struct MakeCustomMma<cutlass::gemm::threadblock::MmaMultistage<Shape,
  55. IteratorA,
  56. SmemIteratorA,
  57. CacheOpA,
  58. IteratorB,
  59. SmemIteratorB,
  60. CacheOpB,
  61. ElementC,
  62. LayoutC,
  63. Policy,
  64. Stages,
  65. SharedMemoryClear>,
  66. kMaxK> {
  67. // Reduce the number of stages if we don't need that many
  68. static int constexpr kStages =
  69. kMaxK == cutlass::platform::numeric_limits<int>::max()
  70. ? Stages
  71. : cutlass::const_min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
  72. using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<Shape,
  73. IteratorA,
  74. SmemIteratorA,
  75. CacheOpA,
  76. IteratorB,
  77. SmemIteratorB,
  78. CacheOpB,
  79. ElementC,
  80. LayoutC,
  81. Policy,
  82. kStages,
  83. SharedMemoryClear,
  84. kMaxK>;
  85. };
  86. template <typename Shape,
  87. typename IteratorA,
  88. typename SmemIteratorA,
  89. typename IteratorB,
  90. typename SmemIteratorB,
  91. typename ElementC,
  92. typename LayoutC,
  93. typename Policy,
  94. int kMaxK>
  95. struct MakeCustomMma<cutlass::gemm::threadblock::MmaPipelined<Shape,
  96. IteratorA,
  97. SmemIteratorA,
  98. IteratorB,
  99. SmemIteratorB,
  100. ElementC,
  101. LayoutC,
  102. Policy>,
  103. kMaxK> {
  104. using Mma = cutlass::gemm::threadblock::CustomMmaPipelined<Shape,
  105. IteratorA,
  106. SmemIteratorA,
  107. IteratorB,
  108. SmemIteratorB,
  109. ElementC,
  110. LayoutC,
  111. Policy>;
  112. };