evaluate_models.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author:XuMing(xuming624@qq.com)
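  4. @description: Evaluate a pycorrector model (chosen with --model) on a test dataset (chosen with --data) via eval_model_batch.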
  5. """
  6. import argparse
  7. import sys
  8. import os
  9. sys.path.append("../..")
  10. from pycorrector import eval_model_batch
  11. pwd_path = os.path.abspath(os.path.dirname(__file__))
  12. def main(args):
  13. if args.model == 'kenlm':
  14. from pycorrector import Corrector
  15. m = Corrector()
  16. if args.data == 'sighan':
  17. eval_model_batch(m.correct_batch)
  18. # Sentence Level: acc:0.5409, precision:0.6532, recall:0.1492, f1:0.2429, cost time:295.07 s, total num: 1100
  19. # Sentence Level: acc:0.5502, precision:0.8022, recall:0.1957, f1:0.3147, cost time:37.28 s, total num: 707
  20. elif args.data == 'ec_law':
  21. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"))
  22. # Sentence Level: acc:0.5790, precision:0.8581, recall:0.2410, f1:0.3763, cost time:64.61 s, total num: 1000
  23. elif args.data == 'mcsc':
  24. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"))
  25. # Sentence Level: acc:0.5850, precision:0.7518, recall:0.2128, f1:0.3317, cost time:30.61 s, total num: 1000
  26. elif args.model == 'macbert':
  27. from pycorrector import MacBertCorrector
  28. model = MacBertCorrector()
  29. if args.data == 'sighan':
  30. eval_model_batch(model.correct_batch)
  31. # macbert: Sentence Level: acc:0.7918, precision:0.8489, recall:0.7035, f1:0.7694, cost time:2.25 s, total num: 1100
  32. # pert-base: Sentence Level: acc:0.7709, precision:0.7893, recall:0.7311, f1:0.7591, cost time:2.52 s, total num: 1100
  33. # pert-large: Sentence Level: acc:0.7709, precision:0.7847, recall:0.7385, f1:0.7609, cost time:7.22 s, total num: 1100
  34. # macbert4csc Sentence Level: acc:0.8388, precision:0.9274, recall:0.7534, f1:0.8314, cost time:4.26 s, total num: 707
  35. elif args.data == 'ec_law':
  36. eval_model_batch(model.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"))
  37. # Sentence Level: acc:0.2390, precision:0.1921, recall:0.1385, f1:0.1610, cost time:7.11 s, total num: 1000
  38. elif args.data == 'mcsc':
  39. eval_model_batch(model.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"))
  40. # Sentence Level: acc:0.5360, precision:0.6000, recall:0.1240, f1:0.2055, cost time:2.65 s, total num: 1000
  41. elif args.model == 'seq2seq':
  42. from pycorrector import ConvSeq2SeqCorrector
  43. model = ConvSeq2SeqCorrector()
  44. eval_model_batch(model.correct_batch)
  45. # Sentence Level: acc:0.3909, precision:0.2803, recall:0.1492, f1:0.1947, cost time:219.50 s, total num: 1100
  46. elif args.model == 't5':
  47. from pycorrector import T5Corrector
  48. m = T5Corrector()
  49. if args.data == 'sighan':
  50. eval_model_batch(m.correct_batch)
  51. # Sentence Level: acc:0.7582, precision:0.8321, recall:0.6390, f1:0.7229, cost time:26.36 s, total num: 1100
  52. # Sentence Level: acc:0.7907, precision:0.8920, recall:0.6863, f1:0.7758, cost time:20.82 s, total num: 707
  53. elif args.data == 'ec_law':
  54. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"))
  55. # Sentence Level: acc:0.5230, precision:0.6471, recall:0.2087, f1:0.3156, cost time:43.61 s, total num: 1000
  56. elif args.data == 'mcsc':
  57. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"))
  58. # Sentence Level: acc:0.4650, precision:0.2743, recall:0.0640, f1:0.1039, cost time:14.99 s, total num: 1000
  59. elif args.model == 'deepcontext':
  60. from pycorrector import DeepContextCorrector
  61. model = DeepContextCorrector()
  62. eval_model_batch(model.correct_batch)
  63. elif args.model == 'ernie_csc':
  64. from pycorrector import ErnieCscCorrector
  65. m = ErnieCscCorrector()
  66. if args.data == 'sighan':
  67. eval_model_batch(m.correct_batch)
  68. # Sentence Level: acc:0.7491, precision:0.7623, recall:0.7145, f1:0.7376, cost time:3.03 s, total num: 1100
  69. # Sentence Level: acc:0.8373, precision:0.8817, recall:0.7989, f1:0.8383, cost time:14.97 s, total num: 707
  70. elif args.data == 'ec_law':
  71. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"))
  72. # Sentence Level: acc:0.5370, precision:0.6882, recall:0.2220, f1:0.3357, cost time:25.15 s, total num: 1000
  73. elif args.data == 'mcsc':
  74. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"))
  75. # Sentence Level: acc:0.4600, precision:0.2971, recall:0.0847, f1:0.1318, cost time:18.69 s, total num: 1000
  76. elif args.model == 'chatglm':
  77. from pycorrector.gpt.gpt_corrector import GptCorrector
  78. model = GptCorrector(model_name_or_path="THUDM/chatglm3-6b",
  79. model_type='chatglm',
  80. peft_name="shibing624/chatglm3-6b-csc-chinese-lora")
  81. eval_model_batch(model.correct_batch)
  82. # chatglm3-6b-csc: Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100
  83. elif args.model == 'qwen1.5b':
  84. from pycorrector.gpt.gpt_corrector import GptCorrector
  85. m = GptCorrector(model_name_or_path="shibing624/chinese-text-correction-1.5b")
  86. if args.data == 'sighan':
  87. eval_model_batch(m.correct_batch)
  88. # Sentence Level: acc:0.4540, precision:0.4641, recall:0.2252, f1:0.3032, cost time:243.50 s, total num: 707
  89. elif args.data == 'ec_law':
  90. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"))
  91. # Sentence Level: acc:0.7990, precision:0.9015, recall:0.6945, f1:0.7846, cost time:266.26 s, total num: 1000
  92. elif args.data == 'mcsc':
  93. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"))
  94. # Sentence Level: acc:0.9560, precision:0.9889, recall:0.9194, f1:0.9529, cost time:210.11 s, total num: 1000
  95. elif args.model == 'qwen7b':
  96. from pycorrector.gpt.gpt_corrector import GptCorrector
  97. m = GptCorrector(model_name_or_path="shibing624/chinese-text-correction-7b")
  98. if args.data == 'sighan':
  99. eval_model_batch(m.correct_batch)
  100. # Sentence Level: acc:0.5672, precision:0.6463, recall:0.3968, f1:0.4917, cost time:392.10 s, total num: 707
  101. elif args.data == 'ec_law':
  102. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"))
  103. # Sentence Level: acc:0.9790, precision:0.9941, recall:0.9658, f1:0.9798, cost time:717.37 s, total num: 1000
  104. elif args.data == 'mcsc':
  105. eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"))
  106. # Sentence Level: acc:0.9960, precision:0.9979, recall:0.9938, f1:0.9959, cost time:267.12 s, total num: 1000
  107. else:
  108. raise ValueError('model name error.')
  109. if __name__ == '__main__':
  110. parser = argparse.ArgumentParser()
  111. parser.add_argument('--model', type=str, default='macbert', help='which model to evaluate')
  112. parser.add_argument('--data', type=str, default='sighan', help='test dataset')
  113. args = parser.parse_args()
  114. main(args)