# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: corrector with pinyin and stroke
"""
- import operator
- import os
- from codecs import open
- from typing import List
- import pypinyin
- from loguru import logger
- from pycorrector.detector import Detector, ErrorType
- from pycorrector.utils.math_utils import edit_distance_word
- from pycorrector.utils.text_utils import is_chinese_string
- from pycorrector.utils.tokenizer import segment, split_text_into_sentences_by_symbol
# Directory containing this module; the bundled data files are resolved against it.
pwd_path = os.path.abspath(os.path.dirname(__file__))
# Common Chinese character set
common_char_path = os.path.join(pwd_path, 'data/common_char_set.txt')
# Characters sharing the same pinyin (homophones)
same_pinyin_path = os.path.join(pwd_path, 'data/same_pinyin.txt')
# Characters with a similar shape (similar strokes)
same_stroke_path = os.path.join(pwd_path, 'data/same_stroke.txt')
class Corrector(Detector):
    """Spelling corrector built on same-pinyin and similar-stroke confusion sets.

    Candidate generation and language-model based selection live here.
    Error detection (``_detect``), word frequencies (``word_freq`` /
    ``word_frequency``), the custom confusion dict (``custom_confusion``)
    and perplexity scoring (``ppl_score``) are not defined in this class and
    are expected to be provided by ``Detector``.
    """

    def __init__(
            self,
            common_char_path=common_char_path,
            same_pinyin_path=same_pinyin_path,
            same_stroke_path=same_stroke_path,
            **kwargs,
    ):
        """
        Store the resource paths; the files are only read on first use.

        :param common_char_path: path of the common Chinese character set file
        :param same_pinyin_path: path of the same-pinyin character file
        :param same_stroke_path: path of the similar-stroke character file
        :param kwargs: forwarded unchanged to ``Detector.__init__``
        """
        super().__init__(**kwargs)
        self.name = 'kenlm_corrector'
        self.common_char_path = common_char_path
        self.same_pinyin_path = same_pinyin_path
        self.same_stroke_path = same_stroke_path
        # Populated lazily by _initialize_corrector().
        self.initialized_corrector = False
        self.cn_char_set = None
        self.same_pinyin = None
        self.same_stroke = None

    @staticmethod
    def load_set_file(path):
        """
        Load a file with one entry per line into a set.

        Blank lines and lines starting with ``#`` are skipped.

        :param path: file path, read as UTF-8
        :return: set of stripped, non-empty lines
        """
        words = set()
        with open(path, 'r', encoding='utf-8') as f:
            for w in f:
                w = w.strip()
                if w and not w.startswith('#'):
                    words.add(w)
        return words

    @staticmethod
    def load_same_pinyin(path, sep='\t'):
        """
        Load the same-pinyin (homophone) character table.

        Each usable line has at least three ``sep``-separated fields: the key
        character, the characters with the same pronunciation and tone, and
        the characters with the same pronunciation but a different tone.
        Lines starting with ``#`` are skipped.

        :param path: file path, read as UTF-8
        :param sep: field separator
        :return: dict mapping key character -> set of homophone characters;
            empty dict (with a warning logged) when the file does not exist
        """
        result = dict()
        if not os.path.exists(path):
            logger.warning(f"file not exists: {path}")
            return result
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    continue
                parts = line.split(sep)
                if len(parts) > 2:
                    key_char = parts[0]
                    # Iterating a str yields its characters, so each field
                    # becomes a set of single characters.
                    value = set(parts[1]) | set(parts[2])
                    if key_char and value:
                        result[key_char] = value
        return result

    @staticmethod
    def load_same_stroke(path, sep='\t'):
        """
        Load the similar-stroke (similar shape) character table.

        Each usable line is a group of at least two ``sep``-separated items;
        every item is mapped to all the other items of its group. Groups that
        share an item are merged. Lines starting with ``#`` are skipped.

        :param path: file path, read as UTF-8
        :param sep: field separator
        :return: dict mapping item -> set of similar items; empty dict (with a
            warning logged) when the file does not exist
        """
        result = dict()
        if not os.path.exists(path):
            logger.warning(f"file not exists: {path}")
            return result
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    continue
                parts = line.split(sep)
                if len(parts) > 1:
                    for i, c in enumerate(parts):
                        # Everything in the group except position i.
                        result.setdefault(c, set()).update(parts[:i] + parts[i + 1:])
        return result

    def _initialize_corrector(self):
        """Read the three resource files and mark the corrector as initialized."""
        # common Chinese characters
        self.cn_char_set = self.load_set_file(self.common_char_path)
        # same pinyin
        self.same_pinyin = self.load_same_pinyin(self.same_pinyin_path)
        # same stroke
        self.same_stroke = self.load_same_stroke(self.same_stroke_path)
        self.initialized_corrector = True

    def check_corrector_initialized(self):
        """Load the corrector resources if that has not happened yet."""
        if not self.initialized_corrector:
            self._initialize_corrector()

    def get_same_pinyin(self, char):
        """
        Return the same-pinyin characters of ``char``.

        :param char: a single character
        :return: set of characters, empty when ``char`` is not in the table
        """
        self.check_corrector_initialized()
        return self.same_pinyin.get(char, set())

    def get_same_stroke(self, char):
        """
        Return the similar-stroke characters of ``char``.

        :param char: a single character
        :return: set of characters, empty when ``char`` is not in the table
        """
        self.check_corrector_initialized()
        return self.same_stroke.get(char, set())

    def known(self, words):
        """
        Keep only the words that are present in ``self.word_freq``.

        :param words: iterable of words
        :return: set of known words
        """
        self.check_detector_initialized()
        return {word for word in words if word in self.word_freq}

    def _confusion_char_set(self, c):
        """Return the union of the same-pinyin and similar-stroke sets of ``c``."""
        return self.get_same_pinyin(c).union(self.get_same_stroke(c))

    def _confusion_word_set(self, word):
        """
        Return known words produced by ``edit_distance_word`` that read the same as ``word``.

        :param word: source word
        :return: set of known candidate words whose ``lazy_pinyin`` equals that of ``word``
        """
        # cn_char_set is loaded lazily; make sure it exists when this helper
        # is called directly rather than through generate_items().
        self.check_corrector_initialized()
        # The pinyin of the source word does not change inside the loop.
        word_pinyin = pypinyin.lazy_pinyin(word)
        return {
            candidate_word
            for candidate_word in self.known(edit_distance_word(word, self.cn_char_set))
            if pypinyin.lazy_pinyin(candidate_word) == word_pinyin
        }

    def _confusion_custom_set(self, word):
        """Return ``{custom_confusion[word]}`` when ``word`` has a custom entry, else an empty set."""
        if word in self.custom_confusion:
            return {self.custom_confusion[word]}
        return set()

    def generate_items(self, word, fragment=1):
        """
        Generate the correction candidates of ``word``.

        Candidates come from same-pinyin known words, the custom confusion
        dict, and character substitutions drawn from the confusion sets.
        Only candidates accepted by ``is_chinese_string`` are kept; they are
        ordered by ``word_frequency``, highest first.

        :param word: word to correct
        :param fragment: keeps the first ``len(candidates) // fragment + 1``
            candidates; values below 1 are treated as 1
        :return: list of candidate words
        """
        self.check_corrector_initialized()
        # Guard against ZeroDivisionError / negative slices for fragment < 1.
        if fragment < 1:
            fragment = 1

        # same pinyin words + custom confusion word
        candidates = set(self._confusion_word_set(word))
        candidates.update(self._confusion_custom_set(word))

        word_len = len(word)
        if word_len == 1:
            # similar single char
            candidates.update(c for c in self._confusion_char_set(word[0]) if c)
        elif word_len == 2:
            confusion_first = [c for c in self._confusion_char_set(word[0]) if c]
            confusion_last = [c for c in self._confusion_char_set(word[1]) if c]
            # replace the first char
            candidates.update(c + word[1] for c in confusion_first)
            # replace the last char
            candidates.update(word[0] + c for c in confusion_last)
            # replace both chars
            candidates.update(a + b for a in confusion_first for b in confusion_last)
        elif word_len > 2:
            # Replace the second char; empty confusion entries are skipped so
            # this branch agrees with the 1- and 2-char branches above.
            candidates.update(word[0] + c + word[2:] for c in self._confusion_char_set(word[1]) if c)
            # similar word for everything but the last char
            candidates.update(w + word[-1] for w in self._confusion_word_set(word[:-1]))
            # similar word for everything but the first char
            candidates.update(word[0] + w for w in self._confusion_word_set(word[1:]))

        confusion_word_list = [item for item in candidates if is_chinese_string(item)]
        confusion_sorted = sorted(confusion_word_list, key=self.word_frequency, reverse=True)
        return confusion_sorted[:len(confusion_word_list) // fragment + 1]

    def get_lm_correct_item(self, cur_item, candidates, before_sent, after_sent, threshold=57.0, cut_type='char'):
        """
        Pick the replacement for ``cur_item`` using language-model perplexity.

        Every candidate (plus ``cur_item`` itself) is placed between
        ``before_sent`` and ``after_sent`` and scored with ``ppl_score``.
        ``cur_item`` is kept when its score is within ``threshold`` of the
        best score; otherwise the best-scoring candidate is returned. The
        tolerance band reduces over-correction.

        :param cur_item: current word
        :param candidates: candidate words; not modified by this method
        :param before_sent: text before the word
        :param after_sent: text after the word
        :param threshold: ppl tolerance above the best score
        :param cut_type: segmentation granularity passed to ``segment``
        :return: str, the chosen word
        """
        # Work on a copy so the caller's list is not mutated.
        candidates = list(candidates)
        if cur_item not in candidates:
            candidates.append(cur_item)
        ppl_scores = {
            item: self.ppl_score(segment(before_sent + item + after_sent, cut_type=cut_type))
            for item in candidates
        }
        sorted_ppl_scores = sorted(ppl_scores.items(), key=operator.itemgetter(1))

        # candidates always contains cur_item, so there is at least one entry.
        best_item, best_score = sorted_ppl_scores[0]
        top_items = [best_item]
        for item, score in sorted_ppl_scores[1:]:
            if score < best_score + threshold:
                top_items.append(item)
            else:
                # Scores are ascending; nothing further can qualify.
                break
        return cur_item if cur_item in top_items else best_item

    def correct(
            self,
            sentence: str,
            include_symbol: bool = True,
            num_fragment: int = 1,
            threshold: float = 57.0,
            **kwargs
    ):
        """
        Correct a single text.

        The text is split into short sentences; for each detected error the
        fix is chosen by error type:
            1. custom confusion set: take the mapped value directly
            2. proper-name error: take the value reported by detection
            3. char/word error: generate candidates and pick one with the LM

        :param sentence: str, query text
        :param include_symbol: bool, passed to ``split_text_into_sentences_by_symbol``
        :param num_fragment: candidate truncation factor, passed to ``generate_items`` as ``fragment``
        :param threshold: ppl threshold for the language-model selection
        :param kwargs: forwarded to ``_detect``
        :return: {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
        """
        corrected_parts = []
        details = []
        self.check_corrector_initialized()
        # split into short sentences by punctuation
        short_sents = split_text_into_sentences_by_symbol(sentence, include_symbol=include_symbol)
        for sent, idx in short_sents:
            # detect
            maybe_errors, proper_details = self._detect(sent, idx, **kwargs)
            # (original begin index, length change) of replacements already
            # applied to `sent`. Error indices refer to the original text, so
            # a replacement of different length shifts every later span.
            applied = []
            for cur_item, begin_idx, end_idx, err_type in maybe_errors:
                shift = sum(delta for pos, delta in applied if pos < begin_idx)
                before_sent = sent[:begin_idx - idx + shift]
                after_sent = sent[end_idx - idx + shift:]

                if err_type == ErrorType.confusion:
                    # word listed in the custom confusion set: take the mapped value
                    corrected_item = self.custom_confusion[cur_item]
                elif err_type == ErrorType.proper:
                    # proper_details format: (error_word, corrected_word, begin_idx, end_idx)
                    corrected_item = [i[1] for i in proper_details if cur_item == i[0] and begin_idx == i[2]][0]
                else:
                    # char/word error: collect every plausible replacement
                    candidates = self.generate_items(cur_item, fragment=num_fragment)
                    if not candidates:
                        continue
                    corrected_item = self.get_lm_correct_item(
                        cur_item,
                        candidates,
                        before_sent,
                        after_sent,
                        threshold=threshold
                    )

                if corrected_item != cur_item:
                    sent = before_sent + corrected_item + after_sent
                    # The reported position stays in original-text coordinates.
                    details.append((cur_item, corrected_item, begin_idx))
                    delta = len(corrected_item) - (end_idx - begin_idx)
                    if delta:
                        applied.append((begin_idx, delta))
            corrected_parts.append(sent)
        details = sorted(details, key=operator.itemgetter(2))
        return {'source': sentence, 'target': ''.join(corrected_parts), 'errors': details}

    def correct_batch(self, sentences: List[str], **kwargs):
        """
        Correct a batch of sentences, one ``correct`` call per sentence.

        :param sentences: list of sentence texts
        :param kwargs: forwarded to ``correct``
        :return: list of {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
        """
        return [self.correct(s, **kwargs) for s in sentences]
|