text_mask_utils.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. from typing import Tuple, List
  2. import numpy as np
  3. import cv2
  4. import math
  5. from tqdm import tqdm
  6. from shapely.geometry import Polygon
  7. # from sklearn.mixture import BayesianGaussianMixture
  8. # from functools import reduce
  9. # from collections import defaultdict
  10. # from scipy.optimize import linear_sum_assignment
  11. from ..utils import Quadrilateral, image_resize
  12. COLOR_RANGE_SIGMA = 1.5 # how many stddev away is considered the same color
  13. def save_rgb(fn, img):
  14. if len(img.shape) == 3 and img.shape[2] == 3:
  15. cv2.imwrite(fn, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
  16. else:
  17. cv2.imwrite(fn, img)
  18. def area_overlap(x1, y1, w1, h1, x2, y2, w2, h2): # returns None if rectangles don't intersect
  19. x_overlap = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
  20. y_overlap = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
  21. return x_overlap * y_overlap
  22. def dist(x1, y1, x2, y2):
  23. return math.sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2))
  24. def rect_distance(x1, y1, x1b, y1b, x2, y2, x2b, y2b):
  25. left = x2b < x1
  26. right = x1b < x2
  27. bottom = y2b < y1
  28. top = y1b < y2
  29. if top and left:
  30. return dist(x1, y1b, x2b, y2)
  31. elif left and bottom:
  32. return dist(x1, y1, x2b, y2b)
  33. elif bottom and right:
  34. return dist(x1b, y1, x2, y2b)
  35. elif right and top:
  36. return dist(x1b, y1b, x2, y2)
  37. elif left:
  38. return x1 - x2b
  39. elif right:
  40. return x2 - x1b
  41. elif bottom:
  42. return y1 - y2b
  43. elif top:
  44. return y2 - y1b
  45. else: # rectangles intersect
  46. return 0
  47. def extend_rect(x, y, w, h, max_x, max_y, extend_size):
  48. x1 = max(x - extend_size, 0)
  49. y1 = max(y - extend_size, 0)
  50. w1 = min(w + extend_size * 2, max_x - x1 - 1)
  51. h1 = min(h + extend_size * 2, max_y - y1 - 1)
  52. return x1, y1, w1, h1
  53. def complete_mask_fill(text_lines: List[Tuple[int, int, int, int]]):
  54. for (x, y, w, h) in text_lines:
  55. final_mask = cv2.rectangle(final_mask, (x, y), (x + w, y + h), (255), -1)
  56. return final_mask
  57. from pydensecrf.utils import compute_unary, unary_from_softmax
  58. import pydensecrf.densecrf as dcrf
  59. def refine_mask(rgbimg, rawmask):
  60. if len(rawmask.shape) == 2:
  61. rawmask = rawmask[:, :, None]
  62. mask_softmax = np.concatenate([cv2.bitwise_not(rawmask)[:, :, None], rawmask], axis=2)
  63. mask_softmax = mask_softmax.astype(np.float32) / 255.0
  64. n_classes = 2
  65. feat_first = mask_softmax.transpose((2, 0, 1)).reshape((n_classes,-1))
  66. unary = unary_from_softmax(feat_first)
  67. unary = np.ascontiguousarray(unary)
  68. d = dcrf.DenseCRF2D(rgbimg.shape[1], rgbimg.shape[0], n_classes)
  69. d.setUnaryEnergy(unary)
  70. d.addPairwiseGaussian(sxy=1, compat=3, kernel=dcrf.DIAG_KERNEL,
  71. normalization=dcrf.NO_NORMALIZATION)
  72. d.addPairwiseBilateral(sxy=23, srgb=7, rgbim=rgbimg,
  73. compat=20,
  74. kernel=dcrf.DIAG_KERNEL,
  75. normalization=dcrf.NO_NORMALIZATION)
  76. Q = d.inference(5)
  77. res = np.argmax(Q, axis=0).reshape((rgbimg.shape[0], rgbimg.shape[1]))
  78. crf_mask = np.array(res * 255, dtype=np.uint8)
  79. return crf_mask
  80. def complete_mask(img: np.ndarray, mask: np.ndarray, textlines: List[Quadrilateral], keep_threshold = 1e-2, dilation_offset = 0,kernel_size=3):
  81. bboxes = [txtln.aabb.xywh for txtln in textlines]
  82. polys = [Polygon(txtln.pts) for txtln in textlines]
  83. for (x, y, w, h) in bboxes:
  84. cv2.rectangle(mask, (x, y), (x + w, y + h), (0), 1)
  85. num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
  86. M = len(textlines)
  87. textline_ccs = [np.zeros_like(mask) for _ in range(M)]
  88. iinfo = np.iinfo(labels.dtype)
  89. textline_rects = np.full(shape = (M, 4), fill_value = [iinfo.max, iinfo.max, iinfo.min, iinfo.min], dtype = labels.dtype)
  90. ratio_mat = np.zeros(shape = (num_labels, M), dtype = np.float32)
  91. dist_mat = np.zeros(shape = (num_labels, M), dtype = np.float32)
  92. valid = False
  93. for label in range(1, num_labels):
  94. # skip area too small
  95. if stats[label, cv2.CC_STAT_AREA] <= 9:
  96. continue
  97. x1 = stats[label, cv2.CC_STAT_LEFT]
  98. y1 = stats[label, cv2.CC_STAT_TOP]
  99. w1 = stats[label, cv2.CC_STAT_WIDTH]
  100. h1 = stats[label, cv2.CC_STAT_HEIGHT]
  101. area1 = stats[label, cv2.CC_STAT_AREA]
  102. cc_pts = np.array([[x1, y1], [x1 + w1, y1], [x1 + w1, y1 + h1], [x1, y1 + h1]])
  103. cc_poly = Polygon(cc_pts)
  104. for tl_idx in range(M):
  105. area2 = polys[tl_idx].area
  106. overlapping_area = polys[tl_idx].intersection(cc_poly).area
  107. ratio_mat[label, tl_idx] = overlapping_area / min(area1, area2)
  108. dist_mat[label, tl_idx] = polys[tl_idx].distance(cc_poly.centroid)
  109. # print(textlines[tl_idx].pts, cc_pts, '->', overlapping_area, min(area1, area2), '=', overlapping_area / min(area1, area2), '|', polys[tl_idx].distance(cc_poly))
  110. avg = np.argmax(ratio_mat[label])
  111. # print(avg, 'overlap:', ratio_mat[label, avg], '<=', keep_threshold)
  112. area2 = polys[avg].area
  113. if area1 >= area2:
  114. continue
  115. if ratio_mat[label, avg] <= keep_threshold:
  116. avg = np.argmin(dist_mat[label])
  117. area2 = polys[avg].area
  118. unit = max(min([textlines[avg].font_size, w1, h1]), 10)
  119. # print("unit", unit, textlines[avg].font_size, w1, h1)
  120. # if area1 < 0.4 * w1 * h1:
  121. # # ccs is probably angled
  122. # unit /= 2
  123. # if avg == 0:
  124. # print('no intersect', area1, '>=', area2, dist_mat[label, avg], '>=', 0.5 * unit)
  125. if dist_mat[label, avg] >= 0.5 * unit:
  126. # print(dist_mat[label])
  127. # print('CONTINUE')
  128. continue
  129. textline_ccs[avg][y1:y1+h1, x1:x1+w1][labels[y1:y1+h1, x1:x1+w1] == label] = 255
  130. # if avg == 0:
  131. # print(avg)
  132. # cv2.imshow('ccs', image_resize(textline_ccs[avg], height = 800))
  133. # cv2.waitKey(0)
  134. textline_rects[avg, 0] = min(textline_rects[avg, 0], x1)
  135. textline_rects[avg, 1] = min(textline_rects[avg, 1], y1)
  136. textline_rects[avg, 2] = max(textline_rects[avg, 2], x1 + w1)
  137. textline_rects[avg, 3] = max(textline_rects[avg, 3], y1 + h1)
  138. valid = True
  139. if not valid:
  140. return None
  141. # tblr to xywh
  142. textline_rects[:, 2] -= textline_rects[:, 0]
  143. textline_rects[:, 3] -= textline_rects[:, 1]
  144. final_mask = np.zeros_like(mask)
  145. img = cv2.bilateralFilter(img, 17, 80, 80)
  146. for i, cc in enumerate(tqdm(textline_ccs, '[mask]')):
  147. x1, y1, w1, h1 = textline_rects[i]
  148. text_size = min(w1, h1, textlines[i].font_size)
  149. x1, y1, w1, h1 = extend_rect(x1, y1, w1, h1, img.shape[1], img.shape[0], int(text_size * 0.1))
  150. # TODO: Need to think of better way to determine dilate_size.
  151. dilate_size = max((int((text_size + dilation_offset) * 0.3) // 2) * 2 + 1, 3)
  152. # print(textlines[i].font_size, min(w1, h1), dilate_size)
  153. kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_size, dilate_size))
  154. cc_region = np.ascontiguousarray(cc[y1: y1 + h1, x1: x1 + w1])
  155. if cc_region.size == 0:
  156. continue
  157. # cv2.imshow('cc before', image_resize(cc_region, height = 800))
  158. img_region = np.ascontiguousarray(img[y1: y1 + h1, x1: x1 + w1])
  159. # cv2.imshow('img', image_resize(img_region, height = 800))
  160. cc_region = refine_mask(img_region, cc_region)
  161. # cv2.imshow('cc after', image_resize(cc_region, height = 800))
  162. # cv2.waitKey(0)
  163. cc[y1: y1 + h1, x1: x1 + w1] = cc_region
  164. # cc = cv2.dilate(cc, kern)
  165. x2, y2, w2, h2 = extend_rect(x1, y1, w1, h1, img.shape[1], img.shape[0], -(-dilate_size // 2))
  166. cc[y2:y2+h2, x2:x2+w2] = cv2.dilate(cc[y2:y2+h2, x2:x2+w2], kern)
  167. final_mask[y2:y2+h2, x2:x2+w2] = cv2.bitwise_or(final_mask[y2:y2+h2, x2:x2+w2], cc[y2:y2+h2, x2:x2+w2])
  168. kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
  169. # for (x, y, w, h) in text_lines:
  170. # final_mask = cv2.rectangle(final_mask, (x, y), (x + w, y + h), (255), -1)
  171. return cv2.dilate(final_mask, kern)
  172. def unsharp(image):
  173. gaussian_3 = cv2.GaussianBlur(image, (3, 3), 2.0)
  174. return cv2.addWeighted(image, 1.5, gaussian_3, -0.5, 0, image)