model.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn


class ErnieForCSC(nn.Layer):
    r"""
    ErnieForCSC is a model designed for the Chinese Spelling Correction (CSC)
    task. It integrates phonetic features into the language model by leveraging
    the powerful pre-training and fine-tuning method.

    See more details on https://aclanthology.org/2021.findings-acl.198.pdf.

    Args:
        ernie (ErnieModel):
            An instance of `paddlenlp.transformers.ErnieModel`.
        pinyin_vocab_size (int):
            The size of the pinyin vocab.
        pad_pinyin_id (int, optional):
            The pad token id of the pinyin vocab. Defaults to 0.
    """

    def __init__(self, ernie, pinyin_vocab_size, pad_pinyin_id=0):
        super(ErnieForCSC, self).__init__()
        self.ernie = ernie
        emb_size = self.ernie.config["hidden_size"]
        hidden_size = self.ernie.config["hidden_size"]
        vocab_size = self.ernie.config["vocab_size"]
        self.pad_token_id = self.ernie.config["pad_token_id"]
        self.pinyin_vocab_size = pinyin_vocab_size
        self.pad_pinyin_id = pad_pinyin_id
        self.pinyin_embeddings = nn.Embedding(
            self.pinyin_vocab_size, emb_size, padding_idx=pad_pinyin_id)
        self.detection_layer = nn.Linear(hidden_size, 2)
        self.correction_layer = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.Softmax()

    def forward(self,
                input_ids,
                pinyin_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
  50. r"""
  51. Args:
  52. input_ids (Tensor):
  53. Indices of input sequence tokens in the vocabulary. They are
  54. numerical representations of tokens that build the input sequence.
  55. It's data type should be `int64` and has a shape of [batch_size, sequence_length].
  56. pinyin_ids (Tensor):
  57. Indices of pinyin tokens of input sequence in the pinyin vocabulary. They are
  58. numerical representations of tokens that build the pinyin input sequence.
  59. It's data type should be `int64` and has a shape of [batch_size, sequence_length].
  60. token_type_ids (Tensor, optional):
  61. Segment token indices to indicate first and second portions of the inputs.
  62. Indices can be either 0 or 1:
  63. - 0 corresponds to a **sentence A** token,
  64. - 1 corresponds to a **sentence B** token.
  65. It's data type should be `int64` and has a shape of [batch_size, sequence_length].
  66. Defaults to None, which means no segment embeddings is added to token embeddings.
  67. position_ids (Tensor, optional):
  68. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
  69. config.max_position_embeddings - 1]``.
  70. Defaults to `None`. Shape as `(batch_sie, num_tokens)` and dtype as `int32` or `int64`.
  71. attention_mask (Tensor, optional):
  72. Mask to indicate whether to perform attention on each input token or not.
  73. The values should be either 0 or 1. The attention scores will be set
  74. to **-infinity** for any positions in the mask that are **0**, and will be
  75. **unchanged** for positions that are **1**.
  76. - **1** for tokens that are **not masked**,
  77. - **0** for tokens that are **masked**.
  78. It's data type should be `float32` and has a shape of [batch_size, sequence_length].
  79. Defaults to `None`.
  80. Returns:
  81. detection_error_probs (Tensor):
  82. A Tensor of the detection probablity of each tokens.
  83. Shape as `(batch_size, sequence_length, 2)` and dtype as `int`.
  84. correction_logits (Tensor):
  85. A Tensor of the correction logits of each tokens.
  86. Shape as `(batch_size, sequence_length, vocab_size)` and dtype as `int`.
  87. """
        if attention_mask is None:
            attention_mask = paddle.unsqueeze(
                (input_ids == self.pad_token_id).astype(
                    self.detection_layer.weight.dtype) * -1e9,
                axis=[1, 2])
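        # Note: the default mask built above is additive rather than 0/1: pad
        # positions get -1e9 (effectively -inf) added to the attention scores,
        # and unsqueeze(axis=[1, 2]) shapes it as
        # [batch_size, 1, 1, sequence_length] so it broadcasts over attention
        # heads and query positions.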
        embedding_output = self.ernie.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids)
        pinyin_embedding_output = self.pinyin_embeddings(pinyin_ids)
        # The detection module detects whether each Chinese character has a
        # spelling error.
        detection_outputs = self.ernie.encoder(embedding_output, attention_mask)
        # detection_error_probs shape: [B, T, 2]. It gives the probability,
        # between 0 and 1, that each token in the sequence is erroneous.
        detection_error_probs = self.softmax(
            self.detection_layer(detection_outputs))
        # The correction module corrects each potentially wrong character to
        # the right one.
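        # Soft-masked combination: for each token i,
        #   e_i = p_i(correct) * word_embedding_i + p_i(error) * pinyin_embedding_i
        # where, as the indexing below implies, channel 0 of
        # detection_error_probs is the "correct" probability and channel 1 the
        # "error" probability, so likely-wrong tokens lean more on their
        # phonetic (pinyin) embedding.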
        word_pinyin_embedding_output = (
            detection_error_probs[:, :, 0:1] * embedding_output
            + detection_error_probs[:, :, 1:2] * pinyin_embedding_output)
        correction_outputs = self.ernie.encoder(word_pinyin_embedding_output,
                                                attention_mask)
        # correction_logits shape: [B, T, V]. For each position in the
        # sequence, it scores every vocabulary token as a candidate correction.
        correction_logits = self.correction_layer(correction_outputs)
        return detection_error_probs, correction_logits
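

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes
    # `paddlenlp` is installed and the "ernie-1.0" pretrained weights are
    # available; `pinyin_vocab_size=1024` is an illustrative value, not the
    # vocabulary size used by the released CSC example.
    from paddlenlp.transformers import ErnieModel

    ernie = ErnieModel.from_pretrained("ernie-1.0")
    model = ErnieForCSC(ernie, pinyin_vocab_size=1024, pad_pinyin_id=0)
    model.eval()

    # Random ids stand in for real tokenized text and its pinyin sequence.
    input_ids = paddle.randint(low=1, high=100, shape=[2, 16], dtype="int64")
    pinyin_ids = paddle.randint(low=1, high=1024, shape=[2, 16], dtype="int64")

    with paddle.no_grad():
        detection_error_probs, correction_logits = model(input_ids, pinyin_ids)
    print(detection_error_probs.shape)  # [2, 16, 2]
    print(correction_logits.shape)      # [2, 16, vocab_size]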