train.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author:XuMing(xuming624@qq.com)
  4. @description:
  5. """
  6. import argparse
  7. import sys
  8. sys.path.append("../..")
  9. from pycorrector.deepcontext.deepcontext_model import DeepContextModel
  10. from pycorrector.deepcontext.deepcontext_corrector import DeepContextCorrector
  11. def main():
  12. parser = argparse.ArgumentParser()
  13. # Required parameters
  14. parser.add_argument("--train_path", default="../data/wiki_zh_200.txt", type=str,
  15. help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
  16. parser.add_argument("--do_train", action="store_true", help="Whether not to train")
  17. parser.add_argument("--do_predict", action="store_true", help="Whether not to predict")
  18. parser.add_argument("--output_dir", default="outputs-deepcontext-lm/", type=str, help="Dir for model save.")
  19. # Other parameters
  20. parser.add_argument("--max_length", default=1024, type=int, help="Max length of input sentence.")
  21. parser.add_argument("--batch_size", default=512, type=int, help="Batch size.")
  22. parser.add_argument("--min_freq", default=1, type=int, help="Mini word frequency.")
  23. parser.add_argument("--dropout", default=0.5, type=float, help="Dropout rate.")
  24. parser.add_argument("--num_epochs", default=80, type=int, help="Epoch num.")
  25. args = parser.parse_args()
  26. print(args)
  27. # Train
  28. if args.do_train:
  29. m = DeepContextModel(args.output_dir, max_length=args.max_length)
  30. # Train model with train data file
  31. m.train_model(
  32. args.train_path,
  33. batch_size=args.batch_size,
  34. num_epochs=args.num_epochs,
  35. min_freq=args.min_freq,
  36. dropout=args.dropout
  37. )
  38. sent = '老是较书。'
  39. pred_words_res = m.predict_mask_token(list(sent), mask_index=2)
  40. print(sent, pred_words_res)
  41. # Predict
  42. if args.do_predict:
  43. m = DeepContextCorrector(args.output_dir, max_length=args.max_length)
  44. inputs = [
  45. '老是较书。',
  46. '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
  47. '遇到一位很棒的奴生跟我聊天。',
  48. '遇到一位很美的女生跟我疗天。',
  49. '他们只能有两个选择:接受降新或自动离职。',
  50. '王天华开心得一直说话。'
  51. ]
  52. for i in inputs:
  53. output = m.correct(i)
  54. print(output)
  55. print()
if __name__ == "__main__":
    # Run the command-line interface only when this file is executed as a script,
    # not when it is imported as a module.
    main()