123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- import os
- import shutil
- import numpy as np
- import einops
- from typing import Union, Tuple
- import cv2
- import torch
- from .ctd_utils.basemodel import TextDetBase, TextDetBaseDNN
- from .ctd_utils.utils.yolov5_utils import non_max_suppression
- from .ctd_utils.utils.db_utils import SegDetectorRepresenter
- from .ctd_utils.utils.imgproc_utils import letterbox
- from .ctd_utils.textmask import REFINEMASK_INPAINT, refine_mask
- from .common import OfflineDetector
- from ..utils import Quadrilateral, det_rearrange_forward
- def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
- if bgr2rgb:
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64)
- if to_tensor:
- img_in = img_in.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
- img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255
- if to_tensor:
- img_in = torch.from_numpy(img_in).to(device)
- if half:
- img_in = img_in.half()
- return img_in, ratio, int(dw), int(dh)
- def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
- # img = img.permute(1, 2, 0)
- if isinstance(img, torch.Tensor):
- img = img.squeeze_()
- if img.device != 'cpu':
- img = img.detach().cpu()
- img = img.numpy()
- else:
- img = img.squeeze()
- if thresh is not None:
- img = img > thresh
- img = img * 255
- # if isinstance(img, torch.Tensor):
- return img.astype(np.uint8)
- def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
- det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
- # bbox = det[..., 0:4]
- if det.device != 'cpu':
- det = det.detach_().cpu().numpy()
- det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
- det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
- if sort_func is not None:
- det = sort_func(det)
- blines = det[..., 0:4].astype(np.int32)
- confs = np.round(det[..., 4], 3)
- cls = det[..., 5].astype(np.int32)
- return blines, cls, confs
- class ComicTextDetector(OfflineDetector):
- _MODEL_MAPPING = {
- 'model-cuda': {
- 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt',
- 'hash': '1f90fa60aeeb1eb82e2ac1167a66bf139a8a61b8780acd351ead55268540cccb',
- 'file': '.',
- },
- 'model-cpu': {
- 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt.onnx',
- 'hash': '1a86ace74961413cbd650002e7bb4dcec4980ffa21b2f19b86933372071d718f',
- 'file': '.',
- },
- }
- def __init__(self, *args, **kwargs):
- os.makedirs(self.model_dir, exist_ok=True)
- if os.path.exists('comictextdetector.pt'):
- shutil.move('comictextdetector.pt', self._get_file_path('comictextdetector.pt'))
- if os.path.exists('comictextdetector.pt.onnx'):
- shutil.move('comictextdetector.pt.onnx', self._get_file_path('comictextdetector.pt.onnx'))
- super().__init__(*args, **kwargs)
- async def _load(self, device: str, input_size=1024, half=False, nms_thresh=0.35, conf_thresh=0.4):
- self.device = device
- if self.device == 'cuda' or self.device == 'mps':
- self.model = TextDetBase(self._get_file_path('comictextdetector.pt'), device=self.device, act='leaky')
- self.model.to(self.device)
- self.backend = 'torch'
- else:
- model_path = self._get_file_path('comictextdetector.pt.onnx')
- self.model = cv2.dnn.readNetFromONNX(model_path)
- self.model = TextDetBaseDNN(input_size, model_path)
- self.backend = 'opencv'
- if isinstance(input_size, int):
- input_size = (input_size, input_size)
- self.input_size = input_size
- self.half = half
- self.conf_thresh = conf_thresh
- self.nms_thresh = nms_thresh
- self.seg_rep = SegDetectorRepresenter(thresh=0.3)
- async def _unload(self):
- del self.model
- def det_batch_forward_ctd(self, batch: np.ndarray, device: str) -> Tuple[np.ndarray, np.ndarray]:
- if isinstance(self.model, TextDetBase):
- batch = einops.rearrange(batch.astype(np.float32) / 255., 'n h w c -> n c h w')
- batch = torch.from_numpy(batch).to(device)
- _, mask, lines = self.model(batch)
- mask = mask.detach().cpu().numpy()
- lines = lines.detach().cpu().numpy()
- elif isinstance(self.model, TextDetBaseDNN):
- mask_lst, line_lst = [], []
- for b in batch:
- _, mask, lines = self.model(b)
- if mask.shape[1] == 2: # some version of opencv spit out reversed result
- tmp = mask
- mask = lines
- lines = tmp
- mask_lst.append(mask)
- line_lst.append(lines)
- lines, mask = np.concatenate(line_lst, 0), np.concatenate(mask_lst, 0)
- else:
- raise NotImplementedError
- return lines, mask
- @torch.no_grad()
- async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
- unclip_ratio: float, verbose: bool = False):
- # keep_undetected_mask = False
- # refine_mode = REFINEMASK_INPAINT
- im_h, im_w = image.shape[:2]
- lines_map, mask = det_rearrange_forward(image, self.det_batch_forward_ctd, self.input_size[0], 4, self.device, verbose)
- # blks = []
- # resize_ratio = [1, 1]
- if lines_map is None:
- img_in, ratio, dw, dh = preprocess_img(image, input_size=self.input_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')
- blks, mask, lines_map = self.model(img_in)
- if self.backend == 'opencv':
- if mask.shape[1] == 2: # some version of opencv spit out reversed result
- tmp = mask
- mask = lines_map
- lines_map = tmp
- mask = mask.squeeze()
- # resize_ratio = (im_w / (self.input_size[0] - dw), im_h / (self.input_size[1] - dh))
- # blks = postprocess_yolo(blks, self.conf_thresh, self.nms_thresh, resize_ratio)
- mask = mask[..., :mask.shape[0]-dh, :mask.shape[1]-dw]
- lines_map = lines_map[..., :lines_map.shape[2]-dh, :lines_map.shape[3]-dw]
- mask = postprocess_mask(mask)
- lines, scores = self.seg_rep(None, lines_map, height=im_h, width=im_w)
- box_thresh = 0.6
- idx = np.where(scores[0] > box_thresh)
- lines, scores = lines[0][idx], scores[0][idx]
- # map output to input img
- mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
- # if lines.size == 0:
- # lines = []
- # else:
- # lines = lines.astype(np.int32)
- # YOLO was used for finding bboxes which to order the lines into. This is now solved
- # through the textline merger, which seems to work more reliably.
- # The YOLO language detection seems unnecessary as it could never be as good as
- # using the OCR extracted string directly.
- # Doing it for increasing the textline merge accuracy doesn't really work either,
- # as the merge could be postponed until after the OCR finishes.
- textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(lines, scores)]
- mask_refined = refine_mask(image, mask, textlines, refine_mode=None)
- return textlines, mask_refined, None
- # blk_list = group_output(blks, lines, im_w, im_h, mask)
- # mask_refined = refine_mask(image, mask, blk_list, refine_mode=refine_mode)
- # if keep_undetected_mask:
- # mask_refined = refine_undetected_mask(image, mask, mask_refined, blk_list, refine_mode=refine_mode)
- # return blk_list, mask, mask_refined
|