# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: error word detector
"""
import os
import re
from codecs import open

import numpy as np
from loguru import logger

from pycorrector.proper_corrector import ProperCorrector
from pycorrector.utils.get_file import get_file
from pycorrector.utils.text_utils import uniform, is_alphabet_string, is_chinese_string
from pycorrector.utils.tokenizer import Tokenizer, split_text_into_sentences_by_symbol

pwd_path = os.path.abspath(os.path.dirname(__file__))

# ----- user data directory for model files -----
USER_DATA_DIR = os.environ.get('PYCORRECTOR_DATA_DIR', os.path.expanduser('~/.pycorrector/datasets'))
os.makedirs(USER_DATA_DIR, exist_ok=True)
language_model_path = os.path.join(USER_DATA_DIR, 'zh_giga.no_cna_cmn.prune01244.klm')

# ----- dictionary file paths -----
# general tokenizer dictionary, format: word freq
word_freq_path = os.path.join(pwd_path, 'data/word_freq.txt')
# stroke dictionary
stroke_path = os.path.join(pwd_path, 'data/stroke.txt')
# well-known person names, format: word freq
person_name_path = os.path.join(pwd_path, 'data/person_name.txt')
# place names, format: word freq
place_name_path = os.path.join(pwd_path, 'data/place_name.txt')
# proper-noun dictionary, including idioms, common sayings, domain terms, etc., format: word
proper_name_path = os.path.join(pwd_path, 'data/proper_name.txt')
# stopwords
stopwords_path = os.path.join(pwd_path, 'data/stopwords.txt')


class ErrorType:
    confusion = 'confusion'
    word = 'word'
    char = 'char'
    proper = 'proper'  # proper-noun errors, including idiom and person-name errors


class Detector:
    pretrained_language_models = {
        # language model, 2.95 GB
        'zh_giga.no_cna_cmn.prune01244.klm':
            'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        # language model trained on the People's Daily corpus, 148 MB
        'people2014_corpus_chars.klm':
            'https://github.com/shibing624/pycorrector/releases/download/1.0.0/people2014_corpus_chars.klm',
        # language model trained on the People's Daily corpus (tiny), 20 MB
        'people_chars_lm.klm':
            'https://github.com/shibing624/pycorrector/releases/download/0.4.3/people_chars_lm.klm',
    }
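
    # Each entry above maps a pretrained model filename to its download URL;
    # if language_model_path does not point to an existing file, the model is
    # fetched into USER_DATA_DIR on first use (see _initialize_detector).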

    def __init__(
            self,
            language_model_path=language_model_path,
            word_freq_path=word_freq_path,
            custom_word_freq_path='',
            custom_confusion_path_or_dict='',
            person_name_path=person_name_path,
            place_name_path=place_name_path,
            stopwords_path=stopwords_path,
            proper_name_path=proper_name_path,
            stroke_path=stroke_path
    ):
        self.name = 'detector'
        self.language_model_path = language_model_path
        self.word_freq_path = word_freq_path
        self.custom_word_freq_path = custom_word_freq_path
        self.custom_confusion_path_or_dict = custom_confusion_path_or_dict
        self.person_name_path = person_name_path
        self.place_name_path = place_name_path
        self.stopwords_path = stopwords_path
        self.is_char_error_detect = True
        self.is_word_error_detect = True
        self.initialized_detector = False
        self.lm = None
        self.word_freq = None
        self.custom_confusion = None
        self.custom_word_freq = None
        self.person_names = None
        self.place_names = None
        self.stopwords = None
        self.tokenizer = None
        self.proper_corrector = None
        self.proper_name_path = proper_name_path
        self.stroke_path = stroke_path

    def _initialize_detector(self):
        try:
            import kenlm
        except ImportError:
            raise ImportError(
                'pycorrector dependencies are not fully installed; '
                'kenlm is required for the statistical language model. '
                'Please run "pip install kenlm" to install it. '
                'On Windows, please install kenlm under Cygwin.'
            )
        if not os.path.exists(self.language_model_path):
            filename = self.pretrained_language_models.get(
                self.language_model_path, 'zh_giga.no_cna_cmn.prune01244.klm'
            )
            url = self.pretrained_language_models.get(filename)
            self.language_model_path = get_file(
                filename, url, extract=True,
                cache_dir='~',
                cache_subdir=USER_DATA_DIR,
                verbose=1
            )
        self.lm = kenlm.Model(self.language_model_path)

        # word -> frequency dict
        self.word_freq = self.load_word_freq_dict(self.word_freq_path)
        # custom confusion set
        if isinstance(self.custom_confusion_path_or_dict, dict):
            self.custom_confusion = self.custom_confusion_path_or_dict
            for k, v in self.custom_confusion.items():
                self.word_freq[v] = self.word_freq.get(v, 1)
        elif isinstance(self.custom_confusion_path_or_dict, str):
            self.custom_confusion = self._get_custom_confusion_dict(self.custom_confusion_path_or_dict)
        else:
            raise ValueError('custom_confusion_path_or_dict must be dict or str.')
        # custom tokenizer dictionary
        self.custom_word_freq = self.load_word_freq_dict(self.custom_word_freq_path)
        self.person_names = self.load_word_freq_dict(self.person_name_path)
        self.place_names = self.load_word_freq_dict(self.place_name_path)
        self.stopwords = self.load_word_freq_dict(self.stopwords_path)
        # merge the tokenizer dictionary with the custom dictionaries
        self.custom_word_freq.update(self.person_names)
        self.custom_word_freq.update(self.place_names)
        self.custom_word_freq.update(self.stopwords)
        self.word_freq.update(self.custom_word_freq)
        self.tokenizer = Tokenizer(
            dict_path=self.word_freq_path,
            custom_word_freq_dict=self.custom_word_freq,
            custom_confusion_dict=self.custom_confusion
        )
        self.proper_corrector = ProperCorrector(
            proper_name_path=self.proper_name_path,
            stroke_path=self.stroke_path
        )
        self.initialized_detector = True
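
    # Heavy resources (the kenlm model and all dictionaries) load lazily:
    # the scoring and detection methods call check_detector_initialized()
    # first, so the cost is paid once, on first use.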
    def check_detector_initialized(self):
        if not self.initialized_detector:
            self._initialize_detector()

    @staticmethod
    def load_word_freq_dict(path):
        """
        Load a word-frequency dictionary.
        :param path: dict file path, one "word freq" pair per line
        :return: dict, {word: freq}
        """
        word_freq = {}
        if path:
            if not os.path.exists(path):
                logger.warning('file not found: %s' % path)
                return word_freq
            else:
                with open(path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line.startswith('#'):
                            continue
                        info = line.split()
                        if len(info) < 1:
                            continue
                        word = info[0]
                        # frequency, default 1
                        freq = int(info[1]) if len(info) > 1 else 1
                        word_freq[word] = freq
        return word_freq
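
    # Example word-frequency file accepted by load_word_freq_dict: one
    # "word freq" pair per line, the frequency column optional (defaults
    # to 1); the counts below are illustrative only.
    #   中国 10000
    #   北京 8000
    #   让座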

    def _get_custom_confusion_dict(self, path):
        """
        Load the custom confusion dict.
        :param path: confusion file path
        :return: dict, {variant: origin}, eg: {"交通先行": "交通限行"}
        """
        confusion = {}
        if path:
            if not os.path.exists(path):
                logger.warning('file not found: %s' % path)
                return confusion
            else:
                with open(path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line.startswith('#'):
                            continue
                        info = line.split()
                        if len(info) < 2:
                            continue
                        variant = info[0]
                        origin = info[1]
                        freq = int(info[2]) if len(info) > 2 else 1
                        self.word_freq[origin] = freq
                        confusion[variant] = origin
        return confusion
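
    # Example confusion file for _get_custom_confusion_dict: each line holds
    # a variant, its correct form, and an optional frequency, e.g. the
    # docstring pair above as a line:
    #   交通先行 交通限行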

    def set_language_model_path(self, path):
        self.check_detector_initialized()
        import kenlm
        self.lm = kenlm.Model(path)
        logger.debug('Loaded language model: %s' % path)

    def set_custom_confusion_path_or_dict(self, data):
        self.check_detector_initialized()
        if isinstance(data, dict):
            self.custom_confusion = data
            for k, v in self.custom_confusion.items():
                self.word_freq[v] = self.word_freq.get(v, 1)
        elif isinstance(data, str):
            self.custom_confusion = self._get_custom_confusion_dict(data)
        else:
            raise ValueError('custom_confusion_path_or_dict must be dict or str.')
        logger.debug('Loaded confusion size: %d' % len(self.custom_confusion))

    def set_custom_word_freq(self, path):
        self.check_detector_initialized()
        word_freqs = self.load_word_freq_dict(path)
        # merge into the custom dictionary
        self.custom_word_freq.update(word_freqs)
        # merge the tokenizer dictionary with the custom dictionaries
        self.word_freq.update(self.custom_word_freq)
        self.tokenizer = Tokenizer(
            dict_path=self.word_freq_path,
            custom_word_freq_dict=self.custom_word_freq,
            custom_confusion_dict=self.custom_confusion
        )
        for k, v in word_freqs.items():
            self.set_word_frequency(k, v)
        logger.debug('Loaded custom word path: %s, size: %d' % (path, len(word_freqs)))

    def enable_char_error(self, enable=True):
        """
        Enable or disable char-level error detection.
        :param enable: bool
        :return:
        """
        self.is_char_error_detect = enable

    def enable_word_error(self, enable=True):
        """
        Enable or disable word-level error detection.
        :param enable: bool
        :return:
        """
        self.is_word_error_detect = enable

    def ngram_score(self, chars):
        """
        Get the n-gram language model score.
        :param chars: list, split by word or char
        :return: float, log10 probability of the sequence
        """
        self.check_detector_initialized()
        return self.lm.score(' '.join(chars), bos=False, eos=False)

    def ppl_score(self, words):
        """
        Get the language model perplexity score; the lower, the more fluent the sentence.
        :param words: list, split by word or char
        :return: float
        """
        self.check_detector_initialized()
        return self.lm.perplexity(' '.join(words))
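
    # Illustrative use of the scoring helpers (actual values depend on the
    # loaded model): ngram_score(['中', '国']) returns a log10 probability
    # and ppl_score(['中', '国']) the perplexity; a garbled sequence scores
    # lower and has higher perplexity than a fluent one.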

    def word_frequency(self, word):
        """
        Get the word's frequency in the corpus.
        :param word:
        :return: int, frequency, 0 if unseen
        """
        self.check_detector_initialized()
        return self.word_freq.get(word, 0)

    def set_word_frequency(self, word, num):
        """
        Update the word's frequency in the corpus.
        """
        self.check_detector_initialized()
        self.word_freq[word] = num
        return self.word_freq

    @staticmethod
    def _check_contain_error(maybe_err, maybe_errors):
        """
        Check whether the error set (maybe_errors) already covers this error span (maybe_err).
        :param maybe_err: [error_word, begin_pos, end_pos, error_type]
        :param maybe_errors: list
        :return: bool
        """
        error_word_idx = 0
        begin_idx = 1
        end_idx = 2
        for err in maybe_errors:
            if maybe_err[error_word_idx] in err[error_word_idx] and maybe_err[begin_idx] >= err[begin_idx] and \
                    maybe_err[end_idx] <= err[end_idx]:
                return True
        return False
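
    # Containment example: if maybe_errors already holds
    # ['因该', 4, 6, 'word'], the char-level candidate ['因', 4, 5, 'char']
    # is contained ('因' in '因该', 4 >= 4, 5 <= 6) and is not added again.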

    def _add_maybe_error_item(self, maybe_err, maybe_errors):
        """
        Add a suspected error if it is not already covered.
        :param maybe_err:
        :param maybe_errors:
        :return:
        """
        if maybe_err not in maybe_errors and not self._check_contain_error(maybe_err, maybe_errors):
            maybe_errors.append(maybe_err)

    @staticmethod
    def _get_maybe_error_index(scores, ratio=0.6745, threshold=2):
        """
        Get the indices of suspected wrong chars via the median absolute deviation (MAD).
        :param scores: np.array
        :param ratio: scale constant from the normal distribution table
        :param threshold: the smaller the threshold, the more suspected errors returned
        :return: list, indices of all suspected wrong chars
        """
        result = []
        scores = np.array(scores)
        if len(scores.shape) == 1:
            scores = scores[:, None]
        median = np.median(scores, axis=0)  # median of all scores
        margin_median = np.abs(scores - median).flatten()  # deviation from the median
        # median absolute deviation
        med_abs_deviation = np.median(margin_median)
        if med_abs_deviation == 0:
            return result
        y_score = ratio * margin_median / med_abs_deviation
        # flatten
        scores = scores.flatten()
        maybe_error_indices = np.where((y_score > threshold) & (scores < median))
        # collect the indices of all suspected wrong chars
        result = [int(i) for i in maybe_error_indices[0]]
        return result
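
    # A worked example of the MAD heuristic (illustrative scores): for
    # [-5.1, -5.3, -12.0, -5.2] the median is -5.25, so the deviations are
    # [0.15, 0.05, 6.75, 0.05] and the MAD is 0.10. The modified z-score
    # 0.6745 * deviation / MAD exceeds the threshold only at index 2, and
    # -12.0 is also below the median, so index 2 is flagged.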

    @staticmethod
    def _get_maybe_error_index_by_stddev(scores, n=2):
        """
        Get the indices of suspected wrong chars; points within n standard
        deviations of the mean are treated as normal.
        :param scores: list, float
        :param n: number of standard deviations
        :return: list, indices of all suspected wrong chars
        """
        scores = np.array(scores)  # accept a plain list as input
        std = np.std(scores, ddof=1)
        mean = np.mean(scores)
        down_limit = mean - n * std
        upper_limit = mean + n * std
        maybe_error_indices = np.where((scores > upper_limit) | (scores < down_limit))
        # collect the indices of all suspected wrong chars
        result = list(maybe_error_indices[0])
        return result
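
    # Unlike the MAD variant above, this flags outliers on both sides of
    # the mean, and a single extreme score can inflate the standard
    # deviation, since mean and std are not robust statistics.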

    @staticmethod
    def is_filter_token(token):
        """
        Whether the token should be excluded from error detection.
        :param token: word or char
        :return: bool
        """
        result = False
        # pass blank
        if not token.strip():
            result = True
        # pass num
        if token.isdigit():
            result = True
        # pass alpha
        if is_alphabet_string(token.lower()):
            result = True
        # pass not chinese
        if not is_chinese_string(token):
            result = True
        return result
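
    # Filtered-token examples: ' ' (blank), '2024' (digits), 'abc' (pure
    # alphabet) and '中A1' (mixed, not a pure-Chinese string) all return
    # True; a pure-Chinese token such as '中国' returns False.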

    def _detect(self, sentence, start_idx=0, **kwargs):
        """
        Detect suspected wrong words/chars in a sentence, as [word, position, error type].
        Detection steps:
        1. custom confusion set
        2. proper-noun errors
        3. word errors
        4. char errors
        :param sentence:
        :param start_idx:
        :return: list[list], [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        # initialize
        self.check_detector_initialized()
        # 1. add custom confusion-set hits to the suspected errors
        for confuse in self.custom_confusion:
            for i in re.finditer(confuse, sentence):
                maybe_err = [confuse, i.span()[0] + start_idx, i.span()[1] + start_idx, ErrorType.confusion]
                self._add_maybe_error_item(maybe_err, maybe_errors)
        # 2. proper-noun error detection
        proper_details = self.proper_corrector.correct(sentence, start_idx=start_idx, **kwargs)['errors']
        for error_word, corrected_word, begin_idx in proper_details:
            end_idx = begin_idx + len(error_word)
            maybe_err = [error_word, begin_idx, end_idx, ErrorType.proper]
            self._add_maybe_error_item(maybe_err, maybe_errors)
        # 3. word errors
        if self.is_word_error_detect:
            tokens = self.tokenizer.tokenize(sentence)
            # add out-of-vocabulary words to the suspected errors
            for token, begin_idx, end_idx in tokens:
                # pass filter word
                if self.is_filter_token(token):
                    continue
                # pass word in dict
                if token in self.word_freq:
                    continue
                maybe_err = [token, begin_idx + start_idx, end_idx + start_idx, ErrorType.word]
                self._add_maybe_error_item(maybe_err, maybe_errors)
        # 4. char errors: use the language model to find suspected wrong chars
        if self.is_char_error_detect:
            try:
                ngram_avg_scores = []
                for n in [2, 3]:
                    scores = []
                    for i in range(len(sentence) - n + 1):
                        word = sentence[i:i + n]
                        score = self.ngram_score(list(word))
                        scores.append(score)
                    if not scores:
                        continue
                    # pad the score list so the sliding window covers every char
                    for _ in range(n - 1):
                        scores.insert(0, scores[0])
                        scores.append(scores[-1])
                    avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]
                    ngram_avg_scores.append(avg_scores)
                if ngram_avg_scores:
                    # average the n-gram scores over all n
                    sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))
                    # collect info on the suspected wrong chars
                    for i in self._get_maybe_error_index(sent_scores):
                        token = sentence[i]
                        # pass filter word
                        if self.is_filter_token(token):
                            continue
                        # pass word in stopword dict
                        if token in self.stopwords:
                            continue
                        # token, begin_idx, end_idx, error_type
                        maybe_err = [token, i + start_idx, i + start_idx + 1, ErrorType.char]
                        self._add_maybe_error_item(maybe_err, maybe_errors)
            except IndexError as ie:
                logger.warning("index error, sentence: %s, %s" % (sentence, ie))
            except Exception as e:
                logger.warning("detect error, sentence: %s, %s" % (sentence, e))
        return sorted(maybe_errors, key=lambda k: k[1]), proper_details

    def detect(self, sentence):
        """
        Detect errors in a text.
        :param sentence: input sentence
        :return: list of suspected errors, [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        if not sentence.strip():
            return maybe_errors
        # normalize the text
        sentence = uniform(sentence)
        # split the text into short sentences
        short_sents = split_text_into_sentences_by_symbol(sentence)
        for sent, idx in short_sents:
            maybe_errors += self._detect(sent, idx)[0]
        return maybe_errors
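

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module), assuming the
    # kenlm dependency and bundled dictionary files are available; the
    # default language model is downloaded to USER_DATA_DIR on first use.
    d = Detector()
    errors = d.detect('少先队员因该为老人让坐')
    # each entry is [error_word, begin_pos, end_pos, error_type],
    # e.g. ['因该', 4, 6, 'word']
    print(errors)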