# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import random
import sys
import time
from functools import partial

import numpy as np
import paddle
from paddlenlp.data import Stack, Tuple, Pad, Vocab
from paddlenlp.datasets import load_dataset, MapDataset
from paddlenlp.metrics import DetectionF1, CorrectionF1
from paddlenlp.transformers import ErnieModel, ErnieTokenizer
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.utils.log import logger

sys.path.append('../..')

from pycorrector.ernie_csc.model import ErnieForCSC
from pycorrector.ernie_csc.utils import convert_example, create_dataloader, read_train_ds

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="ernie-1.0", choices=["ernie-1.0"],
                    help="Pretrained model name or path.")
parser.add_argument("--max_seq_length", type=int, default=128,
                    help="The maximum total input sequence length after SentencePiece tokenization.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train.")
parser.add_argument("--batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--save_steps", type=int, default=1000, help="Save a checkpoint every X update steps.")
parser.add_argument("--logging_steps", type=int, default=1, help="Log every X update steps.")
parser.add_argument("--output_dir", type=str, default='checkpoints/', help="Directory to save model checkpoints.")
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs for training.")
parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu"],
                    help="Select cpu or gpu device to train the model.")
parser.add_argument("--seed", type=int, default=1, help="Random seed for initialization.")
parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
                    help="Linear warmup proportion over the training process.")
parser.add_argument("--max_steps", default=-1, type=int,
                    help="If > 0: set total number of training steps to perform. Overrides epochs.")
parser.add_argument("--pinyin_vocab_file_path", type=str, default="pinyin_vocab.txt", help="Pinyin vocab file path.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for the Adam optimizer.")
parser.add_argument("--ignore_label", default=-1, type=int, help="Ignore label for CrossEntropyLoss.")
parser.add_argument("--extra_train_ds_dir", default=None, type=str, help="The directory of extra training datasets.")
# yapf: enable

args = parser.parse_args()


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)


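# Evaluation runs without gradient tracking and reports sentence-level detection
# and correction F1 on the dev set; do_train averages the two F1 scores to pick
# the best checkpoint.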
@paddle.no_grad()
def evaluate(model, eval_data_loader):
    model.eval()
    det_metric = DetectionF1()
    corr_metric = CorrectionF1()
    for step, batch in enumerate(eval_data_loader, start=1):
        input_ids, token_type_ids, pinyin_ids, det_labels, corr_labels, length = batch
        # det_error_probs shape: [B, T, 2]
        # corr_logits shape: [B, T, V]
        det_error_probs, corr_logits = model(input_ids, pinyin_ids,
                                             token_type_ids)
        det_metric.update(det_error_probs, det_labels, length)
        corr_metric.update(det_error_probs, det_labels, corr_logits,
                           corr_labels, length)

    det_f1, det_precision, det_recall = det_metric.accumulate()
    corr_f1, corr_precision, corr_recall = corr_metric.accumulate()
    logger.info("Sentence-Level Performance:")
    logger.info("Detection metric: F1={:.4f}, Recall={:.4f}, Precision={:.4f}".format(
        det_f1, det_recall, det_precision))
    logger.info("Correction metric: F1={:.4f}, Recall={:.4f}, Precision={:.4f}".format(
        corr_f1, corr_recall, corr_precision))
    model.train()
    return det_f1, corr_f1


def do_train(args):
    set_seed(args)
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    pinyin_vocab = Vocab.load_vocabulary(
        args.pinyin_vocab_file_path, unk_token='[UNK]', pad_token='[PAD]')

    tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    ernie = ErnieModel.from_pretrained(args.model_name_or_path)
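    # ErnieForCSC pairs the ERNIE encoder with an extra pinyin embedding so the
    # model can use phonetic information when detecting and correcting
    # misspelled characters (see the paper linked in the training loop below).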
    model = ErnieForCSC(
        ernie,
        pinyin_vocab_size=len(pinyin_vocab),
        pad_pinyin_id=pinyin_vocab[pinyin_vocab.pad_token])

    train_ds, eval_ds = load_dataset('sighan-cn', splits=['train', 'dev'])

    # Extend the training dataset by providing a directory of extra training
    # datasets. Every dataset file in that directory must have the ".txt"
    # suffix, and each line must hold a pair of sentences separated by a tab,
    # for example:
    # "城府宫员表示,这是过去三十六小时内第三期强烈的余震。\t政府官员表示,这是过去三十六小时内第三起强烈的余震。\n"
    if args.extra_train_ds_dir is not None and os.path.exists(
            args.extra_train_ds_dir):
        data = train_ds.data
        data_files = [
            os.path.join(args.extra_train_ds_dir, data_file)
            for data_file in os.listdir(args.extra_train_ds_dir)
            if data_file.endswith(".txt")
        ]
        for data_file in data_files:
            ds = load_dataset(
                read_train_ds,
                data_path=data_file,
                splits=["train"],
                lazy=False)
            data += ds.data
        train_ds = MapDataset(data)

    det_loss_act = paddle.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label, use_softmax=False)
    corr_loss_act = paddle.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label, reduction='none')
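    # The detection head already outputs probabilities, so its loss skips the
    # softmax; the correction loss keeps per-token values (reduction='none') so
    # it can be weighted by the detection confidence in the training loop.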

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        pinyin_vocab=pinyin_vocab,
        max_seq_length=args.max_seq_length)
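    # convert_example (see pycorrector.ernie_csc.utils) maps each raw example to
    # (input_ids, token_type_ids, pinyin_ids, det_labels, corr_labels, length),
    # the same field order that batchify_fn pads/stacks below.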
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        Pad(axis=0, pad_val=pinyin_vocab.token_to_idx[pinyin_vocab.pad_token]),  # pinyin
        Pad(axis=0, dtype="int64"),  # detection label
        Pad(axis=0, dtype="int64"),  # correction label
        Stack(axis=0, dtype="int64")  # length
    ): [data for data in fn(samples)]

    train_data_loader = create_dataloader(
        train_ds,
        mode='train',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)
    eval_data_loader = create_dataloader(
        eval_ds,
        mode='eval',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.epochs
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_proportion)
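    # The scheduler warms the learning rate up linearly over the first
    # warmup_proportion of the steps, then decays it linearly for the rest
    # of training.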
    logger.info("Total training steps: {}".format(num_training_steps))

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    global_steps = 1
    best_f1 = -1
    tic_train = time.time()
    for epoch in range(args.epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, pinyin_ids, det_labels, corr_labels, length = batch
            det_error_probs, corr_logits = model(input_ids, pinyin_ids,
                                                 token_type_ids)
            # Chinese Spelling Correction consists of two tasks: detection and correction.
            # The detection task detects whether each Chinese character is misspelled,
            # and the correction task corrects each potentially wrong character.
            # So we need to minimize the detection loss and the correction loss simultaneously.
            # See more loss design details in https://aclanthology.org/2021.findings-acl.198.pdf
            det_loss = det_loss_act(det_error_probs, det_labels)
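            # Scale each token's correction loss by the detector's confidence
            # (the larger of its two class probabilities) before averaging.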
            corr_loss = corr_loss_act(
                corr_logits, corr_labels) * det_error_probs.max(axis=-1)
            loss = (det_loss + corr_loss).mean()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            if global_steps % args.logging_steps == 0:
                logger.info(
                    "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                    % (global_steps, epoch, step, loss,
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
            if global_steps % args.save_steps == 0:
                if paddle.distributed.get_rank() == 0:
                    logger.info("Eval:")
                    det_f1, corr_f1 = evaluate(model, eval_data_loader)
                    f1 = (det_f1 + corr_f1) / 2
                    model_file = "model_%d" % global_steps
                    if f1 > best_f1:
                        # save best model
                        paddle.save(model.state_dict(),
                                    os.path.join(args.output_dir,
                                                 "best_model.pdparams"))
                        logger.info("Save best model at {} step.".format(
                            global_steps))
                        best_f1 = f1
                        model_file = model_file + "_best"
                    model_file = model_file + ".pdparams"
                    paddle.save(model.state_dict(),
                                os.path.join(args.output_dir, model_file))
                    logger.info("Save model at {} step.".format(global_steps))
            if args.max_steps > 0 and global_steps >= args.max_steps:
                return
            global_steps += 1


if __name__ == "__main__":
    do_train(args)
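

# Example launch (a sketch using only the flags defined above; adjust paths,
# device, and batch size for your environment):
#   python train.py --device gpu --batch_size 32 --epochs 10 \
#       --max_seq_length 128 --pinyin_vocab_file_path pinyin_vocab.txt \
#       --output_dir checkpoints/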