__init__.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import numpy as np
  2. from typing import List
  3. from .common import CommonOCR, OfflineOCR
  4. from .model_32px import Model32pxOCR
  5. from .model_48px import Model48pxOCR
  6. from .model_48px_ctc import Model48pxCTCOCR
  7. from .model_manga_ocr import ModelMangaOCR
  8. from ..utils import Quadrilateral
  9. OCRS = {
  10. '32px': Model32pxOCR,
  11. '48px': Model48pxOCR,
  12. '48px_ctc': Model48pxCTCOCR,
  13. 'mocr': ModelMangaOCR,
  14. }
  15. ocr_cache = {}
  16. def get_ocr(key: str, *args, **kwargs) -> CommonOCR:
  17. if key not in OCRS:
  18. raise ValueError(f'Could not find OCR for: "{key}". Choose from the following: %s' % ','.join(OCRS))
  19. if not ocr_cache.get(key):
  20. ocr = OCRS[key]
  21. ocr_cache[key] = ocr(*args, **kwargs)
  22. return ocr_cache[key]
  23. async def prepare(ocr_key: str, device: str = 'cpu'):
  24. ocr = get_ocr(ocr_key)
  25. if isinstance(ocr, OfflineOCR):
  26. await ocr.download()
  27. await ocr.load(device)
  28. async def dispatch(ocr_key: str, image: np.ndarray, regions: List[Quadrilateral], args = None, device: str = 'cpu', verbose: bool = False) -> List[Quadrilateral]:
  29. ocr = get_ocr(ocr_key)
  30. if isinstance(ocr, OfflineOCR):
  31. await ocr.load(device)
  32. args = args or {}
  33. return await ocr.recognize(image, regions, args, verbose)