attention.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "gemm_kernel_utils.h"
#include "kernel_forward.h"
#include "transform/bias_broadcast.h"
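
// Fallback overload, selected via SFINAE when CheckArch<arch, scalar_t>::value is false,
// i.e. the current GPU architecture / data type combination is not supported: it only
// raises an error and never launches a kernel.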
template <typename arch,
          typename scalar_t,
          typename torch_scalar_t,
          template <typename, typename, typename>
          class Broadcast1_,
          template <typename, typename, typename>
          class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_impl_template(
    torch::Tensor& q,
    torch::Tensor& k,
    torch::Tensor& v,
    torch::Tensor& bias1,
    torch::Tensor& bias2,
    torch::Tensor& o,
    float* lse_ptr)
{
    EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
}
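
// Supported-architecture overload: instantiates AttentionKernel with the requested bias
// broadcast functors, fills in the kernel launch parameters, and launches the fused kernel.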
template <typename arch,
          typename scalar_t,
          typename torch_scalar_t,
          template <typename, typename, typename>
          class Broadcast1_,
          template <typename, typename, typename>
          class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_impl_template(
    torch::Tensor& q,
    torch::Tensor& k,
    torch::Tensor& v,
    torch::Tensor& bias1,
    torch::Tensor& bias2,
    torch::Tensor& o,
    float* lse_ptr)
{
    // Attention definition goes here, replaced with BroadcastType1 and
    // BroadcastType2
    using Attention = AttentionKernel<scalar_t, /* scalar_t */
                                      arch,     /* ArchTag */
                                      true,     /* Memory is aligned */
                                      64,
                                      64,
                                      true,
                                      true, /* Supports bias */
                                      Broadcast1_,
                                      Broadcast2_>;

    static_assert(!Attention::kNeedsOutputAccumulatorBuffer,
                  "This test does not support output accumulator buffer");
    int head_size = q.size(-1);
    int head_number = q.size(-2);
    int seq_length = q.size(-3);
    auto q_view = q.view({-1, seq_length, head_number, head_size});
    auto k_view = k.view({-1, seq_length, head_number, head_size});
    auto v_view = v.view({-1, seq_length, head_number, head_size});
    auto o_view = o.view({-1, seq_length, head_number, head_size});
    int batch_size = q_view.size(0);

    auto q_ptr = reinterpret_cast<scalar_t*>(q.data_ptr<torch_scalar_t>());
    auto k_ptr = reinterpret_cast<scalar_t*>(k.data_ptr<torch_scalar_t>());
    auto v_ptr = reinterpret_cast<scalar_t*>(v.data_ptr<torch_scalar_t>());
    auto o_ptr = reinterpret_cast<scalar_t*>(o.data_ptr<torch_scalar_t>());
    auto bias1_ptr = reinterpret_cast<scalar_t*>(bias1.data_ptr<torch_scalar_t>());
    auto bias2_ptr = reinterpret_cast<scalar_t*>(bias2.data_ptr<torch_scalar_t>());

    typename Attention::Params p;
    {  // set parameters
        p.query_ptr = q_ptr;
        p.key_ptr = k_ptr;
        p.value_ptr = v_ptr;
        p.logsumexp_ptr = lse_ptr;  // Only needed for bw
        p.output_accum_ptr = nullptr;
        p.output_ptr = o_ptr;
        p.scale = 1.0f / sqrt(float(head_size));

        p.bias1_ptr = bias1_ptr;
        p.bias2_ptr = bias2_ptr;
        p.B = q.size(0);
        p.N = q.size(1);

        p.num_heads = head_number;
        p.num_batches = batch_size;
        p.head_dim = head_size;
        p.head_dim_value = head_size;
        p.num_queries = seq_length;
        p.num_keys = seq_length;

        // All tensors are in BMHK shapes
        p.q_strideH = q_view.stride(-2);
        p.k_strideH = k_view.stride(-2);
        p.v_strideH = v_view.stride(-2);
        p.q_strideM = q_view.stride(-3);
        p.k_strideM = k_view.stride(-3);
        p.v_strideM = v_view.stride(-3);
        p.o_strideM = o_view.stride(-3);
        p.q_strideB = q_view.stride(-4);
        p.k_strideB = k_view.stride(-4);
        p.v_strideB = v_view.stride(-4);
    }
    constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
    int smem_bytes = sizeof(typename Attention::SharedStorage);
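    // Kernels needing more than 48 KB (0xc000 bytes) of shared memory per block must opt in
    // to the larger dynamic shared-memory carve-out before launch.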
    if (smem_bytes > 0xc000) {
        cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    }
    if (!Attention::check_supported(p)) { throw std::runtime_error("Parameters not supported"); }
    kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}
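
// A bias tensor whose leading dimension is 0 is treated as "not provided"; the macro below
// picks the pair of broadcast functors (BroadcastNoLoad / BroadcastA / BroadcastB) that
// matches which of the two biases are actually present.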
#define CODE(scalar_t, torch_scalar_t)                                                       \
    do {                                                                                      \
        if (bias1.size(0) == 0 && bias2.size(0) == 0) {                                       \
            attention_impl_template<ArchTag,                                                  \
                                    scalar_t,                                                 \
                                    torch_scalar_t,                                           \
                                    BroadcastNoLoad,                                          \
                                    BroadcastNoLoad>(q, k, v, bias1, bias2, o, lse_ptr);      \
        } else if (bias1.size(0) == 0) {                                                      \
            attention_impl_template<ArchTag,                                                  \
                                    scalar_t,                                                 \
                                    torch_scalar_t,                                           \
                                    BroadcastNoLoad,                                          \
                                    BroadcastB>(q, k, v, bias1, bias2, o, lse_ptr);           \
        } else if (bias2.size(0) == 0) {                                                      \
            attention_impl_template<ArchTag,                                                  \
                                    scalar_t,                                                 \
                                    torch_scalar_t,                                           \
                                    BroadcastA,                                               \
                                    BroadcastNoLoad>(q, k, v, bias1, bias2, o, lse_ptr);      \
        } else {                                                                              \
            attention_impl_template<ArchTag, scalar_t, torch_scalar_t, BroadcastA, BroadcastB>( \
                q, k, v, bias1, bias2, o, lse_ptr);                                           \
        }                                                                                     \
    } while (0)
// Select and call the correct template instantiation based on the bias sizes.
void attention_impl(torch::Tensor& q,
                    torch::Tensor& k,
                    torch::Tensor& v,
                    torch::Tensor& bias1,
                    torch::Tensor& bias2,
                    torch::Tensor& o,
                    torch::Tensor& lse)
{
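    // An empty lse tensor means the caller does not need the logsumexp, which is only
    // required for the backward pass.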
    auto lse_ptr = lse.size(0) == 0 ? nullptr : reinterpret_cast<float*>(lse.data_ptr<float>());
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
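    // Dispatch on the device's compute capability (major * 10 + minor, e.g. 80 for Ampere)
    // and on q's scalar type.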
    DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
                     DISPATCH_TYPES(q, { CODE(scalar_t, torch_scalar_t); }));
}