# evaluate_util.py
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
  4. @description:
  5. """
  6. def compute_corrector_prf(results, logger):
  7. """
  8. copy from https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check/blob/master/utils/evaluation_metrics.py
  9. """
  10. TP = 0
  11. FP = 0
  12. FN = 0
  13. all_predict_true_index = []
  14. all_gold_index = []
  15. for item in results:
  16. src, tgt, predict = item
  17. gold_index = []
  18. each_true_index = []
  19. for i in range(len(list(src))):
  20. if src[i] == tgt[i]:
  21. continue
  22. else:
  23. gold_index.append(i)
  24. all_gold_index.append(gold_index)
  25. predict_index = []
  26. for i in range(len(list(src))):
  27. if src[i] == predict[i]:
  28. continue
  29. else:
  30. predict_index.append(i)
  31. for i in predict_index:
  32. if i in gold_index:
  33. TP += 1
  34. each_true_index.append(i)
  35. else:
  36. FP += 1
  37. for i in gold_index:
  38. if i in predict_index:
  39. continue
  40. else:
  41. FN += 1
  42. all_predict_true_index.append(each_true_index)
  43. # For the detection Precision, Recall and F1
  44. detection_precision = TP / (TP + FP) if (TP + FP) > 0 else 0
  45. detection_recall = TP / (TP + FN) if (TP + FN) > 0 else 0
  46. if detection_precision + detection_recall == 0:
  47. detection_f1 = 0
  48. else:
  49. detection_f1 = 2 * (detection_precision * detection_recall) / (detection_precision + detection_recall)
  50. logger.info(
  51. "The detection result is precision={}, recall={} and F1={}".format(detection_precision, detection_recall,
  52. detection_f1))
  53. TP = 0
  54. FP = 0
  55. FN = 0
  56. for i in range(len(all_predict_true_index)):
  57. # we only detect those correctly detected location, which is a different from the common metrics since
  58. # we want to see the precision improve by using the confusionset
  59. if len(all_predict_true_index[i]) > 0:
  60. predict_words = []
  61. for j in all_predict_true_index[i]:
  62. predict_words.append(results[i][2][j])
  63. if results[i][1][j] == results[i][2][j]:
  64. TP += 1
  65. else:
  66. FP += 1
  67. for j in all_gold_index[i]:
  68. if results[i][1][j] in predict_words:
  69. continue
  70. else:
  71. FN += 1
  72. # For the correction Precision, Recall and F1
  73. correction_precision = TP / (TP + FP) if (TP + FP) > 0 else 0
  74. correction_recall = TP / (TP + FN) if (TP + FN) > 0 else 0
  75. if correction_precision + correction_recall == 0:
  76. correction_f1 = 0
  77. else:
  78. correction_f1 = 2 * (correction_precision * correction_recall) / (correction_precision + correction_recall)
  79. logger.info("The correction result is precision={}, recall={} and F1={}".format(correction_precision,
  80. correction_recall,
  81. correction_f1))
  82. return detection_f1, correction_f1
  83. def compute_sentence_level_prf(results, logger):
  84. """
  85. 自定义的句级prf,设定需要纠错为正样本,无需纠错为负样本
  86. :param results:
  87. :return:
  88. """
  89. TP = 0.0
  90. FP = 0.0
  91. FN = 0.0
  92. TN = 0.0
  93. total_num = len(results)
  94. for item in results:
  95. src, tgt, predict = item
  96. # 负样本
  97. if src == tgt:
  98. # 预测也为负
  99. if tgt == predict:
  100. TN += 1
  101. # 预测为正
  102. else:
  103. FP += 1
  104. # 正样本
  105. else:
  106. # 预测也为正
  107. if tgt == predict:
  108. TP += 1
  109. # 预测为负
  110. else:
  111. FN += 1
  112. acc = (TP + TN) / total_num
  113. precision = TP / (TP + FP) if TP > 0 else 0.0
  114. recall = TP / (TP + FN) if TP > 0 else 0.0
  115. f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
  116. logger.info(f'Sentence Level: acc:{acc:.6f}, precision:{precision:.6f}, recall:{recall:.6f}, f1:{f1:.6f}')
  117. return acc, precision, recall, f1
  118. def report_prf(tp, fp, fn, phase, logger=None, return_dict=False):
  119. # For the detection Precision, Recall and F1
  120. precision = tp / (tp + fp) if (tp + fp) > 0 else 0
  121. recall = tp / (tp + fn) if (tp + fn) > 0 else 0
  122. if precision + recall == 0:
  123. f1_score = 0
  124. else:
  125. f1_score = 2 * (precision * recall) / (precision + recall)
  126. if phase and logger:
  127. logger.info(f"The {phase} result is: "
  128. f"{precision:.4f}/{recall:.4f}/{f1_score:.4f} -->\n"
  129. # f"precision={precision:.6f}, recall={recall:.6f} and F1={f1_score:.6f}\n"
  130. f"support: TP={tp}, FP={fp}, FN={fn}")
  131. if return_dict:
  132. ret_dict = {
  133. f'{phase}_p': precision,
  134. f'{phase}_r': recall,
  135. f'{phase}_f1': f1_score}
  136. return ret_dict
  137. return precision, recall, f1_score
  138. def compute_corrector_prf_faspell(results, logger=None, strict=True):
  139. """
  140. All-in-one measure function.
  141. based on FASpell's measure script.
  142. :param results: a list of (wrong, correct, predict, ...)
  143. both token_ids or characters are fine for the script.
  144. :param logger: take which logger to print logs.
  145. :param strict: a more strict evaluation mode (all-char-detected/corrected)
  146. References:
  147. sentence-level PRF: https://github.com/iqiyi/
  148. FASPell/blob/master/faspell.py
  149. """
  150. corrected_char, wrong_char = 0, 0
  151. corrected_sent, wrong_sent = 0, 0
  152. true_corrected_char = 0
  153. true_corrected_sent = 0
  154. true_detected_char = 0
  155. true_detected_sent = 0
  156. accurate_detected_sent = 0
  157. accurate_corrected_sent = 0
  158. all_sent = 0
  159. for item in results:
  160. # wrong, correct, predict, d_tgt, d_predict = item
  161. wrong, correct, predict = item[:3]
  162. all_sent += 1
  163. wrong_num = 0
  164. corrected_num = 0
  165. original_wrong_num = 0
  166. true_detected_char_in_sentence = 0
  167. for c, w, p in zip(correct, wrong, predict):
  168. if c != p:
  169. wrong_num += 1
  170. if w != p:
  171. corrected_num += 1
  172. if c == p:
  173. true_corrected_char += 1
  174. if w != c:
  175. true_detected_char += 1
  176. true_detected_char_in_sentence += 1
  177. if c != w:
  178. original_wrong_num += 1
  179. corrected_char += corrected_num
  180. wrong_char += original_wrong_num
  181. if original_wrong_num != 0:
  182. wrong_sent += 1
  183. if corrected_num != 0 and wrong_num == 0:
  184. true_corrected_sent += 1
  185. if corrected_num != 0:
  186. corrected_sent += 1
  187. if strict: # find out all faulty wordings' potisions
  188. true_detected_flag = (true_detected_char_in_sentence == original_wrong_num \
  189. and original_wrong_num != 0 \
  190. and corrected_num == true_detected_char_in_sentence)
  191. else: # think it has faulty wordings
  192. true_detected_flag = (corrected_num != 0 and original_wrong_num != 0)
  193. # if corrected_num != 0 and original_wrong_num != 0:
  194. if true_detected_flag:
  195. true_detected_sent += 1
  196. if correct == predict:
  197. accurate_corrected_sent += 1
  198. if correct == predict or true_detected_flag:
  199. accurate_detected_sent += 1
  200. counts = { # TP, FP, TN for each level
  201. 'det_char_counts': [true_detected_char,
  202. corrected_char - true_detected_char,
  203. wrong_char - true_detected_char],
  204. 'cor_char_counts': [true_corrected_char,
  205. corrected_char - true_corrected_char,
  206. wrong_char - true_corrected_char],
  207. 'det_sent_counts': [true_detected_sent,
  208. corrected_sent - true_detected_sent,
  209. wrong_sent - true_detected_sent],
  210. 'cor_sent_counts': [true_corrected_sent,
  211. corrected_sent - true_corrected_sent,
  212. wrong_sent - true_corrected_sent],
  213. 'det_sent_acc': accurate_detected_sent / all_sent,
  214. 'cor_sent_acc': accurate_corrected_sent / all_sent,
  215. 'all_sent_count': all_sent,
  216. }
  217. details = {}
  218. for phase in ['det_char', 'cor_char', 'det_sent', 'cor_sent']:
  219. dic = report_prf(
  220. *counts[f'{phase}_counts'],
  221. phase=phase, logger=logger,
  222. return_dict=True)
  223. details.update(dic)
  224. details.update(counts)
  225. return details