浏览代码

update infer eval.

shibing624 1 周之前
父节点
当前提交
b949208e71
共有 4 个文件被更改,包括 26 次插入20 次删除
  1. 10 10
      README.md
  2. 6 8
      examples/evaluate_models/evaluate_models.py
  3. 7 1
      examples/gpt/demo.py
  4. 3 1
      pycorrector/gpt/gpt_model.py

+ 10 - 10
README.md

@@ -86,15 +86,15 @@ python examples/macbert/gradio_demo.py
 - CTC(CHinese Text Correction): 文本纠错模型,表示模型支持拼写、语法等长度对齐的错误纠正,还可以处理多字、少字等长度不对齐的错误纠正
 - CTC(CHinese Text Correction): 文本纠错模型,表示模型支持拼写、语法等长度对齐的错误纠正,还可以处理多字、少字等长度不对齐的错误纠正
 - GPU:Tesla V100,显存 32 GB
 - GPU:Tesla V100,显存 32 GB
 
 
-| Model Name       | Model Link                                                                                                          | Base Model                 | Avg        | SIGHAN-2015 | EC-LAW | MCSC   | GPU/CPU | QPS     |
-|:-----------------|:--------------------------------------------------------------------------------------------------------------------|:---------------------------|:-----------|:----------------|:-------|:-------|:--------|:--------|
-| Kenlm-CSC        | [shibing624/chinese-kenlm-klm](https://huggingface.co/shibing624/chinese-kenlm-klm)                                 | kenlm | 0.3409     | 0.3147 | 0.3763 | 0.3317 | CPU     | 9 |
-| Mengzi-T5-CSC    | [shibing624/mengzi-t5-base-chinese-correction](https://huggingface.co/shibing624/mengzi-t5-base-chinese-correction) | mengzi-t5-base | 0.3984     | 0.7758 | 0.3156 | 0.1039 | GPU     | 214 |
-| ERNIE-CSC        | [PaddleNLP/ernie-csc](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/legacy/examples/text_correction/ernie-csc) | PaddlePaddle/ernie-1.0-base-zh | 0.4353     | 0.8383 | 0.3357 | 0.1318 | GPU     | 114 |
-| MacBERT-CSC      | [shibing624/macbert4csc-base-chinese](https://huggingface.co/shibing624/macbert4csc-base-chinese)                   | hfl/chinese-macbert-base   | 0.3993     | 0.8314 | 0.1610 | 0.2055 | GPU     | **224** |
-| ChatGLM3-6B-CSC  | [shibing624/chatglm3-6b-csc-chinese-lora](https://huggingface.co/shibing624/chatglm3-6b-csc-chinese-lora)           | THUDM/chatglm3-6b          | -          | 0.5225          | -      | -      | GPU     | 1 |
-| Qwen2.5-1.5B-CTC | [shibing624/chinese-text-correction-1.5b](https://huggingface.co/shibing624/chinese-text-correction-1.5b)           | Qwen/Qwen2.5-1.5B-Instruct | 0.6802     | 0.3032 | 0.7846 | 0.9529 | GPU     | 3 |
-| Qwen2.5-7B-CTC   | [shibing624/chinese-text-correction-7b](https://huggingface.co/shibing624/chinese-text-correction-7b)               | Qwen/Qwen2.5-7B-Instruct   | **0.8225** | 0.4917 | 0.9798 | 0.9959 | GPU     | 2 |
+| Model Name       | Model Link                                                                                                              | Base Model                 | Avg        | SIGHAN-2015 | EC-LAW | MCSC   | GPU/CPU | QPS     |
+|:-----------------|:------------------------------------------------------------------------------------------------------------------------|:---------------------------|:-----------|:------------|:-------|:-------|:--------|:--------|
+| Kenlm-CSC        | [shibing624/chinese-kenlm-klm](https://huggingface.co/shibing624/chinese-kenlm-klm)                                     | kenlm | 0.3409     | 0.3147      | 0.3763 | 0.3317 | CPU     | 9       |
+| Mengzi-T5-CSC    | [shibing624/mengzi-t5-base-chinese-correction](https://huggingface.co/shibing624/mengzi-t5-base-chinese-correction)     | mengzi-t5-base | 0.3984     | 0.7758      | 0.3156 | 0.1039 | GPU     | 214     |
+| ERNIE-CSC        | [PaddleNLP/ernie-csc](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/legacy/examples/text_correction/ernie-csc) | PaddlePaddle/ernie-1.0-base-zh | 0.4353     | 0.8383      | 0.3357 | 0.1318 | GPU     | 114     |
+| MacBERT-CSC      | [shibing624/macbert4csc-base-chinese](https://huggingface.co/shibing624/macbert4csc-base-chinese)                       | hfl/chinese-macbert-base   | 0.3993     | 0.8314      | 0.1610 | 0.2055 | GPU     | **224** |
+| ChatGLM3-6B-CSC  | [shibing624/chatglm3-6b-csc-chinese-lora](https://huggingface.co/shibing624/chatglm3-6b-csc-chinese-lora)               | THUDM/chatglm3-6b          | 0.4538     | 0.6572      | 0.4369     | 0.2672      | GPU     | 3       |
+| Qwen2.5-1.5B-CTC | [shibing624/chinese-text-correction-1.5b](https://huggingface.co/shibing624/chinese-text-correction-1.5b)               | Qwen/Qwen2.5-1.5B-Instruct | 0.6802     | 0.3032      | 0.7846 | 0.9529 | GPU     | 6       |
+| Qwen2.5-7B-CTC   | [shibing624/chinese-text-correction-7b](https://huggingface.co/shibing624/chinese-text-correction-7b)                   | Qwen/Qwen2.5-7B-Instruct   | **0.8225** | 0.4917      | 0.9798 | 0.9959 | GPU     | 3       |
 
 
 
 
 ## Install
 ## Install
@@ -125,7 +125,7 @@ docker run -it -v ~/.pycorrector:/root/.pycorrector shibing624/pycorrector:0.0.2
 ## Usage
 ## Usage
 本项目的初衷之一是比对、调研各种中文文本纠错方法,抛砖引玉。
 本项目的初衷之一是比对、调研各种中文文本纠错方法,抛砖引玉。
 
 
-项目实现了kenlm、macbert、seq2seq、 ernie_csc、T5、deepcontext、LLaMA等模型应用于文本纠错任务,各模型均可基于已经训练好的纠错模型快速预测,也可使用自有数据训练、预测。
+项目实现了kenlm、macbert、seq2seq、 ernie_csc、T5、deepcontext、GPT(Qwen/ChatGLM)等模型应用于文本纠错任务,各模型均可基于已经训练好的纠错模型快速预测,也可使用自有数据训练、预测。
 
 
 
 
 ### kenlm模型(统计模型)
 ### kenlm模型(统计模型)

+ 6 - 8
examples/evaluate_models/evaluate_models.py

@@ -84,19 +84,17 @@ def main(args):
                          model_type='chatglm',
                          model_type='chatglm',
                          peft_name="shibing624/chatglm3-6b-csc-chinese-lora")
                          peft_name="shibing624/chatglm3-6b-csc-chinese-lora")
         if args.data == 'sighan':
         if args.data == 'sighan':
-            eval_model_batch(m.correct_batch, prefix_prompt="对这个句子语法纠错\n\n", prompt_template_name='vicuna')
+            eval_model_batch(m.correct_batch, prefix_prompt="对下面文本纠错:", prompt_template_name='vicuna')
             # Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100
             # Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100
-            #
+            # Sentence Level: acc:0.6591, precision:0.7000, recall:0.6193, f1:0.6572, cost time:273.06 s, total num: 707
         elif args.data == 'ec_law':
         elif args.data == 'ec_law':
             eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"),
             eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"),
-                             prefix_prompt="对这个句子语法纠错\n\n",
-                             prompt_template_name='vicuna')
-            #
+                             prefix_prompt="对下面文本纠错:", prompt_template_name='vicuna')
+            # Sentence Level: acc:0.4870, precision:0.5182, recall:0.3776, f1:0.4369, cost time:372.46 s, total num: 1000
         elif args.data == 'mcsc':
         elif args.data == 'mcsc':
             eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"),
             eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"),
-                             prefix_prompt="对这个句子语法纠错\n\n",
-                             prompt_template_name='vicuna')
-            #
+                             prefix_prompt="对下面文本纠错:", prompt_template_name='vicuna')
+            # Sentence Level: acc:0.4790, precision:0.4185, recall:0.1963, f1:0.2672, cost time:383.76 s, total num: 1000
     elif args.model == 'qwen1.5b':
     elif args.model == 'qwen1.5b':
         from pycorrector.gpt.gpt_corrector import GptCorrector
         from pycorrector.gpt.gpt_corrector import GptCorrector
         m = GptCorrector(model_name_or_path="shibing624/chinese-text-correction-1.5b")
         m = GptCorrector(model_name_or_path="shibing624/chinese-text-correction-1.5b")

+ 7 - 1
examples/gpt/demo.py

@@ -21,7 +21,13 @@ if __name__ == '__main__':
     ]
     ]
     m = GptCorrector("shibing624/chinese-text-correction-1.5b")
     m = GptCorrector("shibing624/chinese-text-correction-1.5b")
 
 
-    batch_res = m.correct_batch(error_sentences, system_prompt="你是一个中文文本纠错助手。请根据用户提供的原始文本,生成纠正后的文本。")
+    batch_res = m.correct_batch(error_sentences,
+                                system_prompt="你是一个中文文本纠错助手。请根据用户提供的原始文本,生成纠正后的文本。")
     for i in batch_res:
     for i in batch_res:
         print(i)
         print(i)
         print()
         print()
+
+    # batch_res = m.correct_batch(error_sentences, prefix_prompt='文本纠错:\n\n', prompt_template_name='qwen')
+    # for i in batch_res:
+    #     print(i)
+    #     print()

+ 3 - 1
pycorrector/gpt/gpt_model.py

@@ -579,7 +579,8 @@ class GptModel:
                 prompt_len = len(input_ids[0])
                 prompt_len = len(input_ids[0])
                 generated_sequence = generated_sequence[prompt_len:]
                 generated_sequence = generated_sequence[prompt_len:]
                 gen_text = self.tokenizer.decode(generated_sequence, skip_special_tokens=True)
                 gen_text = self.tokenizer.decode(generated_sequence, skip_special_tokens=True)
-                # logger.error(f"input_text: {input_text}, gen_text: {gen_text}")
+                gen_text = gen_text.strip()
+                # logger.debug(f"input_text: {input_text}, gen_text: {gen_text}")
                 all_outputs.append(gen_text)
                 all_outputs.append(gen_text)
 
 
         return all_outputs
         return all_outputs
@@ -641,6 +642,7 @@ class GptModel:
             )
             )
             output_tensor = outputs[0][len(input_ids[0]):] if skip_prompt else outputs[0]
             output_tensor = outputs[0][len(input_ids[0]):] if skip_prompt else outputs[0]
             response = self.tokenizer.decode(output_tensor, skip_special_tokens=True)
             response = self.tokenizer.decode(output_tensor, skip_special_tokens=True)
+            response = response.strip()
             history[-1][1] = response
             history[-1][1] = response
             return response, history
             return response, history