base_model.py

# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
import operator
from abc import ABC

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from loguru import logger

from pycorrector.macbert import lr_scheduler
from pycorrector.macbert.evaluate_util import compute_corrector_prf, compute_sentence_level_prf


class FocalLoss(nn.Module):
    """
    Softmax and sigmoid focal loss.
    copy from https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, num_labels, activation_type='softmax', gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super(FocalLoss, self).__init__()
        self.num_labels = num_labels
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type

    def forward(self, input, target):
        """
        Args:
            input: model's raw output (logits), shape of [batch_size, num_labels]
            target: ground truth labels, shape of [batch_size]
        Returns:
            mean focal loss over the batch (scalar tensor)
        """
        if self.activation_type == 'softmax':
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            logits = torch.softmax(input, dim=-1)
            loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss = loss.sum(1)
        elif self.activation_type == 'sigmoid':
            multi_hot_key = target
            logits = torch.sigmoid(input)
            zero_hot_key = 1 - multi_hot_key
            loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        return loss.mean()

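# Illustrative usage sketch (added; not part of the original file). It assumes a
# two-class detection head that produces raw logits:
#
#   >>> loss_fn = FocalLoss(num_labels=2, activation_type='softmax')
#   >>> logits = torch.randn(4, 2)            # [batch_size, num_labels]
#   >>> target = torch.tensor([0, 1, 1, 0])   # [batch_size]
#   >>> loss = loss_fn(logits, target)        # scalar tensor (mean over the batch)
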
def make_optimizer(cfg, model):
    """Build the optimizer from cfg.SOLVER, one param group per parameter.

    Bias parameters get a scaled learning rate (BIAS_LR_FACTOR) and their own
    weight decay (WEIGHT_DECAY_BIAS).
    """
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    return optimizer

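# Illustrative sketch (added) of the cfg.SOLVER fields make_optimizer reads; the
# field names come from the code above, while the yacs CfgNode container and the
# concrete values are assumptions for illustration only:
#
#   >>> from yacs.config import CfgNode as CN
#   >>> cfg = CN()
#   >>> cfg.SOLVER = CN()
#   >>> cfg.SOLVER.OPTIMIZER_NAME = 'AdamW'   # any optimizer name under torch.optim
#   >>> cfg.SOLVER.BASE_LR = 5e-5
#   >>> cfg.SOLVER.BIAS_LR_FACTOR = 2.0
#   >>> cfg.SOLVER.WEIGHT_DECAY = 0.01
#   >>> cfg.SOLVER.WEIGHT_DECAY_BIAS = 0.0
#   >>> cfg.SOLVER.MOMENTUM = 0.9             # only read when OPTIMIZER_NAME == 'SGD'
#   >>> optimizer = make_optimizer(cfg, model)
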
def build_lr_scheduler(cfg, optimizer):
    scheduler_args = {
        "optimizer": optimizer,

        # warmup options
        "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
        "warmup_epochs": cfg.SOLVER.WARMUP_EPOCHS,
        "warmup_method": cfg.SOLVER.WARMUP_METHOD,

        # multi-step lr scheduler options
        "milestones": cfg.SOLVER.STEPS,
        "gamma": cfg.SOLVER.GAMMA,

        # cosine annealing lr scheduler options
        "max_iters": cfg.SOLVER.MAX_ITER,
        "delay_iters": cfg.SOLVER.DELAY_ITERS,
        "eta_min_lr": cfg.SOLVER.ETA_MIN_LR,
    }
    scheduler = getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)
    return {'scheduler': scheduler, 'interval': cfg.SOLVER.INTERVAL}

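# Illustrative sketch (added) of the scheduler-related SOLVER fields read above.
# The returned dict follows PyTorch Lightning's lr scheduler configuration
# format ('scheduler' instance plus 'interval'). Field names come from the code;
# the values are assumptions, and the SCHED name must match a scheduler class
# actually defined in pycorrector.macbert.lr_scheduler (the name below is
# hypothetical):
#
#   >>> cfg.SOLVER.SCHED = 'WarmupExponentialLR'
#   >>> cfg.SOLVER.WARMUP_FACTOR = 0.01
#   >>> cfg.SOLVER.WARMUP_EPOCHS = 1024
#   >>> cfg.SOLVER.WARMUP_METHOD = 'linear'
#   >>> cfg.SOLVER.STEPS = (10,)
#   >>> cfg.SOLVER.GAMMA = 0.9999
#   >>> cfg.SOLVER.MAX_ITER = 10
#   >>> cfg.SOLVER.DELAY_ITERS = 0
#   >>> cfg.SOLVER.ETA_MIN_LR = 3e-7
#   >>> cfg.SOLVER.INTERVAL = 'step'   # 'step' or 'epoch'
#   >>> sched_cfg = build_lr_scheduler(cfg, optimizer)   # {'scheduler': ..., 'interval': ...}
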
class BaseTrainingEngine(pl.LightningModule):
    def __init__(self, cfg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cfg = cfg

    def configure_optimizers(self):
        optimizer = make_optimizer(self.cfg, self)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        return [optimizer], [scheduler]

    def on_validation_epoch_start(self) -> None:
        logger.info('Valid.')

    def on_test_epoch_start(self) -> None:
        logger.info('Testing...')

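# Note (added): `configure_optimizers` uses Lightning's two-list return form,
# ([optimizers], [lr_scheduler_configs]); the scheduler entry is the dict built
# by build_lr_scheduler above, so Lightning advances it at the configured interval.
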
class CscTrainingModel(BaseTrainingEngine, ABC):
    """
    Base model for CSC (Chinese Spelling Correction); defines the training and prediction steps.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__(cfg, *args, **kwargs)
        # loss weight
        self.w = cfg.MODEL.HYPER_PARAMS[0]

    def training_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                 batch_size=len(ori_text))
        return loss
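
    # Note (added): `self.w` (cfg.MODEL.HYPER_PARAMS[0]) balances the two loss
    # terms returned by forward(). Judging from the indexing in validation_step
    # below (outputs[2] thresholded as detection predictions, outputs[3]
    # arg-maxed as correction predictions), outputs[0] appears to be the
    # detection loss and outputs[1] the correction loss, giving the objective
    # w * correction_loss + (1 - w) * detection_loss.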

    def validation_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        det_y_hat = (outputs[2] > 0.5).long()
        cor_y_hat = torch.argmax(outputs[3], dim=-1)
        encoded_x = self.tokenizer(cor_text, padding=True, return_tensors='pt')
        encoded_x.to(self._device)
        cor_y = encoded_x['input_ids']
        # zero out predictions at padded positions
        cor_y_hat *= encoded_x['attention_mask']
        results = []
        det_acc_labels = []
        cor_acc_labels = []
        for src, tgt, predict, det_predict, det_label in zip(ori_text, cor_y, cor_y_hat, det_y_hat, det_labels):
            _src = self.tokenizer(src, add_special_tokens=False)['input_ids']
            # compare only the real tokens of each sentence (skip position 0, the [CLS] token)
            _tgt = tgt[1:len(_src) + 1].cpu().numpy().tolist()
            _predict = predict[1:len(_src) + 1].cpu().numpy().tolist()
            cor_acc_labels.append(1 if operator.eq(_tgt, _predict) else 0)
            det_acc_labels.append(det_predict[1:len(_src) + 1].equal(det_label[1:len(_src) + 1]))
            results.append((_src, _tgt, _predict,))
        return loss.cpu().item(), det_acc_labels, cor_acc_labels, results
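
    # Note (added): each validation step returns
    #   (loss value, per-sentence detection accuracy flags,
    #    per-sentence correction accuracy flags, (src, tgt, predict) token-id triples);
    # validation_epoch_end below aggregates these into the logged metrics.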

    def validation_epoch_end(self, outputs) -> None:
        det_acc_labels = []
        cor_acc_labels = []
        results = []
        for out in outputs:
            det_acc_labels += out[1]
            cor_acc_labels += out[2]
            results += out[3]
        loss = np.mean([out[0] for out in outputs])
        self.log('val_loss', loss)
        logger.info(f'loss: {loss}')
        logger.info(f'Detection: acc: {np.mean(det_acc_labels):.4f}')
        logger.info(f'Correction: acc: {np.mean(cor_acc_labels):.4f}')
        compute_corrector_prf(results, logger)
        compute_sentence_level_prf(results, logger)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs) -> None:
        logger.info('Test.')
        self.validation_epoch_end(outputs)

    def predict(self, texts):
        inputs = self.tokenizer(texts, padding=True, return_tensors='pt')
        inputs.to(self.cfg.MODEL.DEVICE)
        with torch.no_grad():
            outputs = self.forward(texts)
            y_hat = torch.argmax(outputs[1], dim=-1)
            expand_text_lens = torch.sum(inputs['attention_mask'], dim=-1) - 1
        rst = []
        for t_len, _y_hat in zip(expand_text_lens, y_hat):
            # drop the leading [CLS] token and remove the spaces the tokenizer
            # inserts between decoded (Chinese) characters
            rst.append(self.tokenizer.decode(_y_hat[1:t_len]).replace(' ', ''))
        return rst
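
# Illustrative usage sketch (added; not part of the original file). It assumes a
# trained subclass of CscTrainingModel whose `tokenizer` has been attached and
# whose cfg.MODEL.DEVICE matches the model's device; the input sentences are
# made-up examples:
#
#   >>> texts = ['今天新情很好', '少先队员因该为老人让座']
#   >>> corrected = model.predict(texts)   # list of corrected sentences, one per input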