# -*- coding: utf-8 -*-
"""
@author: XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description: Base training components for the MacBERT corrector: focal loss,
    optimizer and LR-scheduler builders, and the PyTorch Lightning training engine.
"""
import operator
from abc import ABC

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from loguru import logger

from pycorrector.macbert import lr_scheduler
from pycorrector.macbert.evaluate_util import compute_corrector_prf, compute_sentence_level_prf


class FocalLoss(nn.Module):
    """
    Softmax and sigmoid focal loss.
    Copied from https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, num_labels, activation_type='softmax', gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super(FocalLoss, self).__init__()
        self.num_labels = num_labels
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type

    def forward(self, input, target):
        """
        Args:
            input: model logits, shape [batch_size, num_labels]
            target: ground-truth labels; shape [batch_size] for softmax,
                or a multi-hot [batch_size, num_labels] tensor for sigmoid
        Returns:
            scalar loss, averaged over the batch
        """
        if self.activation_type == 'softmax':
            # One-hot encode the targets, then apply the focal modulating factor
            # (1 - p)^gamma so that well-classified examples contribute less.
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            logits = torch.softmax(input, dim=-1)
            loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss = loss.sum(1)
        elif self.activation_type == 'sigmoid':
            # Multi-label case: positives are weighted by alpha, negatives by (1 - alpha).
            multi_hot_key = target
            logits = torch.sigmoid(input)
            zero_hot_key = 1 - multi_hot_key
            loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        return loss.mean()
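
# Usage sketch (illustrative only, not part of the training pipeline): with the
# default softmax activation, FocalLoss expects raw logits of shape
# [batch_size, num_labels] and integer class targets of shape [batch_size].
#
#     criterion = FocalLoss(num_labels=2, activation_type='softmax')
#     logits = torch.randn(8, 2)            # e.g. detection logits
#     labels = torch.randint(0, 2, (8,))    # e.g. 0/1 error labels
#     loss = criterion(logits, labels)      # scalar tensor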


def make_optimizer(cfg, model):
    """Build an optimizer with one parameter group per weight, so bias terms
    can use their own learning rate and weight decay."""
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        # SGD needs an explicit momentum; lr and weight decay come from the per-group dicts.
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    return optimizer


def build_lr_scheduler(cfg, optimizer):
    """Build a learning-rate scheduler from cfg.SOLVER and wrap it in the
    dict format expected by Lightning's configure_optimizers."""
    scheduler_args = {
        "optimizer": optimizer,
        # warmup options
        "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
        "warmup_epochs": cfg.SOLVER.WARMUP_EPOCHS,
        "warmup_method": cfg.SOLVER.WARMUP_METHOD,
        # multi-step lr scheduler options
        "milestones": cfg.SOLVER.STEPS,
        "gamma": cfg.SOLVER.GAMMA,
        # cosine annealing lr scheduler options
        "max_iters": cfg.SOLVER.MAX_ITER,
        "delay_iters": cfg.SOLVER.DELAY_ITERS,
        "eta_min_lr": cfg.SOLVER.ETA_MIN_LR,
    }
    scheduler = getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)
    return {'scheduler': scheduler, 'interval': cfg.SOLVER.INTERVAL}
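
# Illustrative solver settings matching the cfg.SOLVER.* fields read above.
# A yacs/CfgNode-style config is assumed; the concrete values below are
# examples only, not the project defaults:
#
#     SOLVER:
#       OPTIMIZER_NAME: AdamW
#       BASE_LR: 5.0e-5
#       BIAS_LR_FACTOR: 2.0
#       WEIGHT_DECAY: 0.01
#       WEIGHT_DECAY_BIAS: 0.0
#       MOMENTUM: 0.9                 # only used when OPTIMIZER_NAME is SGD
#       SCHED: WarmupExponentialLR    # must name a class defined in pycorrector.macbert.lr_scheduler
#       WARMUP_FACTOR: 0.01
#       WARMUP_EPOCHS: 1
#       WARMUP_METHOD: linear
#       STEPS: [10, 20]
#       GAMMA: 0.99
#       MAX_ITER: 10
#       DELAY_ITERS: 0
#       ETA_MIN_LR: 3.0e-7
#       INTERVAL: step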


class BaseTrainingEngine(pl.LightningModule):
    def __init__(self, cfg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cfg = cfg

    def configure_optimizers(self):
        optimizer = make_optimizer(self.cfg, self)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        # Lightning accepts a list of optimizers and a list of scheduler dicts.
        return [optimizer], [scheduler]

    def on_validation_epoch_start(self) -> None:
        logger.info('Valid.')

    def on_test_epoch_start(self) -> None:
        logger.info('Testing...')


class CscTrainingModel(BaseTrainingEngine, ABC):
    """
    Base model for CSC (Chinese Spelling Correction); defines the training and
    prediction steps.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__(cfg, *args, **kwargs)
        # weight balancing the two loss terms returned by forward()
        self.w = cfg.MODEL.HYPER_PARAMS[0]

    def training_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        # Weighted sum of the detection and correction loss terms returned by forward().
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                 batch_size=len(ori_text))
        return loss

    def validation_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        # Detection predictions (binary) and correction predictions (token ids).
        det_y_hat = (outputs[2] > 0.5).long()
        cor_y_hat = torch.argmax((outputs[3]), dim=-1)
        encoded_x = self.tokenizer(cor_text, padding=True, return_tensors='pt')
        encoded_x = encoded_x.to(self.device)
        cor_y = encoded_x['input_ids']
        cor_y_hat *= encoded_x['attention_mask']
        results = []
        det_acc_labels = []
        cor_acc_labels = []
        for src, tgt, predict, det_predict, det_label in zip(ori_text, cor_y, cor_y_hat, det_y_hat, det_labels):
            # Compare only the len(_src) tokens after [CLS], skipping special tokens.
            _src = self.tokenizer(src, add_special_tokens=False)['input_ids']
            _tgt = tgt[1:len(_src) + 1].cpu().numpy().tolist()
            _predict = predict[1:len(_src) + 1].cpu().numpy().tolist()
            cor_acc_labels.append(1 if operator.eq(_tgt, _predict) else 0)
            det_acc_labels.append(det_predict[1:len(_src) + 1].equal(det_label[1:len(_src) + 1]))
            results.append((_src, _tgt, _predict,))
        return loss.cpu().item(), det_acc_labels, cor_acc_labels, results

    def validation_epoch_end(self, outputs) -> None:
        # Aggregate the per-batch outputs collected by validation_step.
        det_acc_labels = []
        cor_acc_labels = []
        results = []
        for out in outputs:
            det_acc_labels += out[1]
            cor_acc_labels += out[2]
            results += out[3]
        loss = np.mean([out[0] for out in outputs])
        self.log('val_loss', loss)
        logger.info(f'loss: {loss}')
        logger.info(f'Detection: acc: {np.mean(det_acc_labels):.4f}')
        logger.info(f'Correction: acc: {np.mean(cor_acc_labels):.4f}')
        # Precision/recall/F1 over the collected (src, tgt, predict) triples.
        compute_corrector_prf(results, logger)
        compute_sentence_level_prf(results, logger)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs) -> None:
        logger.info('Test.')
        self.validation_epoch_end(outputs)

    def predict(self, texts):
        """Correct a batch of raw texts and return the corrected strings."""
        inputs = self.tokenizer(texts, padding=True, return_tensors='pt')
        inputs = inputs.to(self.cfg.MODEL.DEVICE)
        with torch.no_grad():
            outputs = self.forward(texts)
            y_hat = torch.argmax(outputs[1], dim=-1)
            # Number of non-padding tokens minus one, so the trailing [SEP] is excluded.
            expand_text_lens = torch.sum(inputs['attention_mask'], dim=-1) - 1
        rst = []
        for t_len, _y_hat in zip(expand_text_lens, y_hat):
            # Skip [CLS], decode up to the real length, and drop WordPiece spaces.
            rst.append(self.tokenizer.decode(_y_hat[1:t_len]).replace(' ', ''))
        return rst
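
# Usage sketch for predict() (illustrative; in the project a concrete subclass,
# e.g. a MacBERT-based corrector, provides forward() and self.tokenizer):
#
#     texts = ['今天新情很好', '你找到你最喜欢的工作，我也很高心。']
#     corrected = model.predict(texts)
#     print(corrected)    # corrected sentences, one per input text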