# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math
import torch
import torch.nn as nn
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp
from deepspeed.ops.transformer.inference.triton import (
    softmax,
    score_4d_matmul,
    context_4d_matmul,
)
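
# Large negative bias substituted for masked positions when an integer attention
# mask is converted into an additive mask (see ds_compute_attention below).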
minus_inf = -10000.0
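

# Triton-based self-attention module for DeepSpeed transformer inference.
# Supports fp16 only and a single tensor-parallel partition (mp_size == 1); the
# main path dispatches the score, softmax, and context matmuls to Triton
# kernels, with a DeepSpeed-op fallback (ds_compute_attention) when ALiBi is used.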
class TritonSelfAttention(nn.Module):
    num_layers = 0

    def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, qkv_merging=False):
        super(TritonSelfAttention, self).__init__()
        self.config = config
        data_type = self.config.dtype
        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        assert data_type_fp == torch.half, "triton attention supports only fp16 (data_type_fp must be torch.half)"

        self.config.layer_id = TritonSelfAttention.num_layers
        TritonSelfAttention.num_layers = TritonSelfAttention.num_layers + 1
        device = get_accelerator().current_device_name()  #if config.bigscience_bloom else 'cpu'
        assert config.mp_size == 1, "mp_size must be 1; tensor parallelism is not yet supported with triton attention"

        if self.config.set_empty_params:
            self.attn_qw = None
            self.attn_qb = None
            self.attn_kw = None
            self.attn_kb = None
            self.attn_vw = None
            self.attn_vb = None
            self.attn_qkvw = None
            self.attn_qkvb = None
            self.attn_ow = None
            self.attn_ob = None
        else:
            qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
            self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                      qkv_size_per_partition,
                                                      dtype=data_type,
                                                      device=device),
                                          requires_grad=False)
            self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
                                          requires_grad=False)
            # self-output weights
            out_size_per_partition = self.config.hidden_size // self.config.mp_size
            self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                                    self.config.hidden_size,
                                                    dtype=data_type,
                                                    device=device),
                                        requires_grad=False)

            self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)

        self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
        self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads

        self.mp_group = mp_group
        self.use_flash = False

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups
        self.merge_count = int(math.log2(merge_count))

        self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads)
        if not config.use_mup:
            self.norm_factor = math.sqrt(self.norm_factor)

        if self.config.scale_attn_by_inverse_layer_idx is True:
            self.norm_factor *= math.sqrt(self.config.layer_id + 1)
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191

        triton_autotune = self.config.triton_autotune and self.config.layer_id == 0
        self.qkv_func = QKVGemmOp(config)
        self.score_context_func = SoftmaxContextOp(config)
        self.linear_func = LinearOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)

        self.hidden_size = config.hidden_size
        self.head_size = config.hidden_size // config.heads
        self.scale = (1 / self.norm_factor / self.norm_factor if self.config.scale_attention else 1.0
                      )  # making it back to 1/sqrt(head_size)
        self.triangular_masking = self.config.triangular_masking

        # triton autotune table update for score/context matmul
        if triton_autotune:
            print(f"running triton autotune for attention")
            __class__._triton_autotune(2, self.config.max_out_tokens, self.head_size, self.config.hidden_size,
                                       self.triangular_masking, self.scale)
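
    # Warm the Triton matmul autotune cache: run the score/context matmuls once
    # for every cached sequence-length bucket so later calls hit tuned configs.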
    @staticmethod
    def _triton_autotune(min_seqlen,
                         max_seqlen,
                         head_size,
                         hidden_size,
                         triangular_masking,
                         scale,
                         dtype=torch.float16):
        from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, score_4d_matmul, context_4d_matmul
        seqlen = [(min_seqlen + i)
                  for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
        Fp16Matmul._read_autotune_table()
        for N in seqlen:
            qkv = torch.randn((1, N, 3 * hidden_size), dtype=dtype, device='cuda')
            output = score_4d_matmul(qkv, head_size, triangular_masking, scale)
            context_4d_matmul(output, qkv, head_size)
        Fp16Matmul._update_autotune_table()
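
    # Fallback attention path: hand the fused QKV tensor to DeepSpeed's
    # SoftmaxContextOp kernel (used e.g. when ALiBi biases are present) and
    # return the context layer along with the key/value layers.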
    def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi):
        if isinstance(qkv_out, list):
            qkv_out = qkv_out[0]

        no_masking = input_mask is None

        if no_masking:
            input_mask = torch.empty(1)

        attn_key_value = self.score_context_func(
            query_key_value=qkv_out,
            attn_mask=((1 - input_mask).to(qkv_out.dtype) *
                       minus_inf) if input_mask.dtype == torch.int64 else input_mask,
            heads=self.num_attention_heads_per_partition,
            norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0),
            no_masking=no_masking,
            layer_id=self.config.layer_id,
            num_layers=TritonSelfAttention.num_layers,
            alibi=alibi)

        context_layer, key_layer, value_layer = attn_key_value
        return context_layer, key_layer, value_layer
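
    # Forward pass: project to fused QKV (optionally fused with the pre-layernorm),
    # run attention through the Triton kernels (or the DeepSpeed fallback when
    # ALiBi is used), then apply the output projection and, under tensor
    # parallelism, all-reduce the result across the model-parallel group.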
    def forward(
            self,
            input,
            input_mask,
            head_mask=None,
            layer_past=None,
            get_present=False,  # not used
            encoder_hidden_states=None,  # not used
            encoder_attention_mask=None,  # not used
            output_attentions=False,  # not used
            norm_w=None,
            norm_b=None,
            alibi=None,
            use_triton_attention=True):

        if not self.config.pre_layer_norm:
            qkv_out = self.linear_func(input=input,
                                       weight=self.attn_qkvw,
                                       bias=self.attn_qkvb,
                                       add_bias=self.attn_qkvb is not None,
                                       do_flash_attn=False,
                                       num_heads=self.num_attention_heads_per_partition,
                                       num_layers=TritonSelfAttention.num_layers)
            qkv = qkv_out
        else:
            qkv_out = self.qkv_func(input=input,
                                    weight=self.attn_qkvw,
                                    bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b),
                                    gamma=norm_w,
                                    beta=norm_b)
            qkv = qkv_out[0]

        if use_triton_attention and (alibi is None):
            context_layer = compute_attention(qkv=qkv,
                                              input_mask=input_mask,
                                              scale=self.scale,
                                              layer_past=layer_past,
                                              alibi=alibi,
                                              head_size=self.head_size,
                                              use_triton_flash=self.use_flash,
                                              use_cuda_flash=False,
                                              triangular=self.triangular_masking)
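            # The Triton path returns only the context layer, so slice the key and
            # value activations out of the fused QKV tensor to return them as well.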
            key_layer, value_layer = qkv[:, :, self.hidden_size:2 * self.hidden_size], qkv[:, :, 2 * self.hidden_size:]
        else:
            context_layer, key_layer, value_layer = self.ds_compute_attention(qkv_out=qkv_out,
                                                                              input_mask=input_mask,
                                                                              layer_past=layer_past,
                                                                              alibi=alibi)
        output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)

        inp_norm = qkv_out[-1]

        if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(output, group=self.mp_group)

        return (output, key_layer, value_layer, context_layer, inp_norm)


global inference_module
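

# Standalone Triton attention over a fused [batch, seq, 3 * hidden] QKV tensor:
# scaled Q.K^T scores (score_4d_matmul), Triton softmax with either triangular
# (causal) masking or an explicit input mask, then the context matmul with V.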
def compute_attention(qkv,
                      input_mask,
                      layer_past,
                      alibi,
                      scale,
                      head_size,
                      triangular=False,
                      use_cuda_flash=False,
                      use_triton_flash=False,
                      use_ds_attention=False):
    if isinstance(qkv, list):
        qkv = qkv[0]

    #assert layer_past is None, "layer_past not supported in triton yet"
    assert alibi is None, "alibi is not supported in triton attention yet"

    output = score_4d_matmul(qkv, head_size, triangular, scale)
    if triangular:
        output = softmax(output)
    else:
        output = softmax(output, input_mask)
    output = context_4d_matmul(output, qkv, head_size)

    return output
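

# Minimal usage sketch of the standalone compute_attention() path, assuming a
# CUDA device, fp16 tensors, and a working Triton install. The QKV layout
# [batch, seq, 3 * hidden] mirrors _triton_autotune above; the heads/head_size/
# seqlen values are arbitrary illustration choices, not requirements.
if __name__ == "__main__":
    heads, head_size, seqlen = 16, 64, 128
    hidden_size = heads * head_size
    qkv = torch.randn((1, seqlen, 3 * hidden_size), dtype=torch.float16, device='cuda')
    context = compute_attention(qkv,
                                input_mask=None,
                                layer_past=None,
                                alibi=None,
                                scale=1.0 / math.sqrt(head_size),
                                head_size=head_size,
                                triangular=True)  # causal masking, so no input_mask is needed
    print(context.shape)  # expected: (1, seqlen, hidden_size)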