attention.cpp 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #include <torch/extension.h>
  5. void attention_impl(torch::Tensor& q,
  6. torch::Tensor& k,
  7. torch::Tensor& v,
  8. torch::Tensor& bias1,
  9. torch::Tensor& bias2,
  10. torch::Tensor& o,
  11. torch::Tensor& lse);
  12. void attention(torch::Tensor& q,
  13. torch::Tensor& k,
  14. torch::Tensor& v,
  15. torch::Tensor& bias1,
  16. torch::Tensor& bias2,
  17. torch::Tensor& o,
  18. torch::Tensor& lse)
  19. {
  20. attention_impl(q, k, v, bias1, bias2, o, lse);
  21. }
  22. void attention_back_impl(torch::Tensor& go,
  23. torch::Tensor& q,
  24. torch::Tensor& k,
  25. torch::Tensor& v,
  26. torch::Tensor& o,
  27. torch::Tensor& lse,
  28. torch::Tensor& delta,
  29. torch::Tensor& bias1,
  30. torch::Tensor& bias2,
  31. torch::Tensor& gq,
  32. torch::Tensor& gk,
  33. torch::Tensor& gv,
  34. torch::Tensor& gb1,
  35. torch::Tensor& gb2);
  36. void attention_bwd(torch::Tensor& go,
  37. torch::Tensor& q,
  38. torch::Tensor& k,
  39. torch::Tensor& v,
  40. torch::Tensor& o,
  41. torch::Tensor& lse,
  42. torch::Tensor& delta,
  43. torch::Tensor& bias1,
  44. torch::Tensor& bias2,
  45. torch::Tensor& gq,
  46. torch::Tensor& gk,
  47. torch::Tensor& gv,
  48. torch::Tensor& gb1,
  49. torch::Tensor& gb2)
  50. {
  51. attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2);
  52. }
  53. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
  54. {
  55. m.def("attention", &attention, "");
  56. m.def("attention_bwd", &attention_bwd, "");
  57. }