m2m100.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import os
  2. import ctranslate2
  3. import sentencepiece as spm
  4. from typing import List
  5. from .common import OfflineTranslator
  6. # Adapted from:
  7. # https://gist.github.com/ymoslem/a414a0ead0d3e50f4d7ff7110b1d1c0d
  8. # https://github.com/ymoslem/DesktopTranslator
  9. class M2M100Translator(OfflineTranslator):
  10. # Refer to https://github.com/ymoslem/DesktopTranslator/blob/main/utils/m2m_languages.json
  11. # other languages can be added as well
  12. _LANGUAGE_CODE_MAP = {
  13. 'CHS': '__zh__',
  14. 'CHT': '__zh__',
  15. 'CSY': '__cs__',
  16. 'NLD': '__nl__',
  17. 'ENG': '__en__',
  18. 'FRA': '__fr__',
  19. 'DEU': '__de__',
  20. 'HUN': '__hu__',
  21. 'ITA': '__it__',
  22. 'JPN': '__ja__',
  23. 'KOR': '__ko__',
  24. 'PLK': '__pl__',
  25. 'PTB': '__pt__',
  26. 'ROM': '__ro__',
  27. 'RUS': '__ru__',
  28. 'ESP': '__es__',
  29. 'TRK': '__tr__',
  30. 'UKR': '__uk__',
  31. 'VIN': '__vi__',
  32. 'ARA': '__ar__',
  33. 'SRP': '__sr__',
  34. 'HRV': '__hr__',
  35. 'THA': '__th__',
  36. 'IND': '__id__'
  37. }
  38. _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_SUB_DIR, 'm2m_100')
  39. _CT2_MODEL_DIR = 'm2m100_418m'
  40. _MODEL_MAPPING = {
  41. 'models': {
  42. 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/m2m100_418m_ct2.zip',
  43. 'hash': '8a9cd0e00505a7879f26e5a1b396b447bc29967783a1e17e8df5eecb0c13d1c3',
  44. 'archive': {
  45. 'm2m100_418m/': '.',
  46. },
  47. },
  48. }
  49. async def _load(self, from_lang: str, to_lang: str, device: str):
  50. self.load_params = {
  51. 'from_lang': from_lang,
  52. 'to_lang': to_lang,
  53. 'device': device,
  54. }
  55. self.model = ctranslate2.Translator(
  56. model_path=self._get_file_path(self._CT2_MODEL_DIR),
  57. device=device,
  58. device_index=0,
  59. )
  60. self.model.load_model()
  61. self.sentence_piece_processor = spm.SentencePieceProcessor(model_file=self._get_file_path(self._CT2_MODEL_DIR, 'sentencepiece.model'))
  62. async def _unload(self):
  63. self.model.unload_model()
  64. del self.model
  65. del self.sentence_piece_processor
  66. async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
  67. queries_tokenized = self.tokenize(queries, from_lang)
  68. translated_tokenized = self.model.translate_batch(
  69. source=queries_tokenized,
  70. target_prefix=[[to_lang]] * len(queries),
  71. beam_size=5,
  72. max_batch_size=1024,
  73. return_alternatives=False,
  74. disable_unk=True,
  75. replace_unknowns=True,
  76. repetition_penalty=3,
  77. )
  78. translated = self.detokenize(list(map(lambda t: t[0]['tokens'], translated_tokenized)), to_lang)
  79. return translated
  80. def tokenize(self, queries, lang):
  81. sp = self.sentence_piece_processor
  82. if isinstance(queries, list):
  83. return sp.encode(queries, out_type=str)
  84. else:
  85. return [sp.encode(queries, out_type=str)]
  86. def detokenize(self, queries, lang):
  87. sp = self.sentence_piece_processor
  88. translation = sp.decode(queries)
  89. prefix_len = len(lang) + 1
  90. translation = [''.join(query)[prefix_len:] for query in translation]
  91. return translation
  92. class M2M100BigTranslator(M2M100Translator):
  93. _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_SUB_DIR, 'm2m_100')
  94. _CT2_MODEL_DIR = 'm2m100_12b'
  95. _MODEL_MAPPING = {
  96. 'models': {
  97. 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/m2m100_12b_ct2.zip',
  98. 'hash': '742d5380c2837affd3680339145d37fc78f537ad633958347b76e9be9c577662',
  99. 'archive': {
  100. 'm2m100_12b/': '.',
  101. },
  102. },
  103. }