evaluate_utils.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author:XuMing(xuming624@qq.com)
  4. @description:
  5. """
  6. import os
  7. import time
  8. from codecs import open
  9. pwd_path = os.path.abspath(os.path.dirname(__file__))
  10. sighan_2015_path = os.path.join(pwd_path, '../data/sighan2015_test.tsv')
  11. def eval_sighan2015_by_model(correct_fn, sighan_path=sighan_2015_path, verbose=True):
  12. """
  13. SIGHAN句级评估结果,设定需要纠错为正样本,无需纠错为负样本
  14. Args:
  15. correct_fn:
  16. input_eval_path:
  17. output_eval_path:
  18. verbose:
  19. Returns:
  20. Acc, Recall, F1
  21. """
  22. TP = 0.0
  23. FP = 0.0
  24. FN = 0.0
  25. TN = 0.0
  26. total_num = 0
  27. start_time = time.time()
  28. with open(sighan_path, 'r', encoding='utf-8') as f:
  29. for line in f:
  30. line = line.strip()
  31. if line.startswith('#'):
  32. continue
  33. parts = line.split('\t')
  34. if len(parts) != 2:
  35. continue
  36. src = parts[0]
  37. tgt = parts[1]
  38. r = correct_fn(src)
  39. tgt_pred, pred_detail = r['target'], r['errors']
  40. if verbose:
  41. print()
  42. print('input :', src)
  43. print('truth :', tgt)
  44. print('predict:', tgt_pred, pred_detail)
  45. # 负样本
  46. if src == tgt:
  47. # 预测也为负
  48. if tgt == tgt_pred:
  49. TN += 1
  50. print('right')
  51. # 预测为正
  52. else:
  53. FP += 1
  54. print('wrong')
  55. # 正样本
  56. else:
  57. # 预测也为正
  58. if tgt == tgt_pred:
  59. TP += 1
  60. print('right')
  61. # 预测为负
  62. else:
  63. FN += 1
  64. print('wrong')
  65. total_num += 1
  66. spend_time = time.time() - start_time
  67. acc = (TP + TN) / total_num
  68. precision = TP / (TP + FP) if TP > 0 else 0.0
  69. recall = TP / (TP + FN) if TP > 0 else 0.0
  70. f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
  71. print(
  72. f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, '
  73. f'cost time:{spend_time:.2f} s, total num: {total_num}')
  74. return acc, precision, recall, f1
  75. def eval_sighan2015_by_model_batch(correct_fn, sighan_path=sighan_2015_path, verbose=True):
  76. """
  77. SIGHAN句级评估结果,设定需要纠错为正样本,无需纠错为负样本
  78. Args:
  79. correct_fn:
  80. sighan_path:
  81. verbose:
  82. Returns:
  83. Acc, Recall, F1
  84. """
  85. TP = 0.0
  86. FP = 0.0
  87. FN = 0.0
  88. TN = 0.0
  89. total_num = 0
  90. start_time = time.time()
  91. srcs = []
  92. tgts = []
  93. with open(sighan_path, 'r', encoding='utf-8') as f:
  94. for line in f:
  95. line = line.strip()
  96. if line.startswith('#'):
  97. continue
  98. parts = line.split('\t')
  99. if len(parts) != 2:
  100. continue
  101. src = parts[0]
  102. tgt = parts[1]
  103. srcs.append(src)
  104. tgts.append(tgt)
  105. res = correct_fn(srcs)
  106. for each_res, src, tgt in zip(res, srcs, tgts):
  107. if isinstance(each_res, str):
  108. tgt_pred = each_res
  109. elif isinstance(each_res, dict):
  110. tgt_pred = each_res['target']
  111. else:
  112. raise ValueError('correct_fn return type error.')
  113. if verbose:
  114. print()
  115. print('input :', src)
  116. print('truth :', tgt)
  117. print('predict:', each_res)
  118. # 负样本
  119. if src == tgt:
  120. # 预测也为负
  121. if tgt == tgt_pred:
  122. TN += 1
  123. print('right')
  124. # 预测为正
  125. else:
  126. FP += 1
  127. print('wrong')
  128. # 正样本
  129. else:
  130. # 预测也为正
  131. if tgt == tgt_pred:
  132. TP += 1
  133. print('right')
  134. # 预测为负
  135. else:
  136. FN += 1
  137. print('wrong')
  138. total_num += 1
  139. spend_time = time.time() - start_time
  140. acc = (TP + TN) / total_num
  141. precision = TP / (TP + FP) if TP > 0 else 0.0
  142. recall = TP / (TP + FN) if TP > 0 else 0.0
  143. f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
  144. print(
  145. f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, '
  146. f'cost time:{spend_time:.2f} s, total num: {total_num}')
  147. return acc, precision, recall, f1
  148. if __name__ == "__main__":
  149. # 评估macbert模型的纠错准召率
  150. from pycorrector.macbert.macbert_corrector import MacBertCorrector
  151. model = MacBertCorrector()
  152. eval_sighan2015_by_model(model.correct)