# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: error word detector
"""
import os
import re
from codecs import open

import numpy as np
from loguru import logger

from pycorrector.proper_corrector import ProperCorrector
from pycorrector.utils.get_file import get_file
from pycorrector.utils.text_utils import uniform, is_alphabet_string, is_chinese_string
from pycorrector.utils.tokenizer import Tokenizer, split_text_into_sentences_by_symbol

pwd_path = os.path.abspath(os.path.dirname(__file__))

# -----User directory for storing model files-----
USER_DATA_DIR = os.environ.get('PYCORRECTOR_DATA_DIR', os.path.expanduser('~/.pycorrector/datasets'))
os.makedirs(USER_DATA_DIR, exist_ok=True)
language_model_path = os.path.join(USER_DATA_DIR, 'zh_giga.no_cna_cmn.prune01244.klm')

# -----Dictionary file paths-----
# General word-segmentation dictionary, format: word frequency
word_freq_path = os.path.join(pwd_path, 'data/word_freq.txt')
# Stroke dictionary (five basic stroke types)
stroke_path = os.path.join(pwd_path, 'data/stroke.txt')
# Well-known person-name dictionary, format: word frequency
person_name_path = os.path.join(pwd_path, 'data/person_name.txt')
# Place-name dictionary, format: word frequency
place_name_path = os.path.join(pwd_path, 'data/place_name.txt')
# Proper-name dictionary (idioms, sayings, domain terms, etc.), format: word
proper_name_path = os.path.join(pwd_path, 'data/proper_name.txt')
# Stop words
stopwords_path = os.path.join(pwd_path, 'data/stopwords.txt')


class ErrorType:
    confusion = 'confusion'
    word = 'word'
    char = 'char'
    proper = 'proper'  # proper-name correction, e.g. idioms, person names


class Detector:
    pretrained_language_models = {
        # Language model, 2.95GB
        'zh_giga.no_cna_cmn.prune01244.klm':
            'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        # Language model trained on the People's Daily corpus, 148MB
        'people2014_corpus_chars.klm':
            'https://github.com/shibing624/pycorrector/releases/download/1.0.0/people2014_corpus_chars.klm',
        # Language model trained on the People's Daily corpus (tiny), 20MB
        'people_chars_lm.klm':
            'https://github.com/shibing624/pycorrector/releases/download/0.4.3/people_chars_lm.klm',
    }
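
    # Note: the model file is fetched lazily. If language_model_path does not
    # exist on disk, _initialize_detector() downloads one of the pretrained
    # models above into USER_DATA_DIR on first use.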

    def __init__(
            self,
            language_model_path=language_model_path,
            word_freq_path=word_freq_path,
            custom_word_freq_path='',
            custom_confusion_path_or_dict='',
            person_name_path=person_name_path,
            place_name_path=place_name_path,
            stopwords_path=stopwords_path,
            proper_name_path=proper_name_path,
            stroke_path=stroke_path
    ):
        self.name = 'detector'
        self.language_model_path = language_model_path
        self.word_freq_path = word_freq_path
        self.custom_word_freq_path = custom_word_freq_path
        self.custom_confusion_path_or_dict = custom_confusion_path_or_dict
        self.person_name_path = person_name_path
        self.place_name_path = place_name_path
        self.stopwords_path = stopwords_path
        self.is_char_error_detect = True
        self.is_word_error_detect = True
        self.initialized_detector = False
        self.lm = None
        self.word_freq = None
        self.custom_confusion = None
        self.custom_word_freq = None
        self.person_names = None
        self.place_names = None
        self.stopwords = None
        self.tokenizer = None
        self.proper_corrector = None
        self.proper_name_path = proper_name_path
        self.stroke_path = stroke_path

    def _initialize_detector(self):
        try:
            import kenlm
        except ImportError:
            raise ImportError(
                'pycorrector dependencies are not fully installed; '
                'kenlm is required for the statistical language model. '
                'Please run "pip install kenlm" to install it. '
                'On Windows, install kenlm under Cygwin.'
            )
        if not os.path.exists(self.language_model_path):
            # download a pretrained model when the given path is missing
            filename = os.path.basename(self.language_model_path)
            if filename not in self.pretrained_language_models:
                filename = 'zh_giga.no_cna_cmn.prune01244.klm'
            url = self.pretrained_language_models.get(filename)
            self.language_model_path = get_file(
                filename, url, extract=True,
                cache_dir='~',
                cache_subdir=USER_DATA_DIR,
                verbose=1
            )
        self.lm = kenlm.Model(self.language_model_path)
        # dict of word -> frequency
        self.word_freq = self.load_word_freq_dict(self.word_freq_path)
        # custom confusion set
        if isinstance(self.custom_confusion_path_or_dict, dict):
            self.custom_confusion = self.custom_confusion_path_or_dict
            for k, v in self.custom_confusion.items():
                self.word_freq[v] = self.word_freq.get(v, 1)
        elif isinstance(self.custom_confusion_path_or_dict, str):
            self.custom_confusion = self._get_custom_confusion_dict(self.custom_confusion_path_or_dict)
        else:
            raise ValueError('custom_confusion_path_or_dict must be dict or str.')
        # custom segmentation dictionary
        self.custom_word_freq = self.load_word_freq_dict(self.custom_word_freq_path)
        self.person_names = self.load_word_freq_dict(self.person_name_path)
        self.place_names = self.load_word_freq_dict(self.place_name_path)
        self.stopwords = self.load_word_freq_dict(self.stopwords_path)
        # merge the segmentation dictionary with the custom dictionaries
        self.custom_word_freq.update(self.person_names)
        self.custom_word_freq.update(self.place_names)
        self.custom_word_freq.update(self.stopwords)
        self.word_freq.update(self.custom_word_freq)
        self.tokenizer = Tokenizer(
            dict_path=self.word_freq_path,
            custom_word_freq_dict=self.custom_word_freq,
            custom_confusion_dict=self.custom_confusion
        )
        self.proper_corrector = ProperCorrector(
            proper_name_path=self.proper_name_path,
            stroke_path=self.stroke_path
        )
        self.initialized_detector = True

    def check_detector_initialized(self):
        if not self.initialized_detector:
            self._initialize_detector()

    @staticmethod
    def load_word_freq_dict(path):
        """
        Load a word-frequency dictionary for segmentation.
        :param path: dict file path, one "word freq" pair per line
        :return: dict, {word: freq}
        """
        word_freq = {}
        if path:
            if not os.path.exists(path):
                logger.warning('file not found: %s' % path)
                return word_freq
            else:
                with open(path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line.startswith('#'):
                            continue
                        info = line.split()
                        if len(info) < 1:
                            continue
                        word = info[0]
                        # frequency defaults to 1
                        freq = int(info[1]) if len(info) > 1 else 1
                        word_freq[word] = freq
        return word_freq

    def _get_custom_confusion_dict(self, path):
        """
        Load the custom confusion dict.
        :param path: confusion file path, one "variant origin [freq]" triple per line
        :return: dict, {variant: origin}, e.g. {"交通先行": "交通限行"}
        """
        confusion = {}
        if path:
            if not os.path.exists(path):
                logger.warning('file not found: %s' % path)
                return confusion
            else:
                with open(path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line.startswith('#'):
                            continue
                        info = line.split()
                        if len(info) < 2:
                            continue
                        variant = info[0]
                        origin = info[1]
                        freq = int(info[2]) if len(info) > 2 else 1
                        self.word_freq[origin] = freq
                        confusion[variant] = origin
        return confusion

    def set_language_model_path(self, path):
        self.check_detector_initialized()
        import kenlm
        self.lm = kenlm.Model(path)
        logger.debug('Loaded language model: %s' % path)

    def set_custom_confusion_path_or_dict(self, data):
        self.check_detector_initialized()
        if isinstance(data, dict):
            self.custom_confusion = data
            for k, v in self.custom_confusion.items():
                self.word_freq[v] = self.word_freq.get(v, 1)
        elif isinstance(data, str):
            self.custom_confusion = self._get_custom_confusion_dict(data)
        else:
            raise ValueError('custom_confusion_path_or_dict must be dict or str.')
        logger.debug('Loaded confusion size: %d' % len(self.custom_confusion))

    def set_custom_word_freq(self, path):
        self.check_detector_initialized()
        word_freqs = self.load_word_freq_dict(path)
        # merge dictionaries
        self.custom_word_freq.update(word_freqs)
        # merge the segmentation dictionary with the custom dictionary
        self.word_freq.update(self.custom_word_freq)
        self.tokenizer = Tokenizer(
            dict_path=self.word_freq_path,
            custom_word_freq_dict=self.custom_word_freq,
            custom_confusion_dict=self.custom_confusion
        )
        for k, v in word_freqs.items():
            self.set_word_frequency(k, v)
        logger.debug('Loaded custom word path: %s, size: %d' % (path, len(word_freqs)))

    def enable_char_error(self, enable=True):
        """
        Enable or disable character-level error detection.
        :param enable: bool
        :return:
        """
        self.is_char_error_detect = enable

    def enable_word_error(self, enable=True):
        """
        Enable or disable word-level error detection.
        :param enable: bool
        :return:
        """
        self.is_word_error_detect = enable

    def ngram_score(self, chars):
        """
        Get the n-gram language model score.
        :param chars: list, tokenized by word or character
        :return: float
        """
        self.check_detector_initialized()
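        # kenlm's Model.score returns a log10 probability: values are negative,
        # and a less-negative score means a more fluent sequence.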
        return self.lm.score(' '.join(chars), bos=False, eos=False)

    def ppl_score(self, words):
        """
        Get the language model perplexity score; the lower, the more fluent the sentence.
        :param words: list, tokenized by word or character
        :return: float, perplexity
        """
        self.check_detector_initialized()
        return self.lm.perplexity(' '.join(words))

    def word_frequency(self, word):
        """
        Get the frequency of a word in the corpus.
        :param word:
        :return: int, frequency (0 if unseen)
        """
        self.check_detector_initialized()
        return self.word_freq.get(word, 0)

    def set_word_frequency(self, word, num):
        """
        Update the frequency of a word in the corpus.
        """
        self.check_detector_initialized()
        self.word_freq[word] = num
        return self.word_freq

    @staticmethod
    def _check_contain_error(maybe_err, maybe_errors):
        """
        Check whether the error set (maybe_errors) already covers this error span (maybe_err).
        :param maybe_err: [error_word, begin_pos, end_pos, error_type]
        :param maybe_errors: list
        :return: bool
        """
        error_word_idx = 0
        begin_idx = 1
        end_idx = 2
        for err in maybe_errors:
            if maybe_err[error_word_idx] in err[error_word_idx] and maybe_err[begin_idx] >= err[begin_idx] and \
                    maybe_err[end_idx] <= err[end_idx]:
                return True
        return False

    def _add_maybe_error_item(self, maybe_err, maybe_errors):
        """
        Add a new suspected error if it is not already covered.
        :param maybe_err:
        :param maybe_errors:
        :return:
        """
        if maybe_err not in maybe_errors and not self._check_contain_error(maybe_err, maybe_errors):
            maybe_errors.append(maybe_err)

    @staticmethod
    def _get_maybe_error_index(scores, ratio=0.6745, threshold=2):
        """
        Find positions of suspected wrong characters via median absolute deviation (MAD).
        :param scores: np.array
        :param ratio: scaling constant from the standard normal distribution
        :param threshold: the smaller the threshold, the more suspected errors returned
        :return: list of indices of all suspected wrong characters
        """
        result = []
        scores = np.array(scores)
        if len(scores.shape) == 1:
            scores = scores[:, None]
        median = np.median(scores, axis=0)  # get median of all scores
        margin_median = np.abs(scores - median).flatten()  # deviation from the median
        # median absolute deviation
        med_abs_deviation = np.median(margin_median)
        if med_abs_deviation == 0:
            return result
        y_score = ratio * margin_median / med_abs_deviation
        # flatten
        scores = scores.flatten()
        maybe_error_indices = np.where((y_score > threshold) & (scores < median))
        # indices of all suspected wrong characters
        result = [int(i) for i in maybe_error_indices[0]]
        return result
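
    # The statistic above is the "modified z-score": y = ratio * |x - median| / MAD,
    # with ratio = 0.6745 (the 0.75 quantile of the standard normal distribution).
    # Worked example with hypothetical scores [-3.2, -3.1, -9.8, -3.0]:
    # median = -3.15 and MAD = 0.10, so the outlier -9.8 gets
    # y = 0.6745 * 6.65 / 0.10 ≈ 44.9 > 2 and, lying below the median,
    # index 2 is flagged as a suspected error.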

    @staticmethod
    def _get_maybe_error_index_by_stddev(scores, n=2):
        """
        Find positions of suspected wrong characters; points within n standard
        deviations of the mean count as normal.
        :param scores: list, float
        :param n: multiple of the standard deviation
        :return: list of indices of all suspected wrong characters
        """
        scores = np.array(scores)  # ensure ndarray so the comparisons broadcast
        std = np.std(scores, ddof=1)
        mean = np.mean(scores)
        down_limit = mean - n * std
        upper_limit = mean + n * std
        maybe_error_indices = np.where((scores > upper_limit) | (scores < down_limit))
        # indices of all suspected wrong characters
        result = list(maybe_error_indices[0])
        return result

    @staticmethod
    def is_filter_token(token):
        """
        Check whether the token should be filtered out.
        :param token: word or character
        :return: bool
        """
        result = False
        # pass blank
        if not token.strip():
            result = True
        # pass numbers
        if token.isdigit():
            result = True
        # pass alphabetic strings
        if is_alphabet_string(token.lower()):
            result = True
        # pass non-Chinese strings
        if not is_chinese_string(token):
            result = True
        return result
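
    # Examples of filtered tokens: ' ' (blank), '123' (digits), 'abc'
    # (alphabetic), '，' (non-Chinese). Only Chinese tokens reach the
    # downstream error checks.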

    def _detect(self, sentence, start_idx=0, **kwargs):
        """
        Detect suspected erroneous words/characters in a sentence, as [word, position, error type].
        Detection order:
        1. custom confusion set
        2. proper-name errors
        3. word errors
        4. character errors
        :param sentence:
        :param start_idx:
        :return: list[list], [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        # initialize
        self.check_detector_initialized()
        # 1. add custom confusion words to the suspected errors
        for confuse in self.custom_confusion:
            for i in re.finditer(confuse, sentence):
                maybe_err = [confuse, i.span()[0] + start_idx, i.span()[1] + start_idx, ErrorType.confusion]
                self._add_maybe_error_item(maybe_err, maybe_errors)
        # 2. proper-name error detection
        proper_details = self.proper_corrector.correct(sentence, start_idx=start_idx, **kwargs)['errors']
        for error_word, corrected_word, begin_idx in proper_details:
            end_idx = begin_idx + len(error_word)
            maybe_err = [error_word, begin_idx, end_idx, ErrorType.proper]
            self._add_maybe_error_item(maybe_err, maybe_errors)
        # 3. word errors
        if self.is_word_error_detect:
            tokens = self.tokenizer.tokenize(sentence)
            # add out-of-vocabulary words to the suspected errors
            for token, begin_idx, end_idx in tokens:
                # skip filtered tokens
                if self.is_filter_token(token):
                    continue
                # skip words present in the dictionary
                if token in self.word_freq:
                    continue
                maybe_err = [token, begin_idx + start_idx, end_idx + start_idx, ErrorType.word]
                self._add_maybe_error_item(maybe_err, maybe_errors)
        # 4. character errors, detected with the language model
        if self.is_char_error_detect:
            try:
                ngram_avg_scores = []
                for n in [2, 3]:
                    scores = []
                    for i in range(len(sentence) - n + 1):
                        word = sentence[i:i + n]
                        score = self.ngram_score(list(word))
                        scores.append(score)
                    if not scores:
                        continue
                    # pad the sliding-window scores at both ends
                    for _ in range(n - 1):
                        scores.insert(0, scores[0])
                        scores.append(scores[-1])
                    avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]
                    ngram_avg_scores.append(avg_scores)
                if ngram_avg_scores:
                    # average the n-gram scores per character
                    sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))
                    # collect the suspected wrong characters
                    for i in self._get_maybe_error_index(sent_scores):
                        token = sentence[i]
                        # skip filtered tokens
                        if self.is_filter_token(token):
                            continue
                        # skip stop words
                        if token in self.stopwords:
                            continue
                        # token, begin_idx, end_idx, error_type
                        maybe_err = [token, i + start_idx, i + start_idx + 1, ErrorType.char]
                        self._add_maybe_error_item(maybe_err, maybe_errors)
            except IndexError as ie:
                logger.warning('index error, sentence: %s, %s' % (sentence, ie))
            except Exception as e:
                logger.warning('detect error, sentence: %s, %s' % (sentence, e))
        return sorted(maybe_errors, key=lambda k: k[1]), proper_details

    def detect(self, sentence):
        """
        Detect errors in text.
        :param sentence: sentence text
        :return: list of suspected errors, [error_word, begin_pos, end_pos, error_type]
        """
        maybe_errors = []
        if not sentence.strip():
            return maybe_errors
        # normalize the text
        sentence = uniform(sentence)
        # split the text into sentences
        short_sents = split_text_into_sentences_by_symbol(sentence)
        for sent, idx in short_sents:
            maybe_errors += self._detect(sent, idx)[0]
        return maybe_errors
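

if __name__ == '__main__':
    # Minimal usage sketch, assuming the default kenlm model is already cached
    # or can be downloaded on first use; the sample sentence is illustrative only.
    d = Detector()
    print(d.detect('少先队员因该为老人让坐'))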
|