# db_utils.py (28 KB)
  1. import cv2
  2. import numpy as np
  3. import pyclipper
  4. from shapely.geometry import Polygon
  5. from collections import namedtuple
  6. import warnings
  7. import torch
  8. warnings.filterwarnings('ignore')
  9. def iou_rotate(box_a, box_b, method='union'):
  10. rect_a = cv2.minAreaRect(box_a)
  11. rect_b = cv2.minAreaRect(box_b)
  12. r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b)
  13. if r1[0] == 0:
  14. return 0
  15. else:
  16. inter_area = cv2.contourArea(r1[1])
  17. area_a = cv2.contourArea(box_a)
  18. area_b = cv2.contourArea(box_b)
  19. union_area = area_a + area_b - inter_area
  20. if union_area == 0 or inter_area == 0:
  21. return 0
  22. if method == 'union':
  23. iou = inter_area / union_area
  24. elif method == 'intersection':
  25. iou = inter_area / min(area_a, area_b)
  26. else:
  27. raise NotImplementedError
  28. return iou
  29. class SegDetectorRepresenter():
  30. def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5):
  31. self.min_size = 3
  32. self.thresh = thresh
  33. self.box_thresh = box_thresh
  34. self.max_candidates = max_candidates
  35. self.unclip_ratio = unclip_ratio
  36. def __call__(self, batch, pred, is_output_polygon=False, height=None, width=None):
  37. '''
  38. batch: (image, polygons, ignore_tags
  39. batch: a dict produced by dataloaders.
  40. image: tensor of shape (N, C, H, W).
  41. polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
  42. ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
  43. shape: the original shape of images.
  44. filename: the original filenames of images.
  45. pred:
  46. binary: text region segmentation map, with shape (N, H, W)
  47. thresh: [if exists] thresh hold prediction with shape (N, H, W)
  48. thresh_binary: [if exists] binarized with threshold, (N, H, W)
  49. '''
  50. pred = pred[:, 0, :, :]
  51. segmentation = self.binarize(pred)
  52. boxes_batch = []
  53. scores_batch = []
  54. # print(pred.size())
  55. batch_size = pred.size(0) if isinstance(pred, torch.Tensor) else pred.shape[0]
  56. if height is None:
  57. height = pred.shape[1]
  58. if width is None:
  59. width = pred.shape[2]
  60. for batch_index in range(batch_size):
  61. if is_output_polygon:
  62. boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
  63. else:
  64. boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
  65. boxes_batch.append(boxes)
  66. scores_batch.append(scores)
  67. return boxes_batch, scores_batch
  68. def binarize(self, pred) -> np.ndarray:
  69. return pred > self.thresh
  70. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  71. '''
  72. _bitmap: single map with shape (H, W),
  73. whose values are binarized as {0, 1}
  74. '''
  75. assert len(_bitmap.shape) == 2
  76. bitmap = _bitmap.cpu().numpy() # The first channel
  77. pred = pred.cpu().detach().numpy()
  78. height, width = bitmap.shape
  79. boxes = []
  80. scores = []
  81. contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  82. for contour in contours[:self.max_candidates]:
  83. epsilon = 0.005 * cv2.arcLength(contour, True)
  84. approx = cv2.approxPolyDP(contour, epsilon, True)
  85. points = approx.reshape((-1, 2))
  86. if points.shape[0] < 4:
  87. continue
  88. # _, sside = self.get_mini_boxes(contour)
  89. # if sside < self.min_size:
  90. # continue
  91. score = self.box_score_fast(pred, contour.squeeze(1))
  92. if self.box_thresh > score:
  93. continue
  94. if points.shape[0] > 2:
  95. box = self.unclip(points, unclip_ratio=self.unclip_ratio)
  96. if len(box) > 1:
  97. continue
  98. else:
  99. continue
  100. box = box.reshape(-1, 2)
  101. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  102. if sside < self.min_size + 2:
  103. continue
  104. if not isinstance(dest_width, int):
  105. dest_width = dest_width.item()
  106. dest_height = dest_height.item()
  107. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
  108. box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
  109. boxes.append(box)
  110. scores.append(score)
  111. return boxes, scores
  112. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  113. '''
  114. _bitmap: single map with shape (H, W),
  115. whose values are binarized as {0, 1}
  116. '''
  117. assert len(_bitmap.shape) == 2
  118. if isinstance(pred, torch.Tensor):
  119. bitmap = _bitmap.cpu().numpy() # The first channel
  120. pred = pred.cpu().detach().numpy()
  121. else:
  122. bitmap = _bitmap
  123. # cv2.imwrite('tmp.png', (bitmap*255).astype(np.uint8))
  124. height, width = bitmap.shape
  125. contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  126. num_contours = min(len(contours), self.max_candidates)
  127. boxes = np.zeros((num_contours, 4, 2), dtype=np.int64)
  128. scores = np.zeros((num_contours,), dtype=np.float32)
  129. for index in range(num_contours):
  130. contour = contours[index].squeeze(1)
  131. points, sside = self.get_mini_boxes(contour)
  132. # if sside < self.min_size:
  133. # continue
  134. if sside < 2:
  135. continue
  136. points = np.array(points)
  137. score = self.box_score_fast(pred, contour)
  138. # if self.box_thresh > score:
  139. # continue
  140. box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
  141. box, sside = self.get_mini_boxes(box)
  142. # if sside < 5:
  143. # continue
  144. box = np.array(box)
  145. if not isinstance(dest_width, int):
  146. dest_width = dest_width.item()
  147. dest_height = dest_height.item()
  148. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
  149. box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
  150. boxes[index, :, :] = box.astype(np.int64)
  151. scores[index] = score
  152. return boxes, scores
  153. def unclip(self, box, unclip_ratio=1.5):
  154. poly = Polygon(box)
  155. distance = poly.area * unclip_ratio / poly.length
  156. offset = pyclipper.PyclipperOffset()
  157. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  158. expanded = np.array(offset.Execute(distance))
  159. return expanded
  160. def get_mini_boxes(self, contour):
  161. bounding_box = cv2.minAreaRect(contour)
  162. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  163. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  164. if points[1][1] > points[0][1]:
  165. index_1 = 0
  166. index_4 = 1
  167. else:
  168. index_1 = 1
  169. index_4 = 0
  170. if points[3][1] > points[2][1]:
  171. index_2 = 2
  172. index_3 = 3
  173. else:
  174. index_2 = 3
  175. index_3 = 2
  176. box = [points[index_1], points[index_2], points[index_3], points[index_4]]
  177. return box, min(bounding_box[1])
  178. def box_score_fast(self, bitmap, _box):
  179. h, w = bitmap.shape[:2]
  180. box = _box.copy()
  181. xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
  182. xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
  183. ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
  184. ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
  185. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  186. box[:, 0] = box[:, 0] - xmin
  187. box[:, 1] = box[:, 1] - ymin
  188. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  189. if bitmap.dtype == np.float16:
  190. bitmap = bitmap.astype(np.float32)
  191. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  192. class AverageMeter(object):
  193. """Computes and stores the average and current value"""
  194. def __init__(self):
  195. self.reset()
  196. def reset(self):
  197. self.val = 0
  198. self.avg = 0
  199. self.sum = 0
  200. self.count = 0
  201. def update(self, val, n=1):
  202. self.val = val
  203. self.sum += val * n
  204. self.count += n
  205. self.avg = self.sum / self.count
  206. return self
  207. class DetectionIoUEvaluator(object):
  208. def __init__(self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5):
  209. self.is_output_polygon = is_output_polygon
  210. self.iou_constraint = iou_constraint
  211. self.area_precision_constraint = area_precision_constraint
  212. def evaluate_image(self, gt, pred):
  213. def get_union(pD, pG):
  214. return Polygon(pD).union(Polygon(pG)).area
  215. def get_intersection_over_union(pD, pG):
  216. return get_intersection(pD, pG) / get_union(pD, pG)
  217. def get_intersection(pD, pG):
  218. return Polygon(pD).intersection(Polygon(pG)).area
  219. def compute_ap(confList, matchList, numGtCare):
  220. correct = 0
  221. AP = 0
  222. if len(confList) > 0:
  223. confList = np.array(confList)
  224. matchList = np.array(matchList)
  225. sorted_ind = np.argsort(-confList)
  226. confList = confList[sorted_ind]
  227. matchList = matchList[sorted_ind]
  228. for n in range(len(confList)):
  229. match = matchList[n]
  230. if match:
  231. correct += 1
  232. AP += float(correct) / (n + 1)
  233. if numGtCare > 0:
  234. AP /= numGtCare
  235. return AP
  236. perSampleMetrics = {}
  237. matchedSum = 0
  238. Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
  239. numGlobalCareGt = 0
  240. numGlobalCareDet = 0
  241. arrGlobalConfidences = []
  242. arrGlobalMatches = []
  243. recall = 0
  244. precision = 0
  245. hmean = 0
  246. detMatched = 0
  247. iouMat = np.empty([1, 1])
  248. gtPols = []
  249. detPols = []
  250. gtPolPoints = []
  251. detPolPoints = []
  252. # Array of Ground Truth Polygons' keys marked as don't Care
  253. gtDontCarePolsNum = []
  254. # Array of Detected Polygons' matched with a don't Care GT
  255. detDontCarePolsNum = []
  256. pairs = []
  257. detMatchedNums = []
  258. arrSampleConfidences = []
  259. arrSampleMatch = []
  260. evaluationLog = ""
  261. for n in range(len(gt)):
  262. points = gt[n]['points']
  263. # transcription = gt[n]['text']
  264. dontCare = gt[n]['ignore']
  265. if not Polygon(points).is_valid or not Polygon(points).is_simple:
  266. continue
  267. gtPol = points
  268. gtPols.append(gtPol)
  269. gtPolPoints.append(points)
  270. if dontCare:
  271. gtDontCarePolsNum.append(len(gtPols) - 1)
  272. evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(
  273. gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
  274. for n in range(len(pred)):
  275. points = pred[n]['points']
  276. if not Polygon(points).is_valid or not Polygon(points).is_simple:
  277. continue
  278. detPol = points
  279. detPols.append(detPol)
  280. detPolPoints.append(points)
  281. if len(gtDontCarePolsNum) > 0:
  282. for dontCarePol in gtDontCarePolsNum:
  283. dontCarePol = gtPols[dontCarePol]
  284. intersected_area = get_intersection(dontCarePol, detPol)
  285. pdDimensions = Polygon(detPol).area
  286. precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
  287. if (precision > self.area_precision_constraint):
  288. detDontCarePolsNum.append(len(detPols) - 1)
  289. break
  290. evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(
  291. detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
  292. if len(gtPols) > 0 and len(detPols) > 0:
  293. # Calculate IoU and precision matrixs
  294. outputShape = [len(gtPols), len(detPols)]
  295. iouMat = np.empty(outputShape)
  296. gtRectMat = np.zeros(len(gtPols), np.int8)
  297. detRectMat = np.zeros(len(detPols), np.int8)
  298. if self.is_output_polygon:
  299. for gtNum in range(len(gtPols)):
  300. for detNum in range(len(detPols)):
  301. pG = gtPols[gtNum]
  302. pD = detPols[detNum]
  303. iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
  304. else:
  305. # gtPols = np.float32(gtPols)
  306. # detPols = np.float32(detPols)
  307. for gtNum in range(len(gtPols)):
  308. for detNum in range(len(detPols)):
  309. pG = np.float32(gtPols[gtNum])
  310. pD = np.float32(detPols[detNum])
  311. iouMat[gtNum, detNum] = iou_rotate(pD, pG)
  312. for gtNum in range(len(gtPols)):
  313. for detNum in range(len(detPols)):
  314. if gtRectMat[gtNum] == 0 and detRectMat[
  315. detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
  316. if iouMat[gtNum, detNum] > self.iou_constraint:
  317. gtRectMat[gtNum] = 1
  318. detRectMat[detNum] = 1
  319. detMatched += 1
  320. pairs.append({'gt': gtNum, 'det': detNum})
  321. detMatchedNums.append(detNum)
  322. evaluationLog += "Match GT #" + \
  323. str(gtNum) + " with Det #" + str(detNum) + "\n"
  324. numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
  325. numDetCare = (len(detPols) - len(detDontCarePolsNum))
  326. if numGtCare == 0:
  327. recall = float(1)
  328. precision = float(0) if numDetCare > 0 else float(1)
  329. else:
  330. recall = float(detMatched) / numGtCare
  331. precision = 0 if numDetCare == 0 else float(
  332. detMatched) / numDetCare
  333. hmean = 0 if (precision + recall) == 0 else 2.0 * \
  334. precision * recall / (precision + recall)
  335. matchedSum += detMatched
  336. numGlobalCareGt += numGtCare
  337. numGlobalCareDet += numDetCare
  338. perSampleMetrics = {
  339. 'precision': precision,
  340. 'recall': recall,
  341. 'hmean': hmean,
  342. 'pairs': pairs,
  343. 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
  344. 'gtPolPoints': gtPolPoints,
  345. 'detPolPoints': detPolPoints,
  346. 'gtCare': numGtCare,
  347. 'detCare': numDetCare,
  348. 'gtDontCare': gtDontCarePolsNum,
  349. 'detDontCare': detDontCarePolsNum,
  350. 'detMatched': detMatched,
  351. 'evaluationLog': evaluationLog
  352. }
  353. return perSampleMetrics
  354. def combine_results(self, results):
  355. numGlobalCareGt = 0
  356. numGlobalCareDet = 0
  357. matchedSum = 0
  358. for result in results:
  359. numGlobalCareGt += result['gtCare']
  360. numGlobalCareDet += result['detCare']
  361. matchedSum += result['detMatched']
  362. methodRecall = 0 if numGlobalCareGt == 0 else float(
  363. matchedSum) / numGlobalCareGt
  364. methodPrecision = 0 if numGlobalCareDet == 0 else float(
  365. matchedSum) / numGlobalCareDet
  366. methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
  367. methodRecall * methodPrecision / (
  368. methodRecall + methodPrecision)
  369. methodMetrics = {'precision': methodPrecision,
  370. 'recall': methodRecall, 'hmean': methodHmean}
  371. return methodMetrics
  372. class QuadMetric():
  373. def __init__(self, is_output_polygon=False):
  374. self.is_output_polygon = is_output_polygon
  375. self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon)
  376. def measure(self, batch, output, box_thresh=0.6):
  377. '''
  378. batch: (image, polygons, ignore_tags
  379. batch: a dict produced by dataloaders.
  380. image: tensor of shape (N, C, H, W).
  381. polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
  382. ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
  383. shape: the original shape of images.
  384. filename: the original filenames of images.
  385. output: (polygons, ...)
  386. '''
  387. results = []
  388. gt_polyons_batch = batch['text_polys']
  389. ignore_tags_batch = batch['ignore_tags']
  390. pred_polygons_batch = np.array(output[0])
  391. pred_scores_batch = np.array(output[1])
  392. for polygons, pred_polygons, pred_scores, ignore_tags in zip(gt_polyons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
  393. gt = [dict(points=np.int64(polygons[i]), ignore=ignore_tags[i]) for i in range(len(polygons))]
  394. if self.is_output_polygon:
  395. pred = [dict(points=pred_polygons[i]) for i in range(len(pred_polygons))]
  396. else:
  397. pred = []
  398. # print(pred_polygons.shape)
  399. for i in range(pred_polygons.shape[0]):
  400. if pred_scores[i] >= box_thresh:
  401. # print(pred_polygons[i,:,:].tolist())
  402. pred.append(dict(points=pred_polygons[i, :, :].astype(np.int32)))
  403. # pred = [dict(points=pred_polygons[i,:,:].tolist()) if pred_scores[i] >= box_thresh for i in range(pred_polygons.shape[0])]
  404. results.append(self.evaluator.evaluate_image(gt, pred))
  405. return results
  406. def validate_measure(self, batch, output, box_thresh=0.6):
  407. return self.measure(batch, output, box_thresh)
  408. def evaluate_measure(self, batch, output):
  409. return self.measure(batch, output), np.linspace(0, batch['image'].shape[0]).tolist()
  410. def gather_measure(self, raw_metrics):
  411. raw_metrics = [image_metrics
  412. for batch_metrics in raw_metrics
  413. for image_metrics in batch_metrics]
  414. result = self.evaluator.combine_results(raw_metrics)
  415. precision = AverageMeter()
  416. recall = AverageMeter()
  417. fmeasure = AverageMeter()
  418. precision.update(result['precision'], n=len(raw_metrics))
  419. recall.update(result['recall'], n=len(raw_metrics))
  420. fmeasure_score = 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8)
  421. fmeasure.update(fmeasure_score)
  422. return {
  423. 'precision': precision,
  424. 'recall': recall,
  425. 'fmeasure': fmeasure
  426. }
  427. def shrink_polygon_py(polygon, shrink_ratio):
  428. """
  429. 对框进行缩放,返回去的比例为1/shrink_ratio 即可
  430. """
  431. cx = polygon[:, 0].mean()
  432. cy = polygon[:, 1].mean()
  433. polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
  434. polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
  435. return polygon
  436. def shrink_polygon_pyclipper(polygon, shrink_ratio):
  437. from shapely.geometry import Polygon
  438. import pyclipper
  439. polygon_shape = Polygon(polygon)
  440. distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
  441. subject = [tuple(l) for l in polygon]
  442. padding = pyclipper.PyclipperOffset()
  443. padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  444. shrunk = padding.Execute(-distance)
  445. if shrunk == []:
  446. shrunk = np.array(shrunk)
  447. else:
  448. shrunk = np.array(shrunk[0]).reshape(-1, 2)
  449. return shrunk
  450. class MakeShrinkMap():
  451. r'''
  452. Making binary mask from detection data with ICDAR format.
  453. Typically following the process of class `MakeICDARData`.
  454. '''
  455. def __init__(self, min_text_size=4, shrink_ratio=0.4, shrink_type='pyclipper'):
  456. shrink_func_dict = {'py': shrink_polygon_py, 'pyclipper': shrink_polygon_pyclipper}
  457. self.shrink_func = shrink_func_dict[shrink_type]
  458. self.min_text_size = min_text_size
  459. self.shrink_ratio = shrink_ratio
  460. def __call__(self, data: dict) -> dict:
  461. """
  462. 从scales中随机选择一个尺度,对图片和文本框进行缩放
  463. :param data: {'imgs':,'text_polys':,'texts':,'ignore_tags':}
  464. :return:
  465. """
  466. image = data['imgs']
  467. text_polys = data['text_polys']
  468. ignore_tags = data['ignore_tags']
  469. h, w = image.shape[:2]
  470. text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
  471. gt = np.zeros((h, w), dtype=np.float32)
  472. mask = np.ones((h, w), dtype=np.float32)
  473. for i in range(len(text_polys)):
  474. polygon = text_polys[i]
  475. height = max(polygon[:, 1]) - min(polygon[:, 1])
  476. width = max(polygon[:, 0]) - min(polygon[:, 0])
  477. if ignore_tags[i] or min(height, width) < self.min_text_size:
  478. cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
  479. ignore_tags[i] = True
  480. else:
  481. shrunk = self.shrink_func(polygon, self.shrink_ratio)
  482. if shrunk.size == 0:
  483. cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
  484. ignore_tags[i] = True
  485. continue
  486. cv2.fillPoly(gt, [shrunk.astype(np.int32)], 1)
  487. data['shrink_map'] = gt
  488. data['shrink_mask'] = mask
  489. return data
  490. def validate_polygons(self, polygons, ignore_tags, h, w):
  491. '''
  492. polygons (numpy.array, required): of shape (num_instances, num_points, 2)
  493. '''
  494. if len(polygons) == 0:
  495. return polygons, ignore_tags
  496. assert len(polygons) == len(ignore_tags)
  497. for polygon in polygons:
  498. polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
  499. polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
  500. for i in range(len(polygons)):
  501. area = self.polygon_area(polygons[i])
  502. if abs(area) < 1:
  503. ignore_tags[i] = True
  504. if area > 0:
  505. polygons[i] = polygons[i][::-1, :]
  506. return polygons, ignore_tags
  507. def polygon_area(self, polygon):
  508. return cv2.contourArea(polygon)
  509. class MakeBorderMap():
  510. def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
  511. self.shrink_ratio = shrink_ratio
  512. self.thresh_min = thresh_min
  513. self.thresh_max = thresh_max
  514. def __call__(self, data: dict) -> dict:
  515. """
  516. 从scales中随机选择一个尺度,对图片和文本框进行缩放
  517. :param data: {'imgs':,'text_polys':,'texts':,'ignore_tags':}
  518. :return:
  519. """
  520. im = data['imgs']
  521. text_polys = data['text_polys']
  522. ignore_tags = data['ignore_tags']
  523. canvas = np.zeros(im.shape[:2], dtype=np.float32)
  524. mask = np.zeros(im.shape[:2], dtype=np.float32)
  525. for i in range(len(text_polys)):
  526. if ignore_tags[i]:
  527. continue
  528. self.draw_border_map(text_polys[i], canvas, mask=mask)
  529. canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
  530. data['threshold_map'] = canvas
  531. data['threshold_mask'] = mask
  532. return data
  533. def draw_border_map(self, polygon, canvas, mask):
  534. polygon = np.array(polygon)
  535. assert polygon.ndim == 2
  536. assert polygon.shape[1] == 2
  537. polygon_shape = Polygon(polygon)
  538. if polygon_shape.area <= 0:
  539. return
  540. distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
  541. subject = [tuple(l) for l in polygon]
  542. padding = pyclipper.PyclipperOffset()
  543. padding.AddPath(subject, pyclipper.JT_ROUND,
  544. pyclipper.ET_CLOSEDPOLYGON)
  545. padded_polygon = np.array(padding.Execute(distance)[0])
  546. cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
  547. xmin = padded_polygon[:, 0].min()
  548. xmax = padded_polygon[:, 0].max()
  549. ymin = padded_polygon[:, 1].min()
  550. ymax = padded_polygon[:, 1].max()
  551. width = xmax - xmin + 1
  552. height = ymax - ymin + 1
  553. polygon[:, 0] = polygon[:, 0] - xmin
  554. polygon[:, 1] = polygon[:, 1] - ymin
  555. xs = np.broadcast_to(
  556. np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
  557. ys = np.broadcast_to(
  558. np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
  559. distance_map = np.zeros(
  560. (polygon.shape[0], height, width), dtype=np.float32)
  561. for i in range(polygon.shape[0]):
  562. j = (i + 1) % polygon.shape[0]
  563. absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
  564. distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
  565. distance_map = distance_map.min(axis=0)
  566. xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
  567. xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
  568. ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
  569. ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
  570. canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
  571. 1 - distance_map[
  572. ymin_valid - ymin:ymax_valid - ymax + height,
  573. xmin_valid - xmin:xmax_valid - xmax + width],
  574. canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
  575. def distance(self, xs, ys, point_1, point_2):
  576. '''
  577. compute the distance from point to a line
  578. ys: coordinates in the first axis
  579. xs: coordinates in the second axis
  580. point_1, point_2: (x, y), the end of the line
  581. '''
  582. height, width = xs.shape[:2]
  583. square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
  584. square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
  585. square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])
  586. cosin = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2))
  587. square_sin = 1 - np.square(cosin)
  588. square_sin = np.nan_to_num(square_sin)
  589. result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)
  590. result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0]
  591. return result
  592. def extend_line(self, point_1, point_2, result):
  593. ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
  594. int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
  595. cv2.line(result, tuple(ex_point_1), tuple(point_1), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
  596. ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
  597. int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
  598. cv2.line(result, tuple(ex_point_2), tuple(point_2), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
  599. return ex_point_1, ex_point_2