# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
from functools import partial

import paddle
from paddlenlp.data import Stack, Tuple, Pad, Vocab
from paddlenlp.transformers import ErnieTokenizer

sys.path.append('../..')

from pycorrector.ernie_csc.utils import convert_example, parse_decode
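
# This script runs ERNIE-CSC Chinese spelling correction inference with the
# Paddle Inference API: it loads an exported static graph model
# (.pdmodel/.pdiparams), feeds batched token ids and pinyin ids, and decodes
# the detection/correction outputs back into corrected sentences.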


class Predictor:
    def __init__(self, model_file, params_file, device, max_seq_length,
                 tokenizer, pinyin_vocab):
        self.max_seq_length = max_seq_length

        config = paddle.inference.Config(model_file, params_file)
        if device == "gpu":
            # Set GPU configs accordingly: initial GPU memory pool size (MB)
            # and the GPU card id to run on.
            config.enable_use_gpu(100, 0)
        elif device == "cpu":
            # Set CPU configs accordingly,
            # such as enable_mkldnn, set_cpu_math_library_num_threads.
            config.disable_gpu()
        config.switch_use_feed_fetch_ops(False)
        self.predictor = paddle.inference.create_predictor(config)

        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
        # The exported model has two outputs: detection error probabilities
        # and correction logits, in that order.
        self.det_error_probs_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[0])
        self.corr_logits_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[1])

        self.tokenizer = tokenizer
        self.pinyin_vocab = pinyin_vocab

    def predict(self, sentences, batch_size=1):
        """
        Predicts corrected text for the given sentences.

        Args:
            sentences (obj:`List[str]`): Raw input sentences; each element is
                a piece of Chinese text to be corrected.
            batch_size (obj:`int`, defaults to 1): The batch size.

        Returns:
            results (obj:`List[str]`): The corrected sentences.
        """
        examples = []
        texts = []
        trans_func = partial(
            convert_example,
            tokenizer=self.tokenizer,
            pinyin_vocab=self.pinyin_vocab,
            max_seq_length=self.max_seq_length,
            is_test=True)
        # Pad every field to the longest sample in the batch and stack the
        # sequence lengths into one array.
        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id, dtype='int64'),  # input
            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id, dtype='int64'),  # segment
            Pad(axis=0, pad_val=self.pinyin_vocab.token_to_idx[self.pinyin_vocab.pad_token], dtype='int64'),  # pinyin
            Stack(axis=0, dtype='int64'),  # length
        ): [data for data in fn(samples)]

        for text in sentences:
            example = {"source": text.strip()}
            input_ids, token_type_ids, pinyin_ids, length = trans_func(example)
            examples.append((input_ids, token_type_ids, pinyin_ids, length))
            texts.append(example["source"])

        batch_examples = [
            examples[idx:idx + batch_size]
            for idx in range(0, len(examples), batch_size)
        ]
        batch_texts = [
            texts[idx:idx + batch_size]
            for idx in range(0, len(examples), batch_size)
        ]

        results = []
        for examples, texts in zip(batch_examples, batch_texts):
            token_ids, token_type_ids, pinyin_ids, length = batchify_fn(
                examples)
            # The static graph model takes token ids and pinyin ids as inputs.
            self.input_handles[0].copy_from_cpu(token_ids)
            self.input_handles[1].copy_from_cpu(pinyin_ids)
            self.predictor.run()
            det_error_probs = self.det_error_probs_handle.copy_to_cpu()
            corr_logits = self.corr_logits_handle.copy_to_cpu()
            # Argmax over the last axis yields per-token detection labels and
            # the corrected character ids.
            det_pred = det_error_probs.argmax(axis=-1)
            char_preds = corr_logits.argmax(axis=-1)

            for i in range(len(length)):
                pred_result = parse_decode(texts[i], char_preds[i], det_pred[i],
                                           length[i], self.tokenizer,
                                           self.max_seq_length)
                results.append(''.join(pred_result))

        return results


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_file", type=str, required=True,
                        help="The path to the static graph model file, e.g. ./static_graph_params.pdmodel.")
    parser.add_argument("--params_file", type=str, required=True,
                        help="The path to the static graph parameters file, e.g. ./static_graph_params.pdiparams.")
    parser.add_argument("--batch_size", type=int, default=4, help="The number of sequences contained in a mini-batch.")
    parser.add_argument("--max_seq_len", type=int, default=64, help="The maximum length of an input sequence.")
    parser.add_argument("--device", default="gpu", type=str, choices=["cpu", "gpu"],
                        help="The device to run inference on; must be cpu or gpu.")
    parser.add_argument("--pinyin_vocab_file_path", type=str, default="pinyin_vocab.txt",
                        help="The path to the pinyin vocab file.")
    args = parser.parse_args()
    # yapf: enable

    tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
    pinyin_vocab = Vocab.load_vocabulary(
        args.pinyin_vocab_file_path, unk_token='[UNK]', pad_token='[PAD]')
    predictor = Predictor(args.model_file, args.params_file, args.device,
                          args.max_seq_len, tokenizer, pinyin_vocab)

    # Demo sentences that deliberately contain Chinese spelling errors for the
    # model to detect and correct.
    samples = [
        '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
        '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。',
    ]
    results = predictor.predict(samples, batch_size=args.batch_size)
    for source, target in zip(samples, results):
        print("Source:", source)
        print("Target:", target)
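
# Example usage (paths are illustrative; point them at your exported files):
#   python predict.py \
#       --model_file ./static_graph_params.pdmodel \
#       --params_file ./static_graph_params.pdiparams \
#       --pinyin_vocab_file_path pinyin_vocab.txt \
#       --device cpu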