attention_back.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <type_traits>
#include "gemm_kernel_utils.h"
#include "kernel_backward.h"
#include "transform/bias_broadcast.h"

constexpr auto kBlockSizeI = 64;
constexpr auto kBlockSizeJ = 64;
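
// Fallback overload, selected via SFINAE when the (arch, scalar_t) combination
// is not supported by this build; it only raises a runtime error.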
template <typename arch,
          typename scalar_t,
          typename torch_scalar_t,
          template <typename, typename, typename> class Broadcast1_,
          template <typename, typename, typename> class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
    torch::Tensor& go,
    torch::Tensor& q,
    torch::Tensor& k,
    torch::Tensor& v,
    torch::Tensor& o,
    torch::Tensor& lse,
    torch::Tensor& delta,
    torch::Tensor& bias1,
    torch::Tensor& bias2,
    torch::Tensor& gq,
    torch::Tensor& gk,
    torch::Tensor& gv,
    torch::Tensor& gb1,
    torch::Tensor& gb2)
{
    EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
}
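
// Supported-architecture overload: instantiates the Evoformer attention backward
// kernel for the given arch / scalar type / bias-broadcast policies and launches it.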
template <typename arch,
          typename scalar_t,
          typename torch_scalar_t,
          template <typename, typename, typename> class Broadcast1_,
          template <typename, typename, typename> class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
    torch::Tensor& go,
    torch::Tensor& q,
    torch::Tensor& k,
    torch::Tensor& v,
    torch::Tensor& o,
    torch::Tensor& lse,
    torch::Tensor& delta,
    torch::Tensor& bias1,
    torch::Tensor& bias2,
    torch::Tensor& gq,
    torch::Tensor& gk,
    torch::Tensor& gv,
    torch::Tensor& gb1,
    torch::Tensor& gb2)
{
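    // Tile preloading is only enabled on compute capability 80 (Ampere) and newer.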
    constexpr bool kPreload_ = arch::kMinComputeCapability >= 80;
    using Kernel = AttentionBackwardKernel<arch,
                                           scalar_t,     // scalar_t
                                           true,         // kIsAligned_
                                           false,        // kApplyDropout_
                                           kPreload_,    // kPreload_
                                           kBlockSizeI,  // kBlockSizeI_
                                           kBlockSizeJ,  // kBlockSizeJ_
                                           64,           // kMaxK
                                           Broadcast1_,
                                           Broadcast2_>;
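
    // Collapse all leading batch-like dimensions into a single batch dimension,
    // giving [batch, seq_length, head_number, head_size] views of every tensor.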
    int head_size = q.size(-1);
    int head_number = q.size(-2);
    int seq_length = q.size(-3);
    auto q_view = q.view({-1, seq_length, head_number, head_size});
    auto k_view = k.view({-1, seq_length, head_number, head_size});
    auto v_view = v.view({-1, seq_length, head_number, head_size});
    auto o_view = o.view({-1, seq_length, head_number, head_size});
    auto do_view = go.view({-1, seq_length, head_number, head_size});
    auto dk_view = gk.view({-1, seq_length, head_number, head_size});
    auto dv_view = gv.view({-1, seq_length, head_number, head_size});
    auto dq_view = gq.view({-1, seq_length, head_number, head_size});
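
    // Raw device pointers, reinterpreted as the kernel's scalar type. The bias
    // gradient buffers may be empty tensors, in which case a null pointer
    // disables that output.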
    auto q_ptr = reinterpret_cast<scalar_t*>(q.data_ptr<torch_scalar_t>());
    auto k_ptr = reinterpret_cast<scalar_t*>(k.data_ptr<torch_scalar_t>());
    auto v_ptr = reinterpret_cast<scalar_t*>(v.data_ptr<torch_scalar_t>());
    auto o_ptr = reinterpret_cast<scalar_t*>(o.data_ptr<torch_scalar_t>());
    auto do_ptr = reinterpret_cast<scalar_t*>(go.data_ptr<torch_scalar_t>());
    auto dk_ptr = reinterpret_cast<scalar_t*>(gk.data_ptr<torch_scalar_t>());
    auto dv_ptr = reinterpret_cast<scalar_t*>(gv.data_ptr<torch_scalar_t>());
    auto dq_ptr = reinterpret_cast<scalar_t*>(gq.data_ptr<torch_scalar_t>());
    auto db1_ptr = gb1.size(0) > 0 ? reinterpret_cast<float*>(gb1.data_ptr<float>()) : nullptr;
    auto db2_ptr = gb2.size(0) > 0 ? reinterpret_cast<float*>(gb2.data_ptr<float>()) : nullptr;
    auto lse_ptr = reinterpret_cast<float*>(lse.data_ptr<float>());
    auto delta_ptr = reinterpret_cast<float*>(delta.data_ptr<float>());
    auto bias1_ptr = reinterpret_cast<scalar_t*>(bias1.data_ptr<torch_scalar_t>());
    auto bias2_ptr = reinterpret_cast<scalar_t*>(bias2.data_ptr<torch_scalar_t>());

    static_assert(Kernel::kKernelComputesDelta, "Kernel must compute delta");
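
    // Populate the kernel's parameter struct: data pointers, problem sizes,
    // softmax scale, and the strides of the collapsed 4-D views.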
    typename Kernel::Params p;
    p.query_ptr = q_ptr;
    p.key_ptr = k_ptr;
    p.value_ptr = v_ptr;
    p.logsumexp_ptr = lse_ptr;
    p.output_ptr = o_ptr;
    p.grad_output_ptr = do_ptr;
    p.delta_ptr = delta_ptr;
    p.grad_query_ptr = dq_ptr;
    p.grad_key_ptr = dk_ptr;
    p.grad_value_ptr = dv_ptr;
    p.grad_bias1_ptr = db1_ptr;
    p.grad_bias2_ptr = db2_ptr;
    p.B = q.size(0);
    p.N = q.size(1);
    p.bias1_ptr = bias1.size(0) ? bias1_ptr : nullptr;
    p.bias2_ptr = bias2.size(0) ? bias2_ptr : nullptr;
    p.scale = 1.0f / sqrtf(head_size);
    p.head_dim = head_size;
    p.head_dim_value = head_size;
    p.num_queries = seq_length;
    p.num_keys = seq_length;
    p.num_heads = head_number;
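
    // Element strides of the 4-D views: M is the sequence dimension, H the head
    // dimension, and B the collapsed batch dimension.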
    p.q_strideM = q_view.stride(-3);
    p.k_strideM = k_view.stride(-3);
    p.v_strideM = v_view.stride(-3);
    p.gO_strideM = do_view.stride(-3);
    p.o_strideH = o_view.stride(-2);
    p.q_strideH = q_view.stride(-2);
    p.k_strideH = k_view.stride(-2);
    p.v_strideH = v_view.stride(-2);
    p.o_strideB = o_view.stride(-4);
    p.q_strideB = q_view.stride(-4);
    p.k_strideB = k_view.stride(-4);
    p.v_strideB = v_view.stride(-4);
    p.lse_strideB = lse.stride(-3);
    p.lse_strideH = lse.stride(-2);
    p.delta_strideB = delta.stride(-3);
    p.delta_strideH = delta.stride(-2);
    p.num_batches = q_view.size(-4);
    p.gO_strideB = do_view.stride(-4);
    p.gQ_strideB = dq_view.stride(-4);
    p.gK_strideB = dk_view.stride(-4);
    p.gV_strideB = dv_view.stride(-4);
    p.gO_strideH = do_view.stride(-2);
    p.gQ_strideH = dq_view.stride(-2);
    p.gK_strideH = dk_view.stride(-2);
    p.gV_strideH = dv_view.stride(-2);
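
    // Scratch workspace required by the kernel, allocated as float elements
    // (workspace_size() / 4) with the same dtype/device options as lse.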
    torch::Tensor workspace = torch::empty(p.workspace_size() / 4, lse.options());
    p.workspace = workspace.data_ptr<float>();
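
    // Raise the dynamic shared-memory limit to the kernel's SharedStorage size,
    // validate the parameters, then launch on the grid/block shape chosen by the kernel.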
    auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
    size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
    cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes));
    if (!Kernel::check_supported(p)) { throw std::runtime_error("Unsupported parameters"); }
    kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}
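
// Select the broadcast policies for the two bias terms based on which bias tensors
// are non-empty, then instantiate the matching template specialization.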
#define CODE(scalar_t, torch_scalar_t)                                           \
    do {                                                                         \
        if (bias1.size(0) == 0 && bias2.size(0) == 0) {                          \
            attention_back_impl_template<ArchTag,                                \
                                         scalar_t,                               \
                                         torch_scalar_t,                         \
                                         BroadcastNoLoad,                        \
                                         BroadcastNoLoad>(                       \
                go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
        } else if (bias1.size(0) > 0 && bias2.size(0) > 0) {                     \
            attention_back_impl_template<ArchTag,                                \
                                         scalar_t,                               \
                                         torch_scalar_t,                         \
                                         BroadcastA,                             \
                                         BroadcastB>(                            \
                go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
        } else if (bias1.size(0) > 0) {                                          \
            attention_back_impl_template<ArchTag,                                \
                                         scalar_t,                               \
                                         torch_scalar_t,                         \
                                         BroadcastA,                             \
                                         BroadcastNoLoad>(                       \
                go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
        } else {                                                                 \
            attention_back_impl_template<ArchTag,                                \
                                         scalar_t,                               \
                                         torch_scalar_t,                         \
                                         BroadcastNoLoad,                        \
                                         BroadcastB>(                            \
                go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
        }                                                                        \
    } while (0)
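
// Host entry point: dispatches on the current device's compute capability and the
// input dtype, then expands CODE to pick the bias-broadcast specialization.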
void attention_back_impl(torch::Tensor& go,
                         torch::Tensor& q,
                         torch::Tensor& k,
                         torch::Tensor& v,
                         torch::Tensor& o,
                         torch::Tensor& lse,
                         torch::Tensor& delta,
                         torch::Tensor& bias1,
                         torch::Tensor& bias2,
                         torch::Tensor& gq,
                         torch::Tensor& gk,
                         torch::Tensor& gv,
                         torch::Tensor& gb1,
                         torch::Tensor& gb2)
{
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
                     DISPATCH_TYPES(q, { CODE(scalar_t, torch_scalar_t); }));
}