evaluate_utils.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Sentence-level evaluation of text-correction functions on a SIGHAN-style TSV test set.
"""
  6. import os
  7. import time
  8. from codecs import open
  9. pwd_path = os.path.abspath(os.path.dirname(__file__))
  10. sighan_2015_path = os.path.join(pwd_path, '../data/sighan2015_test.tsv')
  11. def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True):
  12. """
  13. SIGHAN句级评估结果,设定需要纠错为正样本,无需纠错为负样本
  14. Args:
  15. correct_fn:
  16. input_tsv_file:
  17. verbose:
  18. Returns:
  19. Acc, Recall, F1
  20. """
  21. TP = 0.0
  22. FP = 0.0
  23. FN = 0.0
  24. TN = 0.0
  25. total_num = 0
  26. start_time = time.time()
  27. with open(input_tsv_file, 'r', encoding='utf-8') as f:
  28. for line in f:
  29. line = line.strip()
  30. if line.startswith('#'):
  31. continue
  32. parts = line.split('\t')
  33. if len(parts) != 2:
  34. continue
  35. src = parts[0]
  36. tgt = parts[1]
  37. r = correct_fn(src)
  38. tgt_pred, pred_detail = r['target'], r['errors']
  39. if verbose:
  40. print()
  41. print('input :', src)
  42. print('truth :', tgt)
  43. print('predict:', tgt_pred, pred_detail)
  44. # 负样本
  45. if src == tgt:
  46. # 预测也为负
  47. if tgt == tgt_pred:
  48. TN += 1
  49. print('right')
  50. # 预测为正
  51. else:
  52. FP += 1
  53. print('wrong')
  54. # 正样本
  55. else:
  56. # 预测也为正
  57. if tgt == tgt_pred:
  58. TP += 1
  59. print('right')
  60. # 预测为负
  61. else:
  62. FN += 1
  63. print('wrong')
  64. total_num += 1
  65. spend_time = time.time() - start_time
  66. acc = (TP + TN) / total_num
  67. precision = TP / (TP + FP) if TP > 0 else 0.0
  68. recall = TP / (TP + FN) if TP > 0 else 0.0
  69. f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
  70. print(
  71. f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, '
  72. f'cost time:{spend_time:.2f} s, total num: {total_num}')
  73. return acc, precision, recall, f1
  74. def eval_model_batch(correct_fn, input_tsv_file=sighan_2015_path, verbose=True):
  75. """
  76. SIGHAN句级评估结果,设定需要纠错为正样本,无需纠错为负样本
  77. Args:
  78. correct_fn:
  79. input_tsv_file:
  80. verbose:
  81. Returns:
  82. Acc, Recall, F1
  83. """
  84. TP = 0.0
  85. FP = 0.0
  86. FN = 0.0
  87. TN = 0.0
  88. total_num = 0
  89. start_time = time.time()
  90. srcs = []
  91. tgts = []
  92. with open(input_tsv_file, 'r', encoding='utf-8') as f:
  93. for line in f:
  94. line = line.strip()
  95. if line.startswith('#'):
  96. continue
  97. parts = line.split('\t')
  98. if len(parts) != 2:
  99. continue
  100. src = parts[0]
  101. tgt = parts[1]
  102. srcs.append(src)
  103. tgts.append(tgt)
  104. res = correct_fn(srcs)
  105. for each_res, src, tgt in zip(res, srcs, tgts):
  106. pred_detail = ''
  107. if isinstance(each_res, str):
  108. tgt_pred = each_res
  109. elif isinstance(each_res, dict):
  110. tgt_pred = each_res['target']
  111. pred_detail = each_res['errors']
  112. else:
  113. raise ValueError('correct_fn return type error.')
  114. if verbose:
  115. print()
  116. print('input :', src)
  117. print('truth :', tgt)
  118. print('predict:', tgt_pred, pred_detail)
  119. # 负样本
  120. if src == tgt:
  121. # 预测也为负
  122. if tgt == tgt_pred:
  123. TN += 1
  124. print('right')
  125. # 预测为正
  126. else:
  127. FP += 1
  128. print('wrong')
  129. # 正样本
  130. else:
  131. # 预测也为正
  132. if tgt == tgt_pred:
  133. TP += 1
  134. print('right')
  135. # 预测为负
  136. else:
  137. FN += 1
  138. print('wrong')
  139. total_num += 1
  140. spend_time = time.time() - start_time
  141. acc = (TP + TN) / total_num
  142. precision = TP / (TP + FP) if TP > 0 else 0.0
  143. recall = TP / (TP + FN) if TP > 0 else 0.0
  144. f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
  145. print(
  146. f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, '
  147. f'cost time:{spend_time:.2f} s, total num: {total_num}')
  148. return acc, precision, recall, f1
  149. if __name__ == "__main__":
  150. # 评估macbert模型的纠错准召率
  151. from pycorrector.macbert.macbert_corrector import MacBertCorrector
  152. model = MacBertCorrector()
  153. eval_model_batch(model.correct)