bias_broadcast.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

// CUTLASS headers for the layouts, thread maps, and tile iterators used below.
// They are listed here so the header is self-contained; the attention kernel
// that includes this file already pulls them in.
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/layout/pitch_linear.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/transform/pitch_linear_thread_map.h>
#include <cutlass/transform/threadblock/predicated_tile_iterator.h>
#include <cutlass/transform/threadblock/regular_tile_iterator.h>
// Disabled case: no bias is loaded. load() is a no-op and advance() returns the
// pointer unchanged; kEnable = false lets the epilogue below skip this operand
// at compile time.
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastNoLoad {
    using Fragment =
        cutlass::Array<scalar_t, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
    static const bool kEnable = false;
    CUTLASS_DEVICE static void load(Fragment& frag,
                                    scalar_t* ptr,
                                    int thread_id,
                                    const cutlass::MatrixCoord& extent,
                                    int stride)
    {
    }
    CUTLASS_DEVICE static scalar_t*
    advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH)
    {
        return ptr;
    }
};
// Loads the bias matrix from global memory with an on-the-fly broadcast. The
// shape in global memory is [B, N, 1, 1, L]. Each load reads the last dimension
// as an L-element row vector and broadcasts it to an [L, L] tile by repeating
// the vector across all L rows.
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastA : public BroadcastNoLoad<ThreadMap, Shape, scalar_t> {
    using Base = BroadcastNoLoad<ThreadMap, Shape, scalar_t>;
    static const bool kEnable = true;
    using layout = cutlass::layout::AffineRank2RowMajor;
    using GmemTileIterator = cutlass::transform::threadblock::
        PredicatedTileIterator<Shape, scalar_t, layout, 0, ThreadMap>;
    using Fragment = typename GmemTileIterator::Fragment;
    CUTLASS_DEVICE static void load(Fragment& frag,
                                    scalar_t* ptr,
                                    int thread_id,
                                    const cutlass::MatrixCoord& extent,
                                    int stride)
    {
        // Affine layout with row stride 0 and column stride 1: every row of the
        // tile maps to the same L-element vector, which realizes the broadcast.
        GmemTileIterator iter({layout(0, 1)}, ptr, extent, thread_id);
        iter.load(frag);
    }
    CUTLASS_DEVICE static scalar_t*
    advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH)
    {
        return ptr + B_id * strideB + N_id * strideN;
    }
};
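// Illustration only, a minimal sketch that is not used by the kernel: the
// addressing BroadcastA performs, written out in plain C++. `bias_a_offset` is
// a hypothetical helper that assumes the [B, N, 1, 1, L] bias is stored
// contiguously, i.e. strideB = N * L and strideN = L.
constexpr int bias_a_offset(int b, int n, int row, int col, int N, int L)
{
    // advance(): B_id * strideB + N_id * strideN selects the L vector,
    // layout(0, 1): offset = row * 0 + col * 1, so all L rows read the same vector.
    return b * (N * L) + n * L + row * 0 + col * 1;
}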
// Loads the bias matrix from global memory with an on-the-fly broadcast. The
// shape in global memory is [B, 1, H, L, L]. Each load reads an [L, L] matrix;
// all N indices share the same bias matrix for a given B and H.
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastB : public BroadcastNoLoad<ThreadMap, Shape, scalar_t> {
    using Base = BroadcastNoLoad<ThreadMap, Shape, scalar_t>;
    static const bool kEnable = true;
    using layout = cutlass::layout::RowMajor;
    using GmemTileIterator = cutlass::transform::threadblock::
        PredicatedTileIterator<Shape, scalar_t, layout, 0, ThreadMap>;
    using Fragment = typename GmemTileIterator::Fragment;
    CUTLASS_DEVICE static void load(Fragment& frag,
                                    scalar_t* ptr,
                                    int thread_id,
                                    const cutlass::MatrixCoord& extent,
                                    int stride)
    {
        // `stride` is the leading dimension (elements per row) of the [L, L] bias tile.
        GmemTileIterator iter({layout(stride)}, ptr, extent, thread_id);
        iter.load(frag);
    }
    CUTLASS_DEVICE static scalar_t*
    advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH)
    {
        return ptr + B_id * strideB + H_id * strideH;
    }
};
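// Illustration only, a minimal sketch that is not used by the kernel: the
// addressing BroadcastB performs. `bias_b_offset` is a hypothetical helper that
// assumes the [B, 1, H, L, L] bias is stored contiguously, i.e.
// strideB = H * L * L, strideH = L * L, and stride = L for the row-major tile.
constexpr int bias_b_offset(int b, int h, int row, int col, int H, int L)
{
    // advance(): B_id * strideB + H_id * strideH selects the [L, L] matrix,
    // RowMajor layout(stride): offset = row * L + col within that matrix.
    return b * (H * L * L) + h * (L * L) + row * L + col;
}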
template <typename Shape,
          typename scalar_t,
          int kThreads,
          template <typename, typename, typename>
          class Broadcast1_,
          template <typename, typename, typename>
          class Broadcast2_>
struct AttentionBiasEpilogue {
    using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
        cutlass::layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
        kThreads,
        1>;
    using Broadcast1 = Broadcast1_<ThreadMap, Shape, scalar_t>;
    using Broadcast2 = Broadcast2_<ThreadMap, Shape, scalar_t>;
    Broadcast1 broadcast1;
    Broadcast2 broadcast2;
    using Ref = cutlass::TensorRef<float, cutlass::layout::RowMajor>;
    using SmemTileIterator = cutlass::transform::threadblock::
        RegularTileIterator<Shape, float, cutlass::layout::RowMajor, 0, ThreadMap>;
    CUTLASS_DEVICE void operator()(const Ref& ref,
                                   scalar_t* ptr1,
                                   scalar_t* ptr2,
                                   int thread_id,
                                   const cutlass::MatrixCoord& extent,
                                   int stride)
    {
        static_assert(Broadcast1::Fragment::kElements == Broadcast2::Fragment::kElements,
                      "The two broadcast fragments must have the same number of "
                      "elements");
        // Accumulate both bias terms in float, then write the summed tile to
        // shared memory through `ref`.
        typename SmemTileIterator::Fragment frag;
        frag.clear();
        float* frag_ptr = reinterpret_cast<float*>(&frag);
        if (Broadcast1::kEnable) {
            typename Broadcast1::Fragment frag1;
            frag1.clear();
            broadcast1.load(frag1, ptr1, thread_id, extent, stride);
            scalar_t* frag1_ptr = reinterpret_cast<scalar_t*>(&frag1);
            for (int i = 0; i < Broadcast1::Fragment::kElements; ++i) {
                frag_ptr[i] += static_cast<float>(frag1_ptr[i]);
            }
        }
        if (Broadcast2::kEnable) {
            typename Broadcast2::Fragment frag2;
            frag2.clear();
            broadcast2.load(frag2, ptr2, thread_id, extent, stride);
            scalar_t* frag2_ptr = reinterpret_cast<scalar_t*>(&frag2);
            for (int i = 0; i < Broadcast2::Fragment::kElements; ++i) {
                frag_ptr[i] += static_cast<float>(frag2_ptr[i]);
            }
        }
        SmemTileIterator iter(ref, thread_id);
        iter.store(frag);
        __syncthreads();
    }
};
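
// Usage sketch (hypothetical; the tile shape, element type, and thread count
// below are illustrative assumptions, and the real instantiation lives in the
// attention kernel that includes this header):
//
//   using Shape = cutlass::MatrixShape<32, 32>;
//   using BiasEpilogue =
//       AttentionBiasEpilogue<Shape, cutlass::half_t, 128, BroadcastA, BroadcastB>;
//
//   // Inside the kernel, after offsetting ptr1/ptr2 with
//   // Broadcast1::advance() / Broadcast2::advance() for the current B/N/H:
//   //   BiasEpilogue bias_epilogue;
//   //   bias_epilogue(smem_ref, ptr1, ptr2, threadIdx.x, extent, stride);
//
// Passing BroadcastNoLoad for either template argument disables that bias term
// at compile time via its kEnable = false flag.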