softmaskedbert4csc.py

  1. """
  2. @Time : 2021-01-21 12:00:59
  3. @File : modeling_soft_masked_bert.py
  4. @Author : Abtion
  5. @Email : abtion{at}outlook.com
  6. """
  7. from abc import ABC
  8. from collections import OrderedDict
  9. import transformers as tfs
  10. import torch
  11. from torch import nn
  12. from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertOnlyMLMHead
  13. from transformers.modeling_utils import ModuleUtilsMixin
  14. from pycorrector.macbert.base_model import CscTrainingModel
class DetectionNetwork(nn.Module):
    """Bi-GRU detection network: predicts, for every token position, the
    probability that the token is erroneous (the soft-mask weight)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gru = nn.GRU(
            self.config.hidden_size,
            self.config.hidden_size // 2,
            num_layers=2,
            batch_first=True,
            dropout=self.config.hidden_dropout_prob,
            bidirectional=True,
        )
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(self.config.hidden_size, 1)

    def forward(self, hidden_states):
        out, _ = self.gru(hidden_states)
        prob = self.linear(out)
        prob = self.sigmoid(prob)
        return prob


class CorrectionNetwork(torch.nn.Module, ModuleUtilsMixin):
    """BERT-based correction network: consumes the soft-masked embeddings and
    predicts the corrected token at every position."""

    def __init__(self, config, tokenizer, device):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.embeddings = BertEmbeddings(self.config)
        self.bert = BertEncoder(self.config)
        self.mask_token_id = self.tokenizer.mask_token_id
        self.cls = BertOnlyMLMHead(self.config)
        self._device = device

    def forward(self, texts, prob, embed=None, cor_labels=None, residual_connection=False):
        if cor_labels is not None:
            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            # torch's CrossEntropyLoss ignores positions whose label is -100,
            # so padding tokens (id 0) are excluded from the correction loss.
            text_labels[text_labels == 0] = -100
            text_labels = text_labels.to(self._device)
        else:
            text_labels = None
        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_texts.to(self._device)
        if embed is None:
            embed = self.embeddings(input_ids=encoded_texts['input_ids'],
                                    token_type_ids=encoded_texts['token_type_ids'])
        # This deviates slightly from the original paper: the [MASK] embedding is built
        # per position so that the token_type and position embeddings are fully preserved.
        mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach()
        # Implementation from the original paper:
        # mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach()

        # Soft masking: interpolate between the [MASK] embedding and the input
        # embedding, weighted by the detector's per-token error probability.
        cor_embed = prob * mask_embed + (1 - prob) * embed

        input_shape = encoded_texts['input_ids'].size()
        device = encoded_texts['input_ids'].device
        extended_attention_mask = self.get_extended_attention_mask(encoded_texts['attention_mask'],
                                                                   input_shape, device)
        head_mask = self.get_head_mask(None, self.config.num_hidden_layers)
        encoder_outputs = self.bert(
            cor_embed,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            return_dict=False,
        )
        sequence_output = encoder_outputs[0]

        sequence_output = sequence_output + embed if residual_connection else sequence_output
        prediction_scores = self.cls(sequence_output)
        out = (prediction_scores, sequence_output)

        # Masked-language-modeling loss on the corrected tokens.
        if text_labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            cor_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), text_labels.view(-1))
            out = (cor_loss,) + out
        return out
    def load_from_transformers_state_dict(self, gen_fp):
        """Load pretrained BERT weights from a transformers checkpoint and remap
        the parameter names onto this module's sub-modules."""
        state_dict = OrderedDict()
        gen_state_dict = tfs.AutoModelForMaskedLM.from_pretrained(gen_fp).state_dict()
        for k, v in gen_state_dict.items():
            name = k
            if name.startswith('bert'):
                # 'bert.embeddings.*' -> 'embeddings.*', 'bert.encoder.*' -> 'encoder.*'
                name = name[5:]
            if name.startswith('encoder'):
                # The encoder weights must land on the local BertEncoder attribute, self.bert.
                name = f'bert.{name[8:]}'
            if 'gamma' in name:
                name = name.replace('gamma', 'weight')
            if 'beta' in name:
                name = name.replace('beta', 'bias')
            state_dict[name] = v
        self.load_state_dict(state_dict, strict=False)


class SoftMaskedBert4Csc(CscTrainingModel, ABC):
    """Soft-Masked BERT for Chinese spelling correction: a detection network
    produces per-token error probabilities that softly mask the input embeddings
    fed to the BERT correction network."""

    def __init__(self, cfg, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        self.config = tfs.AutoConfig.from_pretrained(cfg.MODEL.BERT_CKPT)
        self.detector = DetectionNetwork(self.config)
        self.tokenizer = tokenizer
        self.corrector = CorrectionNetwork(self.config, tokenizer, cfg.MODEL.DEVICE)
        self.corrector.load_from_transformers_state_dict(self.cfg.MODEL.BERT_CKPT)
        self._device = cfg.MODEL.DEVICE

    def forward(self, texts, cor_labels=None, det_labels=None):
        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_texts.to(self._device)
        embed = self.corrector.embeddings(input_ids=encoded_texts['input_ids'],
                                          token_type_ids=encoded_texts['token_type_ids'])
        prob = self.detector(embed)
        cor_out = self.corrector(texts, prob, embed, cor_labels, residual_connection=True)

        if det_labels is not None:
            det_loss_fct = nn.BCELoss()
            # Padding positions are excluded from the detection loss.
            active_loss = encoded_texts['attention_mask'].view(-1, prob.shape[1]) == 1
            active_probs = prob.view(-1, prob.shape[1])[active_loss]
            active_labels = det_labels[active_loss]
            det_loss = det_loss_fct(active_probs, active_labels.float())
            outputs = (det_loss, cor_out[0], prob.squeeze(-1)) + cor_out[1:]
        else:
            outputs = (prob.squeeze(-1),) + cor_out
        return outputs
    def load_from_transformers_state_dict(self, gen_fp):
        """
        Load pretrained weights from a transformers checkpoint.
        :param gen_fp: path or name of the pretrained transformers checkpoint
        :return:
        """
        self.corrector.load_from_transformers_state_dict(gen_fp)
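

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; the names below are
# illustrative assumptions): it demonstrates the soft-masking step with random
# tensors and a toy BertConfig, so no pretrained checkpoint or tokenizer
# download is needed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_config = tfs.BertConfig(hidden_size=32, num_hidden_layers=2,
                                 num_attention_heads=2, intermediate_size=64)
    detector = DetectionNetwork(demo_config)

    batch_size, seq_len = 2, 8
    # Stand-in for the BertEmbeddings output of a tokenized batch.
    embed = torch.randn(batch_size, seq_len, demo_config.hidden_size)
    # Stand-in for the [MASK] embedding broadcast over every position.
    mask_embed = torch.randn(batch_size, seq_len, demo_config.hidden_size)

    prob = detector(embed)  # (batch, seq_len, 1) per-token error probabilities
    # Soft masking as in CorrectionNetwork.forward: interpolate between the
    # [MASK] embedding and the input embedding with the detector probability.
    soft_masked = prob * mask_embed + (1 - prob) * embed
    print(prob.shape, soft_masked.shape)  # torch.Size([2, 8, 1]) torch.Size([2, 8, 32])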