123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- import os
- from typing import List
- import py3langid as langid
- from .common import OfflineTranslator
- # https://github.com/facebookresearch/flores/blob/main/flores200/README.md
- ISO_639_1_TO_FLORES_200 = {
- 'zh': 'zho_Hans',
- 'ja': 'jpn_Jpan',
- 'en': 'eng_Latn',
- 'kn': 'kor_Hang',
- 'cs': 'ces_Latn',
- 'nl': 'nld_Latn',
- 'fr': 'fra_Latn',
- 'de': 'deu_Latn',
- 'hu': 'hun_Latn',
- 'it': 'ita_Latn',
- 'pl': 'pol_Latn',
- 'pt': 'por_Latn',
- 'ro': 'ron_Latn',
- 'ru': 'rus_Cyrl',
- 'es': 'spa_Latn',
- 'tr': 'tur_Latn',
- 'uk': 'ukr_Cyrl',
- 'vi': 'vie_Latn',
- 'ar': 'arb_Arab',
- 'sr': 'srp_Cyrl',
- 'hr': 'hrv_Latn',
- 'th': 'tha_Thai',
- 'id': 'ind_Latn'
- }
- class NLLBTranslator(OfflineTranslator):
- _LANGUAGE_CODE_MAP = {
- 'CHS': 'zho_Hans',
- 'CHT': 'zho_Hant',
- 'JPN': 'jpn_Jpan',
- 'ENG': 'eng_Latn',
- 'KOR': 'kor_Hang',
- 'CSY': 'ces_Latn',
- 'NLD': 'nld_Latn',
- 'FRA': 'fra_Latn',
- 'DEU': 'deu_Latn',
- 'HUN': 'hun_Latn',
- 'ITA': 'ita_Latn',
- 'PLK': 'pol_Latn',
- 'PTB': 'por_Latn',
- 'ROM': 'ron_Latn',
- 'RUS': 'rus_Cyrl',
- 'ESP': 'spa_Latn',
- 'TRK': 'tur_Latn',
- 'UKR': 'Ukrainian',
- 'VIN': 'vie_Latn',
- 'ARA': 'arb_Arab',
- 'SRP': 'srp_Cyrl',
- 'HRV': 'hrv_Latn',
- 'THA': 'tha_Thai',
- 'IND': 'ind_Latn'
- }
- _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb')
- _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-600M'
- async def _load(self, from_lang: str, to_lang: str, device: str):
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- if ':' not in device:
- device += ':0'
- self.device = device
- self.model = AutoModelForSeq2SeqLM.from_pretrained(self._TRANSLATOR_MODEL)
- self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL)
- async def _unload(self):
- del self.model
- del self.tokenizer
- async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
- if from_lang == 'auto':
- detected_lang = langid.classify('\n'.join(queries))[0]
- target_lang = self._map_detected_lang_to_translator(detected_lang)
- if target_lang == None:
- self.logger.warn('Could not detect language from over all sentence. Will try per sentence.')
- else:
- from_lang = target_lang
- return [self._translate_sentence(from_lang, to_lang, query) for query in queries]
- def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str:
- from transformers import pipeline
- if not self.is_loaded():
- return ''
- if from_lang == 'auto':
- detected_lang = langid.classify(query)[0]
- from_lang = self._map_detected_lang_to_translator(detected_lang)
- if from_lang == None:
- self.logger.warn(f'NLLB Translation Failed. Could not detect language (Or language not supported for text: {query})')
- return ''
- translator = pipeline('translation',
- device=self.device,
- model=self.model,
- tokenizer=self.tokenizer,
- src_lang=from_lang,
- tgt_lang=to_lang,
- max_length = 512,
- )
- result = translator(query)[0]['translation_text']
- return result
- def _map_detected_lang_to_translator(self, lang):
- if not lang in ISO_639_1_TO_FLORES_200:
- return None
- return ISO_639_1_TO_FLORES_200[lang]
- async def _download(self):
- import huggingface_hub
- # do not download msgpack and h5 files as they are not needed to run the model
- huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR, ignore_patterns=["*.msgpack", "*.h5", '*.ot',".*", "*.safetensors"])
- def _check_downloaded(self) -> bool:
- import huggingface_hub
- return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None
- class NLLBBigTranslator(NLLBTranslator):
- _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb_big')
- _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-1.3B'
|