utils.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pypinyin import lazy_pinyin, Style

import paddle
from paddlenlp.transformers import is_chinese_char


def read_train_ds(data_path):
    # Each training line is expected to hold at least two tab-separated
    # columns: the erroneous source text and the corrected target text.
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            source, target = line.strip('\n').split('\t')[0:2]
            yield {'source': source, 'target': target}


def read_test_ds(data_path):
    # Each test line holds an example id followed by the text to correct.
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            ids, words = line.strip('\n').split('\t')[0:2]
            yield {'source': words}
def convert_example(example,
                    tokenizer,
                    pinyin_vocab,
                    max_seq_length=128,
                    ignore_label=-1,
                    is_test=False):
    source = example["source"]
    words = list(source)
    if len(words) > max_seq_length - 2:
        words = words[:max_seq_length - 2]
    length = len(words)
    words = ['[CLS]'] + words + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(words)
    token_type_ids = [0] * len(input_ids)

    # [CLS]/[SEP] positions are mapped to the pad id (0) in the pinyin
    # sequence so that the pinyin ids stay aligned with the word ids.
    pinyins = lazy_pinyin(
        source, style=Style.TONE3, neutral_tone_with_five=True)
    pinyin_ids = [0]
    # Align each pinyin with its Chinese character.
    pinyin_offset = 0
    for i, word in enumerate(words[1:-1]):
        pinyin = '[UNK]' if word != '[PAD]' else '[PAD]'
        if len(word) == 1 and is_chinese_char(ord(word)):
            while pinyin_offset < len(pinyins):
                current_pinyin = pinyins[pinyin_offset][:-1]
                pinyin_offset += 1
                if current_pinyin in pinyin_vocab:
                    pinyin = current_pinyin
                    break
        pinyin_ids.append(pinyin_vocab[pinyin])
    pinyin_ids.append(0)

    assert len(input_ids) == len(
        pinyin_ids), "length of input_ids must be equal to length of pinyin_ids"

    if not is_test:
        target = example["target"]
        correction_labels = list(target)
        if len(correction_labels) > max_seq_length - 2:
            correction_labels = correction_labels[:max_seq_length - 2]
        correction_labels = tokenizer.convert_tokens_to_ids(correction_labels)
        correction_labels = [ignore_label] + correction_labels + [ignore_label]

        # Detection label is 1 where the source character differs from the
        # target character, 0 where it is already correct.
        detection_labels = []
        for input_id, label in zip(input_ids[1:-1], correction_labels[1:-1]):
            detection_label = 0 if input_id == label else 1
            detection_labels += [detection_label]
        detection_labels = [ignore_label] + detection_labels + [ignore_label]
        return input_ids, token_type_ids, pinyin_ids, detection_labels, correction_labels, length
    else:
        return input_ids, token_type_ids, pinyin_ids, length
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        # Use a distributed sampler for training so multi-card runs shard
        # the data across workers.
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
def parse_decode(words, corr_preds, det_preds, lengths, tokenizer,
                 max_seq_length):
    UNK = tokenizer.unk_token
    UNK_id = tokenizer.convert_tokens_to_ids(UNK)

    corr_pred = corr_preds[1:1 + lengths].tolist()
    det_pred = det_preds[1:1 + lengths].tolist()
    words = list(words)
    rest_words = []
    if len(words) > max_seq_length - 2:
        # Characters beyond the truncated window are appended back unchanged.
        rest_words = words[max_seq_length - 2:]
        words = words[:max_seq_length - 2]

    pred_result = ""
    for j, word in enumerate(words):
        candidates = tokenizer.convert_ids_to_tokens(corr_pred[j] if corr_pred[
            j] < tokenizer.vocab_size else UNK_id)
        word_icc = is_chinese_char(ord(word))
        cand_icc = is_chinese_char(ord(candidates)) if len(
            candidates) == 1 else False
        # Keep the original character unless the detector flags it and the
        # corrector proposes a valid Chinese character.
        if not word_icc or det_pred[j] == 0\
                or candidates in [UNK, '[PAD]']\
                or (word_icc and not cand_icc):
            pred_result += word
        else:
            pred_result += candidates.lstrip("##")
    pred_result += ''.join(rest_words)
    return pred_result
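

# Usage sketch: one possible way to wire these helpers together for training.
# The pretrained model name, file paths, and batch size below are illustrative
# assumptions, not values fixed by this module.
if __name__ == "__main__":
    from functools import partial

    from paddlenlp.data import Pad, Stack, Tuple, Vocab
    from paddlenlp.datasets import load_dataset
    from paddlenlp.transformers import ErnieTokenizer

    tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')  # assumed model name
    pinyin_vocab = Vocab.load_vocabulary(
        'pinyin_vocab.txt', unk_token='[UNK]', pad_token='[PAD]')  # assumed path

    train_ds = load_dataset(read_train_ds, data_path='train.tsv', lazy=False)

    trans_fn = partial(
        convert_example,
        tokenizer=tokenizer,
        pinyin_vocab=pinyin_vocab,
        max_seq_length=128)

    ignore_label = -1
    # Pad/stack the six fields returned by convert_example into batch arrays.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type_ids
        Pad(axis=0, pad_val=pinyin_vocab.token_to_idx[pinyin_vocab.pad_token]),  # pinyin_ids
        Pad(axis=0, dtype="int64", pad_val=ignore_label),  # detection_labels
        Pad(axis=0, dtype="int64", pad_val=ignore_label),  # correction_labels
        Stack(axis=0, dtype="int64")  # length
    ): [data for data in fn(samples)]

    train_loader = create_dataloader(
        train_ds,
        mode='train',
        batch_size=32,
        batchify_fn=batchify_fn,
        trans_fn=trans_fn)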