train.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author:XuMing(xuming624@qq.com)
  4. @description: train seq2seq model
  5. # #### PyTorch代码
  6. # - [seq2seq-tutorial](https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb)
  7. # - [Tutorial from Ben Trevett](https://github.com/bentrevett/pytorch-seq2seq)
  8. # - [IBM seq2seq](https://github.com/IBM/pytorch-seq2seq)
  9. # - [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py)
  10. # - [text-generation](https://github.com/shibing624/text-generation)
  11. """
  12. import argparse
  13. import sys
  14. from loguru import logger
  15. sys.path.append('../..')
  16. from pycorrector.seq2seq.conv_seq2seq_model import ConvSeq2SeqModel
  def _build_arg_parser():
      """Build the command-line parser for this ConvSeq2Seq train/predict script.

      Returns:
          argparse.ArgumentParser: parser for the data paths, the run-mode
          flags (--do_train / --do_predict), the output directory and the
          training hyper-parameters.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument("--train_file", default="../data/sighan_2015/train.tsv", type=str, help="Train file")
      parser.add_argument("--test_file", default="../data/sighan_2015/test.tsv", type=str, help="Test file")
      # Run-mode switches; both may be given to train and then predict in one run.
      parser.add_argument("--do_train", action="store_true", help="Whether to train")
      parser.add_argument("--do_predict", action="store_true", help="Whether to predict")
      parser.add_argument("--output_dir", default="outputs-sighan-convseq2seq/", type=str, help="Dir for model save.")
      parser.add_argument("--max_length", default=128, type=int, help="Max length (passed to ConvSeq2SeqModel).")
      parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
      parser.add_argument("--num_epochs", default=200, type=int, help="Epoch num.")
      return parser


  def main(argv=None):
      """Parse CLI arguments, then train/evaluate and/or predict with a ConvSeq2SeqModel.

      Args:
          argv: Optional list of argument strings to parse. ``None`` (the
              default) makes argparse read ``sys.argv[1:]``, i.e. the normal
              command-line behaviour.
      """
      args = _build_arg_parser().parse_args(argv)
      logger.info(args)

      if not (args.do_train or args.do_predict):
          # Without either flag the script would otherwise finish silently.
          logger.warning("Nothing to do: pass --do_train and/or --do_predict.")
          return

      # Train model with train data file, then evaluate it on the test file.
      if args.do_train:
          logger.info('Loading data...')
          model = ConvSeq2SeqModel(
              num_epochs=args.num_epochs,
              batch_size=args.batch_size,
              model_dir=args.output_dir,
              max_length=args.max_length,
          )
          model.train_model(args.train_file)
          model.eval_model(args.test_file)

      if args.do_predict:
          # Only model_dir and max_length are passed here. Presumably the model
          # restores trained weights from output_dir; confirm in ConvSeq2SeqModel.
          model = ConvSeq2SeqModel(
              model_dir=args.output_dir,
              max_length=args.max_length,
          )
          # Hard-coded sample sentences to run through the model.
          sentences = [
              '老是较书。',
              '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
              '遇到一位很棒的奴生跟我聊天。',
          ]
          print("inputs:", sentences)
          print("outputs:", model.predict(sentences))
  # Script entry point: run the CLI only when executed directly, not on import.
  if __name__ == '__main__':
      main()