# sparse_attention_utils.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

import torch
from torch.nn import functional as F
from deepspeed.ops.sparse_attention import BertSparseSelfAttention, SparsityConfig
'''
This file contains a few utility functions for adapting a pretrained transformer model to use the sparse self-attention module.
'''


class SparseAttentionUtils:
    """This class provides utility functions that are used for integrating sparse attention into transformer models.
    Such utilities include extending position embeddings, replacing the current self-attention layer with sparse attention, padding sequences to a multiple of the block size, etc.
    """

    @staticmethod
    def extend_position_embedding(model, max_position):
        """This function extends the position embedding weights of a model loaded from a checkpoint.
        It assumes the new max position is bigger than the original max length.

        Arguments:
            model: required: a transformer model
            max_position: required: an integer determining the new position embedding size

        Return:
            model: updated model; in which position embedding weights have been extended based on the new size
        """

        if hasattr(model, 'bert'):
            original_max_position = model.bert.embeddings.position_embeddings.weight.size(0)
            assert max_position > original_max_position
            extend_multiples = max(1, max_position // original_max_position)
            # The original embedding table is tiled; max_position is assumed to be a multiple of the original size.
            model.bert.embeddings.position_embeddings.weight.data = model.bert.embeddings.position_embeddings.weight.repeat(
                extend_multiples, 1)
        elif hasattr(model, 'roberta'):
            # RoBERTa has positions 0 & 1 reserved, so embedding size is max position + 2
            original_max_position, embed_size = model.roberta.embeddings.position_embeddings.weight.shape
            original_max_position -= 2
            extend_multiples = max(1, max_position // original_max_position)
            assert max_position > original_max_position
            max_position += 2
            extended_position_embedding = model.roberta.embeddings.position_embeddings.weight.new_empty(
                max_position, embed_size)
            # Keep the two reserved position rows, then tile the original embeddings to fill the extended table.
            extended_position_embedding[:2] = model.roberta.embeddings.position_embeddings.weight[:2]
            k = 2
            for i in range(extend_multiples):
                extended_position_embedding[k:(
                    k + original_max_position)] = model.roberta.embeddings.position_embeddings.weight[2:]
                k += original_max_position
            model.roberta.embeddings.position_embeddings.weight.data = extended_position_embedding
        else:
            raise ValueError('Please extend "extend_position_embedding" function to support your model type. '
                             'It currently only supports "bert" & "roberta"!')

        model.config.max_position_embeddings = max_position
        print(f'Extended position embeddings to {original_max_position * extend_multiples}')

        return model
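
    # Usage sketch (illustrative, not part of the original module). Assuming a
    # Hugging Face BERT checkpoint with 512 positions, extending it to 1024 (the
    # new size should be a multiple of the original) might look like:
    #
    #   from transformers import BertForSequenceClassification
    #   model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    #   model = SparseAttentionUtils.extend_position_embedding(model, 1024)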

    @staticmethod
    def update_tokenizer_model_max_length(tokenizer, max_position):
        """This function updates the model maximum length of a tokenizer to a new max position.

        Arguments:
            tokenizer: required: a transformer tokenizer
            max_position: required: an integer determining the new maximum model length

        Return:
            tokenizer: updated tokenizer; in which the model maximum length has been extended based on the new size
        """

        tokenizer.model_max_length = max_position
        tokenizer.init_kwargs['model_max_length'] = max_position
        print(f'Updated tokenizer model maximum length to {max_position}')

        return tokenizer
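
    # Usage sketch (illustrative, not part of the original module): keep the
    # tokenizer consistent with the extended position embeddings, e.g.:
    #
    #   from transformers import BertTokenizer
    #   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    #   tokenizer = SparseAttentionUtils.update_tokenizer_model_max_length(tokenizer, 1024)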

    @staticmethod
    def replace_model_self_attention_with_sparse_self_attention(
            model,
            max_position,
            # SparsityConfig parameters need to be set accordingly
            sparsity_config=SparsityConfig(num_heads=4)):
        """This function replaces the self attention layers in the model encoder with sparse self attention.
        It currently supports bert and roberta models and can easily be extended to other models following similar steps here.
        For sparsity_config, refer to the SparsityConfig class.

        Arguments:
            model: required: a transformer model
            max_position: required: an integer determining the new position embedding size
            sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class

        Return:
            model: updated model; in which the self attention layers have been replaced with the DeepSpeed sparse self attention layer.
        """

        if hasattr(model, 'bert'):
            model.config.max_position_embeddings = max_position
            SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
                model.config, model.bert.encoder.layer, sparsity_config)
        elif hasattr(model, 'roberta'):
            model.config.max_position_embeddings = max_position + 2
            SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
                model.config, model.roberta.encoder.layer, sparsity_config)
        else:
            raise ValueError('Please extend "replace_model_self_attention_with_sparse_self_attention" function to '
                             'support your model type. It currently only supports "bert" & "roberta"!')

        return model
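
    # Usage sketch (illustrative, not part of the original module). A fixed
    # sparsity pattern could be passed in place of the default SparsityConfig;
    # the head count and block size below are assumptions:
    #
    #   from deepspeed.ops.sparse_attention import FixedSparsityConfig
    #   sparsity_config = FixedSparsityConfig(num_heads=12, block=16)
    #   model = SparseAttentionUtils.replace_model_self_attention_with_sparse_self_attention(
    #       model, 1024, sparsity_config=sparsity_config)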

    @staticmethod
    def replace_self_attention_layer_with_sparse_self_attention_layer(
            config,
            layers,
            # SparsityConfig parameters need to be set accordingly
            sparsity_config=SparsityConfig(num_heads=4)):
        """This function replaces the self attention layers of the given transformer encoder layers with sparse self attention.
        For sparsity_config, refer to the SparsityConfig class.

        Arguments:
            config: required: transformer model config
            layers: required: transformer model attention layers
            sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class

        Return:
            layers: updated attention layers; in which the self attention layers have been replaced with the DeepSpeed sparse self attention layer.
        """

        for layer in layers:
            deepspeed_sparse_self_attn = BertSparseSelfAttention(config, sparsity_config)
            # Reuse the pretrained query/key/value projections so no weights are lost.
            deepspeed_sparse_self_attn.query = layer.attention.self.query
            deepspeed_sparse_self_attn.key = layer.attention.self.key
            deepspeed_sparse_self_attn.value = layer.attention.self.value
            layer.attention.self = deepspeed_sparse_self_attn

        return layers
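
    # Usage sketch (illustrative, not part of the original module): for a model
    # type that is not handled above but exposes BERT-style encoder layers, this
    # layer-level helper could be called directly, e.g.:
    #
    #   SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
    #       my_model.config, my_model.encoder.layer, sparsity_config)
    #
    # where `my_model` is a hypothetical model whose layers follow the
    # layer.attention.self.{query,key,value} naming used above.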

    @staticmethod
    def pad_to_block_size(block_size, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds,
                          pad_token_id, model_embeddings):
        """This function pads the input tokens and attention mask on the sequence length dimension to a multiple of block size.
        This is a requirement for the sparse transformer, in which the self attention layer works on sequences whose length is a multiple of the block size.
        It needs to be called in your model, such as BertModel, right before you calculate the embedding outputs.

        Note:
        1- Instead of passing your embedding layer to this function, you can simply add this function to your model. It can be simplified further if the given attention_mask and/or token_type_ids are None.
        2- You need to call the unpad function (unpad_sequence_output) before returning your model output, to unpad the encoder sequence output.

        Arguments:
            block_size: required: an integer determining the block size of the sparsity config.
            pad_token_id: required: an integer determining the pad token from the model config; such as bert.config.pad_token_id.
            input_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary
            attention_mask: a torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
            token_type_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see the BERT paper for more details).
            position_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the indices of positions of each input sequence token in the position embeddings.
            inputs_embeds: an optional torch.FloatTensor of shape [batch_size, sequence_length, hidden_size] that contains the embedded representation and can be passed instead of input_ids directly.
            model_embeddings: an optional object. If inputs_embeds is not None, this will be your model embeddings, such as BertEmbeddings from your model such as BertModel. You can move this function inside your model and use self.embeddings instead of passing this parameter.

        Return:
            pad_len: an integer determining how much the inputs have been padded to make the sequence length dimension a multiple of block size.
            input_ids: padded input_ids if input_ids is not None; otherwise None.
            attention_mask: padded attention_mask if attention_mask is not None; otherwise None.
            token_type_ids: padded token_type_ids if token_type_ids is not None; otherwise None.
            position_ids: padded position_ids if position_ids is not None; otherwise None.
            inputs_embeds: padded inputs_embeds if inputs_embeds is not None; otherwise None.
        """

        batch_size, seq_len = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]

        pad_len = (block_size - seq_len % block_size) % block_size
        if pad_len > 0:
            if inputs_embeds is not None:
                pad_input_ids = inputs_embeds.new_full((batch_size, pad_len), pad_token_id, dtype=torch.long)
                pad_inputs_embeds = model_embeddings(pad_input_ids)
                inputs_embeds = torch.cat([inputs_embeds, pad_inputs_embeds], dim=-2)
            # may not be needed as input_ids are not used if inputs_embeds are given
            if input_ids is not None:
                input_ids = F.pad(input_ids, (0, pad_len), value=pad_token_id)
            if position_ids is not None:
                # pad position_ids with pad_token_id
                position_ids = F.pad(position_ids, (0, pad_len), value=pad_token_id)
            if attention_mask is not None:
                # pad attention mask with zeros so no attention is paid to the padding tokens
                attention_mask = F.pad(attention_mask, (0, pad_len), value=False)
            if token_type_ids is not None:
                # pad token_type_ids with token_type_id = 0
                token_type_ids = F.pad(token_type_ids, (0, pad_len), value=0)

        return pad_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
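
    # Usage sketch (illustrative, not part of the original module): inside a
    # BertModel-style forward, right before computing the embedding output, the
    # padding call might look like the following (self.config / self.embeddings
    # are the host model's attributes, an assumption here):
    #
    #   pad_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = \
    #       SparseAttentionUtils.pad_to_block_size(
    #           block_size=sparsity_config.block,
    #           input_ids=input_ids,
    #           attention_mask=attention_mask,
    #           token_type_ids=token_type_ids,
    #           position_ids=position_ids,
    #           inputs_embeds=inputs_embeds,
    #           pad_token_id=self.config.pad_token_id,
    #           model_embeddings=self.embeddings)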

    @staticmethod
    def unpad_sequence_output(pad_len, sequence_output):
        """This function unpads the sequence output if the inputs of the model were padded.
        This is a requirement for the sparse transformer, in which the self attention layer works on sequences whose length is a multiple of the block size.
        It needs to be called in your model, such as BertModel, right before you return the model outputs.

        Arguments:
            pad_len: required: an integer determining how much the model inputs have been padded to make the sequence length dimension a multiple of block size.
            sequence_output: required: sequence output of the encoder layer.

        Return:
            sequence_output: unpadded sequence output of the encoder layer.
        """

        if pad_len > 0:
            sequence_output = sequence_output[:, :-pad_len]

        return sequence_output
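

# Illustrative end-to-end sketch (not part of the original utility class): a
# minimal helper showing the intended pad -> encode -> unpad flow around an
# arbitrary, caller-supplied `encoder_forward` callable. The function name and
# its signature are assumptions made for this example only.
def _example_pad_encode_unpad(encoder_forward, block_size, input_ids, attention_mask, token_type_ids, pad_token_id):
    # Pad all inputs on the sequence dimension to a multiple of the block size.
    pad_len, input_ids, attention_mask, token_type_ids, _, _ = SparseAttentionUtils.pad_to_block_size(
        block_size, input_ids, attention_mask, token_type_ids, None, None, pad_token_id, None)
    # Run the caller-supplied encoder on the padded batch.
    sequence_output = encoder_forward(input_ids, attention_mask, token_type_ids)
    # Strip the padded positions from the encoder output before returning it.
    return SparseAttentionUtils.unpad_sequence_output(pad_len, sequence_output)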