ctd.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import os
  2. import shutil
  3. import numpy as np
  4. import einops
  5. from typing import Union, Tuple
  6. import cv2
  7. import torch
  8. from .ctd_utils.basemodel import TextDetBase, TextDetBaseDNN
  9. from .ctd_utils.utils.yolov5_utils import non_max_suppression
  10. from .ctd_utils.utils.db_utils import SegDetectorRepresenter
  11. from .ctd_utils.utils.imgproc_utils import letterbox
  12. from .ctd_utils.textmask import REFINEMASK_INPAINT, refine_mask
  13. from .common import OfflineDetector
  14. from ..utils import Quadrilateral, det_rearrange_forward
  def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
      """Letterbox an image to the detector input size, optionally as a batched tensor.

      Args:
          img: image array indexed as HWC (assumed by the transpose below).
          input_size: passed to ``letterbox`` as ``new_shape``.
          device: torch device for the returned tensor (only used if ``to_tensor``).
          bgr2rgb: apply ``cv2.COLOR_BGR2RGB`` before letterboxing.
          half: cast the returned tensor to float16 (only used if ``to_tensor``).
          to_tensor: if True, return a 1xCxHxW float32 tensor scaled by 1/255;
              otherwise return the letterboxed array as produced by ``letterbox``.

      Returns:
          ``(img_in, ratio, dw, dh)``; ``ratio`` comes straight from ``letterbox``
          and the two padding values are truncated to ``int``.
      """
      if bgr2rgb:
          img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
      padded, ratio, (pad_w, pad_h) = letterbox(img, new_shape=input_size, auto=False, stride=64)
      if not to_tensor:
          return padded, ratio, int(pad_w), int(pad_h)

      # HWC -> CHW, then reverse the channel axis.
      # NOTE(review): with bgr2rgb=True the channels are swapped twice (cvtColor
      # and [::-1]); kept as-is since it must match the model's training order.
      chw = padded.transpose((2, 0, 1))[::-1]
      batch = np.array([np.ascontiguousarray(chw)]).astype(np.float32) / 255
      tensor = torch.from_numpy(batch).to(device)
      if half:
          tensor = tensor.half()
      return tensor, ratio, int(pad_w), int(pad_h)
  def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
      """Convert a mask output to a squeezed ``uint8`` numpy image.

      The caller's tensor/array is left unmodified.

      Args:
          img: mask as a torch tensor (any device) or numpy array; all size-1
              dimensions are squeezed away.
          thresh: if given, pixels ``> thresh`` become 255 and all others 0;
              otherwise values are multiplied by 255 (presumably a [0, 1] mask —
              TODO confirm) and truncated by the ``uint8`` cast.

      Returns:
          np.ndarray of dtype ``uint8``.
      """
      if isinstance(img, torch.Tensor):
          # Out-of-place squeeze: the former squeeze_() reshaped the caller's
          # tensor in place (and in-place ops are rejected by autograd on leaf
          # tensors that require grad). detach().cpu() is applied unconditionally
          # (a no-op on CPU) instead of comparing img.device with the string 'cpu'.
          img = img.detach().cpu().squeeze().numpy()
      else:
          img = np.squeeze(img)
      if thresh is not None:
          img = img > thresh
      return (img * 255).astype(np.uint8)
  def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
      """Run NMS on raw YOLO output and return rescaled boxes, classes and scores.

      Args:
          det: raw detection output, passed unchanged to ``non_max_suppression``.
          conf_thresh: confidence threshold forwarded to ``non_max_suppression``.
          nms_thresh: IoU threshold forwarded to ``non_max_suppression``.
          resize_ratio: ``(x_ratio, y_ratio)``; columns 0/2 are scaled by the
              first, columns 1/3 by the second.
          sort_func: optional callable applied to the rescaled detections array.

      Returns:
          ``(blines, cls, confs)``: columns 0:4 as int32, column 5 as int32, and
          column 4 rounded to 3 decimals.
      """
      # Only the detections for the first image are used.
      det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
      # Always convert tensors to numpy: everything below uses ndarray-only APIs
      # (astype, np.round). The former `det.device != 'cpu'` test compared a
      # torch.device with a str and, whenever it evaluated False, left `det` as
      # a tensor. detach() (not detach_()) avoids mutating the NMS result.
      if isinstance(det, torch.Tensor):
          det = det.detach().cpu().numpy()
      det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
      det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
      if sort_func is not None:
          det = sort_func(det)
      blines = det[..., 0:4].astype(np.int32)
      confs = np.round(det[..., 4], 3)
      cls = det[..., 5].astype(np.int32)
      return blines, cls, confs
  class ComicTextDetector(OfflineDetector):
      """Comic text detector producing text-line quadrilaterals and a text mask.

      On 'cuda'/'mps' the PyTorch checkpoint (``TextDetBase``) is used; on any
      other device the ONNX export is run through OpenCV (``TextDetBaseDNN``).
      """

      _MODEL_MAPPING = {
          'model-cuda': {
              'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt',
              'hash': '1f90fa60aeeb1eb82e2ac1167a66bf139a8a61b8780acd351ead55268540cccb',
              'file': '.',
          },
          'model-cpu': {
              'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt.onnx',
              'hash': '1a86ace74961413cbd650002e7bb4dcec4980ffa21b2f19b86933372071d718f',
              'file': '.',
          },
      }

      def __init__(self, *args, **kwargs):
          # Move model files found in the current working directory (presumably
          # left there by older versions — TODO confirm) into the model directory
          # before the base class initialises.
          os.makedirs(self.model_dir, exist_ok=True)
          for filename in ('comictextdetector.pt', 'comictextdetector.pt.onnx'):
              if os.path.exists(filename):
                  shutil.move(filename, self._get_file_path(filename))
          super().__init__(*args, **kwargs)

      async def _load(self, device: str, input_size=1024, half=False, nms_thresh=0.35, conf_thresh=0.4):
          """Load the model for ``device`` and store the inference settings.

          An int ``input_size`` is expanded to a square ``(size, size)`` tuple
          for ``self.input_size``; ``TextDetBaseDNN`` receives the value as given.
          """
          self.device = device
          if self.device in ('cuda', 'mps'):
              self.model = TextDetBase(self._get_file_path('comictextdetector.pt'), device=self.device, act='leaky')
              self.model.to(self.device)
              self.backend = 'torch'
          else:
              # TextDetBaseDNN is given the model path directly. The previous extra
              # cv2.dnn.readNetFromONNX() call loaded the network a second time and
              # its result was immediately overwritten.
              self.model = TextDetBaseDNN(input_size, self._get_file_path('comictextdetector.pt.onnx'))
              self.backend = 'opencv'
          if isinstance(input_size, int):
              input_size = (input_size, input_size)
          self.input_size = input_size
          self.half = half
          self.conf_thresh = conf_thresh
          self.nms_thresh = nms_thresh
          self.seg_rep = SegDetectorRepresenter(thresh=0.3)

      async def _unload(self):
          del self.model

      @staticmethod
      def _fix_dnn_output_order(mask, lines):
          """Return ``(mask, lines)``, swapping them if OpenCV returned them reversed.

          Some OpenCV versions emit the two outputs in the opposite order. The
          lines map is expected to have 2 channels on axis 1, so a 2-channel
          "mask" means the pair arrived swapped.
          """
          if mask.shape[1] == 2:
              return lines, mask
          return mask, lines

      def det_batch_forward_ctd(self, batch: np.ndarray, device: str) -> Tuple[np.ndarray, np.ndarray]:
          """Run the model on a batch and return ``(lines, mask)`` as numpy arrays.

          For the torch backend ``batch`` is indexed ``(n, h, w, c)`` and scaled
          by 1/255 (presumably 0-255 pixel values — TODO confirm). For the
          OpenCV backend each image is run separately and the per-image outputs
          are concatenated along axis 0.

          Raises:
              NotImplementedError: if the loaded model is of an unknown type.
          """
          if isinstance(self.model, TextDetBase):
              batch = einops.rearrange(batch.astype(np.float32) / 255., 'n h w c -> n c h w')
              batch = torch.from_numpy(batch).to(device)
              # Outputs are converted to numpy, so no autograd graph is needed.
              with torch.no_grad():
                  _, mask, lines = self.model(batch)
              mask = mask.detach().cpu().numpy()
              lines = lines.detach().cpu().numpy()
          elif isinstance(self.model, TextDetBaseDNN):
              mask_lst, line_lst = [], []
              for img in batch:
                  _, mask, lines = self.model(img)
                  mask, lines = self._fix_dnn_output_order(mask, lines)
                  mask_lst.append(mask)
                  line_lst.append(lines)
              lines, mask = np.concatenate(line_lst, 0), np.concatenate(mask_lst, 0)
          else:
              raise NotImplementedError
          return lines, mask

      async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
                       unclip_ratio: float, verbose: bool = False):
          """Detect text lines in ``image`` and build a refined text mask.

          Returns:
              ``(textlines, mask_refined, None)`` where ``textlines`` is a list of
              ``Quadrilateral`` with empty text and the line score.

          NOTE(review): ``detect_size``, ``text_threshold``, ``box_threshold`` and
          ``unclip_ratio`` are currently unused; ``self.input_size`` and a fixed
          box-score threshold of 0.6 are used instead.
          """
          # `@torch.no_grad()` cannot decorate an `async def`: it only wraps the
          # call that creates the coroutine, so the body previously ran with
          # autograd enabled. The body contains no `await`, so switching grad mode
          # here cannot leak into other tasks on the event loop.
          with torch.no_grad():
              im_h, im_w = image.shape[:2]
              lines_map, mask = det_rearrange_forward(image, self.det_batch_forward_ctd, self.input_size[0], 4, self.device, verbose)
              if lines_map is None:
                  img_in, ratio, dw, dh = preprocess_img(image, input_size=self.input_size, device=self.device, half=self.half, to_tensor=self.backend == 'torch')
                  _, mask, lines_map = self.model(img_in)
                  if self.backend == 'opencv':
                      mask, lines_map = self._fix_dnn_output_order(mask, lines_map)
                  mask = mask.squeeze()
                  # Drop the last dh rows and dw columns, which preprocess_img
                  # reports as letterbox padding.
                  mask = mask[..., :mask.shape[0] - dh, :mask.shape[1] - dw]
                  lines_map = lines_map[..., :lines_map.shape[2] - dh, :lines_map.shape[3] - dw]
              mask = postprocess_mask(mask)
              lines, scores = self.seg_rep(None, lines_map, height=im_h, width=im_w)
              box_thresh = 0.6
              idx = np.where(scores[0] > box_thresh)
              lines, scores = lines[0][idx], scores[0][idx]
              # Map the mask back to the input image resolution.
              mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
              # YOLO block detection is intentionally not used: line ordering is
              # handled by the textline merger, which works more reliably, and
              # language detection is better done on the OCR output.
              textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(lines, scores)]
              mask_refined = refine_mask(image, mask, textlines, refine_mode=None)
          return textlines, mask_refined, None