123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- # -*- coding: utf-8 -*-
- """
- @author:XuMing(xuming624@qq.com)
- @description: train seq2seq model
- # #### PyTorch code references
- # - [seq2seq-tutorial](https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb)
- # - [Tutorial from Ben Trevett](https://github.com/bentrevett/pytorch-seq2seq)
- # - [IBM seq2seq](https://github.com/IBM/pytorch-seq2seq)
- # - [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py)
- # - [text-generation](https://github.com/shibing624/text-generation)
- """
- import argparse
- import sys
- from loguru import logger
- sys.path.append('../..')
- from pycorrector.seq2seq.conv_seq2seq_model import ConvSeq2SeqModel
def main():
    """Command-line entry point: train and/or run prediction with ConvSeq2SeqModel.

    Flags read from ``sys.argv``:

    * ``--do_train``: build a ``ConvSeq2SeqModel`` from ``--num_epochs``,
      ``--batch_size``, ``--output_dir`` and ``--max_length``, call
      ``train_model`` on ``--train_file`` and then ``eval_model`` on
      ``--test_file``.
    * ``--do_predict``: build a ``ConvSeq2SeqModel`` from ``--output_dir`` and
      ``--max_length`` and print its ``predict`` output for a few built-in
      sample sentences.

    Both flags may be combined (training runs first). If neither is given,
    a warning is logged and nothing else happens.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_file", default="../data/sighan_2015/train.tsv", type=str, help="Train file")
    parser.add_argument("--test_file", default="../data/sighan_2015/test.tsv", type=str, help="Test file")
    # Both switches are opt-in (store_true); the help text now says so instead
    # of the misleading "Whether not to ...".
    parser.add_argument("--do_train", action="store_true", help="Whether to train")
    parser.add_argument("--do_predict", action="store_true", help="Whether to predict")
    parser.add_argument("--output_dir", default="outputs-sighan-convseq2seq/", type=str, help="Dir for model save.")
    parser.add_argument("--max_length", default=128, type=int)
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--num_epochs", default=200, type=int, help="Epoch num.")
    args = parser.parse_args()
    logger.info(args)

    # Without either switch the script would exit silently; tell the user why.
    if not (args.do_train or args.do_predict):
        logger.warning('Nothing to do: pass --do_train and/or --do_predict.')
        return

    # Train model with train data file, then evaluate on the test file.
    if args.do_train:
        logger.info('Loading data...')
        model = ConvSeq2SeqModel(
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            model_dir=args.output_dir,
            max_length=args.max_length,
        )
        model.train_model(args.train_file)
        model.eval_model(args.test_file)

    if args.do_predict:
        # A fresh model object is built from model_dir even right after
        # training -- presumably it picks up the artifacts saved there;
        # confirm against ConvSeq2SeqModel.
        model = ConvSeq2SeqModel(
            model_dir=args.output_dir,
            max_length=args.max_length,
        )
        sentences = [
            '老是较书。',
            '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
            '遇到一位很棒的奴生跟我聊天。',
        ]
        print("inputs:", sentences)
        print("outputs:", model.predict(sentences))
- if __name__ == '__main__':
- main()
|