1234567891011121314151617181920212223242526272829303132333435363738 |
- import numpy as np
- from typing import List
- from .common import CommonOCR, OfflineOCR
- from .model_32px import Model32pxOCR
- from .model_48px import Model48pxOCR
- from .model_48px_ctc import Model48pxCTCOCR
- from .model_manga_ocr import ModelMangaOCR
- from ..utils import Quadrilateral
# Registry of the available OCR implementations: maps the key string accepted
# by get_ocr() to the class that is instantiated for it.
OCRS = {
    '32px': Model32pxOCR,
    '48px': Model48pxOCR,
    '48px_ctc': Model48pxCTCOCR,
    'mocr': ModelMangaOCR,
}

# Instances already built by get_ocr(), stored under the same keys as OCRS so
# each OCR class is constructed at most once per key.
ocr_cache = {}
def get_ocr(key: str, *args, **kwargs) -> CommonOCR:
    """Return the cached OCR instance registered under ``key``, creating it on first use.

    Args:
        key: Name of the OCR implementation; must be one of the keys of ``OCRS``.
        *args: Positional arguments forwarded to the OCR class constructor.
        **kwargs: Keyword arguments forwarded to the OCR class constructor.

    Returns:
        The instance stored in ``ocr_cache`` for ``key``.

    Raises:
        ValueError: If ``key`` is not a key of ``OCRS``.

    Note:
        ``args`` and ``kwargs`` are only used when the instance is first
        created; on later calls the cached instance is returned and they are
        ignored.
    """
    if key not in OCRS:
        # Build the message with a single f-string. Mixing an f-string with
        # %-formatting breaks when ``key`` itself contains a '%' character,
        # because the already-interpolated key is then parsed as a format
        # directive.
        choices = ','.join(OCRS)
        raise ValueError(f'Could not find OCR for: "{key}". Choose from the following: {choices}')

    # Compare against None rather than relying on truthiness, so a cached
    # instance that happens to evaluate as falsy is not rebuilt on every call.
    # A missing entry or one explicitly reset to None is still (re)created.
    if ocr_cache.get(key) is None:
        ocr_cls = OCRS[key]
        ocr_cache[key] = ocr_cls(*args, **kwargs)
    return ocr_cache[key]
async def prepare(ocr_key: str, device: str = 'cpu'):
    """Get the OCR registered under ``ocr_key`` and ready it for use.

    If the OCR is an ``OfflineOCR``, its ``download()`` coroutine is awaited
    and then ``load(device)``. For any other OCR type nothing further is done.

    Args:
        ocr_key: Key passed to ``get_ocr`` to select the OCR.
        device: Device string handed to the OCR's ``load`` method.
    """
    engine = get_ocr(ocr_key)
    if not isinstance(engine, OfflineOCR):
        # Only offline OCRs have the download/load step.
        return
    await engine.download()
    await engine.load(device)
async def dispatch(
    ocr_key: str,
    image: np.ndarray,
    regions: List[Quadrilateral],
    args = None,
    device: str = 'cpu',
    verbose: bool = False,
) -> List[Quadrilateral]:
    """Run the OCR registered under ``ocr_key`` on ``image`` and ``regions``.

    Args:
        ocr_key: Key passed to ``get_ocr`` to select the OCR.
        image: Image array forwarded to the OCR's ``recognize`` method.
        regions: Regions forwarded to the OCR's ``recognize`` method.
        args: Extra options forwarded to ``recognize``; any falsy value
            (including the default ``None``) is replaced by an empty dict.
        device: Device string handed to ``load`` when the OCR is an
            ``OfflineOCR``.
        verbose: Forwarded unchanged to ``recognize``.

    Returns:
        Whatever the OCR's ``recognize`` coroutine returns.
    """
    engine = get_ocr(ocr_key)

    # Offline OCRs get load(device) awaited before every recognition call.
    if isinstance(engine, OfflineOCR):
        await engine.load(device)

    if not args:
        args = {}
    return await engine.recognize(image, regions, args, verbose)
|