- """
- @Time : 2021-01-21 12:00:59
- @File : modeling_soft_masked_bert.py
- @Author : Abtion
- @Email : abtion{at}outlook.com
- """
from abc import ABC
from collections import OrderedDict

import torch
import transformers as tfs
from torch import nn
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertOnlyMLMHead

from pycorrector.macbert.base_model import CscTrainingModel


class DetectionNetwork(nn.Module):
    """Detection network of Soft-Masked BERT: a bidirectional GRU that outputs a per-token error probability."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gru = nn.GRU(
            self.config.hidden_size,
            self.config.hidden_size // 2,
            num_layers=2,
            batch_first=True,
            dropout=self.config.hidden_dropout_prob,
            bidirectional=True,
        )
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(self.config.hidden_size, 1)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_size) token embeddings
        out, _ = self.gru(hidden_states)
        prob = self.linear(out)
        prob = self.sigmoid(prob)
        return prob
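
    # A minimal, illustrative shape check (a sketch, not part of the original module; tfs.BertConfig()
    # defaults are an assumption, any BERT config with hidden_size/hidden_dropout_prob behaves the same):
    #   >>> cfg = tfs.BertConfig()
    #   >>> detector = DetectionNetwork(cfg)
    #   >>> hidden = torch.randn(2, 8, cfg.hidden_size)   # (batch=2, seq_len=8, hidden_size)
    #   >>> detector(hidden).shape
    #   torch.Size([2, 8, 1])                             # one error probability per token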


class CorrectionNetwork(torch.nn.Module, ModuleUtilsMixin):
    """Correction network of Soft-Masked BERT: BERT embeddings + encoder + MLM head over soft-masked inputs."""

    def __init__(self, config, tokenizer, device):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.embeddings = BertEmbeddings(self.config)
        self.corrector = BertEncoder(self.config)
        self.mask_token_id = self.tokenizer.mask_token_id
        self.cls = BertOnlyMLMHead(self.config)
        self._device = device

    def forward(self, texts, prob, embed=None, cor_labels=None, residual_connection=False):
        if cor_labels is not None:
            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            # PyTorch's cross entropy loss ignores labels with value -100, so padding positions are masked out
            text_labels[text_labels == 0] = -100
            text_labels = text_labels.to(self._device)
        else:
            text_labels = None
        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_texts.to(self._device)
        if embed is None:
            embed = self.embeddings(input_ids=encoded_texts['input_ids'],
                                    token_type_ids=encoded_texts['token_type_ids'])
        # This deviates slightly from the original paper: the [MASK] embedding is computed per position
        # so that the token_type and position embeddings are preserved in full.
        mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach()
        # The original paper's implementation:
        # mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach()
        # Soft masking: interpolate between the [MASK] embedding and the original token embedding,
        # weighted by the detector's error probability.
        cor_embed = prob * mask_embed + (1 - prob) * embed

        input_shape = encoded_texts['input_ids'].size()
        device = encoded_texts['input_ids'].device
        extended_attention_mask = self.get_extended_attention_mask(encoded_texts['attention_mask'],
                                                                   input_shape, device)
        head_mask = self.get_head_mask(None, self.config.num_hidden_layers)
        encoder_outputs = self.corrector(
            cor_embed,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            return_dict=False,
        )
        sequence_output = encoder_outputs[0]

        sequence_output = sequence_output + embed if residual_connection else sequence_output
        prediction_scores = self.cls(sequence_output)
        out = (prediction_scores, sequence_output)

        # Masked language modeling softmax layer
        if text_labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            cor_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), text_labels.view(-1))
            out = (cor_loss,) + out
        return out

    def load_from_transformers_state_dict(self, gen_fp):
        """Initialize this module from a pretrained transformers masked-LM checkpoint, remapping parameter names."""
        state_dict = OrderedDict()
        gen_state_dict = tfs.AutoModelForMaskedLM.from_pretrained(gen_fp).state_dict()
        for k, v in gen_state_dict.items():
            name = k
            # Strip the leading 'bert.' prefix and rename the encoder to match this module's attributes.
            if name.startswith('bert'):
                name = name[5:]
            if name.startswith('encoder'):
                name = f'corrector.{name[8:]}'
            # Older checkpoints name LayerNorm parameters gamma/beta instead of weight/bias.
            if 'gamma' in name:
                name = name.replace('gamma', 'weight')
            if 'beta' in name:
                name = name.replace('beta', 'bias')
            state_dict[name] = v
        self.load_state_dict(state_dict, strict=False)
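
    # Illustrative examples of the key remapping above (derived from the code, not an exhaustive list):
    #   'bert.embeddings.word_embeddings.weight'            -> 'embeddings.word_embeddings.weight'
    #   'bert.encoder.layer.0.attention.self.query.weight'  -> 'corrector.layer.0.attention.self.query.weight'
    #   'cls.predictions.decoder.weight'                    -> 'cls.predictions.decoder.weight' (unchanged)
    # Keys that do not match any attribute of CorrectionNetwork are silently skipped because strict=False.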


class SoftMaskedBert4Csc(CscTrainingModel, ABC):
    """Full Soft-Masked BERT model for CSC: combines the detection and correction networks for training and inference."""

    def __init__(self, cfg, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        self.config = tfs.AutoConfig.from_pretrained(cfg.MODEL.BERT_CKPT)
        self.detector = DetectionNetwork(self.config)
        self.tokenizer = tokenizer
        self.corrector = CorrectionNetwork(self.config, tokenizer, cfg.MODEL.DEVICE)
        self.corrector.load_from_transformers_state_dict(self.cfg.MODEL.BERT_CKPT)
        self._device = cfg.MODEL.DEVICE
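
    # forward() return convention (derived from the branches below):
    #   training, det_labels and cor_labels given: (det_loss, cor_loss, det_prob, prediction_scores, sequence_output)
    #   inference, no det_labels:                  (det_prob,) followed by the corrector outputs
    # where det_prob has shape (batch, seq_len) with one error probability per token.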

    def forward(self, texts, cor_labels=None, det_labels=None):
        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_texts.to(self._device)
        embed = self.corrector.embeddings(input_ids=encoded_texts['input_ids'],
                                          token_type_ids=encoded_texts['token_type_ids'])
        prob = self.detector(embed)
        cor_out = self.corrector(texts, prob, embed, cor_labels, residual_connection=True)

        if det_labels is not None:
            det_loss_fct = nn.BCELoss()
            # Padding positions are excluded from the detection loss
            active_loss = encoded_texts['attention_mask'].view(-1, prob.shape[1]) == 1
            active_probs = prob.view(-1, prob.shape[1])[active_loss]
            active_labels = det_labels[active_loss]
            det_loss = det_loss_fct(active_probs, active_labels.float())
            outputs = (det_loss, cor_out[0], prob.squeeze(-1)) + cor_out[1:]
        else:
            outputs = (prob.squeeze(-1),) + cor_out
        return outputs

    def load_from_transformers_state_dict(self, gen_fp):
        """
        Load pretrained weights from a transformers checkpoint.
        :param gen_fp: path or name of the pretrained transformers checkpoint
        :return:
        """
        self.corrector.load_from_transformers_state_dict(gen_fp)
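

if __name__ == "__main__":
    # A minimal smoke test of the detection/correction flow (an illustrative sketch, not part of the
    # original module). "bert-base-chinese" is an assumption; any BERT-style checkpoint with a matching
    # tokenizer should work. Only the correction network is loaded with pretrained weights here, so the
    # randomly initialized detector's probabilities are meaningful for shapes only, not for real detection.
    ckpt = "bert-base-chinese"
    tokenizer = tfs.AutoTokenizer.from_pretrained(ckpt)
    config = tfs.AutoConfig.from_pretrained(ckpt)
    corrector = CorrectionNetwork(config, tokenizer, "cpu")
    corrector.load_from_transformers_state_dict(ckpt)
    detector = DetectionNetwork(config)

    texts = ["今天天气很好"]
    encoded = tokenizer(texts, padding=True, return_tensors="pt")
    with torch.no_grad():
        embed = corrector.embeddings(input_ids=encoded["input_ids"],
                                     token_type_ids=encoded["token_type_ids"])
        prob = detector(embed)                                 # (batch, seq_len, 1) error probabilities
        prediction_scores, _ = corrector(texts, prob, embed)   # (batch, seq_len, vocab_size) MLM logits
    print(prob.squeeze(-1))
    print(tokenizer.convert_ids_to_tokens(prediction_scores.argmax(-1)[0].tolist()))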
|