# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Inspired by original Triton implementation:
https://github.com/openai/triton/blob/b244db06da24a87453a40ad35b085ee37dac3705/python/tutorials/06-fused-attention.py
"""

import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    TMP,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    Z,
    H,
    N_CTX,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    # off_hz spans the flattened (batch, head) axis; using it with the head stride
    # assumes stride_qz == H * stride_qh, i.e. a contiguous (Z, H, N_CTX, D) layout
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    # use V's own strides here (the original reused Q's strides, which is only valid
    # when Q and V share the exact same layout)
    off_v = off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer into the scratch buffer TMP and the running softmax state:
    # m_i is the running row max, l_i the running row sum of exp(qk - m_i)
    t_ptrs = TMP + off_hz * N_CTX + offs_m
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, N_CTX, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        qk *= sm_scale
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        tl.store(t_ptrs, acc_scale)
        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vk)
        p = p.to(tl.float16)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output and write back the accumulated block
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
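

# Reference check (our addition, not part of the original file): a minimal PyTorch sketch
# of the same non-causal scaled-dot-product attention the kernel above computes. The
# helper name is hypothetical; it is only meant for sanity-checking the Triton output on
# small shapes.
def _torch_reference_attention(q, k, v, sm_scale):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v.float()).to(q.dtype)
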

class triton_flash_attn(torch.nn.Module):

    def __init__(self):
        super(triton_flash_attn, self).__init__()

    def forward(self, q, k, v, sm_scale, block_128=True):
        BLOCK = 128 if block_128 else 64
        # shape constraints: q, k, v must share the same head dimension,
        # which becomes BLOCK_DMODEL in the kernel
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8

        _fwd_kernel[grid](
            q,
            k,
            v,
            sm_scale,
            tmp,
            o,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            k.shape[0],
            k.shape[1],
            k.shape[2],
            BLOCK_M=BLOCK,
            BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk,
            num_warps=num_warps,
            num_stages=1,
        )
        return o
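

# Usage sketch (our addition, illustrative only): the wrapper expects q, k, v as CUDA
# tensors laid out as (batch, heads, seq_len, head_dim), typically fp16, with seq_len a
# multiple of the chosen BLOCK (128 here) since the kernel loads without masking.
# sm_scale is the softmax scale, usually 1 / sqrt(head_dim).
if __name__ == "__main__":
    import math

    batch, heads, seq_len, head_dim = 2, 8, 256, 64
    q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sm_scale = 1.0 / math.sqrt(head_dim)

    attn = triton_flash_attn()
    out = attn(q, k, v, sm_scale)

    # Sanity check against the plain PyTorch reference defined above.
    ref = _torch_reference_attention(q, k, v, sm_scale)
    print("max abs diff:", (out - ref).abs().max().item())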