model_manga_ocr.py

import itertools
import math
from typing import Callable, List, Set, Optional, Tuple, Union
from collections import defaultdict, Counter
import os
import shutil

import cv2
from PIL import Image
import numpy as np
import einops
import networkx as nx
from shapely.geometry import Polygon
import torch
import torch.nn as nn
import torch.nn.functional as F
from manga_ocr import MangaOcr

from .xpos_relative_position import XPOS
from .common import OfflineOCR
from .model_48px import OCR
from ..textline_merge import split_text_region
from ..utils import TextBlock, Quadrilateral, quadrilateral_can_merge_region, chunks
from ..utils.generic import AvgMeter
from ..utils.bubble import is_ignore


async def merge_bboxes(bboxes: List[Quadrilateral], width: int, height: int):
    """Group detected text lines into merged regions; return the merged boxes
    and, per region, the source indices sorted into reading order."""
    # step 1: divide into multiple text region candidates
    G = nx.Graph()
    for i, box in enumerate(bboxes):
        G.add_node(i, box=box)
    for ((u, ubox), (v, vbox)) in itertools.combinations(enumerate(bboxes), 2):
        # if quadrilateral_can_merge_region_coarse(ubox, vbox):
        if quadrilateral_can_merge_region(ubox, vbox, aspect_ratio_tol=1.3, font_size_ratio_tol=2,
                                          char_gap_tolerance=1, char_gap_tolerance2=3):
            G.add_edge(u, v)

    # step 2: postprocess - further split each region
    # (connected components make merging transitive: if A merges with B and B
    # with C, all three land in the same candidate region)
    region_indices: List[Set[int]] = []
    for node_set in nx.algorithms.components.connected_components(G):
        region_indices.extend(split_text_region(bboxes, node_set, width, height))

    # step 3: return regions
    merge_box = []
    merge_idx = []
    for node_set in region_indices:
        # for node_set in nx.algorithms.components.connected_components(G):
        nodes = list(node_set)
        txtlns: List[Quadrilateral] = np.array(bboxes)[nodes]

        # majority vote for direction
        dirs = [box.direction for box in txtlns]
        majority_dir_top_2 = Counter(dirs).most_common(2)
        if len(majority_dir_top_2) == 1:
            majority_dir = majority_dir_top_2[0][0]
        elif majority_dir_top_2[0][1] == majority_dir_top_2[1][1]:
            # the top two directions are tied; use the direction of the most
            # elongated textline (in either orientation)
            max_aspect_ratio = -100
            for box in txtlns:
                if box.aspect_ratio > max_aspect_ratio:
                    max_aspect_ratio = box.aspect_ratio
                    majority_dir = box.direction
                if 1.0 / box.aspect_ratio > max_aspect_ratio:
                    max_aspect_ratio = 1.0 / box.aspect_ratio
                    majority_dir = box.direction
        else:
            majority_dir = majority_dir_top_2[0][0]

        # sort textlines: top-to-bottom for horizontal text, right-to-left for vertical
        if majority_dir == 'h':
            nodes = sorted(nodes, key=lambda x: bboxes[x].centroid[1])
        elif majority_dir == 'v':
            nodes = sorted(nodes, key=lambda x: -bboxes[x].centroid[0])
        txtlns = np.array(bboxes)[nodes]

        # collect the region's textlines and their sorted indices
        merge_box.append(txtlns)
        merge_idx.append(nodes)

    return_box = []
    for bbox in merge_box:
        if len(bbox) == 1:
            return_box.append(bbox[0])
        else:
            # average probability over the merged lines
            prob = sum(q.prob for q in bbox) / len(bbox)
            # fold every line into the minimum rotated rectangle covering all of them
            base_box = bbox[0]
            for box in bbox[1:]:
                min_rect = np.array(Polygon([*base_box.pts, *box.pts]).minimum_rotated_rectangle.exterior.coords[:4])
                base_box = Quadrilateral(min_rect, '', prob)
            return_box.append(base_box)
    return return_box, merge_idx
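
# A minimal sketch of the fusion step above, assuming `pts_a` and `pts_b` are
# hypothetical 4x2 corner arrays of two Quadrilateral objects:
#
#     hull = Polygon([*pts_a, *pts_b]).minimum_rotated_rectangle
#     corners = np.array(hull.exterior.coords[:4])
#
# shapely's minimum_rotated_rectangle yields the smallest (possibly rotated)
# rectangle containing every input corner, so repeatedly folding boxes into
# `base_box` produces one rectangle covering the whole region.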


class ModelMangaOCR(OfflineOCR):
    _MODEL_MAPPING = {
        'model': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr_ar_48px.ckpt',
            'hash': '29daa46d080818bb4ab239a518a88338cbccff8f901bef8c9db191a7cb97671d',
        },
        'dict': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/alphabet-all-v7.txt',
            'hash': 'f5722368146aa0fbcc9f4726866e4efc3203318ebb66c811d8cbbe915576538a',
        },
    }

    def __init__(self, *args, **kwargs):
        os.makedirs(self.model_dir, exist_ok=True)
        # move checkpoints left in the working directory into the model directory
        if os.path.exists('ocr_ar_48px.ckpt'):
            shutil.move('ocr_ar_48px.ckpt', self._get_file_path('ocr_ar_48px.ckpt'))
        if os.path.exists('alphabet-all-v7.txt'):
            shutil.move('alphabet-all-v7.txt', self._get_file_path('alphabet-all-v7.txt'))
        super().__init__(*args, **kwargs)

    async def _load(self, device: str):
        with open(self._get_file_path('alphabet-all-v7.txt'), 'r', encoding='utf-8') as fp:
            # one dictionary entry per line; drop the trailing newline
            dictionary = [s[:-1] for s in fp.readlines()]
        self.model = OCR(dictionary, 768)
        self.mocr = MangaOcr()
        sd = torch.load(self._get_file_path('ocr_ar_48px.ckpt'))
        self.model.load_state_dict(sd)
        self.model.eval()
        self.device = device
        self.use_gpu = device in ('cuda', 'mps')
        if self.use_gpu:
            self.model = self.model.to(device)

    async def _unload(self):
        del self.model
        del self.mocr

    async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args: dict, verbose: bool = False, ignore_bubble: int = 0) -> List[TextBlock]:
        text_height = 48
        max_chunk_size = 16

        quadrilaterals = list(self._generate_text_direction(textlines))
        region_imgs = [q.get_transformed_region(image, d, text_height) for q, d in quadrilaterals]
        perm = range(len(region_imgs))
        is_quadrilaterals = False
        if len(quadrilaterals) > 0 and isinstance(quadrilaterals[0][0], Quadrilateral):
            # process regions in ascending width order so each batch needs less padding
            perm = sorted(range(len(region_imgs)), key=lambda x: region_imgs[x].shape[1])
            is_quadrilaterals = True

        texts = {}
        if args['use_mocr_merge']:
            merged_textlines, merged_idx = await merge_bboxes(textlines, image.shape[1], image.shape[0])
            merged_quadrilaterals = list(self._generate_text_direction(merged_textlines))
        else:
            merged_idx = [[i] for i in range(len(region_imgs))]
            merged_quadrilaterals = quadrilaterals

        # crop each merged region and recognize its text with manga-ocr
        merged_region_imgs = []
        for q, d in merged_quadrilaterals:
            if d == 'h':
                merged_text_height = q.aabb.w
                merged_d = 'h'
            elif d == 'v':
                merged_text_height = q.aabb.h
                merged_d = 'h'
            merged_region_imgs.append(q.get_transformed_region(image, merged_d, merged_text_height))
        for idx in range(len(merged_region_imgs)):
            texts[idx] = self.mocr(Image.fromarray(merged_region_imgs[idx]))
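
        # (MangaOcr instances are directly callable: they accept a PIL image or
        # an image path and return the recognized string, so each merged crop
        # above yields one text per region, e.g. `txt = self.mocr(Image.fromarray(crop))`.)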

        ix = 0
        out_regions = {}
        for indices in chunks(perm, max_chunk_size):
            N = len(indices)
            widths = [region_imgs[i].shape[1] for i in indices]
            max_width = 4 * ((max(widths) + 7) // 4)  # pad batch width up to a multiple of 4
            region = np.zeros((N, text_height, max_width, 3), dtype=np.uint8)
            idx_keys = []
            for i, idx in enumerate(indices):
                idx_keys.append(idx)
                W = region_imgs[idx].shape[1]
                region[i, :, :W, :] = region_imgs[idx]
                if verbose:
                    os.makedirs('result/ocrs/', exist_ok=True)
                    if quadrilaterals[idx][1] == 'v':
                        cv2.imwrite(f'result/ocrs/{ix}.png', cv2.rotate(cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR), cv2.ROTATE_90_CLOCKWISE))
                    else:
                        cv2.imwrite(f'result/ocrs/{ix}.png', cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR))
                ix += 1
            # normalize to [-1, 1] and convert NHWC -> NCHW for the model
            image_tensor = (torch.from_numpy(region).float() - 127.5) / 127.5
            image_tensor = einops.rearrange(image_tensor, 'N H W C -> N C H W')
            if self.use_gpu:
                image_tensor = image_tensor.to(self.device)
            with torch.no_grad():
                ret = self.model.infer_beam_batch(image_tensor, widths, beams_k=5, max_seq_length=255)
            # each beam-search result carries character indices, a sequence
            # probability, per-character fg/bg colors and two-class presence scores
            for i, (pred_chars_index, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred) in enumerate(ret):
                if prob < 0.2:
                    continue
                has_fg = (fg_ind_pred[:, 1] > fg_ind_pred[:, 0])
                has_bg = (bg_ind_pred[:, 1] > bg_ind_pred[:, 0])
                # running averages of the predicted text/background colors per channel
                fr = AvgMeter()
                fg = AvgMeter()
                fb = AvgMeter()
                br = AvgMeter()
                bg = AvgMeter()
                bb = AvgMeter()
                for chid, c_fg, c_bg, h_fg, h_bg in zip(pred_chars_index, fg_pred, bg_pred, has_fg, has_bg):
                    ch = self.model.dictionary[chid]
                    if ch == '<S>':
                        continue
                    if ch == '</S>':
                        break
                    if h_fg.item():
                        fr(int(c_fg[0] * 255))
                        fg(int(c_fg[1] * 255))
                        fb(int(c_fg[2] * 255))
                    if h_bg.item():
                        br(int(c_bg[0] * 255))
                        bg(int(c_bg[1] * 255))
                        bb(int(c_bg[2] * 255))
                    else:
                        # no background color detected: fall back to the foreground prediction
                        br(int(c_fg[0] * 255))
                        bg(int(c_fg[1] * 255))
                        bb(int(c_fg[2] * 255))
                # clamp the averaged colors to valid 8-bit values
                fr = min(max(int(fr()), 0), 255)
                fg = min(max(int(fg()), 0), 255)
                fb = min(max(int(fb()), 0), 255)
                br = min(max(int(br()), 0), 255)
                bg = min(max(int(bg()), 0), 255)
                bb = min(max(int(bb()), 0), 255)
                cur_region = quadrilaterals[indices[i]][0]
                if isinstance(cur_region, Quadrilateral):
                    cur_region.prob = prob
                    cur_region.fg_r = fr
                    cur_region.fg_g = fg
                    cur_region.fg_b = fb
                    cur_region.bg_r = br
                    cur_region.bg_g = bg
                    cur_region.bg_b = bb
                else:
                    cur_region.update_font_colors(np.array([fr, fg, fb]), np.array([br, bg, bb]))
                out_regions[idx_keys[i]] = cur_region

        # fuse the per-line predictions back into each merged region
        output_regions = []
        for i, nodes in enumerate(merged_idx):
            total_logprobs = 0
            total_area = 0
            fg_r = []
            fg_g = []
            fg_b = []
            bg_r = []
            bg_g = []
            bg_b = []
            for idx in nodes:
                if idx not in out_regions:
                    continue
                total_logprobs += np.log(out_regions[idx].prob) * out_regions[idx].area
                total_area += out_regions[idx].area
                fg_r.append(out_regions[idx].fg_r)
                fg_g.append(out_regions[idx].fg_g)
                fg_b.append(out_regions[idx].fg_b)
                bg_r.append(out_regions[idx].bg_r)
                bg_g.append(out_regions[idx].bg_g)
                bg_b.append(out_regions[idx].bg_b)
            if total_area == 0:
                # every line of this region fell below the probability threshold
                continue
            total_logprobs /= total_area
            prob = np.exp(total_logprobs)
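            # i.e. the area-weighted geometric mean of the per-line probabilities:
            #
            #     prob = exp( sum_i(area_i * ln(p_i)) / sum_i(area_i) )
            #
            # e.g. lines with p = 0.9 (area 300) and p = 0.6 (area 100) give
            # exp((300*ln(0.9) + 100*ln(0.6)) / 400) ≈ 0.81, weighted toward the larger line.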
            # average each color channel over the region's lines
            fr = round(np.mean(fg_r))
            fg = round(np.mean(fg_g))
            fb = round(np.mean(fg_b))
            br = round(np.mean(bg_r))
            bg = round(np.mean(bg_g))
            bb = round(np.mean(bg_b))
            txt = texts[i]
            self.logger.info(f'prob: {prob} {txt} fg: ({fr}, {fg}, {fb}) bg: ({br}, {bg}, {bb})')
            cur_region = merged_quadrilaterals[i][0]
            if isinstance(cur_region, Quadrilateral):
                cur_region.text = txt
                cur_region.prob = prob
                cur_region.fg_r = fr
                cur_region.fg_g = fg
                cur_region.fg_b = fb
                cur_region.bg_r = br
                cur_region.bg_g = bg
                cur_region.bg_b = bb
            else:
                cur_region.text.append(txt)
                cur_region.update_font_colors(np.array([fr, fg, fb]), np.array([br, bg, bb]))
            output_regions.append(cur_region)

        if is_quadrilaterals:
            return output_regions
        # TextBlock inputs were updated in place above
        return textlines
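
# A rough usage sketch, assuming the OfflineOCR base class drives these private
# hooks in this order (in the real pipeline they sit behind the base class's
# public API, which presumably also downloads and hash-checks the files listed
# in _MODEL_MAPPING):
#
#     ocr = ModelMangaOCR()
#     await ocr._load('cuda')
#     regions = await ocr._infer(image, textlines, {'use_mocr_merge': True})
#     await ocr._unload()
#
# `image` is an RGB uint8 numpy array and `textlines` comes from an upstream
# text detector.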