attention.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math
import torch
import torch.nn as nn
import triton
import triton.language as tl

from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp
from deepspeed.ops.transformer.inference.triton import (
    softmax,
    score_4d_matmul,
    context_4d_matmul,
)

minus_inf = -10000.0
class TritonSelfAttention(nn.Module):
    num_layers = 0

    def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, qkv_merging=False):
        super(TritonSelfAttention, self).__init__()
        self.config = config
        data_type = self.config.dtype
        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        assert data_type_fp == torch.half, "triton attention supports fp16 only for data_type_fp"
        self.config.layer_id = TritonSelfAttention.num_layers
        TritonSelfAttention.num_layers = TritonSelfAttention.num_layers + 1
        device = get_accelerator().current_device_name()  # if config.bigscience_bloom else 'cpu'
        assert config.mp_size == 1, "triton attention does not support mp_size > 1 yet"

        if self.config.set_empty_params:
            self.attn_qw = None
            self.attn_qb = None
            self.attn_kw = None
            self.attn_kb = None
            self.attn_vw = None
            self.attn_vb = None
            self.attn_qkvw = None
            self.attn_qkvb = None
            self.attn_ow = None
            self.attn_ob = None
        else:
            qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
            self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                      qkv_size_per_partition,
                                                      dtype=data_type,
                                                      device=device),
                                          requires_grad=False)
            self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
                                          requires_grad=False)
            # self-output weights
            out_size_per_partition = self.config.hidden_size // self.config.mp_size
            self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                                    self.config.hidden_size,
                                                    dtype=data_type,
                                                    device=device),
                                        requires_grad=False)
            self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)

        self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
        self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads

        self.mp_group = mp_group
        self.use_flash = False
        # triton flash attention is enabled when the compute capability >= 8.0
        if get_accelerator().is_triton_supported():
            self.use_flash = True

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups
        self.merge_count = int(math.log2(merge_count))

        self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads)
        if not config.use_mup:
            self.norm_factor = math.sqrt(self.norm_factor)

        if self.config.scale_attn_by_inverse_layer_idx is True:
            self.norm_factor *= math.sqrt(self.config.layer_id + 1)
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191

        triton_autotune = self.config.triton_autotune and self.config.layer_id == 0
        self.qkv_func = QKVGemmOp(config)
        self.score_context_func = SoftmaxContextOp(config)
        self.linear_func = LinearOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)

        self.hidden_size = config.hidden_size
        self.head_size = config.hidden_size // config.heads
        # squaring 1/norm_factor brings the scaling back to the standard 1/sqrt(head_size)
        self.scale = (1 / self.norm_factor / self.norm_factor if self.config.scale_attention else 1.0)
        self.triangular_masking = self.config.triangular_masking

        # triton autotune table update for score/context matmul
        if triton_autotune:
            print("running triton autotune for regular attention kernel")
            __class__._triton_autotune(2, self.config.max_out_tokens, self.head_size, self.config.hidden_size,
                                       self.triangular_masking, self.scale)
    @staticmethod
    def _triton_autotune(min_seqlen,
                         max_seqlen,
                         head_size,
                         hidden_size,
                         triangular_masking,
                         scale,
                         dtype=torch.float16):
        from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, score_4d_matmul, context_4d_matmul
        seqlen = [(min_seqlen + i)
                  for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
        Fp16Matmul._read_autotune_table()
        for N in seqlen:
            qkv = torch.randn((1, N, 3 * hidden_size), dtype=dtype, device='cuda')
            output = score_4d_matmul(qkv, head_size, triangular_masking, scale)
            context_4d_matmul(output, qkv, head_size)
        Fp16Matmul._update_autotune_table()
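
    # The autotune pass above warms the Fp16Matmul autotune cache for sequence lengths sampled
    # every Fp16Matmul._cache_stride tokens between min_seqlen and max_seqlen, so the first real
    # forward pass does not pay the kernel-tuning cost. A minimal sketch of a manual invocation
    # (shapes are illustrative, not taken from a real model):
    #   TritonSelfAttention._triton_autotune(2, 1024, head_size=64, hidden_size=768,
    #                                        triangular_masking=True, scale=1.0 / 8)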
    def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi):
        if isinstance(qkv_out, list):
            qkv_out = qkv_out[0]

        no_masking = input_mask is None
        if no_masking:
            input_mask = torch.empty(1)

        attn_key_value = self.score_context_func(
            query_key_value=qkv_out,
            attn_mask=((1 - input_mask).to(qkv_out.dtype) *
                       minus_inf) if input_mask.dtype == torch.int64 else input_mask,
            heads=self.num_attention_heads_per_partition,
            norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0),
            no_masking=no_masking,
            layer_id=self.config.layer_id,
            num_layers=TritonSelfAttention.num_layers,
            alibi=alibi)

        context_layer, key_layer, value_layer = attn_key_value
        return context_layer, key_layer, value_layer
    def forward(
            self,
            input,
            input_mask,
            head_mask=None,
            layer_past=None,
            get_present=False,  # not used
            encoder_hidden_states=None,  # not used
            encoder_attention_mask=None,  # not used
            output_attentions=False,  # not used
            norm_w=None,
            norm_b=None,
            alibi=None,
            use_triton_attention=True):

        if not self.config.pre_layer_norm:
            qkv_out = self.linear_func(input=input,
                                       weight=self.attn_qkvw,
                                       bias=self.attn_qkvb,
                                       add_bias=self.attn_qkvb is not None,
                                       do_flash_attn=False,
                                       num_heads=self.num_attention_heads_per_partition,
                                       num_layers=TritonSelfAttention.num_layers)
            qkv = qkv_out
        else:
            qkv_out = self.qkv_func(input=input,
                                    weight=self.attn_qkvw,
                                    bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b),
                                    gamma=norm_w,
                                    beta=norm_b)
            qkv = qkv_out[0]

        if use_triton_attention and (alibi is None):
            context_layer = _triton_attention(qkv=qkv,
                                              input_mask=input_mask,
                                              scale=self.scale,
                                              layer_past=layer_past,
                                              alibi=alibi,
                                              head_size=self.head_size,
                                              use_triton_flash=self.use_flash,
                                              use_cuda_flash=False,
                                              triangular=self.triangular_masking)
            key_layer, value_layer = qkv[:, :, self.hidden_size:2 * self.hidden_size], qkv[:, :, 2 * self.hidden_size:]
        else:
            context_layer, key_layer, value_layer = self.ds_compute_attention(qkv_out=qkv_out,
                                                                              input_mask=input_mask,
                                                                              layer_past=layer_past,
                                                                              alibi=alibi)
        output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)

        inp_norm = qkv_out[-1]

        if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(output, group=self.mp_group)

        return (output, key_layer, value_layer, context_layer, inp_norm)
global inference_module


def _triton_attention(qkv,
                      input_mask,
                      layer_past,
                      alibi,
                      scale,
                      head_size,
                      triangular=False,
                      use_cuda_flash=False,
                      use_triton_flash=False,
                      use_ds_attention=False):
    if isinstance(qkv, list):
        qkv = qkv[0]

    assert alibi is None, "alibi is not supported by triton attention yet"

    if use_triton_flash:
        output = _triton_packed_flash(qkv,
                                      head_size,
                                      input_mask,
                                      scale,
                                      causal=triangular,
                                      add_mask=(not triangular and input_mask is not None))
    else:
        output = score_4d_matmul(qkv, head_size, triangular, scale)
        if triangular:
            output = softmax(output)
        else:
            output = softmax(output, input_mask)
        output = context_4d_matmul(output, qkv, head_size)

    return output
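
# A minimal usage sketch for the fallback (non-flash) branch of _triton_attention above,
# assuming a CUDA device with fp16 support; the shapes (batch=1, seq=128, heads=12,
# head_size=64) are illustrative only:
#   qkv = torch.randn((1, 128, 3 * 768), dtype=torch.float16, device='cuda')  # packed [Q | K | V]
#   ctx = _triton_attention(qkv, input_mask=None, layer_past=None, alibi=None,
#                           scale=1.0 / math.sqrt(64), head_size=64,
#                           triangular=True, use_triton_flash=False)
#   # ctx has shape (1, 128, 768): softmax(Q K^T * scale) V, computed by
#   # score_4d_matmul -> softmax -> context_4d_matmul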
'''
Flash attention 2: a modified version of the Triton fused-attention tutorial kernel from
https://github.com/openai/triton/blob/08c16589573621fcb8cd5a9c3b8a0537077f876d/python/tutorials/06-fused-attention.py
'''
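
# Summary of the scheme used by the kernel below: it computes softmax(Q K^T * sm_scale) V in a
# single pass over K/V blocks with the streaming ("online") softmax of flash attention, keeping a
# running row maximum m_i, a running denominator l_i, and a partial output acc, and rescaling acc
# by alpha = 2^(m_i - m_i_new) whenever a new block raises the maximum. Exponentials are taken in
# base 2 (exp(x) = 2^(x * log2(e)), hence the 1.44269504 factor folded into sm_scale), which maps
# to the fast exp2 instruction.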
@triton.jit
def _flash_packed_kernel(
    QKV,
    mask,
    ADD_MASK: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    sm_scale,
    Out,
    stride_qz,
    stride_qn,
    stride_qm,
    stride_mz,
    stride_oz,
    stride_on,
    Z,
    H,
    N_CTX,
    P_SEQ,
    hidden_size,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch = off_hz // H
    head = off_hz % H

    q_offset = batch * stride_qz + head * BLOCK_DMODEL
    k_offset = q_offset + hidden_size
    v_offset = k_offset + hidden_size

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    q_ptrs = QKV + q_offset + offs_m[:, None] * stride_qn + offs_d[None, :]
    k_ptrs = QKV + hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]
    v_ptrs = QKV + 2 * hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]

    # mask
    off_mask = batch * stride_mz + offs_n[None, :]
    mask_ptrs = mask + off_mask

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504

    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)
    q = (q * qk_scale).to(tl.float16)

    # loop over k, v and update accumulator
    lo = 0
    hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(k_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)
        v = tl.load(v_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)

        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)

        if ADD_MASK:
            mask_val = tl.load(mask_ptrs)
            mask_ptrs += BLOCK_N
            qk = qk + mask_val.to(tl.float32)

        if IS_CAUSAL:
            qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

        qk += tl.dot(q, tl.trans(k), out_dtype=tl.float16)
        qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, minus_inf)

        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])

        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v.to(tl.float16))

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

    # write back l and m
    acc = acc / l_i[:, None]
    o_offset = batch * stride_oz + head * BLOCK_DMODEL
    out_ptrs = Out + o_offset + (offs_m[:, None] * stride_on + offs_d[None, :])
    tl.store(out_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX)
def _triton_packed_flash(qkv, head_size, mask, sm_scale, causal=False, add_mask=True):
    heads = qkv.shape[-1] // 3 // head_size
    hidden_size = qkv.shape[-1] // 3

    BLOCK_M = 128
    BLOCK_N = 64 if head_size <= 64 else 32

    o = torch.empty((qkv.shape[0], qkv.shape[1], hidden_size), device=qkv.device, dtype=torch.half)
    if mask is None:
        mask = torch.empty(0)
        add_mask = False

    grid = (triton.cdiv(qkv.shape[1], BLOCK_M), qkv.shape[0] * heads, 1)
    num_stages = 4 if head_size <= 64 else 3
    num_warps = 4
    P_SEQ = 0

    _flash_packed_kernel[grid](qkv,
                               mask,
                               add_mask,
                               causal,
                               sm_scale,
                               o,
                               qkv.stride(0),
                               qkv.stride(1),
                               qkv.stride(2),
                               mask.stride(1) if add_mask else 0,
                               o.stride(0),
                               o.stride(1),
                               qkv.shape[0],
                               heads,
                               qkv.shape[1],
                               P_SEQ,
                               hidden_size,
                               BLOCK_M=BLOCK_M,
                               BLOCK_N=BLOCK_N,
                               BLOCK_DMODEL=head_size,
                               num_warps=num_warps,
                               num_stages=num_stages)

    return o
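

# A minimal smoke test, not part of the DeepSpeed module: it assumes a CUDA device with Triton
# support (compute capability >= 8.0) and uses illustrative shapes (batch=1, seq=256, heads=12,
# head_size=64). It runs the packed flash kernel and, for comparison, the fallback
# score/softmax/context path, then prints the output shape and the maximum elementwise difference.
if __name__ == "__main__":
    batch, seq, heads, head_size = 1, 256, 12, 64
    hidden = heads * head_size
    # packed [Q | K | V] activations, as produced by the fused QKV GEMM
    qkv = torch.randn((batch, seq, 3 * hidden), dtype=torch.float16, device='cuda')
    sm_scale = 1.0 / math.sqrt(head_size)
    flash_out = _triton_packed_flash(qkv, head_size, mask=None, sm_scale=sm_scale, causal=True)
    fallback_out = _triton_attention(qkv, input_mask=None, layer_past=None, alibi=None,
                                     scale=sm_scale, head_size=head_size,
                                     triangular=True, use_triton_flash=False)
    print(flash_out.shape, (flash_out - fallback_out).abs().max().item())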