common.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. import re
  2. import time
  3. import asyncio
  4. from typing import List, Tuple
  5. from abc import abstractmethod
  6. from ..utils import InfererModule, ModelWrapper, repeating_sequence, is_valuable_text
  7. try:
  8. import readline
  9. except Exception:
  10. readline = None
  # Language codes accepted by this module, mapped to their human-readable
  # names. The insertion order below is the order in which the codes show up
  # whenever the mapping is iterated (e.g. when it is joined into a message).
  VALID_LANGUAGES = dict(
      CHS='Chinese (Simplified)',
      CHT='Chinese (Traditional)',
      CSY='Czech',
      NLD='Dutch',
      ENG='English',
      FRA='French',
      DEU='German',
      HUN='Hungarian',
      ITA='Italian',
      JPN='Japanese',
      KOR='Korean',
      PLK='Polish',
      PTB='Portuguese (Brazil)',
      ROM='Romanian',
      RUS='Russian',
      ESP='Spanish',
      TRK='Turkish',
      UKR='Ukrainian',
      VIN='Vietnamese',
      ARA='Arabic',
      CNR='Montenegrin',
      SRP='Serbian',
      HRV='Croatian',
      THA='Thai',
      IND='Indonesian',
      FIL='Filipino (Tagalog)',
  )
  # Maps short ISO 639 language codes to the three-letter codes used as keys of
  # VALID_LANGUAGES. Each key appears exactly once ('vi' used to be listed
  # twice with the same value, which was redundant).
  # NOTE: 'zh' resolves to Simplified Chinese ('CHS') only, and 'cnr' is a
  # three-letter code rather than a two-letter one.
  ISO_639_1_TO_VALID_LANGUAGES = {
      'zh': 'CHS',
      'ja': 'JPN',
      'en': 'ENG',
      'ko': 'KOR',
      'vi': 'VIN',
      'cs': 'CSY',
      'nl': 'NLD',
      'fr': 'FRA',
      'de': 'DEU',
      'hu': 'HUN',
      'it': 'ITA',
      'pl': 'PLK',
      'pt': 'PTB',
      'ro': 'ROM',
      'ru': 'RUS',
      'es': 'ESP',
      'tr': 'TRK',
      'uk': 'UKR',
      'ar': 'ARA',
      'cnr': 'CNR',
      'sr': 'SRP',
      'hr': 'HRV',
      'th': 'THA',
      'id': 'IND',
      'tl': 'FIL',
  }
  class InvalidServerResponse(Exception):
      """Exception type for an invalid server response (going by its name; no raise site is visible here)."""
  class MissingAPIKeyException(Exception):
      """Exception type for a missing API key (going by its name; no raise site is visible here)."""
  class LanguageUnsupportedException(Exception):
      """
      Signals that a language code is not supported by a translator.

      The exception message names the translator (or 'chosen translator' when
      none is given), the offending language code and, when a non-empty list is
      supplied, the comma-joined supported language codes.
      """

      def __init__(self, language_code: str, translator: str = None, supported_languages: List[str] = None):
          translator_label = translator or 'chosen translator'
          message_parts = ['Language not supported for %s: "%s"' % (translator_label, language_code)]
          if supported_languages:
              message_parts.append('. Supported languages: "%s"' % ','.join(supported_languages))
          super().__init__(''.join(message_parts))
  class MTPEAdapter():
      """Interactive machine translation post-editing (MTPE) on the terminal."""

      async def dispatch(self, queries: List[str], translations: List[str]) -> List[str]:
          """
          Lets the user edit each translation on the command line.

          Every translation is pre-filled into the input prompt through a
          readline startup hook; newlines are shown as a literal '\\n' while
          editing and converted back afterwards. When the readline module is
          unavailable, the translations are returned untouched.
          """
          # TODO: Make it work in windows (e.g. through os.startfile)
          if not readline:
              print('MTPE is currently only supported on linux')
              return translations

          print('Running Machine Translation Post Editing (MTPE)')
          total = len(queries)
          edited = []
          for position, (source_text, draft) in enumerate(zip(queries, translations), start=1):
              print(f'\n[{position}/{total}] {source_text}:')
              # The hook runs when input() starts reading and pre-fills the line.
              readline.set_startup_hook(lambda text=draft: readline.insert_text(text.replace('\n', '\\n')))
              try:
                  answer = input(' -> ')
              finally:
                  # Always remove the hook so later prompts start out empty.
                  readline.set_startup_hook()
              edited.append(answer.replace('\\n', '\n'))

          print()
          return edited
  class CommonTranslator(InfererModule):
      """
      Base class for translators.

      Subclasses implement `_translate` and describe the languages they support
      through `_LANGUAGE_CODE_MAP`. The public entry point `translate` validates
      the language codes, skips queries without valuable text, optionally
      retries invalid translations, rate-limits requests and cleans up the
      returned translations.
      """

      # Translator has to support all languages listed in here. The language codes will be resolved into
      # _LANGUAGE_CODE_MAP[lang_code] automatically if _LANGUAGE_CODE_MAP is a dict.
      # If it is a list it will simply return the language code as is.
      _LANGUAGE_CODE_MAP = {}

      # The amount of repeats upon detecting an invalid translation.
      # Use with _is_translation_invalid and _modify_invalid_translation_query.
      _INVALID_REPEAT_COUNT = 0

      # Requests are spaced at least 60 / _MAX_REQUESTS_PER_MINUTE seconds apart.
      # Values <= 0 disable the rate limit.
      _MAX_REQUESTS_PER_MINUTE = -1

      def __init__(self):
          super().__init__()
          self.mtpe_adapter = MTPEAdapter()
          # Timestamp (time.time()) of the last rate-limited request.
          self._last_request_ts = 0

      def supports_languages(self, from_lang: str, to_lang: str, fatal: bool = False) -> bool:
          """
          Checks whether this translator supports the given language pair.

          'auto' is always accepted as source language. Returns False for an
          unsupported language, or raises LanguageUnsupportedException instead
          when `fatal` is set.
          """
          supported_src_languages = ['auto'] + list(self._LANGUAGE_CODE_MAP)
          supported_tgt_languages = list(self._LANGUAGE_CODE_MAP)

          if from_lang not in supported_src_languages:
              if fatal:
                  raise LanguageUnsupportedException(from_lang, self.__class__.__name__, supported_src_languages)
              return False
          if to_lang not in supported_tgt_languages:
              if fatal:
                  raise LanguageUnsupportedException(to_lang, self.__class__.__name__, supported_tgt_languages)
              return False
          return True

      def parse_language_codes(self, from_lang: str, to_lang: str, fatal: bool = False) -> Tuple[str, str]:
          """
          Resolves the language codes into the codes used by the translator.

          Returns (None, None) if the pair is unsupported and `fatal` is not set.
          If _LANGUAGE_CODE_MAP is a list the codes are returned unchanged,
          otherwise they are looked up in the map ('auto' is kept as is).
          """
          if not self.supports_languages(from_lang, to_lang, fatal):
              return None, None
          # isinstance also covers list subclasses, unlike an exact type check.
          if isinstance(self._LANGUAGE_CODE_MAP, list):
              return from_lang, to_lang

          _from_lang = self._LANGUAGE_CODE_MAP.get(from_lang) if from_lang != 'auto' else 'auto'
          _to_lang = self._LANGUAGE_CODE_MAP.get(to_lang)
          return _from_lang, _to_lang

      async def translate(self, from_lang: str, to_lang: str, queries: List[str], use_mtpe: bool = False) -> List[str]:
          """
          Translates list of queries of one language into another.

          Queries for which is_valuable_text() is false are passed through
          untranslated. If `from_lang` equals `to_lang` the given list is
          returned as is.

          Raises ValueError for language codes that are not in VALID_LANGUAGES
          and LanguageUnsupportedException for languages this translator does
          not support.
          """
          if to_lang not in VALID_LANGUAGES:
              raise ValueError('Invalid language code: "%s". Choose from the following: %s' % (to_lang, ', '.join(VALID_LANGUAGES)))
          if from_lang not in VALID_LANGUAGES and from_lang != 'auto':
              raise ValueError('Invalid language code: "%s". Choose from the following: auto, %s' % (from_lang, ', '.join(VALID_LANGUAGES)))
          self.logger.info(f'Translating into {VALID_LANGUAGES[to_lang]}')

          if from_lang == to_lang:
              return queries

          # Dont translate queries without text
          query_indices = []
          final_translations = []
          for i, query in enumerate(queries):
              if not is_valuable_text(query):
                  final_translations.append(query)
              else:
                  final_translations.append(None)
                  query_indices.append(i)

          # Local copy, so modifying queries for retries never touches the caller's list.
          queries = [queries[i] for i in query_indices]

          # The resolved codes do not change between attempts, so resolve them once.
          lang_codes = self.parse_language_codes(from_lang, to_lang, fatal=True)

          translations = [''] * len(queries)
          untranslated_indices = list(range(len(queries)))
          for i in range(1 + self._INVALID_REPEAT_COUNT): # Repeat until all translations are considered valid
              if i > 0:
                  self.logger.warning(f'Repeating because of invalid translation. Attempt: {i+1}')
                  await asyncio.sleep(0.1)

              # Sleep if speed is over the ratelimit
              await self._ratelimit_sleep()

              # Translate. Copy the result so that the list handed back by
              # _translate is never modified and any sequence type is accepted.
              _translations = list(await self._translate(*lang_codes, queries))

              # Extend returned translations list to have the same size as queries
              if len(_translations) < len(queries):
                  _translations.extend([''] * (len(queries) - len(_translations)))
              elif len(_translations) > len(queries):
                  _translations = _translations[:len(queries)]

              # Only overwrite yet untranslated indices
              for j in untranslated_indices:
                  translations[j] = _translations[j]

              if self._INVALID_REPEAT_COUNT == 0:
                  break

              new_untranslated_indices = []
              for j in untranslated_indices:
                  q, t = queries[j], translations[j]
                  # Repeat invalid translations with slightly modified queries
                  if self._is_translation_invalid(q, t):
                      new_untranslated_indices.append(j)
                      queries[j] = self._modify_invalid_translation_query(q, t)
              untranslated_indices = new_untranslated_indices

              if not untranslated_indices:
                  break

          translations = [self._clean_translation_output(q, r, to_lang) for q, r in zip(queries, translations)]

          if to_lang == 'ARA':
              # Imported lazily, so the package is only needed for Arabic output.
              import arabic_reshaper
              translations = [arabic_reshaper.reshape(t) for t in translations]

          if use_mtpe:
              translations = await self.mtpe_adapter.dispatch(queries, translations)

          # Merge with the queries without text
          for i, trans in enumerate(translations):
              final_translations[query_indices[i]] = trans
              self.logger.info(f'{i}: {queries[i]} => {trans}')

          return final_translations

      @abstractmethod
      async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
          """
          Performs the actual translation. Receives the language codes as
          resolved by parse_language_codes and returns one translation per query.
          """
          pass

      async def _ratelimit_sleep(self):
          """
          Sleeps until at least 60 / _MAX_REQUESTS_PER_MINUTE seconds have passed
          since the previous request. Does nothing if the limit is <= 0.
          """
          if self._MAX_REQUESTS_PER_MINUTE > 0:
              now = time.time()
              ratelimit_timeout = self._last_request_ts + 60 / self._MAX_REQUESTS_PER_MINUTE
              if ratelimit_timeout > now:
                  self.logger.info(f'Ratelimit sleep: {(ratelimit_timeout-now):.2f}s')
                  await asyncio.sleep(ratelimit_timeout-now)
              self._last_request_ts = time.time()

      def _is_translation_invalid(self, query: str, trans: str) -> bool:
          """
          Heuristic check for broken translations.

          A translation counts as invalid if it is empty although the query is
          not, or if the query has more than 6 distinct characters while the
          translation has fewer than 6 and these make up less than a quarter of
          its length (i.e. it is highly repetitive).
          """
          if not trans and query:
              return True
          if not query or not trans:
              return False

          query_symbols_count = len(set(query))
          trans_symbols_count = len(set(trans))
          if query_symbols_count > 6 and trans_symbols_count < 6 and trans_symbols_count < 0.25 * len(trans):
              return True
          return False

      def _modify_invalid_translation_query(self, query: str, trans: str) -> str:
          """
          Can be overwritten if _INVALID_REPEAT_COUNT was set. It modifies the query
          for the next translation attempt.
          """
          return query

      def _clean_translation_output(self, query: str, trans: str, to_lang: str) -> str:
          """
          Tries to spot and skim down invalid translations.

          Normalizes whitespace and the spacing around punctuation and shortens
          translations that consist of a repeating sequence. Returns an empty
          string if the query or the translation is empty.
          """
          if not query or not trans:
              return ''

          # '  ' -> ' '
          trans = re.sub(r'\s+', r' ', trans)
          # 'text.text' -> 'text. text'
          trans = re.sub(r'(?<![.,;!?])([.,;!?])(?=\w)', r'\1 ', trans)
          # ' ! ! . . ' -> ' !!.. '
          trans = re.sub(r'([.,;!?])\s+(?=[.,;!?]|$)', r'\1', trans)

          if to_lang != 'ARA':
              # 'text .' -> 'text.'
              trans = re.sub(r'(?<=[.,;!?\w])\s+([.,;!?])', r'\1', trans)
              # ' ... text' -> ' ...text'
              trans = re.sub(r'((?:\s|^)\.+)\s+(?=\w)', r'\1', trans)

          seq = repeating_sequence(trans.lower())

          # 'aaaaaaaaaaaaa' -> 'aaaaaa'
          # An empty sequence is skipped, it would cause a division by zero below.
          if seq and len(trans) < len(query) and len(seq) < 0.5 * len(trans):
              # Shrink sequence to length of original query
              trans = seq * max(len(query) // len(seq), 1)
              # Transfer capitalization of query to translation. zip() stops at
              # the shorter of both strings.
              trans = ''.join(
                  t_char.upper() if q_char.isupper() else t_char
                  for t_char, q_char in zip(trans, query)
              )

          return trans
  class OfflineTranslator(CommonTranslator, ModelWrapper):
      """
      Translator that is also a ModelWrapper. Subclasses implement `_infer`
      and `_load`.
      """

      _MODEL_SUB_DIR = 'translators'

      async def _translate(self, *args, **kwargs):
          """Forwards all arguments to `self.infer` and returns its result."""
          inferred = await self.infer(*args, **kwargs)
          return inferred

      @abstractmethod
      async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
          """Abstract: translation with the loaded model, implemented by subclasses."""

      async def load(self, from_lang: str, to_lang: str, device: str):
          """
          Resolves the language codes with parse_language_codes and passes them,
          preceded by the device, to the parent class's load().
          """
          resolved_codes = self.parse_language_codes(from_lang, to_lang)
          return await super().load(device, *resolved_codes)

      @abstractmethod
      async def _load(self, from_lang: str, to_lang: str, device: str):
          """Abstract: model loading, implemented by subclasses."""

      async def reload(self, from_lang: str, to_lang: str, device: str):
          """
          Passes the device followed by the unresolved language codes to the
          parent class's reload().
          """
          parent_reload = super().reload
          return await parent_reload(device, from_lang, to_lang)