123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321 |
- import re
- import time
- import asyncio
- from typing import List, Tuple
- from abc import abstractmethod
- from ..utils import InfererModule, ModelWrapper, repeating_sequence, is_valuable_text
- try:
- import readline
- except Exception:
- readline = None
# Language codes accepted by the translators, mapped to a human-readable
# language name. Insertion order is significant: the codes are joined in this
# order when building "invalid language code" error messages.
VALID_LANGUAGES = {
    'CHS': 'Chinese (Simplified)',
    'CHT': 'Chinese (Traditional)',
    'CSY': 'Czech',
    'NLD': 'Dutch',
    'ENG': 'English',
    'FRA': 'French',
    'DEU': 'German',
    'HUN': 'Hungarian',
    'ITA': 'Italian',
    'JPN': 'Japanese',
    'KOR': 'Korean',
    'PLK': 'Polish',
    'PTB': 'Portuguese (Brazil)',
    'ROM': 'Romanian',
    'RUS': 'Russian',
    'ESP': 'Spanish',
    'TRK': 'Turkish',
    'UKR': 'Ukrainian',
    'VIN': 'Vietnamese',
    'ARA': 'Arabic',
    'CNR': 'Montenegrin',
    'SRP': 'Serbian',
    'HRV': 'Croatian',
    'THA': 'Thai',
    'IND': 'Indonesian',
    'FIL': 'Filipino (Tagalog)',
}
# Maps language codes as reported by language detection (mostly two-letter
# ISO 639-1 codes) to the keys used in VALID_LANGUAGES.
# NOTE: 'cnr' (Montenegrin) is a three-letter code; it appears here because no
# two-letter code is listed for it.
# Each key must appear only once: a repeated key in a dict literal is silently
# overwritten by the later entry.
ISO_639_1_TO_VALID_LANGUAGES = {
    'zh': 'CHS',
    'ja': 'JPN',
    'en': 'ENG',
    'ko': 'KOR',
    'vi': 'VIN',
    'cs': 'CSY',
    'nl': 'NLD',
    'fr': 'FRA',
    'de': 'DEU',
    'hu': 'HUN',
    'it': 'ITA',
    'pl': 'PLK',
    'pt': 'PTB',
    'ro': 'ROM',
    'ru': 'RUS',
    'es': 'ESP',
    'tr': 'TRK',
    'uk': 'UKR',
    'ar': 'ARA',
    'cnr': 'CNR',
    'sr': 'SRP',
    'hr': 'HRV',
    'th': 'THA',
    'id': 'IND',
    'tl': 'FIL',
}
class InvalidServerResponse(Exception):
    """Error type for an invalid server response.

    Not raised within this module; presumably used by concrete translators
    that talk to a remote service.
    """
class MissingAPIKeyException(Exception):
    """Error type for a missing API key.

    Not raised within this module; presumably used by concrete translators
    that require credentials.
    """
class LanguageUnsupportedException(Exception):
    """Raised when a translator cannot handle the requested language code."""

    def __init__(self, language_code: str, translator: str = None, supported_languages: List[str] = None):
        """
        Builds the error message.

        :param language_code: The language code that was rejected.
        :param translator: Name of the translator; a generic placeholder is
            used in the message when it is empty or None.
        :param supported_languages: Codes the translator does support. They are
            appended to the message when the list is non-empty.
        """
        translator_name = translator or 'chosen translator'
        message = 'Language not supported for %s: "%s"' % (translator_name, language_code)
        if supported_languages:
            listing = ','.join(supported_languages)
            message = message + '. Supported languages: "%s"' % listing
        super().__init__(message)
class MTPEAdapter():
    """Lets a human revise machine translations interactively on the terminal."""

    async def dispatch(self, queries: List[str], translations: List[str]) -> List[str]:
        """
        Prompts the user once per query/translation pair, with the machine
        translation pre-filled in the input line for editing.

        Newlines inside a translation are shown as a literal ``\\n`` while
        editing and converted back afterwards.

        :param queries: The source texts, shown as context for each prompt.
        :param translations: The machine translations to be revised.
        :return: The revised translations, or ``translations`` unchanged when
            the ``readline`` module is unavailable.
        """
        # TODO: Make it work in windows (e.g. through os.startfile)
        if not readline:
            print('MTPE is currently only supported on linux')
            return translations

        print('Running Machine Translation Post Editing (MTPE)')
        edited = []
        for position, (source_text, machine_text) in enumerate(zip(queries, translations), start=1):
            print(f'\n[{position}/{len(queries)}] {source_text}:')

            # The startup hook runs when input() begins reading, which lets it
            # place the machine translation into the editable line buffer.
            def prefill_prompt(text=machine_text):
                readline.insert_text(text.replace('\n', '\\n'))

            readline.set_startup_hook(prefill_prompt)
            try:
                answer = input(' -> ')
            finally:
                # Always remove the hook so later prompts are not pre-filled.
                readline.set_startup_hook()
            edited.append(answer.replace('\\n', '\n'))
        print()
        return edited
class CommonTranslator(InfererModule):
    """
    Base class for translators.

    Subclasses implement ``_translate`` and configure the class attributes
    below. ``translate`` is the public entry point: it validates the language
    codes, filters out queries without translatable text, applies the rate
    limit, retries invalid translations and cleans up the results.
    """

    # Translator has to support all languages listed in here. The language codes will be resolved into
    # _LANGUAGE_CODE_MAP[lang_code] automatically if _LANGUAGE_CODE_MAP is a dict.
    # If it is a list it will simply return the language code as is.
    _LANGUAGE_CODE_MAP = {}

    # The amount of repeats upon detecting an invalid translation.
    # Use with _is_translation_invalid and _modify_invalid_translation_query.
    _INVALID_REPEAT_COUNT = 0

    # Minimum spacing between requests is 60 / _MAX_REQUESTS_PER_MINUTE seconds.
    # Values <= 0 disable rate limiting.
    _MAX_REQUESTS_PER_MINUTE = -1

    def __init__(self):
        super().__init__()
        self.mtpe_adapter = MTPEAdapter()
        # Timestamp (time.time()) of the last rate-limited request.
        self._last_request_ts = 0

    def supports_languages(self, from_lang: str, to_lang: str, fatal: bool = False) -> bool:
        """
        Checks whether this translator supports the given language pair.

        :param from_lang: Source language code, or 'auto'.
        :param to_lang: Target language code ('auto' is not a valid target).
        :param fatal: Raise instead of returning False on an unsupported code.
        :return: True if both languages are supported.
        :raises LanguageUnsupportedException: If ``fatal`` is set and one of
            the languages is not supported.
        """
        supported_src_languages = ['auto'] + list(self._LANGUAGE_CODE_MAP)
        supported_tgt_languages = list(self._LANGUAGE_CODE_MAP)

        if from_lang not in supported_src_languages:
            if fatal:
                raise LanguageUnsupportedException(from_lang, self.__class__.__name__, supported_src_languages)
            return False
        if to_lang not in supported_tgt_languages:
            if fatal:
                raise LanguageUnsupportedException(to_lang, self.__class__.__name__, supported_tgt_languages)
            return False
        return True

    def parse_language_codes(self, from_lang: str, to_lang: str, fatal: bool = False) -> Tuple[str, str]:
        """
        Resolves the language codes into the codes used by this translator.

        :return: ``(from_code, to_code)``; ``(None, None)`` if the pair is not
            supported and ``fatal`` is False. 'auto' is passed through as is.
        :raises LanguageUnsupportedException: If ``fatal`` is set and one of
            the languages is not supported.
        """
        if not self.supports_languages(from_lang, to_lang, fatal):
            return None, None
        # A list holds the supported codes without any remapping.
        if isinstance(self._LANGUAGE_CODE_MAP, list):
            return from_lang, to_lang

        _from_lang = self._LANGUAGE_CODE_MAP.get(from_lang) if from_lang != 'auto' else 'auto'
        _to_lang = self._LANGUAGE_CODE_MAP.get(to_lang)
        return _from_lang, _to_lang

    async def translate(self, from_lang: str, to_lang: str, queries: List[str], use_mtpe: bool = False) -> List[str]:
        """
        Translates list of queries of one language into another.

        :param from_lang: Source language code from VALID_LANGUAGES, or 'auto'.
        :param to_lang: Target language code from VALID_LANGUAGES.
        :param queries: Texts to translate. The passed list is not modified.
        :param use_mtpe: Let the user post-edit the translations interactively.
        :return: One entry per query. Queries without translatable text are
            returned unchanged at their position. If ``from_lang`` equals
            ``to_lang`` the ``queries`` list itself is returned.
        :raises ValueError: If a language code is not in VALID_LANGUAGES.
        :raises LanguageUnsupportedException: If this translator does not
            support one of the languages.
        """
        if to_lang not in VALID_LANGUAGES:
            raise ValueError('Invalid language code: "%s". Choose from the following: %s' % (to_lang, ', '.join(VALID_LANGUAGES)))
        if from_lang not in VALID_LANGUAGES and from_lang != 'auto':
            raise ValueError('Invalid language code: "%s". Choose from the following: auto, %s' % (from_lang, ', '.join(VALID_LANGUAGES)))
        self.logger.info(f'Translating into {VALID_LANGUAGES[to_lang]}')

        if from_lang == to_lang:
            return queries

        # Dont translate queries without text
        query_indices = []
        final_translations = []
        for i, query in enumerate(queries):
            if not is_valuable_text(query):
                final_translations.append(query)
            else:
                final_translations.append(None)
                query_indices.append(i)

        # New list, so that modified retry queries never leak to the caller.
        queries = [queries[i] for i in query_indices]

        # Resolved once: the result does not change between retry attempts.
        lang_codes = self.parse_language_codes(from_lang, to_lang, fatal=True)

        # Nothing to translate: skip the backend call and the rate limit slot.
        if not queries:
            return final_translations

        translations = [''] * len(queries)
        untranslated_indices = list(range(len(queries)))
        for i in range(1 + self._INVALID_REPEAT_COUNT):  # Repeat until all translations are considered valid
            if i > 0:
                self.logger.warning(f'Repeating because of invalid translation. Attempt: {i+1}')
                await asyncio.sleep(0.1)

            # Sleep if speed is over the ratelimit
            await self._ratelimit_sleep()

            # Translate
            _translations = list(await self._translate(*lang_codes, queries))

            # Extend returned translations list to have the same size as queries
            if len(_translations) < len(queries):
                _translations.extend([''] * (len(queries) - len(_translations)))
            elif len(_translations) > len(queries):
                _translations = _translations[:len(queries)]

            # Only overwrite yet untranslated indices
            for j in untranslated_indices:
                translations[j] = _translations[j]

            if self._INVALID_REPEAT_COUNT == 0:
                break

            new_untranslated_indices = []
            for j in untranslated_indices:
                q, t = queries[j], translations[j]
                # Repeat invalid translations with slightly modified queries
                if self._is_translation_invalid(q, t):
                    new_untranslated_indices.append(j)
                    queries[j] = self._modify_invalid_translation_query(q, t)
            untranslated_indices = new_untranslated_indices

            if not untranslated_indices:
                break

        translations = [self._clean_translation_output(q, r, to_lang) for q, r in zip(queries, translations)]

        if to_lang == 'ARA':
            # Third-party dependency, only needed for Arabic output.
            import arabic_reshaper
            translations = [arabic_reshaper.reshape(t) for t in translations]

        if use_mtpe:
            translations = await self.mtpe_adapter.dispatch(queries, translations)

        # Merge with the queries without text
        for i, trans in enumerate(translations):
            final_translations[query_indices[i]] = trans
            self.logger.info(f'{i}: {queries[i]} => {trans}')

        return final_translations

    @abstractmethod
    async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        """
        Performs the actual translation. Receives the language codes as
        returned by ``parse_language_codes`` and should return one translation
        per query; shorter results are padded with '' and longer ones are
        truncated by ``translate``.
        """
        pass

    async def _ratelimit_sleep(self):
        """
        Sleeps until at least 60 / _MAX_REQUESTS_PER_MINUTE seconds have passed
        since the previous request, then records the current time as the time
        of the latest request. Does nothing if rate limiting is disabled.
        """
        if self._MAX_REQUESTS_PER_MINUTE > 0:
            now = time.time()
            ratelimit_timeout = self._last_request_ts + 60 / self._MAX_REQUESTS_PER_MINUTE
            if ratelimit_timeout > now:
                self.logger.info(f'Ratelimit sleep: {(ratelimit_timeout-now):.2f}s')
                await asyncio.sleep(ratelimit_timeout-now)
            self._last_request_ts = time.time()

    def _is_translation_invalid(self, query: str, trans: str) -> bool:
        """
        Heuristic check for unusable translations: an empty translation of a
        non-empty query, or a translation that consists of very few distinct
        characters although the query is varied.
        """
        if not trans and query:
            return True
        if not query or not trans:
            return False

        query_symbols_count = len(set(query))
        trans_symbols_count = len(set(trans))
        if query_symbols_count > 6 and trans_symbols_count < 6 and trans_symbols_count < 0.25 * len(trans):
            return True
        return False

    def _modify_invalid_translation_query(self, query: str, trans: str) -> str:
        """
        Can be overwritten if _INVALID_REPEAT_COUNT was set. It modifies the query
        for the next translation attempt.
        """
        return query

    def _clean_translation_output(self, query: str, trans: str, to_lang: str) -> str:
        """
        Tries to spot and skim down invalid translations.

        Normalizes whitespace and the spacing around punctuation, and shrinks
        translations that consist of a repeating sequence.

        :return: The cleaned translation, or '' if query or translation is empty.
        """
        if not query or not trans:
            return ''

        # Collapse every run of whitespace into a single space
        trans = re.sub(r'\s+', r' ', trans)
        # 'text.text' -> 'text. text'
        trans = re.sub(r'(?<![.,;!?])([.,;!?])(?=\w)', r'\1 ', trans)
        # ' ! ! . . ' -> ' !!.. '
        trans = re.sub(r'([.,;!?])\s+(?=[.,;!?]|$)', r'\1', trans)

        # NOTE(review): only the following substitution is assumed to be
        # skipped for Arabic - confirm against the intended nesting.
        if to_lang != 'ARA':
            # 'text .' -> 'text.'
            trans = re.sub(r'(?<=[.,;!?\w])\s+([.,;!?])', r'\1', trans)
        # ' ... text' -> ' ...text'
        trans = re.sub(r'((?:\s|^)\.+)\s+(?=\w)', r'\1', trans)

        seq = repeating_sequence(trans.lower())

        # 'aaaaaaaaaaaaa' -> 'aaaaaa'
        if len(trans) < len(query) and len(seq) < 0.5 * len(trans):
            # Shrink sequence to length of original query
            trans = seq * max(len(query) // len(seq), 1)
            # Transfer capitalization of query to translation. zip() stops at
            # the shorter of both strings.
            trans = ''.join(t.upper() if q.isupper() else t for t, q in zip(trans, query))

        return trans
class OfflineTranslator(CommonTranslator, ModelWrapper):
    """
    Base class for translators that run a local model.

    Combines the translation pipeline of CommonTranslator with ModelWrapper.
    Subclasses implement ``_load`` and ``_infer``.
    """

    _MODEL_SUB_DIR = 'translators'

    async def _translate(self, *args, **kwargs):
        """Forwards the translation request unchanged to ``self.infer``."""
        result = await self.infer(*args, **kwargs)
        return result

    @abstractmethod
    async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        """Runs the model on the queries. To be implemented by subclasses."""

    async def load(self, from_lang: str, to_lang: str, device: str):
        """
        Resolves the language codes (non-fatally, so unsupported codes become
        None) and passes them together with the device to the parent ``load``.
        """
        lang_codes = self.parse_language_codes(from_lang, to_lang)
        return await super().load(device, *lang_codes)

    @abstractmethod
    async def _load(self, from_lang: str, to_lang: str, device: str):
        """Loads the model. To be implemented by subclasses."""

    async def reload(self, from_lang: str, to_lang: str, device: str):
        """Passes the device and the language codes to the parent ``reload``."""
        # NOTE(review): unlike load(), the language codes are forwarded without
        # going through parse_language_codes() - confirm this is intended.
        return await super().reload(device, from_lang, to_lang)
|