ds_attention.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator

from .op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp, SoftmaxOp

minus_inf = -10000.0
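# minus_inf is added to the attention logits of masked-out positions before the
# softmax so that they receive ~zero probability.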


class DeepSpeedSelfAttention(nn.Module):
    num_layers = 0
    _qkv_buffers = []

    def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1):
        super(DeepSpeedSelfAttention, self).__init__()
        self.config = config
        data_type = self.config.dtype
        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        self.config.layer_id = DeepSpeedSelfAttention.num_layers
        DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
        device = get_accelerator().current_device_name()  #if config.bigscience_bloom else 'cpu'
        if self.config.set_empty_params:
            self.attn_qw = None
            self.attn_qb = None
            self.attn_kw = None
            self.attn_kb = None
            self.attn_vw = None
            self.attn_vb = None
            self.attn_qkvw = None
            self.attn_qkvb = None
            self.attn_ow = None
            self.attn_ob = None
        else:
            qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
            self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                      qkv_size_per_partition,
                                                      dtype=data_type,
                                                      device=device),
                                          requires_grad=False)
            self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
                                          requires_grad=False)
            out_size_per_partition = self.config.hidden_size // self.config.mp_size
            self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                                    self.config.hidden_size,
                                                    dtype=data_type,
                                                    device=device),
                                        requires_grad=False)
            self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)

        self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
        self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads

        self.mp_group = mp_group

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups
        self.merge_count = int(math.log2(merge_count))

        self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads)
        if not config.use_mup:
            self.norm_factor = math.sqrt(self.norm_factor)

        if self.config.scale_attn_by_inverse_layer_idx is True:
            self.norm_factor *= math.sqrt(self.config.layer_id + 1)
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191
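        # Note: with use_mup False, norm_factor ends up as head_dim ** 0.25; it is
        # applied squared where scores are scaled (e.g. layer_scale = 1 / norm_factor**2
        # in BloomSelfAttention.compute_attention below), recovering the standard
        # 1 / sqrt(head_dim) softmax scaling.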

        self.qkv_func = QKVGemmOp(config)
        self.score_context_func = SoftmaxContextOp(config)
        self.linear_func = LinearOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)
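        # Class-level scratch buffers shared across all layers: _merge_qkv() packs the
        # separate q/k/v weights and biases into them when the layer holds unmerged
        # q/k/v parameters (i.e. attn_qkvw is None).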
        if len(DeepSpeedSelfAttention._qkv_buffers) == 0:
            DeepSpeedSelfAttention._qkv_buffers = [
                torch.empty(self.hidden_size_per_partition * 3,
                            self.config.hidden_size,
                            dtype=data_type_fp,
                            device=device),
                torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device)
            ]
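
    # compute_attention runs the softmax-context op (score_context_func) on the packed
    # QKV activations and returns the attention context together with the key/value
    # layers that the caller keeps as the present cache.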
    def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
        if isinstance(qkv_out, list) or isinstance(qkv_out, tuple):
            qkv_out = qkv_out[0]

        no_masking = input_mask is None

        if no_masking:
            input_mask = torch.empty(1)

        attn_key_value = self.score_context_func(
            query_key_value=qkv_out,
            attn_mask=((1 - input_mask).to(qkv_out.dtype) *
                       minus_inf) if input_mask.dtype == torch.int64 else input_mask,
            heads=self.num_attention_heads_per_partition,
            norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0),
            no_masking=no_masking,
            layer_id=self.config.layer_id,
            num_layers=DeepSpeedSelfAttention.num_layers,
            alibi=alibi)

        context_layer, key_layer, value_layer = attn_key_value
        return context_layer, key_layer, value_layer

    def _merge_qkv(self):
        qvkw = DeepSpeedSelfAttention._qkv_buffers[0]
        qvkw[:self.hidden_size_per_partition, :] = self.attn_qw  # type: ignore
        qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = self.attn_kw  # type: ignore
        qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw  # type: ignore
        if self.attn_qb is not None:
            qvkb = DeepSpeedSelfAttention._qkv_buffers[1]
            qvkb[:self.hidden_size_per_partition] = self.attn_qb
            qvkb[self.hidden_size_per_partition:2 * self.hidden_size_per_partition] = self.attn_kb  # type: ignore
            qvkb[2 * self.hidden_size_per_partition:] = self.attn_vb  # type: ignore
        return DeepSpeedSelfAttention._qkv_buffers

    def forward(self,
                input,
                input_mask,
                head_mask=None,
                layer_past=None,
                get_present=False,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                output_attentions=False,
                norm_w=None,
                norm_b=None,
                alibi=None):
        if self.attn_qkvw is None:
            self._attn_qkvw, self._attn_qkvb = self._merge_qkv()
        else:
            self._attn_qkvw = self.attn_qkvw
            self._attn_qkvb = self.attn_qkvb
        if not self.config.pre_layer_norm:
            qkv_out = self.linear_func(input=input,
                                       weight=self._attn_qkvw,
                                       bias=self._attn_qkvb,
                                       add_bias=self.attn_qkvb is not None,
                                       do_flash_attn=False,
                                       num_heads=self.num_attention_heads_per_partition,
                                       num_layers=DeepSpeedSelfAttention.num_layers)
        else:
            qkv_out = self.qkv_func(input=input,
                                    weight=self._attn_qkvw,
                                    bias=self._attn_qkvb,
                                    gamma=norm_w,
                                    beta=norm_b)

        context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
                                                                       input_mask=input_mask,
                                                                       layer_past=layer_past,
                                                                       alibi=alibi)

        output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)

        inp_norm = qkv_out[-1]
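        # attn_ow is partitioned along its input dimension (out_size_per_partition x
        # hidden_size), so under tensor parallelism each rank produces only a partial
        # sum of the output projection; reduce across mp_group before returning.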
        if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(output, group=self.mp_group)

        return (output, key_layer, value_layer, context_layer, inp_norm)
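

# BloomSelfAttention keeps the QKV and output projections of the base class but
# replaces the softmax-context path with an explicit PyTorch implementation adapted
# from HF's modeling_bloom, using SoftmaxOp (which accepts alibi) for the softmax.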
class BloomSelfAttention(DeepSpeedSelfAttention):

    def __init__(self, *args, **kwargs):
        super(BloomSelfAttention, self).__init__(*args, **kwargs)
        self.softmax_func = SoftmaxOp(self.config)

    ########### This part is taken/modified from the HF modeling_bloom.py ################
    # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py

    def _transpose_for_context(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_layer_shape = x.size()[:-2] + \
                            (self.hidden_size_per_partition,)
        return x.view(*new_x_layer_shape).contiguous()

    def _split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=True):
        """Split a tensor along its last dimension.

        Args:
            tensor ([`torch.Tensor`], *required*):
                input tensor to split
            num_partitions ([`int`], *required*):
                number of partitions to split the tensor into
            contiguous_split_chunks ([`bool`], *optional*, default=`True`):
                If True, make each chunk contiguous in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        numerator, denominator = tensor.size()[last_dim], num_partitions
        if not (numerator % denominator == 0):
            raise ValueError(f"{numerator} is not divisible by {denominator}")
        last_dim_size = numerator // denominator
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)

        return tensor_list

    def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
        if isinstance(qkv_out, list) or isinstance(qkv_out, tuple):
            qkv_out = qkv_out[0]

        no_masking = input_mask is None

        if no_masking:
            input_mask = torch.empty(1)

        mixed_x_layer = qkv_out
        alibi = alibi.to(get_accelerator().current_device_name())
        head_dim = self.hidden_size_per_partition // self.num_attention_heads_per_partition
        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
        query_layer, key_layer, value_layer = self._split_tensor_along_last_dim(mixed_x_layer, 3)

        # [batch_size, num_heads, q_length, k_length]
        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
        # [batch_size, q_length, num_heads, head_dim] -> [batch_size * num_heads, q_length, head_dim]
        query_layer = query_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[2], -1)
        # [batch_size, k_length, num_heads, head_dim] -> [batch_size * num_heads, head_dim, k_length]
        key_layer = key_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3],
                                                      -1).transpose(-1, -2)
        value_layer = value_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3], -1)
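        # Keys are kept as [batch_size * num_heads, head_dim, k_length] (note the
        # transpose above), so cached keys concatenate along dim=-1 while cached
        # values concatenate along dim=-2.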
        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate past and current keys/values along the sequence dimension
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
        presents = (key_layer, value_layer)
        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        matmul_result = torch.matmul(query_layer, key_layer)
        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(output_size[0], output_size[1], output_size[2], -1)

        offset = dist.get_rank() * self.num_attention_heads_per_partition if dist.is_initialized() else 0
        target_dtype = torch.float16 if self.config.dtype == torch.int8 else self.config.dtype
        attention_probs = self.softmax_func(attn_scores=attention_scores,
                                            attn_mask=((1 - input_mask).to(target_dtype) * minus_inf),
                                            alibi=alibi,
                                            triangular=(self.config.triangular_masking
                                                        and (attention_scores.shape[-2] > 1)),
                                            recompute=False,
                                            local_attention=False,
                                            window_size=1,
                                            async_op=False,
                                            layer_scale=1 / (self.norm_factor * self.norm_factor),
                                            head_offset=offset)

        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(
            context_layer.size(0) // self.num_attention_heads_per_partition, self.num_attention_heads_per_partition,
            context_layer.size(1), context_layer.shape[-1])

        context_layer = self._transpose_for_context(context_layer)
        key_layer = presents[0]
        value_layer = presents[1]

        return context_layer, key_layer, value_layer

    ###################### End of HF modeling_bloom addition ########################
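

# ------------------------------------------------------------------------------
# Usage note (illustrative sketch): these modules are normally built by
# DeepSpeed's inference transformer container rather than instantiated directly.
# Assuming `cfg` is an inference config exposing the attributes read above
# (hidden_size, heads, dtype, mp_size, set_empty_params, pre_layer_norm, use_mup,
# scale_attention, scale_attn_by_inverse_layer_idx, triangular_masking,
# mlp_after_attn), construction and a forward call look roughly like:
#
#     attn = DeepSpeedSelfAttention(cfg, mp_group=None)
#     output, key, value, context, inp_norm = attn(hidden_states, attention_mask,
#                                                  norm_w=ln_weight, norm_b=ln_bias)
#
# `cfg`, `hidden_states`, `attention_mask`, `ln_weight`, and `ln_bias` are
# placeholder names, not symbols defined in this file.
# ------------------------------------------------------------------------------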