textmask.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from typing import List
  2. import cv2
  3. import numpy as np
  4. from .utils.imgproc_utils import union_area, enlarge_window
  5. from ...utils import TextBlock, Quadrilateral
  6. WHITE = (255, 255, 255)
  7. BLACK = (0, 0, 0)
  8. LANG_ENG = 0
  9. LANG_JPN = 1
  10. REFINEMASK_INPAINT = 0
  11. REFINEMASK_ANNOTATION = 1
  12. def get_topk_color(color_list, bins, k=3, color_var=10, bin_tol=0.001):
  13. idx = np.argsort(bins * -1)
  14. color_list, bins = color_list[idx], bins[idx]
  15. top_colors = [color_list[0]]
  16. bin_tol = np.sum(bins) * bin_tol
  17. if len(color_list) > 1:
  18. for color, bin in zip(color_list[1:], bins[1:]):
  19. if np.abs(np.array(top_colors) - color).min() > color_var:
  20. top_colors.append(color)
  21. if len(top_colors) >= k or bin < bin_tol:
  22. break
  23. return top_colors
  24. def minxor_thresh(threshed, mask, dilate=False):
  25. neg_threshed = 255 - threshed
  26. e_size = 1
  27. if dilate:
  28. element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
  29. neg_threshed = cv2.dilate(neg_threshed, element, iterations=1)
  30. threshed = cv2.dilate(threshed, element, iterations=1)
  31. neg_xor_sum = cv2.bitwise_xor(neg_threshed, mask).sum()
  32. xor_sum = cv2.bitwise_xor(threshed, mask).sum()
  33. if neg_xor_sum < xor_sum:
  34. return neg_threshed, neg_xor_sum
  35. else:
  36. return threshed, xor_sum
  37. def get_otsuthresh_masklist(img, pred_mask, per_channel=False) -> List[np.ndarray]:
  38. channels = [img[..., 0], img[..., 1], img[..., 2]]
  39. mask_list = []
  40. for c in channels:
  41. _, threshed = cv2.threshold(c, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
  42. threshed, xor_sum = minxor_thresh(threshed, pred_mask, dilate=False)
  43. mask_list.append([threshed, xor_sum])
  44. mask_list.sort(key=lambda x: x[1])
  45. if per_channel:
  46. return mask_list
  47. else:
  48. return [mask_list[0]]
  49. def get_topk_masklist(im_grey, pred_mask):
  50. if len(im_grey.shape) == 3 and im_grey.shape[-1] == 3:
  51. im_grey = cv2.cvtColor(im_grey, cv2.COLOR_BGR2GRAY)
  52. msk = np.ascontiguousarray(pred_mask)
  53. candidate_grey_px = im_grey[np.where(cv2.erode(msk, np.ones((3,3), np.uint8), iterations=1) > 127)]
  54. bin, his = np.histogram(candidate_grey_px, bins=255)
  55. topk_color = get_topk_color(his, bin, color_var=10, k=3)
  56. color_range = 30
  57. mask_list = list()
  58. for ii, color in enumerate(topk_color):
  59. c_top = min(color+color_range, 255)
  60. c_bottom = c_top - 2 * color_range
  61. threshed = cv2.inRange(im_grey, c_bottom, c_top)
  62. threshed, xor_sum = minxor_thresh(threshed, msk)
  63. mask_list.append([threshed, xor_sum])
  64. return mask_list
  65. def merge_mask_list(mask_list, pred_mask, blk: Quadrilateral = None, pred_thresh=30, text_window=None, filter_with_lines=False, refine_mode=REFINEMASK_INPAINT):
  66. mask_list.sort(key=lambda x: x[1])
  67. linemask = None
  68. if blk is not None and filter_with_lines:
  69. linemask = np.zeros_like(pred_mask)
  70. lines = blk.pts.astype(np.int64)
  71. for line in lines:
  72. line[..., 0] -= text_window[0]
  73. line[..., 1] -= text_window[1]
  74. cv2.fillPoly(linemask, [line], 255)
  75. linemask = cv2.dilate(linemask, np.ones((3, 3), np.uint8), iterations=3)
  76. if pred_thresh > 0:
  77. e_size = 1
  78. element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
  79. pred_mask = cv2.erode(pred_mask, element, iterations=1)
  80. _, pred_mask = cv2.threshold(pred_mask, 60, 255, cv2.THRESH_BINARY)
  81. connectivity = 8
  82. mask_merged = np.zeros_like(pred_mask)
  83. for ii, (candidate_mask, xor_sum) in enumerate(mask_list):
  84. num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(candidate_mask, connectivity, cv2.CV_16U)
  85. for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
  86. if label_index != 0: # skip background label
  87. x, y, w, h, area = stat
  88. if w * h < 3:
  89. continue
  90. x1, y1, x2, y2 = x, y, x+w, y+h
  91. label_local = labels[y1: y2, x1: x2]
  92. label_coordinates = np.where(label_local==label_index)
  93. tmp_merged = np.zeros_like(label_local, np.uint8)
  94. tmp_merged[label_coordinates] = 255
  95. tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged)
  96. xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum()
  97. xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum()
  98. if xor_merged < xor_origin:
  99. mask_merged[y1: y2, x1: x2] = tmp_merged
  100. if refine_mode == REFINEMASK_INPAINT:
  101. mask_merged = cv2.dilate(mask_merged, np.ones((5, 5), np.uint8), iterations=1)
  102. # fill holes
  103. num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(255-mask_merged, connectivity, cv2.CV_16U)
  104. sorted_area = np.sort(stats[:, -1])
  105. if len(sorted_area) > 1:
  106. area_thresh = sorted_area[-2]
  107. else:
  108. area_thresh = sorted_area[-1]
  109. for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
  110. x, y, w, h, area = stat
  111. if area < area_thresh:
  112. x1, y1, x2, y2 = x, y, x+w, y+h
  113. label_local = labels[y1: y2, x1: x2]
  114. label_coordinates = np.where(label_local==label_index)
  115. tmp_merged = np.zeros_like(label_local, np.uint8)
  116. tmp_merged[label_coordinates] = 255
  117. tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged)
  118. xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum()
  119. xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum()
  120. if xor_merged < xor_origin:
  121. mask_merged[y1: y2, x1: x2] = tmp_merged
  122. return mask_merged
  123. def refine_undetected_mask(img: np.ndarray, mask_pred: np.ndarray, mask_refined: np.ndarray, blk_list: List[TextBlock], refine_mode=REFINEMASK_INPAINT):
  124. mask_pred[np.where(mask_refined > 30)] = 0
  125. _, pred_mask_t = cv2.threshold(mask_pred, 30, 255, cv2.THRESH_BINARY)
  126. num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(pred_mask_t, 4, cv2.CV_16U)
  127. valid_labels = np.where(stats[:, -1] > 50)[0]
  128. seg_blk_list = []
  129. if len(valid_labels) > 0:
  130. for lab_index in valid_labels[1:]:
  131. x, y, w, h, area = stats[lab_index]
  132. bx1, by1 = x, y
  133. bx2, by2 = x+w, y+h
  134. bbox = [bx1, by1, bx2, by2]
  135. bbox_score = -1
  136. for blk in blk_list:
  137. bbox_s = union_area(blk.xyxy, bbox)
  138. if bbox_s > bbox_score:
  139. bbox_score = bbox_s
  140. if bbox_score / w / h < 0.5:
  141. seg_blk_list.append(TextBlock(bbox))
  142. if len(seg_blk_list) > 0:
  143. mask_refined = cv2.bitwise_or(mask_refined, refine_mask(img, mask_pred, seg_blk_list, refine_mode=refine_mode))
  144. return mask_refined
  145. def refine_mask(img: np.ndarray, pred_mask: np.ndarray, blk_list: List[Quadrilateral], refine_mode: int = REFINEMASK_INPAINT) -> np.ndarray:
  146. mask_refined = np.zeros_like(pred_mask)
  147. for blk in blk_list:
  148. bx1, by1, bx2, by2 = enlarge_window(blk.xyxy, img.shape[1], img.shape[0])
  149. im = np.ascontiguousarray(img[by1: by2, bx1: bx2])
  150. msk = np.ascontiguousarray(pred_mask[by1: by2, bx1: bx2])
  151. mask_list = get_topk_masklist(im, msk)
  152. mask_list += get_otsuthresh_masklist(im, msk, per_channel=False)
  153. mask_merged = merge_mask_list(mask_list, msk, blk=blk, text_window=[bx1, by1, bx2, by2], refine_mode=refine_mode)
  154. mask_refined[by1: by2, bx1: bx2] = cv2.bitwise_or(mask_refined[by1: by2, bx1: bx2], mask_merged)
  155. # cv2.imshow('im', im)
  156. # cv2.imshow('msk', msk)
  157. # cv2.imshow('mask_refined', mask_refined[by1: by2, bx1: bx2])
  158. # cv2.waitKey(0)
  159. return mask_refined