# sparse_attention_utils.py

  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import torch
  5. from torch.nn import functional as F
  6. from deepspeed.ops.sparse_attention import BertSparseSelfAttention, SparsityConfig
'''
This file contains a few utility functions for adapting a pretrained transformer model to use the DeepSpeed sparse self-attention module.
'''


class SparseAttentionUtils:
    """This class provides utility functions used when integrating sparse attention into transformer models.
    Such utilities include extending position embeddings, replacing the current self-attention layer with sparse attention, padding sequences to a multiple of the block size, etc.
    """
    @staticmethod
    def extend_position_embedding(model, max_position):
        """This function extends the position embedding weights of a model loaded from a checkpoint.
        It assumes the new max position is bigger than the original max length.

        Arguments:
            model: required: a transformer model
            max_position: required: an integer determining new position embedding size

        Return:
            model: updated model; in which position embedding weights have been extended based on new size
        """

        if hasattr(model, 'bert'):
            original_max_position = model.bert.embeddings.position_embeddings.weight.size(0)
            assert max_position > original_max_position
            extend_multiples = max(1, max_position // original_max_position)
            model.bert.embeddings.position_embeddings.weight.data = model.bert.embeddings.position_embeddings.weight.repeat(
                extend_multiples, 1)
        elif hasattr(model, 'roberta'):
            # RoBERTa has positions 0 & 1 reserved, so embedding size is max position + 2
            original_max_position, embed_size = model.roberta.embeddings.position_embeddings.weight.shape
            original_max_position -= 2
            extend_multiples = max(1, max_position // original_max_position)
            assert max_position > original_max_position
            max_position += 2
            extended_position_embedding = model.roberta.embeddings.position_embeddings.weight.new_empty(
                max_position, embed_size)
            # keep the two reserved position embeddings; otherwise they remain uninitialized
            extended_position_embedding[:2] = model.roberta.embeddings.position_embeddings.weight[:2]
            k = 2
            for i in range(extend_multiples):
                extended_position_embedding[k:(k + original_max_position)] = \
                    model.roberta.embeddings.position_embeddings.weight[2:]
                k += original_max_position
            model.roberta.embeddings.position_embeddings.weight.data = extended_position_embedding
        else:
            raise ValueError(
                'Please extend "extend_position_embedding" function to support your model type. It currently only supports "bert" & "roberta"!'
            )

        model.config.max_position_embeddings = max_position
        print(f'Extended position embeddings to {original_max_position * extend_multiples}')

        return model
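
    # Usage sketch (not part of the original file; the Hugging Face checkpoint name
    # below is an illustrative assumption): extend a pretrained BERT model from 512
    # to 1024 positions before fine-tuning it on longer sequences.
    #
    #   from transformers import BertForSequenceClassification
    #   model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    #   model = SparseAttentionUtils.extend_position_embedding(model, max_position=1024)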
    @staticmethod
    def update_tokenizer_model_max_length(tokenizer, max_position):
        """This function updates the model maximum length of a tokenizer to a new max position.

        Arguments:
            tokenizer: required: a transformer tokenizer
            max_position: required: an integer determining new position embedding size

        Return:
            tokenizer: updated tokenizer; in which model maximum length has been extended based on new size
        """

        tokenizer.model_max_length = max_position
        tokenizer.init_kwargs['model_max_length'] = max_position
        print(f'updated tokenizer model maximum length to {max_position}')

        return tokenizer
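
    # Usage sketch (illustrative; the tokenizer name is an assumption): keep the
    # tokenizer's advertised maximum length in sync with the extended position
    # embeddings so it does not truncate longer inputs.
    #
    #   from transformers import BertTokenizer
    #   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    #   tokenizer = SparseAttentionUtils.update_tokenizer_model_max_length(tokenizer, max_position=1024)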
    @staticmethod
    def replace_model_self_attention_with_sparse_self_attention(
            model,
            max_position,
            # SparsityConfig parameters need to be set accordingly
            sparsity_config=SparsityConfig(num_heads=4)):
        """This function replaces the self attention layers in the model encoder with sparse self attention.
        It currently supports bert and roberta models and can be easily extended to any other model following similar steps here.
        For sparsity_config, refer to the config class.

        Arguments:
            model: required: a transformer model
            max_position: required: an integer determining new position embedding size
            sparsity_config: optional: this parameter determines sparsity pattern configuration; it is based on SparsityConfig class

        Return:
            model: updated model; in which self attention layer has been replaced with DeepSpeed Sparse Self Attention layer.
        """

        if hasattr(model, 'bert'):
            model.config.max_position_embeddings = max_position
            SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
                model.config,
                model.bert.encoder.layer,
                sparsity_config)
        elif hasattr(model, 'roberta'):
            model.config.max_position_embeddings = max_position + 2
            SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
                model.config,
                model.roberta.encoder.layer,
                sparsity_config)
        else:
            raise ValueError(
                'Please extend "replace_model_self_attention_with_sparse_self_attention" function to support '
                'your model type. It currently only supports "bert" & "roberta"!')

        return model
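
    # Usage sketch (illustrative; the parameter values are assumptions): replace a
    # BERT encoder's dense self-attention with DeepSpeed sparse self-attention using
    # one of the shipped SparsityConfig subclasses, e.g. FixedSparsityConfig.
    #
    #   from deepspeed.ops.sparse_attention import FixedSparsityConfig
    #   sparsity_config = FixedSparsityConfig(num_heads=12, block=16)
    #   model = SparseAttentionUtils.replace_model_self_attention_with_sparse_self_attention(
    #       model, max_position=1024, sparsity_config=sparsity_config)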
    @staticmethod
    def replace_self_attention_layer_with_sparse_self_attention_layer(
            config,
            layers,
            # SparsityConfig parameters need to be set accordingly
            sparsity_config=SparsityConfig(num_heads=4)):
        """This function replaces the self attention layer of the given transformer encoder layers with sparse self attention.
        For sparsity_config, refer to the config class.

        Arguments:
            config: required: transformer model config
            layers: required: transformer model attention layers
            sparsity_config: optional: this parameter determines sparsity pattern configuration; it is based on SparsityConfig class

        Return:
            layers: updated attention layers; in which self attention layers have been replaced with DeepSpeed Sparse Self Attention layer.
        """

        for layer in layers:
            deepspeed_sparse_self_attn = BertSparseSelfAttention(config, sparsity_config)
            deepspeed_sparse_self_attn.query = layer.attention.self.query
            deepspeed_sparse_self_attn.key = layer.attention.self.key
            deepspeed_sparse_self_attn.value = layer.attention.self.value
            layer.attention.self = deepspeed_sparse_self_attn

        return layers
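
    # Usage sketch (illustrative): this lower-level helper can also be applied
    # directly to a list of encoder layers, which is handy when adding support for
    # a model type the function above does not know about.
    #
    #   SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
    #       model.config, model.bert.encoder.layer, sparsity_config)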
    @staticmethod
    def pad_to_block_size(block_size,
                          input_ids,
                          attention_mask,
                          token_type_ids,
                          position_ids,
                          inputs_embeds,
                          pad_token_id,
                          model_embeddings):
        """This function pads input tokens and attention mask on the sequence length dimension to a multiple of block size.
        This is a requirement for Sparse Transformer in which the self attention layer works on sequences whose length is a multiple of block size.
        It needs to be called in your model, such as BertModel, right before you calculate the embedding outputs.
        Note)
        1- Instead of passing your embedding layer to this function, you can simply add this function to your model; it can be further simplified if the given attention_mask and/or token_type_ids are None.
        2- You need to call the unpad function before returning your model output to unpad the encoder sequence output.

        Arguments:
            block_size: required: an integer determining the block size of sparsity config.
            pad_token_id: required: an integer determining the pad token from the model config; such as bert.config.pad_token_id.
            input_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary
            attention_mask: a torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
            token_type_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
            position_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the indices of positions of each input sequence token in the position embeddings.
            inputs_embeds: an optional torch.FloatTensor of shape [batch_size, sequence_length, hidden_size] that contains the embedded representation and can be passed instead of input_ids directly.
            model_embeddings: an optional object. If inputs_embeds is not None, this will be your model embeddings such as BertEmbeddings from your model such as BertModel. You can move this function inside your model and use self.embeddings instead of passing this parameter.

        Return:
            pad_len: an integer determining how much the inputs have been padded to transfer the sequence length dimension to a multiple of block size.
            input_ids: if input_ids is not None, padded input_ids; otherwise None.
            attention_mask: if attention_mask is not None, padded attention_mask; otherwise None.
            token_type_ids: if token_type_ids is not None, padded token_type_ids; otherwise None.
            position_ids: if position_ids is not None, padded position_ids; otherwise None.
            inputs_embeds: if inputs_embeds is not None, padded inputs_embeds; otherwise None.
        """

        batch_size, seq_len = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
        pad_len = (block_size - seq_len % block_size) % block_size
        if pad_len > 0:
            if inputs_embeds is not None:
                pad_input_ids = inputs_embeds.new_full((batch_size, pad_len),
                                                       pad_token_id,
                                                       dtype=torch.long)
                pad_inputs_embeds = model_embeddings(pad_input_ids)
                inputs_embeds = torch.cat([inputs_embeds, pad_inputs_embeds], dim=-2)
            # may not be needed as input_ids are not used if inputs_embeds are given
            if input_ids is not None:
                input_ids = F.pad(input_ids, (0, pad_len), value=pad_token_id)
            if position_ids is not None:
                # pad position_ids with pad_token_id
                position_ids = F.pad(position_ids, (0, pad_len), value=pad_token_id)
            # pad attention mask without attention on the padding tokens
            attention_mask = F.pad(attention_mask, (0, pad_len), value=False)
            # pad token_type_ids with token_type_id = 0
            token_type_ids = F.pad(token_type_ids, (0, pad_len), value=0)

        return pad_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
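
    # Usage sketch (illustrative; `self.embeddings`, `self.config` and
    # `sparsity_config` refer to an assumed BertModel-like module): pad the inputs
    # right before computing the embedding output in the model's forward pass.
    #
    #   pad_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = \
    #       SparseAttentionUtils.pad_to_block_size(
    #           block_size=sparsity_config.block,
    #           input_ids=input_ids,
    #           attention_mask=attention_mask,
    #           token_type_ids=token_type_ids,
    #           position_ids=position_ids,
    #           inputs_embeds=inputs_embeds,
    #           pad_token_id=self.config.pad_token_id,
    #           model_embeddings=self.embeddings)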
    @staticmethod
    def unpad_sequence_output(pad_len, sequence_output):
        """This function unpads the sequence output if the inputs of the model were padded.
        This is a requirement for Sparse Transformer in which the self attention layer works on sequences whose length is a multiple of block size.
        It needs to be called in your model, such as BertModel, right before you return the model outputs.

        Arguments:
            pad_len: required: an integer determining how much the model inputs have been padded to transfer the sequence length dimension to a multiple of block size.
            sequence_output: required: sequence output of the encoder layer.

        Return:
            sequence_output: unpadded sequence output of the encoder layer.
        """

        if pad_len > 0:
            sequence_output = sequence_output[:, :-pad_len]

        return sequence_output
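
    # Usage sketch (illustrative): undo the padding on the encoder output before
    # returning it, mirroring the pad_to_block_size call above.
    #
    #   sequence_output = SparseAttentionUtils.unpad_sequence_output(pad_len, sequence_output)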