123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- # -*- coding: utf-8 -*-
- """
- @author:XuMing(xuming624@qq.com)
- @description:
- """
- import argparse
- import sys
- sys.path.append("../..")
- from pycorrector.deepcontext.deepcontext_model import DeepContextModel
- from pycorrector.deepcontext.deepcontext_corrector import DeepContextCorrector
- def main():
- parser = argparse.ArgumentParser()
- # Required parameters
- parser.add_argument("--train_path", default="../data/wiki_zh_200.txt", type=str,
- help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
- parser.add_argument("--do_train", action="store_true", help="Whether not to train")
- parser.add_argument("--do_predict", action="store_true", help="Whether not to predict")
- parser.add_argument("--output_dir", default="outputs-deepcontext-lm/", type=str, help="Dir for model save.")
- # Other parameters
- parser.add_argument("--max_length", default=1024, type=int, help="Max length of input sentence.")
- parser.add_argument("--batch_size", default=512, type=int, help="Batch size.")
- parser.add_argument("--min_freq", default=1, type=int, help="Mini word frequency.")
- parser.add_argument("--dropout", default=0.5, type=float, help="Dropout rate.")
- parser.add_argument("--num_epochs", default=80, type=int, help="Epoch num.")
- args = parser.parse_args()
- print(args)
- # Train
- if args.do_train:
- m = DeepContextModel(args.output_dir, max_length=args.max_length)
- # Train model with train data file
- m.train_model(
- args.train_path,
- batch_size=args.batch_size,
- num_epochs=args.num_epochs,
- min_freq=args.min_freq,
- dropout=args.dropout
- )
- sent = '老是较书。'
- pred_words_res = m.predict_mask_token(list(sent), mask_index=2)
- print(sent, pred_words_res)
- # Predict
- if args.do_predict:
- m = DeepContextCorrector(args.output_dir, max_length=args.max_length)
- inputs = [
- '老是较书。',
- '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
- '遇到一位很棒的奴生跟我聊天。',
- '遇到一位很美的女生跟我疗天。',
- '他们只能有两个选择:接受降新或自动离职。',
- '王天华开心得一直说话。'
- ]
- for i in inputs:
- output = m.correct(i)
- print(output)
- print()
- if __name__ == "__main__":
- main()
|