macbert4csc.py

# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
from abc import ABC

import torch.nn as nn
from transformers import BertForMaskedLM

from pycorrector.macbert.base_model import CscTrainingModel, FocalLoss


class MacBert4Csc(CscTrainingModel, ABC):
    def __init__(self, cfg, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        # MacBERT backbone with a masked-LM head; its logits are the correction output
        self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
        # per-token detection head: hidden state -> one error logit
        self.detection = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        self.tokenizer = tokenizer

    def forward(self, texts, cor_labels=None, det_labels=None):
        if cor_labels:
            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            # token id 0 is [PAD]; positions set to -100 are ignored by the MLM loss
            text_labels[text_labels == 0] = -100
            text_labels = text_labels.to(self.device)
        else:
            text_labels = None
        encoded_text = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_text.to(self.device)
        bert_outputs = self.bert(**encoded_text,
                                 labels=text_labels,
                                 return_dict=True,
                                 output_hidden_states=True)
        # per-token error-detection logits from the last hidden layer
        prob = self.detection(bert_outputs.hidden_states[-1])

        if text_labels is None:
            # inference: (detection output, correction output)
            outputs = (prob, bert_outputs.logits)
        else:
            det_loss_fct = FocalLoss(num_labels=None, activation_type='sigmoid')
            # exclude pad positions from the detection loss
            active_loss = encoded_text['attention_mask'].view(-1, prob.shape[1]) == 1
            active_probs = prob.view(-1, prob.shape[1])[active_loss]
            active_labels = det_labels[active_loss]
            det_loss = det_loss_fct(active_probs, active_labels.float())
            # training: (detection loss, correction loss, detection probs, correction logits)
            outputs = (det_loss,
                       bert_outputs.loss,
                       self.sigmoid(prob).squeeze(-1),
                       bert_outputs.logits)
        return outputs
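
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original file). It assumes the
# project's yacs-style config with MODEL.BERT_CKPT pointing at a MacBERT
# checkpoint such as 'hfl/chinese-macbert-base'. The checkpoint name and the
# sample sentence are assumptions, and CscTrainingModel may require config
# fields beyond the one read in this file, so the sketch is left as comments:
#
#   from transformers import BertTokenizerFast
#
#   tokenizer = BertTokenizerFast.from_pretrained(cfg.MODEL.BERT_CKPT)
#   model = MacBert4Csc(cfg, tokenizer).eval()
#
#   # No labels, so forward returns (detection logits, correction logits);
#   # argmax over the vocab dimension yields the corrected token ids.
#   det_logits, cor_logits = model(['今天新情很好'])
#   corrected = tokenizer.decode(cor_logits.argmax(-1)[0], skip_special_tokens=True)
# ---------------------------------------------------------------------------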