training_llama_demo.py

# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Fine-tune a Llama model with PEFT for Chinese text error correction, then run batch prediction.
"""
import argparse
import sys

from loguru import logger

sys.path.append('../..')
from pycorrector.gpt.gpt_model import GptModel
from pycorrector.gpt.gpt_corrector import GptCorrector


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', default='../data/grammar/train_sharegpt.jsonl', type=str, help='Train file')
    parser.add_argument('--test_file', default='../data/grammar/test_sharegpt.jsonl', type=str, help='Test file')
    parser.add_argument('--model_type', default='llama', type=str, help='Transformers model type')
    parser.add_argument('--model_name', default='shibing624/chinese-alpaca-plus-7b-hf', type=str,
                        help='Transformers model or path')
    parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
    parser.add_argument('--do_predict', action='store_true', help='Whether to run predict.')
    parser.add_argument('--bf16', action='store_true', help='Whether to use bf16 mixed precision training.')
    parser.add_argument('--output_dir', default='./outputs-llama-demo/', type=str, help='Model output directory')
    parser.add_argument('--prompt_template_name', default='vicuna', type=str, help='Prompt template name')
    parser.add_argument('--max_seq_length', default=128, type=int, help='Input max sequence length')
    parser.add_argument('--max_length', default=128, type=int, help='Output max sequence length')
    parser.add_argument('--num_epochs', default=0.2, type=float, help='Number of training epochs')
    parser.add_argument('--batch_size', default=8, type=int, help='Batch size')
    parser.add_argument('--eval_steps', default=50, type=int, help='Eval every X steps')
    parser.add_argument('--save_steps', default=50, type=int, help='Save checkpoint every X steps')
    parser.add_argument("--local_rank", type=int, help="Used by dist launchers")
    args = parser.parse_args()
    logger.info(args)
    # fine-tune Llama model with parameter-efficient fine-tuning (PEFT)
    if args.do_train:
        logger.info('Loading data...')
        model_args = {
            "use_peft": True,
            "overwrite_output_dir": True,
            "reprocess_input_data": True,
            "max_seq_length": args.max_seq_length,
            "max_length": args.max_length,
            "per_device_train_batch_size": args.batch_size,
            "eval_batch_size": args.batch_size,
            "num_train_epochs": args.num_epochs,
            "output_dir": args.output_dir,
            "resume_from_checkpoint": args.output_dir,
            "eval_steps": args.eval_steps,
            "save_steps": args.save_steps,
            "bf16": args.bf16,
            "prompt_template_name": args.prompt_template_name,
        }
        model = GptModel(args.model_type, args.model_name, args=model_args)
        model.train_model(args.train_file, eval_data=args.test_file)
    if args.do_predict:
        # sample sentences containing grammatical/spelling errors to be corrected
        error_sentences = [
            "美国总统特朗普访日,不仅吸引了美日民众的关注,中国人民也同样密切关注。",
            "这块名表带带相传",
            "少先队员因该为老人让坐",
        ]
        # load the base model plus the fine-tuned PEFT adapter saved in output_dir
        m = GptCorrector(
            model_name_or_path=args.model_name,
            model_type=args.model_type,
            peft_name=args.output_dir,
            args={'use_peft': True, 'eval_batch_size': args.batch_size, 'max_length': args.max_length}
        )
        result = m.correct_batch(error_sentences)
        for res_dict in result:
            print(res_dict)


if __name__ == '__main__':
    main()