import os
from typing import List

import py3langid as langid

from .common import OfflineTranslator

# Maps the ISO 639-1 codes returned by language detection to FLORES-200 codes.
# https://github.com/facebookresearch/flores/blob/main/flores200/README.md
ISO_639_1_TO_FLORES_200 = {
    'zh': 'zho_Hans',
    'ja': 'jpn_Jpan',
    'en': 'eng_Latn',
    'ko': 'kor_Hang',  # ISO 639-1 for Korean is 'ko' ('kn' is Kannada)
    'cs': 'ces_Latn',
    'nl': 'nld_Latn',
    'fr': 'fra_Latn',
    'de': 'deu_Latn',
    'hu': 'hun_Latn',
    'it': 'ita_Latn',
    'pl': 'pol_Latn',
    'pt': 'por_Latn',
    'ro': 'ron_Latn',
    'ru': 'rus_Cyrl',
    'es': 'spa_Latn',
    'tr': 'tur_Latn',
    'uk': 'ukr_Cyrl',
    'vi': 'vie_Latn',
    'ar': 'arb_Arab',
    'sr': 'srp_Cyrl',
    'hr': 'hrv_Latn',
    'th': 'tha_Thai',
    'id': 'ind_Latn',
}


class NLLBTranslator(OfflineTranslator):
    """Offline translator backed by the ``facebook/nllb-200-distilled-600M`` model.

    Source and target languages are passed to the model as FLORES-200 codes.
    When the source language is ``'auto'`` it is detected with ``py3langid``,
    first over all queries joined together and, if that fails, per sentence.
    """

    # Language keys used by the translator framework -> FLORES-200 codes.
    # NOTE(review): presumably consumed by OfflineTranslator to convert the
    # from_lang/to_lang arguments before _infer is called -- confirm in .common.
    _LANGUAGE_CODE_MAP = {
        'CHS': 'zho_Hans',
        'CHT': 'zho_Hant',
        'JPN': 'jpn_Jpan',
        'ENG': 'eng_Latn',
        'KOR': 'kor_Hang',
        'CSY': 'ces_Latn',
        'NLD': 'nld_Latn',
        'FRA': 'fra_Latn',
        'DEU': 'deu_Latn',
        'HUN': 'hun_Latn',
        'ITA': 'ita_Latn',
        'PLK': 'pol_Latn',
        'PTB': 'por_Latn',
        'ROM': 'ron_Latn',
        'RUS': 'rus_Cyrl',
        'ESP': 'spa_Latn',
        'TRK': 'tur_Latn',
        'UKR': 'ukr_Cyrl',  # was 'Ukrainian', which is not a FLORES-200 code
        'VIN': 'vie_Latn',
        'ARA': 'arb_Arab',
        'SRP': 'srp_Cyrl',
        'HRV': 'hrv_Latn',
        'THA': 'tha_Thai',
        'IND': 'ind_Latn',
    }
    # Directory used as the Hugging Face cache for this model's files.
    _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb')
    # Hugging Face Hub repository id of the model and tokenizer.
    _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-600M'

    async def _load(self, from_lang: str, to_lang: str, device: str):
        """Load the model and tokenizer and remember the target device.

        Args:
            from_lang: Source language code (unused here).
            to_lang: Target language code (unused here).
            device: Device string; ``':0'`` is appended when no index is given.
        """
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        if ':' not in device:
            device += ':0'
        self.device = device
        # Load from the same cache directory that _download() populates and
        # _check_downloaded() inspects, instead of the default Hugging Face cache.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self._TRANSLATOR_MODEL,
            cache_dir=self._MODEL_SUB_DIR,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._TRANSLATOR_MODEL,
            cache_dir=self._MODEL_SUB_DIR,
        )

    async def _unload(self):
        """Drop the references to the loaded model and tokenizer."""
        del self.model
        del self.tokenizer

    async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        """Translate every query and return the translations in the same order.

        Args:
            from_lang: FLORES-200 source code, or ``'auto'`` to detect it.
            to_lang: FLORES-200 target code.
            queries: Sentences to translate.

        Returns:
            One translated string per query; ``''`` for a query whose
            translation was skipped (see ``_translate_sentence``).
        """
        if from_lang == 'auto':
            # Try a single detection over the whole text first.
            detected_lang = langid.classify('\n'.join(queries))[0]
            target_lang = self._map_detected_lang_to_translator(detected_lang)
            if target_lang is None:
                # Keep from_lang == 'auto' so each sentence is detected on its own.
                self.logger.warning('Could not detect language from over all sentence. Will try per sentence.')
            else:
                from_lang = target_lang

        return [self._translate_sentence(from_lang, to_lang, query) for query in queries]

    def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str:
        """Translate a single sentence.

        Args:
            from_lang: FLORES-200 source code, or ``'auto'`` to detect it from ``query``.
            to_lang: FLORES-200 target code.
            query: Text to translate.

        Returns:
            The translated text, or ``''`` when the translator is not loaded or
            the source language could not be detected/mapped.
        """
        from transformers import pipeline

        if not self.is_loaded():
            return ''

        if from_lang == 'auto':
            detected_lang = langid.classify(query)[0]
            from_lang = self._map_detected_lang_to_translator(detected_lang)

        if from_lang is None:
            self.logger.warning(f'NLLB Translation Failed. Could not detect language (Or language not supported for text: {query})')
            return ''

        # The pipeline is built per call because src_lang may differ per sentence.
        translator = pipeline(
            'translation',
            device=self.device,
            model=self.model,
            tokenizer=self.tokenizer,
            src_lang=from_lang,
            tgt_lang=to_lang,
            max_length=512,
        )
        return translator(query)[0]['translation_text']

    def _map_detected_lang_to_translator(self, lang):
        """Return the FLORES-200 code for an ISO 639-1 code, or ``None`` if unmapped."""
        return ISO_639_1_TO_FLORES_200.get(lang)

    async def _download(self):
        """Download the model snapshot from the Hugging Face Hub into ``_MODEL_SUB_DIR``."""
        import huggingface_hub

        # do not download msgpack and h5 files as they are not needed to run the model
        huggingface_hub.snapshot_download(
            self._TRANSLATOR_MODEL,
            cache_dir=self._MODEL_SUB_DIR,
            ignore_patterns=["*.msgpack", "*.h5", '*.ot', ".*", "*.safetensors"],
        )

    def _check_downloaded(self) -> bool:
        """Return True when ``pytorch_model.bin`` is present in the local cache."""
        import huggingface_hub

        cached = huggingface_hub.try_to_load_from_cache(
            self._TRANSLATOR_MODEL,
            'pytorch_model.bin',
            cache_dir=self._MODEL_SUB_DIR,
        )
        return cached is not None


class NLLBBigTranslator(NLLBTranslator):
    """NLLB translator variant using the ``facebook/nllb-200-distilled-1.3B`` model."""

    _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb_big')
    _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-1.3B'