# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description: MacBERT for Chinese Spelling Correction (CSC): a BERT
masked-LM correction head plus a token-level error-detection head.
"""
from abc import ABC

import torch.nn as nn
from transformers import BertForMaskedLM

from pycorrector.macbert.base_model import CscTrainingModel, FocalLoss


class MacBert4Csc(CscTrainingModel, ABC):
    def __init__(self, cfg, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
        # Token-level detection head: one logit per token (error / no error).
        self.detection = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        self.tokenizer = tokenizer

    def forward(self, texts, cor_labels=None, det_labels=None):
        if cor_labels:
            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            # [PAD] has token id 0; set it to -100 so the MLM loss ignores it.
            text_labels[text_labels == 0] = -100
            text_labels = text_labels.to(self.device)
        else:
            text_labels = None
        encoded_text = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_text.to(self.device)
        bert_outputs = self.bert(**encoded_text,
                                 labels=text_labels,
                                 return_dict=True,
                                 output_hidden_states=True)
        # Per-token error-detection logits from the last hidden layer.
        prob = self.detection(bert_outputs.hidden_states[-1])

        if text_labels is None:
            # Inference: (detection logits, correction logits)
            outputs = (prob, bert_outputs.logits)
        else:
            det_loss_fct = FocalLoss(num_labels=None, activation_type='sigmoid')
            # Exclude padded positions from the detection loss.
            active_loss = encoded_text['attention_mask'].view(-1, prob.shape[1]) == 1
            active_probs = prob.view(-1, prob.shape[1])[active_loss]
            active_labels = det_labels[active_loss]
            det_loss = det_loss_fct(active_probs, active_labels.float())
            # Training: (detection loss, correction loss, detection probs, correction logits)
            outputs = (det_loss,
                       bert_outputs.loss,
                       self.sigmoid(prob).squeeze(-1),
                       bert_outputs.logits)
        return outputs
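

if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Illustrative sketch only (not part of the original module): the same
    # detect-then-correct forward pass rebuilt from plain transformers
    # pieces, so it runs without the project's CscTrainingModel/config
    # machinery. The checkpoint name 'hfl/chinese-macbert-base' and the
    # sample sentence are assumptions; any BERT-style masked-LM checkpoint
    # would work, and the fresh detection head below is untrained, so its
    # probabilities are meaningless until fine-tuned.
    # ------------------------------------------------------------------
    import torch
    from transformers import BertTokenizerFast

    ckpt = 'hfl/chinese-macbert-base'  # assumed checkpoint
    tok = BertTokenizerFast.from_pretrained(ckpt)
    mlm = BertForMaskedLM.from_pretrained(ckpt)
    det_head = nn.Linear(mlm.config.hidden_size, 1)  # untrained stand-in

    texts = ['今天新情很好']  # '新' is a typo for '心'
    enc = tok(texts, padding=True, return_tensors='pt')
    with torch.no_grad():
        out = mlm(**enc, return_dict=True, output_hidden_states=True)
        det_prob = torch.sigmoid(det_head(out.hidden_states[-1])).squeeze(-1)

    # Greedy correction: argmax over the MLM vocabulary at every position.
    # A real pipeline would align predictions back to the source text and
    # keep tokens whose detection probability is low.
    pred_ids = out.logits.argmax(dim=-1)
    print(tok.batch_decode(pred_ids, skip_special_tokens=True))
    print(det_prob)  # per-token error probability, shape (batch, seq_len)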