export_model.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. import argparse
  16. import sys
  17. import paddle
  18. from paddle.static import InputSpec
  19. from paddlenlp.data import Vocab
  20. from paddlenlp.transformers import ErnieModel
  21. sys.path.append('../..')
  22. from pycorrector.ernie_csc.model import ErnieForCSC
  23. # yapf: disable
  24. parser = argparse.ArgumentParser(__doc__)
  25. parser.add_argument("--params_path", type=str, default='./checkpoints/final.pdparams',
  26. help="The path of model parameter to be loaded.")
  27. parser.add_argument("--output_path", type=str, default='./infer_model/static_graph_params',
  28. help="The path of model parameter in static graph to be saved.")
  29. parser.add_argument("--model_name_or_path", type=str, default="ernie-1.0", choices=["ernie-1.0"],
  30. help="Pretraining model name or path")
  31. parser.add_argument("--pinyin_vocab_file_path", type=str, default="pinyin_vocab.txt", help="pinyin vocab file path")
  32. args = parser.parse_args()
  33. # yapf: enable
  34. def main():
  35. pinyin_vocab = Vocab.load_vocabulary(
  36. args.pinyin_vocab_file_path, unk_token='[UNK]', pad_token='[PAD]')
  37. ernie = ErnieModel.from_pretrained(args.model_name_or_path)
  38. model = ErnieForCSC(
  39. ernie,
  40. pinyin_vocab_size=len(pinyin_vocab),
  41. pad_pinyin_id=pinyin_vocab[pinyin_vocab.pad_token])
  42. model_dict = paddle.load(args.params_path)
  43. model.set_dict(model_dict)
  44. model.eval()
  45. model = paddle.jit.to_static(
  46. model,
  47. input_spec=[
  48. InputSpec(
  49. shape=[None, None], dtype="int64", name='input_ids'), InputSpec(
  50. shape=[None, None], dtype="int64", name='pinyin_ids')
  51. ])
  52. paddle.jit.save(model, args.output_path)
  53. if __name__ == "__main__":
  54. main()