# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from pypinyin import Style, lazy_pinyin

from paddlenlp.transformers import is_chinese_char


def read_train_ds(data_path):
    """Yield {source, target} sentence pairs from a tab-separated training file."""
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            source, target = line.strip('\n').split('\t')[0:2]
            yield {'source': source, 'target': target}


def read_test_ds(data_path):
    """Yield {source} examples from a tab-separated test file of (id, sentence) rows."""
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            ids, words = line.strip('\n').split('\t')[0:2]
            yield {'source': words}
def convert_example(example,
                    tokenizer,
                    pinyin_vocab,
                    max_seq_length=128,
                    ignore_label=-1,
                    is_test=False):
    """Convert a raw example into model features: token ids, pinyin ids and,
    for training, detection/correction labels."""
    source = example["source"]
    words = list(source)
    if len(words) > max_seq_length - 2:
        words = words[:max_seq_length - 2]
    length = len(words)
    words = ['[CLS]'] + words + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(words)
    token_type_ids = [0] * len(input_ids)

    # [CLS] and [SEP] have no pronunciation, so they map to the pad token
    # (index 0) in the pinyin vocabulary.
    pinyins = lazy_pinyin(
        source, style=Style.TONE3, neutral_tone_with_five=True)
    pinyin_ids = [0]

    # Align each pinyin syllable with its Chinese character; non-Chinese
    # tokens fall back to [UNK].
    pinyin_offset = 0
    for word in words[1:-1]:
        pinyin = '[UNK]' if word != '[PAD]' else '[PAD]'
        if len(word) == 1 and is_chinese_char(ord(word)):
            while pinyin_offset < len(pinyins):
                # Drop the trailing tone digit produced by Style.TONE3.
                current_pinyin = pinyins[pinyin_offset][:-1]
                pinyin_offset += 1
                if current_pinyin in pinyin_vocab:
                    pinyin = current_pinyin
                    break
        pinyin_ids.append(pinyin_vocab[pinyin])
    pinyin_ids.append(0)

    assert len(input_ids) == len(
        pinyin_ids), "length of input_ids must be equal to length of pinyin_ids"

    if not is_test:
        target = example["target"]
        correction_labels = list(target)
        if len(correction_labels) > max_seq_length - 2:
            correction_labels = correction_labels[:max_seq_length - 2]
        correction_labels = tokenizer.convert_tokens_to_ids(correction_labels)
        correction_labels = [ignore_label] + correction_labels + [ignore_label]

        # Detection label is 1 wherever the source token differs from the target.
        detection_labels = []
        for input_id, label in zip(input_ids[1:-1], correction_labels[1:-1]):
            detection_label = 0 if input_id == label else 1
            detection_labels += [detection_label]
        detection_labels = [ignore_label] + detection_labels + [ignore_label]
        return input_ids, token_type_ids, pinyin_ids, detection_labels, correction_labels, length
    else:
        return input_ids, token_type_ids, pinyin_ids, length
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    """Wrap a dataset in a paddle.io.DataLoader, with distributed sampling for training."""
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = (mode == 'train')
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
def parse_decode(words, corr_preds, det_preds, lengths, tokenizer,
                 max_seq_length):
    """Merge correction and detection predictions back into a corrected sentence."""
    UNK = tokenizer.unk_token
    UNK_id = tokenizer.convert_tokens_to_ids(UNK)

    # Strip the [CLS]/[SEP] positions from the predictions.
    corr_pred = corr_preds[1:1 + lengths].tolist()
    det_pred = det_preds[1:1 + lengths].tolist()
    words = list(words)

    # Characters beyond max_seq_length - 2 were never fed to the model and
    # are appended back unchanged at the end.
    rest_words = []
    if len(words) > max_seq_length - 2:
        rest_words = words[max_seq_length - 2:]
        words = words[:max_seq_length - 2]

    pred_result = ""
    for j, word in enumerate(words):
        candidates = tokenizer.convert_ids_to_tokens(corr_pred[j] if corr_pred[
            j] < tokenizer.vocab_size else UNK_id)
        word_icc = is_chinese_char(ord(word))
        cand_icc = is_chinese_char(ord(candidates)) if len(
            candidates) == 1 else False
        # Keep the original character unless it is a Chinese character that the
        # detector flags and the corrector replaces with another Chinese
        # character (not [UNK] or [PAD]).
        if not word_icc or det_pred[j] == 0 \
                or candidates in [UNK, '[PAD]'] \
                or (word_icc and not cand_icc):
            pred_result += word
        else:
            pred_result += candidates.lstrip("##")
    pred_result += ''.join(rest_words)
    return pred_result
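

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original utilities): shows how the
    # helpers above are typically wired together for training. The file paths
    # "train.tsv" and "pinyin_vocab.txt" are hypothetical placeholders, and the
    # "ernie-1.0" checkpoint name and batch size are illustrative assumptions;
    # adjust them to your own setup.
    from functools import partial

    from paddlenlp.data import Pad, Stack, Tuple, Vocab
    from paddlenlp.datasets import load_dataset
    from paddlenlp.transformers import ErnieTokenizer

    tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
    # Assumes a pinyin vocabulary file with one token per line that contains
    # '[PAD]' and '[UNK]' entries (hypothetical path).
    pinyin_vocab = Vocab.load_vocabulary(
        "pinyin_vocab.txt", unk_token="[UNK]", pad_token="[PAD]")

    # Build a map-style dataset from the TSV reader defined above.
    train_ds = load_dataset(read_train_ds, data_path="train.tsv", lazy=False)

    trans_fn = partial(
        convert_example,
        tokenizer=tokenizer,
        pinyin_vocab=pinyin_vocab,
        max_seq_length=128)

    # Pad/stack the six fields returned by convert_example in training mode.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type_ids
        Pad(axis=0, pad_val=pinyin_vocab.token_to_idx[pinyin_vocab.pad_token]),  # pinyin_ids
        Pad(axis=0, dtype="int64"),  # detection_labels
        Pad(axis=0, dtype="int64"),  # correction_labels
        Stack(axis=0, dtype="int64")  # length
    ): [data for data in fn(samples)]

    train_loader = create_dataloader(
        train_ds,
        mode="train",
        batch_size=32,
        batchify_fn=batchify_fn,
        trans_fn=trans_fn)

    # Peek at one batch to confirm the shapes line up.
    for batch in train_loader:
        input_ids, token_type_ids, pinyin_ids, det_labels, corr_labels, lengths = batch
        print(input_ids.shape, pinyin_ids.shape, det_labels.shape)
        break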