sparse_self_attention.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch.nn as nn
import torch
from torch import distributed as dist
from deepspeed.ops.sparse_attention import SparsityConfig


class SparseSelfAttention(nn.Module):
    """Implements an efficient Sparse Self Attention of Transformer layer based on `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509

    For more information please see: TODO DeepSpeed Sparse Transformer.

    For usage example please see: TODO DeepSpeed Sparse Transformer Tutorial.
    """

    def __init__(
            self,
            # SparsityConfig parameters need to be set accordingly
            sparsity_config=SparsityConfig(num_heads=4),
            key_padding_mask_mode='add',
            attn_mask_mode='mul',
            max_seq_length=2048):
        """Initialize the sparse self attention layer.

        Arguments:
            sparsity_config: optional: determines the sparsity pattern configuration; it is based on the SparsityConfig class.
            key_padding_mask_mode: optional: a string determining whether the key padding mask is added, `add`, or multiplied, `mul`.
            attn_mask_mode: optional: a string determining whether the attention mask is added, `add`, or multiplied, `mul`.
            max_seq_length: optional: the maximum sequence length this sparse attention module will be applied to; it controls the size of the master_layout.
        """
        super().__init__()

        # sparsity information
        self.sparsity_config = sparsity_config

        # initialize sparse layout and register as buffer
        master_layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("master_layout", master_layout)
        self._need_layout_synchronization = True

        # mask modes
        self.key_padding_mask_mode = key_padding_mask_mode
        self.attn_mask_mode = attn_mask_mode

    # class-level cache of block-sparse kernels, keyed by sequence length
    ops = dict()

    def get_layout(self, L):
        # if the layout has never been synchronized across GPUs, broadcast it from global rank 0
        if self._need_layout_synchronization and dist.is_initialized():
            dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False

        if (L % self.sparsity_config.block != 0):
            raise ValueError(
                f'Sequence Length, {L}, needs to be divisible by Block size {self.sparsity_config.block}!')

        num_blocks = L // self.sparsity_config.block
        return self.master_layout[..., :num_blocks, :num_blocks].cpu()  # layout needs to be a CPU tensor
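
    # Layout semantics (sketch, assuming the standard layouts produced by SparsityConfig
    # subclasses): master_layout has shape
    # (num_heads, max_seq_length // block, max_seq_length // block), the slice returned
    # above covers the first L // block block-rows and block-columns, and a nonzero
    # entry at [h, i, j] marks block (i, j) as attended for head h.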

    # add kernels to the cache
    def get_ops(self, H, L):
        from deepspeed.ops.sparse_attention.matmul import MatMul
        from deepspeed.ops.sparse_attention.softmax import Softmax
        if L not in SparseSelfAttention.ops:
            sparsity_layout = self.get_layout(L)
            # 'sdd': dense query x dense key^T -> block-sparse attention scores
            sparse_dot_sdd_nt = MatMul(sparsity_layout, self.sparsity_config.block, 'sdd', trans_a=False, trans_b=True)

            # 'dsd': block-sparse probabilities x dense value -> dense output
            sparse_dot_dsd_nn = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'dsd',
                                       trans_a=False,
                                       trans_b=False)

            sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)

            SparseSelfAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax)
        return SparseSelfAttention.ops[L]
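
    # Caching note: `ops` is a class-level dict shared by all SparseSelfAttention
    # instances and keyed only by sequence length L, which assumes every instance in
    # the process uses the same sparsity layout; instances with different layouts but
    # the same L would reuse the kernels first built for that length.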

    def transpose_key_for_scores(self, x, L):
        bsz, num_heads, seq_len, head_dim = x.size()
        if seq_len != L:
            return x.permute(0, 1, 3, 2)
        return x

    def transpose_mask_for_sparse(self, qtype, x, is_key_padding_mask=False):
        x = x.type(qtype)
        if is_key_padding_mask:
            xdim = x.dim()
            for d in range(xdim - 1, 0, -1):
                x = x.squeeze(dim=d)
            return x
        return x.squeeze()

    # forward pass
    def forward(self, query, key, value, rpe=None, key_padding_mask=None, attn_mask=None):
        """Applies the forward pass of sparse self attention.

        Arguments:
            query: required: query tensor
            key: required: key tensor
            value: required: value tensor
            rpe: optional: a relative position embedding tensor that is passed to the sparse softmax along with the attention scores
            key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
            attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported

        Note:
            key_padding_mask_mode and attn_mask_mode are strings (`add` or `mul`) configured in __init__, not arguments of forward.

        Return:
            attn_output: a dense tensor containing the attention context
        """
        assert query.dtype == torch.half, "sparse attention only supports training in fp16 currently, please file a github issue if you need fp32 support"
        bsz, num_heads, tgt_len, head_dim = query.size()

        # transpose back key if it is already transposed
        key = self.transpose_key_for_scores(key, tgt_len)

        # check that operation is supported
        if query.shape != key.shape or key.shape != value.shape:
            raise NotImplementedError('only self-attention is supported for now')

        # squeeze key_padding_mask if it is given
        if key_padding_mask is not None:
            key_padding_mask = self.transpose_mask_for_sparse(query.dtype, key_padding_mask, is_key_padding_mask=True)

        # squeeze attn_mask if it is given
        if attn_mask is not None:
            attn_mask = self.transpose_mask_for_sparse(query.dtype, attn_mask)

        # cache look-up table computations etc.
        sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops(num_heads, tgt_len)

        scaling = float(head_dim)**-0.5

        # attention scores
        attn_output_weights = sparse_dot_sdd_nt(query, key)
        attn_output_weights = sparse_softmax(attn_output_weights,
                                             scale=scaling,
                                             rpe=rpe,
                                             key_padding_mask=key_padding_mask,
                                             attn_mask=attn_mask,
                                             key_padding_mask_mode=self.key_padding_mask_mode,
                                             attn_mask_mode=self.attn_mask_mode)

        # outputs
        attn_output = sparse_dot_dsd_nn(attn_output_weights, value)
        return attn_output
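

# Minimal usage sketch (assumptions: a CUDA device with the DeepSpeed sparse attention
# kernels available, and FixedSparsityConfig as the layout provider; the shapes and
# hyper-parameters below are illustrative only).
if __name__ == "__main__":
    from deepspeed.ops.sparse_attention import FixedSparsityConfig

    bsz, num_heads, seq_len, head_dim = 2, 4, 256, 64

    # fixed block-sparse pattern over 256 tokens with block size 16 (a 16x16 block layout)
    config = FixedSparsityConfig(num_heads=num_heads, block=16)
    attn = SparseSelfAttention(sparsity_config=config, max_seq_length=seq_len).cuda()

    # q/k/v must be fp16, shaped (batch, heads, seq_len, head_dim),
    # with seq_len a multiple of config.block
    q = torch.rand(bsz, num_heads, seq_len, head_dim, dtype=torch.half, device='cuda')
    k = torch.rand_like(q)
    v = torch.rand_like(q)

    out = attn(q, k, v)
    print(out.shape)  # expected: torch.Size([2, 4, 256, 64])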