__init__.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import py3langid as langid
  2. from .common import *
  3. from .baidu import BaiduTranslator
  4. from .deepseek import DeepseekTranslator
  5. # from .google import GoogleTranslator
  6. from .youdao import YoudaoTranslator
  7. from .deepl import DeeplTranslator
  8. from .papago import PapagoTranslator
  9. from .caiyun import CaiyunTranslator
  10. from .chatgpt import GPT3Translator, GPT35TurboTranslator, GPT4Translator
  11. from .nllb import NLLBTranslator, NLLBBigTranslator
  12. from .sugoi import JparacrawlTranslator, JparacrawlBigTranslator, SugoiTranslator
  13. from .m2m100 import M2M100Translator, M2M100BigTranslator
  14. from .mbart50 import MBart50Translator
  15. from .selective import SelectiveOfflineTranslator, prepare as prepare_selective_translator
  16. from .none import NoneTranslator
  17. from .original import OriginalTranslator
  18. from .sakura import SakuraTranslator
  19. from .qwen2 import Qwen2Translator, Qwen2BigTranslator
  20. OFFLINE_TRANSLATORS = {
  21. 'offline': SelectiveOfflineTranslator,
  22. 'nllb': NLLBTranslator,
  23. 'nllb_big': NLLBBigTranslator,
  24. 'sugoi': SugoiTranslator,
  25. 'jparacrawl': JparacrawlTranslator,
  26. 'jparacrawl_big': JparacrawlBigTranslator,
  27. 'm2m100': M2M100Translator,
  28. 'm2m100_big': M2M100BigTranslator,
  29. 'mbart50': MBart50Translator,
  30. 'qwen2': Qwen2Translator,
  31. 'qwen2_big': Qwen2BigTranslator,
  32. }
  33. TRANSLATORS = {
  34. # 'google': GoogleTranslator,
  35. 'youdao': YoudaoTranslator,
  36. 'baidu': BaiduTranslator,
  37. 'deepl': DeeplTranslator,
  38. 'papago': PapagoTranslator,
  39. 'caiyun': CaiyunTranslator,
  40. 'gpt3': GPT3Translator,
  41. 'gpt3.5': GPT35TurboTranslator,
  42. 'gpt4': GPT4Translator,
  43. 'none': NoneTranslator,
  44. 'original': OriginalTranslator,
  45. 'sakura': SakuraTranslator,
  46. 'deepseek': DeepseekTranslator,
  47. **OFFLINE_TRANSLATORS,
  48. }
  49. translator_cache = {}
  50. def get_translator(key: str, *args, **kwargs) -> CommonTranslator:
  51. if key not in TRANSLATORS:
  52. raise ValueError(f'Could not find translator for: "{key}". Choose from the following: %s' % ','.join(TRANSLATORS))
  53. if not translator_cache.get(key):
  54. translator = TRANSLATORS[key]
  55. translator_cache[key] = translator(*args, **kwargs)
  56. return translator_cache[key]
  57. prepare_selective_translator(get_translator)
  58. # TODO: Refactor
  59. class TranslatorChain():
  60. def __init__(self, string: str):
  61. """
  62. Parses string in form 'trans1:lang1;trans2:lang2' into chains,
  63. which will be executed one after another when passed to the dispatch function.
  64. """
  65. if not string:
  66. raise Exception('Invalid translator chain')
  67. self.chain = []
  68. self.target_lang = None
  69. for g in string.split(';'):
  70. trans, lang = g.split(':')
  71. if trans not in TRANSLATORS:
  72. raise ValueError(f'Invalid choice: %s (choose from %s)' % (trans, ', '.join(map(repr, TRANSLATORS))))
  73. if lang not in VALID_LANGUAGES:
  74. raise ValueError(f'Invalid choice: %s (choose from %s)' % (lang, ', '.join(map(repr, VALID_LANGUAGES))))
  75. self.chain.append((trans, lang))
  76. self.translators, self.langs = list(zip(*self.chain))
  77. def has_offline(self) -> bool:
  78. """
  79. Returns True if the chain contains offline translators.
  80. """
  81. return any(translator in OFFLINE_TRANSLATORS for translator in self.translators)
  82. def __eq__(self, __o: object) -> bool:
  83. if type(__o) is str:
  84. return __o == self.translators[0]
  85. return super.__eq__(self, __o)
  86. async def prepare(chain: TranslatorChain):
  87. for key, tgt_lang in chain.chain:
  88. translator = get_translator(key)
  89. translator.supports_languages('auto', tgt_lang, fatal=True)
  90. if isinstance(translator, OfflineTranslator):
  91. await translator.download()
  92. # TODO: Optionally take in strings instead of TranslatorChain for simplicity
  93. async def dispatch(chain: TranslatorChain, queries: List[str], use_mtpe: bool = False, args = None, device: str = 'cpu') -> List[str]:
  94. if not queries:
  95. return queries
  96. if chain.target_lang is not None:
  97. text_lang = ISO_639_1_TO_VALID_LANGUAGES.get(langid.classify('\n'.join(queries))[0])
  98. translator = None
  99. for key, lang in chain.chain:
  100. if text_lang == lang:
  101. translator = get_translator(key)
  102. break
  103. if translator is None:
  104. translator = get_translator(chain.langs[0])
  105. if isinstance(translator, OfflineTranslator):
  106. await translator.load('auto', chain.target_lang, device)
  107. translator.parse_args(args)
  108. queries = await translator.translate('auto', chain.target_lang, queries, use_mtpe)
  109. return queries
  110. if args is not None:
  111. args['translations'] = {}
  112. for key, tgt_lang in chain.chain:
  113. translator = get_translator(key)
  114. if isinstance(translator, OfflineTranslator):
  115. await translator.load('auto', tgt_lang, device)
  116. translator.parse_args(args)
  117. queries = await translator.translate('auto', tgt_lang, queries, use_mtpe)
  118. if args is not None:
  119. args['translations'][tgt_lang] = queries
  120. return queries
  121. LANGDETECT_MAP = {
  122. 'zh-cn': 'CHS',
  123. 'zh-tw': 'CHT',
  124. 'cs': 'CSY',
  125. 'nl': 'NLD',
  126. 'en': 'ENG',
  127. 'fr': 'FRA',
  128. 'de': 'DEU',
  129. 'hu': 'HUN',
  130. 'it': 'ITA',
  131. 'ja': 'JPN',
  132. 'ko': 'KOR',
  133. 'pl': 'PLK',
  134. 'pt': 'PTB',
  135. 'ro': 'ROM',
  136. 'ru': 'RUS',
  137. 'es': 'ESP',
  138. 'tr': 'TRK',
  139. 'uk': 'UKR',
  140. 'vi': 'VIN',
  141. 'ar': 'ARA',
  142. 'hr': 'HRV',
  143. 'th': 'THA',
  144. 'id': 'IND',
  145. 'tl': 'FIL'
  146. }