triton_ops.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Inspired by the original Triton implementation:
https://github.com/openai/triton/blob/b244db06da24a87453a40ad35b085ee37dac3705/python/tutorials/06-fused-attention.py
"""
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    TMP,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    Z,
    H,
    N_CTX,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    # use V's own strides here; reusing stride_qm/stride_qk is only valid when
    # q and v share the same memory layout
    off_v = off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    t_ptrs = TMP + off_hz * N_CTX + offs_m
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, N_CTX, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        qk *= sm_scale
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        tl.store(t_ptrs, acc_scale)
        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vk)
        p = p.to(tl.float16)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
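

# The kernel above is the "online softmax" formulation of flash attention: it never
# materializes the full N_CTX x N_CTX score matrix, but instead keeps a running row
# maximum (m_i), a running softmax denominator (l_i), and a rescaled accumulator (acc).
# The helper below is NOT part of the original file; it is a minimal PyTorch reference
# sketch (an assumption, intended only for sanity-checking the Triton output) of the
# quantity the kernel computes: softmax(Q @ K^T * sm_scale) @ V per batch and head.
def _reference_attention(q, k, v, sm_scale):
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v.float()).to(q.dtype)

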
class triton_flash_attn(torch.nn.Module):

    def __init__(self):
        super(triton_flash_attn, self).__init__()

    def forward(self, q, k, v, sm_scale, block_128=True):
        BLOCK = 128 if block_128 else 64
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        _fwd_kernel[grid](
            q,
            k,
            v,
            sm_scale,
            tmp,
            o,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            k.shape[0],
            k.shape[1],
            k.shape[2],
            BLOCK_M=BLOCK,
            BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk,
            num_warps=num_warps,
            num_stages=1,
        )
        return o
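

# A minimal usage sketch (not from the original file; shapes and values are
# illustrative assumptions): q, k, v are half-precision CUDA tensors laid out as
# [batch, heads, seq_len, head_dim], seq_len is a multiple of BLOCK (128 by default),
# and head_dim matches BLOCK_DMODEL. sm_scale is typically 1 / sqrt(head_dim).
if __name__ == "__main__":
    batch, heads, seq_len, head_dim = 2, 8, 1024, 64
    q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    attn = triton_flash_attn()
    out = attn(q, k, v, sm_scale=head_dim**-0.5)
    print(out.shape)  # expected: torch.Size([2, 8, 1024, 64])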