@@ -67,7 +67,7 @@ def main():
peft_name=args.output_dir,
args={'use_peft': True, 'eval_batch_size': args.batch_size, "max_length": args.max_length, }
)
- result = m.correct_batch(error_sentences)
+ result = m.correct_batch(error_sentences, prefix_prompt="对这个句子语法纠错\n\n")
for res_dict in result:
print(res_dict)