inference.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. import os
  2. import stat
  3. import sys
  4. import tempfile
  5. import re
  6. import torch
  7. import shutil
  8. import filecmp
  9. from abc import ABC, abstractmethod
  10. from functools import cached_property
  11. from .generic import (
  12. BASE_PATH,
  13. Context,
  14. download_url_with_progressbar,
  15. prompt_yes_no,
  16. replace_prefix,
  17. get_digest,
  18. get_filename_from_url,
  19. )
  20. from .log import get_logger
  21. class InfererModule(ABC):
  22. def __init__(self):
  23. self.logger = get_logger(self.__class__.__name__)
  24. super().__init__()
  25. def parse_args(self, args: Context):
  26. """May be overwritten by super classes to parse commandline arguments"""
  27. pass
  28. # class InfererModuleManager(ABC):
  29. # _KEY = ''
  30. # _VARIANTS = []
  31. # def __init__(self):
  32. # self.onstart: Callable = None
  33. # self.onfinish: Callable = None
  34. # def validate(self):
  35. # """
  36. # Throws exception if a
  37. # """
  38. # ...
  39. # async def prepare(self):
  40. # ...
  41. # async def dispatch(self):
  42. # ...
  43. class ModelVerificationException(Exception):
  44. pass
  45. class InvalidModelMappingException(ValueError):
  46. def __init__(self, cls: str, map_key: str, error_msg: str):
  47. error = f'[{cls}->{map_key}] Invalid _MODEL_MAPPING - {error_msg}'
  48. super().__init__(error)
  49. class ModelWrapper(ABC):
  50. r"""
  51. A class that provides a unified interface for downloading models and making forward passes.
  52. All model inferer classes should extend it.
  53. Download specifications can be made through overwriting the `_MODEL_MAPPING` property.
  54. ```python
  55. _MODEL_MAPPTING = {
  56. 'model_id': {
  57. **PARAMETERS
  58. },
  59. ...
  60. }
  61. ```
  62. Parameters:
  63. model_id - Used for temporary caches and debug messages
  64. url - A direct download url
  65. hash - Hash of downloaded file, Can be obtained upon ModelVerificationException
  66. file - File download destination, If set to '.' the filename will be inferred
  67. from the url (fallback is `model_id` value)
  68. archive - Dict that contains all files/folders that are to be extracted from
  69. the downloaded archive and their destinations, Mutually exclusive with `file`
  70. executables - List of files that need to have the executable flag set
  71. """
  72. _MODEL_DIR = os.path.join(BASE_PATH, 'models')
  73. _MODEL_SUB_DIR = ''
  74. _MODEL_MAPPING = {}
  75. _KEY = ''
  76. def __init__(self):
  77. os.makedirs(self.model_dir, exist_ok=True)
  78. self._key = self._KEY or self.__class__.__name__
  79. self._loaded = False
  80. self._check_for_malformed_model_mapping()
  81. self._downloaded = self._check_downloaded()
  82. def is_loaded(self) -> bool:
  83. return self._loaded
  84. def is_downloaded(self) -> bool:
  85. return self._downloaded
  86. @property
  87. def model_dir(self):
  88. return os.path.join(self._MODEL_DIR, self._MODEL_SUB_DIR)
  89. def _get_file_path(self, *args) -> str:
  90. return os.path.join(self.model_dir, *args)
  91. def _get_used_gpu_memory(self) -> bool:
  92. '''
  93. Gets the total amount of GPU memory used by model (Can be used in the future
  94. to determine whether a model should be loaded into vram or ram or automatically choose a model size).
  95. TODO: Use together with `--use-cuda-limited` flag to enforce stricter memory checks
  96. '''
  97. return torch.cuda.mem_get_info()
  98. def _check_for_malformed_model_mapping(self):
  99. for map_key, mapping in self._MODEL_MAPPING.items():
  100. if 'url' not in mapping:
  101. raise InvalidModelMappingException(self._key, map_key, 'Missing url property')
  102. elif not re.search(r'^https?://', mapping['url']):
  103. raise InvalidModelMappingException(self._key, map_key, 'Malformed url property: "%s"' % mapping['url'])
  104. if 'file' not in mapping and 'archive' not in mapping:
  105. mapping['file'] = '.'
  106. elif 'file' in mapping and 'archive' in mapping:
  107. raise InvalidModelMappingException(self._key, map_key, 'Properties file and archive are mutually exclusive')
  108. async def _download_file(self, url: str, path: str):
  109. print(f' -- Downloading: "{url}"')
  110. download_url_with_progressbar(url, path)
  111. async def _verify_file(self, sha256_pre_calculated: str, path: str):
  112. print(f' -- Verifying: "{path}"')
  113. sha256_calculated = get_digest(path).lower()
  114. sha256_pre_calculated = sha256_pre_calculated.lower()
  115. if sha256_calculated != sha256_pre_calculated:
  116. self._on_verify_failure(sha256_calculated, sha256_pre_calculated)
  117. else:
  118. print(' -- Verifying: OK!')
  119. def _on_verify_failure(self, sha256_calculated: str, sha256_pre_calculated: str):
  120. print(f' -- Mismatch between downloaded and created hash: "{sha256_calculated}" <-> "{sha256_pre_calculated}"')
  121. raise ModelVerificationException()
  122. @cached_property
  123. def _temp_working_directory(self):
  124. p = os.path.join(tempfile.gettempdir(), 'manga-image-translator', self._key.lower())
  125. os.makedirs(p, exist_ok=True)
  126. return p
  127. async def download(self, force=False):
  128. '''
  129. Downloads required models.
  130. '''
  131. if force or not self.is_downloaded():
  132. while True:
  133. try:
  134. await self._download()
  135. self._downloaded = True
  136. break
  137. except ModelVerificationException:
  138. if not prompt_yes_no('Failed to verify signature. Do you want to restart the download?', default=True):
  139. print('Aborting.', end='')
  140. raise KeyboardInterrupt()
  141. async def _download(self):
  142. '''
  143. Downloads models as defined in `_MODEL_MAPPING`. Can be overwritten (together
  144. with `_check_downloaded`) to implement unconventional download logic.
  145. '''
  146. print(f'\nDownloading models into {self.model_dir}\n')
  147. for map_key, mapping in self._MODEL_MAPPING.items():
  148. if self._check_downloaded_map(map_key):
  149. print(f' -- Skipping {map_key} as it\'s already downloaded')
  150. continue
  151. is_archive = 'archive' in mapping
  152. if is_archive:
  153. download_path = os.path.join(self._temp_working_directory, map_key, '')
  154. else:
  155. download_path = self._get_file_path(mapping['file'])
  156. if not os.path.basename(download_path):
  157. os.makedirs(download_path, exist_ok=True)
  158. if os.path.basename(download_path) in ('', '.'):
  159. download_path = os.path.join(download_path, get_filename_from_url(mapping['url'], map_key))
  160. if not is_archive:
  161. download_path += '.part'
  162. if 'hash' in mapping:
  163. downloaded = False
  164. if os.path.isfile(download_path):
  165. try:
  166. print(' -- Found existing file')
  167. await self._verify_file(mapping['hash'], download_path)
  168. downloaded = True
  169. except ModelVerificationException:
  170. print(' -- Resuming interrupted download')
  171. if not downloaded:
  172. await self._download_file(mapping['url'], download_path)
  173. await self._verify_file(mapping['hash'], download_path)
  174. else:
  175. await self._download_file(mapping['url'], download_path)
  176. if download_path.endswith('.part'):
  177. p = download_path[:len(download_path)-5]
  178. shutil.move(download_path, p)
  179. download_path = p
  180. if is_archive:
  181. extracted_path = os.path.join(os.path.dirname(download_path), 'extracted')
  182. print(f' -- Extracting files')
  183. shutil.unpack_archive(download_path, extracted_path)
  184. def get_real_archive_files():
  185. archive_files = []
  186. for root, dirs, files in os.walk(extracted_path):
  187. for name in files:
  188. file_path = replace_prefix(os.path.join(root, name), extracted_path, '')
  189. archive_files.append(file_path)
  190. return archive_files
  191. # Move every specified file from archive to destination
  192. for orig, dest in mapping['archive'].items():
  193. p1 = os.path.join(extracted_path, orig)
  194. if os.path.exists(p1):
  195. p2 = self._get_file_path(dest)
  196. if os.path.basename(p2) in ('', '.'):
  197. p2 = os.path.join(p2, os.path.basename(p1))
  198. if os.path.isfile(p2):
  199. if filecmp.cmp(p1, p2):
  200. continue
  201. raise InvalidModelMappingException(self._key, map_key, 'File "{orig}" already exists at "{dest}"')
  202. os.makedirs(os.path.dirname(p2), exist_ok=True)
  203. shutil.move(p1, p2)
  204. else:
  205. raise InvalidModelMappingException(self._key, map_key, f'File "{orig}" does not exist within archive' +
  206. '\nAvailable files:\n%s' % '\n'.join(get_real_archive_files()))
  207. if len(mapping['archive']) == 0:
  208. raise InvalidModelMappingException(self._key, map_key, 'No archive files specified' +
  209. '\nAvailable files:\n%s' % '\n'.join(get_real_archive_files()))
  210. self._grant_execute_permissions(map_key)
  211. # Remove temporary files
  212. try:
  213. os.remove(download_path)
  214. shutil.rmtree(extracted_path)
  215. except Exception:
  216. pass
  217. print()
  218. self._on_download_finished(map_key)
  219. def _on_download_finished(self, map_key):
  220. '''
  221. Can be overwritten to further process the downloaded files
  222. '''
  223. pass
  224. def _check_downloaded(self) -> bool:
  225. '''
  226. Scans filesystem for required files as defined in `_MODEL_MAPPING`.
  227. Returns `False` if files should be redownloaded.
  228. '''
  229. for map_key in self._MODEL_MAPPING:
  230. if not self._check_downloaded_map(map_key):
  231. return False
  232. return True
  233. def _check_downloaded_map(self, map_key: str) -> str:
  234. mapping = self._MODEL_MAPPING[map_key]
  235. if 'file' in mapping:
  236. path = mapping['file']
  237. if os.path.basename(path) in ('.', ''):
  238. path = os.path.join(path, get_filename_from_url(mapping['url'], map_key))
  239. if not os.path.exists(self._get_file_path(path)):
  240. return False
  241. elif 'archive' in mapping:
  242. for orig, dest in mapping['archive'].items():
  243. if os.path.basename(dest) in ('', '.'):
  244. dest = os.path.join(dest, os.path.basename(orig[:-1] if orig.endswith('/') else orig))
  245. if not os.path.exists(self._get_file_path(dest)):
  246. return False
  247. self._grant_execute_permissions(map_key)
  248. return True
  249. def _grant_execute_permissions(self, map_key: str):
  250. mapping = self._MODEL_MAPPING[map_key]
  251. if sys.platform == 'linux':
  252. # Grant permission to executables
  253. for file in mapping.get('executables', []):
  254. p = self._get_file_path(file)
  255. if os.path.basename(p) in ('', '.'):
  256. p = os.path.join(p, file)
  257. if not os.path.isfile(p):
  258. raise InvalidModelMappingException(self._key, map_key, f'File "{file}" does not exist')
  259. if not os.access(p, os.X_OK):
  260. os.chmod(p, os.stat(p).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
  261. async def reload(self, device: str, *args, **kwargs):
  262. await self.unload()
  263. await self.load(*args, **kwargs, device=device)
  264. async def load(self, device: str, *args, **kwargs):
  265. '''
  266. Loads models into memory. Has to be called before `forward`.
  267. '''
  268. if not self.is_downloaded():
  269. await self.download()
  270. if not self.is_loaded():
  271. await self._load(*args, **kwargs, device=device)
  272. self._loaded = True
  273. async def unload(self):
  274. if self.is_loaded():
  275. await self._unload()
  276. self._loaded = False
  277. async def infer(self, *args, **kwargs):
  278. '''
  279. Makes a forward pass through the network.
  280. '''
  281. if not self.is_loaded():
  282. raise Exception(f'{self._key}: Tried to forward pass without having loaded the model.')
  283. return await self._infer(*args, **kwargs)
  284. @abstractmethod
  285. async def _load(self, device: str, *args, **kwargs):
  286. pass
  287. @abstractmethod
  288. async def _unload(self):
  289. pass
  290. @abstractmethod
  291. async def _infer(self, *args, **kwargs):
  292. pass