# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch.nn as nn
import torch
from torch import distributed as dist
from deepspeed.ops.sparse_attention import SparsityConfig


class SparseSelfAttention(nn.Module):
    """Implements an efficient sparse self attention layer for Transformers, based on
    `Generating Long Sequences with Sparse Transformers`: https://arxiv.org/abs/1904.10509

    For more information please see, TODO DeepSpeed Sparse Transformer.

    For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
    """

    def __init__(
            self,
            # SparsityConfig parameters need to be set accordingly
            sparsity_config=SparsityConfig(num_heads=4),
            key_padding_mask_mode='add',
            attn_mask_mode='mul',
            max_seq_length=2048):
        """Initialize the sparse self attention layer.

        Arguments:
            sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class.
            key_padding_mask_mode: optional: a string determining whether the key padding mask is added, `add`, or multiplied, `mul`.
            attn_mask_mode: optional: a string determining whether the attention mask is added, `add`, or multiplied, `mul`.
            max_seq_length: optional: the maximum sequence length this sparse attention module will be applied to; it controls the size of the master_layout.
        """
        super().__init__()

        # sparsity information
        self.sparsity_config = sparsity_config

        # initialize the sparse layout and register it as a buffer
        master_layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("master_layout", master_layout)
        self._need_layout_synchronization = True

        # mask modes
        self.key_padding_mask_mode = key_padding_mask_mode
        self.attn_mask_mode = attn_mask_mode

    # class-level cache of sparse kernels, keyed by sequence length
    ops = dict()

    def get_layout(self, L):
        # if the layout has not yet been synchronized across GPUs, broadcast it from global rank 0
        if self._need_layout_synchronization and dist.is_initialized():
            dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False

        if (L % self.sparsity_config.block != 0):
            raise ValueError(
                f'Sequence length, {L}, needs to be divisible by block size {self.sparsity_config.block}!')

        num_blocks = L // self.sparsity_config.block
        return self.master_layout[..., :num_blocks, :num_blocks].cpu()  # layout needs to be a CPU tensor

    # build the sparse kernels for a given sequence length and add them to the cache
    def get_ops(self, H, L):
        from deepspeed.ops.sparse_attention.matmul import MatMul
        from deepspeed.ops.sparse_attention.softmax import Softmax
        if L not in SparseSelfAttention.ops:
            sparsity_layout = self.get_layout(L)
            sparse_dot_sdd_nt = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'sdd',
                                       trans_a=False,
                                       trans_b=True)

            sparse_dot_dsd_nn = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'dsd',
                                       trans_a=False,
                                       trans_b=False)

            sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)

            SparseSelfAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax)
        return SparseSelfAttention.ops[L]

    def transpose_key_for_scores(self, x, L):
        bsz, num_heads, seq_len, head_dim = x.size()
        if seq_len != L:
            return x.permute(0, 1, 3, 2)
        return x

    def transpose_mask_for_sparse(self, qtype, x, is_key_padding_mask=False):
        x = x.type(qtype)
        if is_key_padding_mask:
            xdim = x.dim()
            for d in range(xdim - 1, 0, -1):
                x = x.squeeze(dim=d)
            return x
        return x.squeeze()

    # forward pass
    def forward(self, query, key, value, rpe=None, key_padding_mask=None, attn_mask=None):
        """Applies the forward pass of sparse self attention.

        Arguments:
            query: required: query tensor
            key: required: key tensor
            value: required: value tensor
            rpe: optional: a tensor of the same dimension as the attention scores that is used as relative position embedding
            key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
            attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
            key_padding_mask_mode: optional: a string determining if key_padding_mask needs to be added or multiplied; set at construction
            attn_mask_mode: optional: a string determining if attn_mask needs to be added or multiplied; set at construction

        Return:
            attn_output: a dense tensor containing attention context
        """
        assert query.dtype == torch.half, "sparse attention only supports training in fp16 currently, please file a github issue if you need fp32 support"
        bsz, num_heads, tgt_len, head_dim = query.size()

        # transpose key back if it is already transposed
        key = self.transpose_key_for_scores(key, tgt_len)

        # check that the operation is supported
        if query.shape != key.shape or key.shape != value.shape:
            raise NotImplementedError('only self-attention is supported for now')

        # squeeze key_padding_mask if it is given
        if key_padding_mask is not None:
            key_padding_mask = self.transpose_mask_for_sparse(query.dtype, key_padding_mask, is_key_padding_mask=True)

        # squeeze attn_mask if it is given
        if attn_mask is not None:
            attn_mask = self.transpose_mask_for_sparse(query.dtype, attn_mask)

        # fetch the cached sparse kernels (look-up table computations etc.) for this sequence length
        sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops(num_heads, tgt_len)

        scaling = float(head_dim)**-0.5

        # attention scores
        attn_output_weights = sparse_dot_sdd_nt(query, key)
        attn_output_weights = sparse_softmax(attn_output_weights,
                                             scale=scaling,
                                             rpe=rpe,
                                             key_padding_mask=key_padding_mask,
                                             attn_mask=attn_mask,
                                             key_padding_mask_mode=self.key_padding_mask_mode,
                                             attn_mask_mode=self.attn_mask_mode)

        # outputs
        attn_output = sparse_dot_dsd_nn(attn_output_weights, value)
        return attn_output
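

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API and not
# the tutorial referenced in the class docstring). It assumes a CUDA device,
# fp16 inputs, and the FixedSparsityConfig pattern from
# deepspeed.ops.sparse_attention; the chosen shapes, block size, and sequence
# length are arbitrary examples. Tensors follow the
# (batch, heads, seq_len, head_dim) layout that forward() expects, and
# seq_len must be divisible by the block size.
#
#     from deepspeed.ops.sparse_attention import FixedSparsityConfig
#
#     config = FixedSparsityConfig(num_heads=4, block=16)
#     attn = SparseSelfAttention(sparsity_config=config, max_seq_length=1024)
#
#     bsz, heads, seq_len, head_dim = 2, 4, 1024, 64
#     q = torch.randn(bsz, heads, seq_len, head_dim, dtype=torch.half, device='cuda')
#     k = torch.randn_like(q)
#     v = torch.randn_like(q)
#
#     context = attn(q, k, v)  # dense output of shape (bsz, heads, seq_len, head_dim)
# ---------------------------------------------------------------------------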