# nllb.py
  1. import os
  2. from typing import List
  3. import py3langid as langid
  4. from .common import OfflineTranslator
  5. # https://github.com/facebookresearch/flores/blob/main/flores200/README.md
  6. ISO_639_1_TO_FLORES_200 = {
  7. 'zh': 'zho_Hans',
  8. 'ja': 'jpn_Jpan',
  9. 'en': 'eng_Latn',
  10. 'kn': 'kor_Hang',
  11. 'cs': 'ces_Latn',
  12. 'nl': 'nld_Latn',
  13. 'fr': 'fra_Latn',
  14. 'de': 'deu_Latn',
  15. 'hu': 'hun_Latn',
  16. 'it': 'ita_Latn',
  17. 'pl': 'pol_Latn',
  18. 'pt': 'por_Latn',
  19. 'ro': 'ron_Latn',
  20. 'ru': 'rus_Cyrl',
  21. 'es': 'spa_Latn',
  22. 'tr': 'tur_Latn',
  23. 'uk': 'ukr_Cyrl',
  24. 'vi': 'vie_Latn',
  25. 'ar': 'arb_Arab',
  26. 'sr': 'srp_Cyrl',
  27. 'hr': 'hrv_Latn',
  28. 'th': 'tha_Thai',
  29. 'id': 'ind_Latn'
  30. }
  31. class NLLBTranslator(OfflineTranslator):
  32. _LANGUAGE_CODE_MAP = {
  33. 'CHS': 'zho_Hans',
  34. 'CHT': 'zho_Hant',
  35. 'JPN': 'jpn_Jpan',
  36. 'ENG': 'eng_Latn',
  37. 'KOR': 'kor_Hang',
  38. 'CSY': 'ces_Latn',
  39. 'NLD': 'nld_Latn',
  40. 'FRA': 'fra_Latn',
  41. 'DEU': 'deu_Latn',
  42. 'HUN': 'hun_Latn',
  43. 'ITA': 'ita_Latn',
  44. 'PLK': 'pol_Latn',
  45. 'PTB': 'por_Latn',
  46. 'ROM': 'ron_Latn',
  47. 'RUS': 'rus_Cyrl',
  48. 'ESP': 'spa_Latn',
  49. 'TRK': 'tur_Latn',
  50. 'UKR': 'Ukrainian',
  51. 'VIN': 'vie_Latn',
  52. 'ARA': 'arb_Arab',
  53. 'SRP': 'srp_Cyrl',
  54. 'HRV': 'hrv_Latn',
  55. 'THA': 'tha_Thai',
  56. 'IND': 'ind_Latn'
  57. }
  58. _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb')
  59. _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-600M'
  60. async def _load(self, from_lang: str, to_lang: str, device: str):
  61. from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  62. if ':' not in device:
  63. device += ':0'
  64. self.device = device
  65. self.model = AutoModelForSeq2SeqLM.from_pretrained(self._TRANSLATOR_MODEL)
  66. self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL)
  67. async def _unload(self):
  68. del self.model
  69. del self.tokenizer
  70. async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
  71. if from_lang == 'auto':
  72. detected_lang = langid.classify('\n'.join(queries))[0]
  73. target_lang = self._map_detected_lang_to_translator(detected_lang)
  74. if target_lang == None:
  75. self.logger.warn('Could not detect language from over all sentence. Will try per sentence.')
  76. else:
  77. from_lang = target_lang
  78. return [self._translate_sentence(from_lang, to_lang, query) for query in queries]
  79. def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str:
  80. from transformers import pipeline
  81. if not self.is_loaded():
  82. return ''
  83. if from_lang == 'auto':
  84. detected_lang = langid.classify(query)[0]
  85. from_lang = self._map_detected_lang_to_translator(detected_lang)
  86. if from_lang == None:
  87. self.logger.warn(f'NLLB Translation Failed. Could not detect language (Or language not supported for text: {query})')
  88. return ''
  89. translator = pipeline('translation',
  90. device=self.device,
  91. model=self.model,
  92. tokenizer=self.tokenizer,
  93. src_lang=from_lang,
  94. tgt_lang=to_lang,
  95. max_length = 512,
  96. )
  97. result = translator(query)[0]['translation_text']
  98. return result
  99. def _map_detected_lang_to_translator(self, lang):
  100. if not lang in ISO_639_1_TO_FLORES_200:
  101. return None
  102. return ISO_639_1_TO_FLORES_200[lang]
  103. async def _download(self):
  104. import huggingface_hub
  105. # do not download msgpack and h5 files as they are not needed to run the model
  106. huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR, ignore_patterns=["*.msgpack", "*.h5", '*.ot',".*", "*.safetensors"])
  107. def _check_downloaded(self) -> bool:
  108. import huggingface_hub
  109. return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None
class NLLBBigTranslator(NLLBTranslator):
    """NLLB variant that uses the larger 1.3B distilled checkpoint.

    Identical behavior to NLLBTranslator; only the model identifier and the
    on-disk cache subdirectory differ.
    """
    _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb_big')
    _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-1.3B'