corrector.py

# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: corrector with pinyin and stroke
"""
import operator
import os
from codecs import open
from typing import List

import pypinyin
from loguru import logger

from pycorrector.detector import Detector, ErrorType
from pycorrector.utils.math_utils import edit_distance_word
from pycorrector.utils.text_utils import is_chinese_string
from pycorrector.utils.tokenizer import segment, split_text_into_sentences_by_symbol

pwd_path = os.path.abspath(os.path.dirname(__file__))
# common Chinese character set
common_char_path = os.path.join(pwd_path, 'data/common_char_set.txt')
# same-pinyin (homophone) characters
same_pinyin_path = os.path.join(pwd_path, 'data/same_pinyin.txt')
# similar-stroke (visually similar) characters
same_stroke_path = os.path.join(pwd_path, 'data/same_stroke.txt')


class Corrector(Detector):
    def __init__(
            self,
            common_char_path=common_char_path,
            same_pinyin_path=same_pinyin_path,
            same_stroke_path=same_stroke_path,
            **kwargs,
    ):
        super(Corrector, self).__init__(**kwargs)
        self.name = 'kenlm_corrector'
        self.common_char_path = common_char_path
        self.same_pinyin_path = same_pinyin_path
        self.same_stroke_path = same_stroke_path
        self.initialized_corrector = False
        self.cn_char_set = None
        self.same_pinyin = None
        self.same_stroke = None
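
    # NOTE: the three dictionaries above stay None until first use;
    # check_corrector_initialized() below loads them lazily, so constructing
    # a Corrector is cheap until a correction is actually requested.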

    @staticmethod
    def load_set_file(path):
        words = set()
        with open(path, 'r', encoding='utf-8') as f:
            for w in f:
                w = w.strip()
                if w.startswith('#'):
                    continue
                if w:
                    words.add(w)
        return words
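
    # Expected file format for load_set_file (inferred from the parsing above):
    # one entry per line, lines starting with '#' treated as comments, e.g.
    #   的
    #   一
    #   是
    # (example characters only; not necessarily the shipped data)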

    @staticmethod
    def load_same_pinyin(path, sep='\t'):
        """
        Load same-pinyin (homophone) characters
        :param path: dict file path
        :param sep: column separator
        :return: dict, char -> set of homophone chars
        """
        result = dict()
        if not os.path.exists(path):
            logger.warning(f"file not exists: {path}")
            return result
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    continue
                parts = line.split(sep)
                if parts and len(parts) > 2:
                    key_char = parts[0]
                    same_pron_same_tone = set(parts[1])
                    same_pron_diff_tone = set(parts[2])
                    value = same_pron_same_tone.union(same_pron_diff_tone)
                    if key_char and value:
                        result[key_char] = value
        return result
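
    # Illustrative same_pinyin.txt line (format inferred from the parser above):
    # key char, same-pinyin-same-tone chars, same-pinyin-different-tone chars,
    # tab-separated; the stored value is the union of columns 2 and 3, e.g.
    #   长	常肠尝	场厂唱
    # (example characters only; not necessarily the shipped data)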

    @staticmethod
    def load_same_stroke(path, sep='\t'):
        """
        Load similar-stroke (visually similar) characters
        :param path: dict file path
        :param sep: column separator
        :return: dict, char -> set of visually similar chars
        """
        result = dict()
        if not os.path.exists(path):
            logger.warning(f"file not exists: {path}")
            return result
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#'):
                    continue
                parts = line.split(sep)
                if parts and len(parts) > 1:
                    for i, c in enumerate(parts):
                        exist = result.get(c, set())
                        current = set(parts[:i] + parts[i + 1:])
                        result[c] = exist.union(current)
        return result
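
    # Illustrative same_stroke.txt line (format inferred from the parser above):
    # tab-separated visually similar characters; each character on the line is
    # mapped to the set of all the others, e.g. the line
    #   己	已	巳
    # yields result['己'] == {'已', '巳'}, result['已'] == {'己', '巳'}, etc.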

    def _initialize_corrector(self):
        # chinese common char
        self.cn_char_set = self.load_set_file(self.common_char_path)
        # same pinyin
        self.same_pinyin = self.load_same_pinyin(self.same_pinyin_path)
        # same stroke
        self.same_stroke = self.load_same_stroke(self.same_stroke_path)
        self.initialized_corrector = True

    def check_corrector_initialized(self):
        if not self.initialized_corrector:
            self._initialize_corrector()

    def get_same_pinyin(self, char):
        """
        Get same-pinyin (homophone) characters
        :param char: input char
        :return: set of homophone chars
        """
        self.check_corrector_initialized()
        return self.same_pinyin.get(char, set())

    def get_same_stroke(self, char):
        """
        Get similar-stroke characters
        :param char: input char
        :return: set of visually similar chars
        """
        self.check_corrector_initialized()
        return self.same_stroke.get(char, set())

    def known(self, words):
        """
        Get the subset of the word sequence that are common (in-vocabulary) words
        :param words: word sequence
        :return: set of known words
        """
        self.check_detector_initialized()
        return set(word for word in words if word in self.word_freq)

    def _confusion_char_set(self, c):
        return self.get_same_pinyin(c).union(self.get_same_stroke(c))

    def _confusion_word_set(self, word):
        confusion_word_set = set()
        candidate_words = list(self.known(edit_distance_word(word, self.cn_char_set)))
        for candidate_word in candidate_words:
            if pypinyin.lazy_pinyin(candidate_word) == pypinyin.lazy_pinyin(word):
                # same pinyin
                confusion_word_set.add(candidate_word)
        return confusion_word_set

    def _confusion_custom_set(self, word):
        confusion_word_set = set()
        if word in self.custom_confusion:
            confusion_word_set = {self.custom_confusion[word]}
        return confusion_word_set

    def generate_items(self, word, fragment=1):
        """
        Generate correction candidates for a word
        :param word: input word
        :param fragment: number of fragments; roughly the top 1/fragment of candidates is kept
        :return: list of candidate words, sorted by word frequency (descending)
        """
        self.check_corrector_initialized()
        # single-char candidates
        candidates_1 = []
        # two-char candidates
        candidates_2 = []
        # candidates longer than two chars
        candidates_3 = []
        # same pinyin word
        candidates_1.extend(self._confusion_word_set(word))
        # custom confusion word
        candidates_1.extend(self._confusion_custom_set(word))
        # get similarity char
        if len(word) == 1:
            # sim one char
            confusion = [i for i in self._confusion_char_set(word[0]) if i]
            candidates_1.extend(confusion)
        if len(word) == 2:
            # sim first char
            confusion_first = [i for i in self._confusion_char_set(word[0]) if i]
            candidates_2.extend([i + word[1] for i in confusion_first])
            # sim last char
            confusion_last = [i for i in self._confusion_char_set(word[1]) if i]
            candidates_2.extend([word[0] + i for i in confusion_last])
            # both changed, sim chars
            candidates_2.extend([i + j for i in confusion_first for j in confusion_last if i + j])
            # sim word
            # candidates_2.extend([i for i in self._confusion_word_set(word) if i])
        if len(word) > 2:
            # sim mid char
            confusion = [word[0] + i + word[2:] for i in self._confusion_char_set(word[1])]
            candidates_3.extend(confusion)
            # sim first word
            confusion_word = [i + word[-1] for i in self._confusion_word_set(word[:-1])]
            candidates_3.extend(confusion_word)
            # sim last word
            confusion_word = [word[0] + i for i in self._confusion_word_set(word[1:])]
            candidates_3.extend(confusion_word)
        # merge all confusion candidates, keeping Chinese-only strings
        confusion_word_set = set(candidates_1 + candidates_2 + candidates_3)
        confusion_word_list = [item for item in confusion_word_set if is_chinese_string(item)]
        confusion_sorted = sorted(confusion_word_list, key=lambda k: self.word_frequency(k), reverse=True)
        return confusion_sorted[:len(confusion_word_list) // fragment + 1]
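
    # Usage sketch (illustrative; real candidates depend on the dictionaries
    # and the frequency data that Detector loads):
    #   corrector = Corrector()
    #   candidates = corrector.generate_items('机器', fragment=1)
    #   # -> list of same-pinyin / similar-stroke variants, highest-frequency first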

    def get_lm_correct_item(self, cur_item, candidates, before_sent, after_sent, threshold=57.0, cut_type='char'):
        """
        Pick the best candidate with the language model
        :param cur_item: current word
        :param candidates: candidate words
        :param before_sent: sentence text before the word
        :param after_sent: sentence text after the word
        :param threshold: ppl threshold; the original word is only replaced if its
            ppl exceeds the best candidate's ppl by more than this margin
        :param cut_type: tokenization granularity, char-level by default
        :return: str, the corrected word
        """
        result = cur_item
        if cur_item not in candidates:
            candidates.append(cur_item)
        ppl_scores = {i: self.ppl_score(segment(before_sent + i + after_sent, cut_type=cut_type)) for i in candidates}
        sorted_ppl_scores = sorted(ppl_scores.items(), key=lambda d: d[1])
        # widen the band of "acceptable" words around the best score to reduce false corrections
        top_items = []
        top_score = 0.0
        for i, v in enumerate(sorted_ppl_scores):
            v_word = v[0]
            v_score = v[1]
            if i == 0:
                top_score = v_score
                top_items.append(v_word)
            # keep every candidate whose ppl is within the threshold of the best
            elif v_score < top_score + threshold:
                top_items.append(v_word)
            else:
                break
        if cur_item not in top_items:
            result = top_items[0]
        return result
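
    # Hypothetical scoring walk-through (numbers invented for illustration):
    # sorted ppl scores [('让座', 120.0), ('让坐', 190.0)] with threshold=57.0
    # -> top_items = ['让座'] since 190.0 >= 120.0 + 57.0; the original '让坐'
    #    falls outside the band, so it is replaced by '让座'.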

    def correct(
            self,
            sentence: str,
            include_symbol: bool = True,
            num_fragment: int = 1,
            threshold: float = 57.0,
            **kwargs
    ):
        """
        Correct a single text
        Correction logic:
        1. custom confusion dict
        2. proper-noun errors
        3. word/char errors
        :param sentence: str, query text
        :param include_symbol: bool, whether to keep punctuation symbols
        :param num_fragment: number of fragments for pruning the candidate set;
            roughly the top 1/num_fragment of candidates is kept
        :param threshold: ppl threshold for language-model correction
        :param kwargs: ...
        :return: {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
        """
        corrected_sentence = ''
        details = []
        self.check_corrector_initialized()
        # split the text into short sentences at punctuation
        short_sents = split_text_into_sentences_by_symbol(sentence, include_symbol=include_symbol)
        for sent, idx in short_sents:
            # error detection
            maybe_errors, proper_details = self._detect(sent, idx, **kwargs)
            for cur_item, begin_idx, end_idx, err_type in maybe_errors:
                # correct errors one by one
                before_sent = sent[:(begin_idx - idx)]
                after_sent = sent[(end_idx - idx):]
                # word listed in the custom confusion dict: take the mapped result directly
                if err_type == ErrorType.confusion:
                    corrected_item = self.custom_confusion[cur_item]
                elif err_type == ErrorType.proper:
                    # proper-noun error; proper_details format: (error_word, corrected_word, begin_idx, end_idx)
                    corrected_item = [i[1] for i in proper_details if cur_item == i[0] and begin_idx == i[2]][0]
                else:
                    # word/char error: collect all plausible corrections
                    candidates = self.generate_items(cur_item, fragment=num_fragment)
                    if not candidates:
                        continue
                    corrected_item = self.get_lm_correct_item(
                        cur_item,
                        candidates,
                        before_sent,
                        after_sent,
                        threshold=threshold
                    )
                # output
                if corrected_item != cur_item:
                    sent = before_sent + corrected_item + after_sent
                    details.append((cur_item, corrected_item, begin_idx))
            corrected_sentence += sent
        details = sorted(details, key=operator.itemgetter(2))
        return {'source': sentence, 'target': corrected_sentence, 'errors': details}

    def correct_batch(self, sentences: List[str], **kwargs):
        """
        Correct a batch of sentences
        :param sentences: list of sentence texts
        :param kwargs: other arguments, passed through to correct()
        :return: list of {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
        """
        return [self.correct(s, **kwargs) for s in sentences]
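

if __name__ == '__main__':
    # Minimal usage sketch. Assumes the Detector base class can locate its
    # default kenlm language model and frequency dicts (pycorrector fetches
    # them on first use); the sample sentences are illustrative.
    m = Corrector()
    print(m.correct('少先队员因该为老人让坐'))
    # -> dict like {'source': ..., 'target': ..., 'errors': [(wrong, right, pos), ...]}
    print(m.correct_batch(['真麻烦你了。希望你们好好的跳无']))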