12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- # -*- coding: utf-8 -*-
- """
- @author:XuMing(xuming624@qq.com)
- @description:
- """
- import argparse
- import sys
- from loguru import logger
- sys.path.append('../..')
- from pycorrector.gpt.gpt_model import GptModel
- from pycorrector.gpt.gpt_corrector import GptCorrector
def build_arg_parser():
    """Build the command-line parser for the LLaMA text-correction demo.

    Returns:
        argparse.ArgumentParser: parser exposing data paths, model selection,
        the ``--do_train`` / ``--do_predict`` switches and training
        hyper-parameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', default='../data/grammar/train_sharegpt.jsonl', type=str,
                        help='Train file')
    parser.add_argument('--test_file', default='../data/grammar/test_sharegpt.jsonl', type=str,
                        help='Test file')
    parser.add_argument('--model_type', default='llama', type=str, help='Transformers model type')
    parser.add_argument('--model_name', default='shibing624/chinese-alpaca-plus-7b-hf', type=str,
                        help='Transformers model or path')
    parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
    parser.add_argument('--do_predict', action='store_true', help='Whether to run predict.')
    parser.add_argument('--bf16', action='store_true', help='Whether to use bf16 mixed precision training.')
    parser.add_argument('--output_dir', default='./outputs-llama-demo/', type=str,
                        help='Model output directory')
    parser.add_argument('--prompt_template_name', default='vicuna', type=str, help='Prompt template name')
    parser.add_argument('--max_seq_length', default=128, type=int, help='Input max sequence length')
    parser.add_argument('--max_length', default=128, type=int, help='Output max sequence length')
    # A float, so fractional epochs (e.g. 0.2 for a quick demo run) are accepted.
    parser.add_argument('--num_epochs', default=0.2, type=float, help='Number of training epochs')
    parser.add_argument('--batch_size', default=8, type=int, help='Batch size')
    parser.add_argument('--eval_steps', default=50, type=int, help='Eval every X steps')
    parser.add_argument('--save_steps', default=50, type=int, help='Save checkpoint every X steps')
    # Parsed only so distributed launchers can pass it; not read by this script.
    parser.add_argument("--local_rank", type=int, help="Used by dist launchers")
    return parser


def _train(args):
    """Fine-tune ``args.model_name`` with PEFT on ``args.train_file``, evaluating on ``args.test_file``."""
    logger.info('Loading data...')
    model_args = {
        "use_peft": True,
        "overwrite_output_dir": True,
        "reprocess_input_data": True,
        "max_seq_length": args.max_seq_length,
        "max_length": args.max_length,
        "per_device_train_batch_size": args.batch_size,
        "eval_batch_size": args.batch_size,
        "num_train_epochs": args.num_epochs,
        "output_dir": args.output_dir,
        # Points at the output dir so a re-run can pick up a previous checkpoint.
        # NOTE(review): relies on GptModel tolerating a dir with no checkpoint yet — confirm.
        "resume_from_checkpoint": args.output_dir,
        "eval_steps": args.eval_steps,
        "save_steps": args.save_steps,
        "bf16": args.bf16,
        "prompt_template_name": args.prompt_template_name,
    }
    model = GptModel(args.model_type, args.model_name, args=model_args)
    model.train_model(args.train_file, eval_data=args.test_file)


def _predict(args):
    """Correct a few sample sentences with the base model plus the PEFT adapter in ``args.output_dir``."""
    error_sentences = [
        "美国总统特朗普访日,不仅吸引了美日民众的关注,中国人民也同样密切关注。",
        "这块名表带带相传",
        "少先队员因该为老人让坐",
    ]
    m = GptCorrector(
        model_name_or_path=args.model_name,
        model_type=args.model_type,
        peft_name=args.output_dir,
        args={
            'use_peft': True,
            'eval_batch_size': args.batch_size,
            "max_length": args.max_length,
            # Decode with the same prompt template the adapter was trained with.
            "prompt_template_name": args.prompt_template_name,
        },
    )
    # NOTE(review): some pycorrector versions give correct_batch its own
    # max_length / prompt_template_name keyword defaults that may take precedence
    # over the args above — confirm against the installed version.
    result = m.correct_batch(error_sentences)
    for res_dict in result:
        print(res_dict)


def main(argv=None):
    """Entry point: optionally fine-tune, then optionally run sample predictions.

    Args:
        argv: optional list of command-line tokens; ``None`` reads ``sys.argv``.
    """
    args = build_arg_parser().parse_args(argv)
    logger.info(args)
    if not (args.do_train or args.do_predict):
        # Previously exited silently; make the no-op explicit.
        logger.warning('Nothing to do: pass --do_train and/or --do_predict.')
        return
    # fine-tune Llama model
    if args.do_train:
        _train(args)
    if args.do_predict:
        _predict(args)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
|