# model_32px.py

import math
from typing import List
from collections import defaultdict
import os
import shutil

import cv2
import numpy as np
import einops

import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import OfflineOCR
from ..utils import TextBlock, Quadrilateral, chunks
from ..utils.bubble import is_ignore
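
# Overview: this module implements the "32px" OCR backend. Text-line crops are scaled to a
# height of 32 px, encoded by a ResNet feature extractor (ResNet_FeatureExtractor below) and a
# Transformer encoder, then decoded autoregressively with beam search
# (OCR.infer_beam_batch / OCR.infer_beam). Besides the character sequence, the decoder also
# predicts per-character foreground and background RGB values via the *_pred heads.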

class Model32pxOCR(OfflineOCR):
    _MODEL_MAPPING = {
        'model': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr.zip',
            'hash': '47405638b96fa2540a5ee841a4cd792f25062c09d9458a973362d40785f95d7a',
            'archive': {
                'ocr.ckpt': '.',
                'alphabet-all-v5.txt': '.',
            },
        },
    }

    def __init__(self, *args, **kwargs):
        os.makedirs(self.model_dir, exist_ok=True)
        if os.path.exists('ocr.ckpt'):
            shutil.move('ocr.ckpt', self._get_file_path('ocr.ckpt'))
        if os.path.exists('alphabet-all-v5.txt'):
            shutil.move('alphabet-all-v5.txt', self._get_file_path('alphabet-all-v5.txt'))
        super().__init__(*args, **kwargs)

    async def _load(self, device: str):
        with open(self._get_file_path('alphabet-all-v5.txt'), 'r', encoding='utf-8') as fp:
            dictionary = [s[:-1] for s in fp.readlines()]
        self.model = OCR(dictionary, 768)
        sd = torch.load(self._get_file_path('ocr.ckpt'), map_location='cpu')
        self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
        self.model.eval()
        self.device = device
        if device == 'cuda' or device == 'mps':
            self.use_gpu = True
        else:
            self.use_gpu = False
        if self.use_gpu:
            self.model = self.model.to(device)

    async def _unload(self):
        del self.model

    async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args: dict, verbose: bool = False) -> List[TextBlock]:
        text_height = 32
        max_chunk_size = 16
        ignore_bubble = args.get('ignore_bubble', 0)

        quadrilaterals = list(self._generate_text_direction(textlines))
        region_imgs = [q.get_transformed_region(image, d, text_height) for q, d in quadrilaterals]
        out_regions = []

        perm = range(len(region_imgs))
        is_quadrilaterals = False
        if len(quadrilaterals) > 0 and isinstance(quadrilaterals[0][0], Quadrilateral):
            perm = sorted(range(len(region_imgs)), key=lambda x: region_imgs[x].shape[1])
            is_quadrilaterals = True

        ix = 0
        for indices in chunks(perm, max_chunk_size):
            N = len(indices)
            widths = [region_imgs[i].shape[1] for i in indices]
            max_width = 4 * (max(widths) + 7) // 4
            region = np.zeros((N, text_height, max_width, 3), dtype=np.uint8)
            for i, idx in enumerate(indices):
                W = region_imgs[idx].shape[1]
                tmp = region_imgs[idx]
                # Skip text blocks that is_ignore() flags as lying outside a speech bubble.
                if 1 <= ignore_bubble <= 50 and is_ignore(region_imgs[idx], ignore_bubble):
                    ix += 1
                    continue
                region[i, :, :W, :] = tmp
                if verbose:
                    os.makedirs('result/ocrs/', exist_ok=True)
                    if quadrilaterals[idx][1] == 'v':
                        cv2.imwrite(f'result/ocrs/{ix}.png', cv2.rotate(cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR), cv2.ROTATE_90_CLOCKWISE))
                    else:
                        cv2.imwrite(f'result/ocrs/{ix}.png', cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR))
                ix += 1
            image_tensor = (torch.from_numpy(region).float() - 127.5) / 127.5
            image_tensor = einops.rearrange(image_tensor, 'N H W C -> N C H W')
            if self.use_gpu:
                image_tensor = image_tensor.to(self.device)
            with torch.no_grad():
                ret = self.model.infer_beam_batch(image_tensor, widths, beams_k=5, max_seq_length=255)
            for i, (pred_chars_index, prob, fr, fg, fb, br, bg, bb) in enumerate(ret):
                if prob < 0.7:
                    continue
                fr = (torch.clip(fr.view(-1), 0, 1).mean() * 255).long().item()
                fg = (torch.clip(fg.view(-1), 0, 1).mean() * 255).long().item()
                fb = (torch.clip(fb.view(-1), 0, 1).mean() * 255).long().item()
                br = (torch.clip(br.view(-1), 0, 1).mean() * 255).long().item()
                bg = (torch.clip(bg.view(-1), 0, 1).mean() * 255).long().item()
                bb = (torch.clip(bb.view(-1), 0, 1).mean() * 255).long().item()
                seq = []
                for chid in pred_chars_index:
                    ch = self.model.dictionary[chid]
                    if ch == '<S>':
                        continue
                    if ch == '</S>':
                        break
                    if ch == '<SP>':
                        ch = ' '
                    seq.append(ch)
                txt = ''.join(seq)
                self.logger.info(f'prob: {prob} {txt} fg: ({fr}, {fg}, {fb}) bg: ({br}, {bg}, {bb})')
                cur_region = quadrilaterals[indices[i]][0]
                if isinstance(cur_region, Quadrilateral):
                    cur_region.text = txt
                    cur_region.prob = prob
                    cur_region.fg_r = fr
                    cur_region.fg_g = fg
                    cur_region.fg_b = fb
                    cur_region.bg_r = br
                    cur_region.bg_g = bg
                    cur_region.bg_b = bb
                else:
                    cur_region.text.append(txt)
                    cur_region.update_font_colors(np.array([fr, fg, fb]), np.array([br, bg, bb]))
                out_regions.append(cur_region)

        if is_quadrilaterals:
            return out_regions
        return textlines

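# Usage sketch for driving the underlying OCR module directly, bypassing the OfflineOCR wrapper.
# Assumptions (not part of this file): `alphabet-all-v5.txt` and `ocr.ckpt` are in the working
# directory, and `line_rgb` is a hypothetical RGB uint8 crop already scaled to 32 px height.
# The normalization and layout changes mirror _load()/_infer() above.
#
#   with open('alphabet-all-v5.txt', 'r', encoding='utf-8') as fp:
#       dictionary = [s[:-1] for s in fp.readlines()]
#   model = OCR(dictionary, 768)
#   sd = torch.load('ocr.ckpt', map_location='cpu')
#   model.load_state_dict(sd['model'] if 'model' in sd else sd)
#   model.eval()
#   img = (torch.from_numpy(line_rgb).float() - 127.5) / 127.5   # scale to [-1, 1]
#   img = einops.rearrange(img, 'H W C -> 1 C H W')              # HWC -> NCHW with batch of 1
#   with torch.no_grad():
#       ret = model.infer_beam_batch(img, [line_rgb.shape[1]], beams_k=5, max_seq_length=255)
#   # ret[0] = (char indices, probability, fg_r, fg_g, fg_b, bg_r, bg_g, bg_b)
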
class ResNet(nn.Module):

    def __init__(self, input_channel, output_channel, block, layers):
        super(ResNet, self).__init__()

        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]

        self.inplanes = int(output_channel / 8)
        self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 8),
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(int(output_channel / 8))
        self.conv0_2 = nn.Conv2d(int(output_channel / 8), self.inplanes,
                                 kernel_size=3, stride=1, padding=1, bias=False)

        self.maxpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
        self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[0],
                               kernel_size=3, stride=1, padding=1, bias=False)

        self.maxpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
        self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
        self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[1],
                               kernel_size=3, stride=1, padding=1, bias=False)

        self.maxpool3 = nn.AvgPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
        self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
        self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[2],
                               kernel_size=3, stride=1, padding=1, bias=False)

        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
        self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
        self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
                                 kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
        self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
        self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
                                 kernel_size=2, stride=1, padding=0, bias=False)
        self.bn4_3 = nn.BatchNorm2d(self.output_channel_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.BatchNorm2d(self.inplanes),
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = F.relu(x)
        x = self.conv0_2(x)

        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv1(x)

        x = self.maxpool2(x)
        x = self.layer2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.conv2(x)

        x = self.maxpool3(x)
        x = self.layer3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.conv3(x)

        x = self.layer4(x)
        x = self.bn4_1(x)
        x = F.relu(x)
        x = self.conv4_1(x)
        x = self.bn4_2(x)
        x = F.relu(x)
        x = self.conv4_2(x)
        x = self.bn4_3(x)

        return x

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        "3x3 convolution with padding"
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = F.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = F.relu(out)
        out = self.conv2(out)

        if self.downsample is not None:
            residual = self.downsample(residual)

        return out + residual

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

    def __init__(self, input_channel, output_channel=128):
        super(ResNet_FeatureExtractor, self).__init__()
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [3, 6, 7, 5])

    def forward(self, input):
        return self.ConvNet(input)

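# For an (N, 3, 32, W) input, the stack above collapses the height to 1 and shrinks the width by
# roughly a factor of 4, giving an (N, output_channel, 1, ~W/4) feature map; this is what the
# `valid_feats_length = (w + 3) // 4 + 2` bookkeeping in OCR.infer_beam_batch below approximates.
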
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x, offset=0):
        x = x + self.pe[offset: offset + x.size(0), :]
        return x  # self.dropout(x)

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

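# Worked example: generate_square_subsequent_mask(3) returns an additive attention mask in which
# position i may only attend to positions <= i:
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])
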
class AddCoords(nn.Module):

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)

        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret

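# AddCoords is a CoordConv-style helper that appends normalized x/y coordinate channels (and
# optionally a radius channel) to its input; it does not appear to be referenced elsewhere in
# this file.
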
class Beam:

    def __init__(self, char_seq=[], logprobs=[]):
        # L
        if isinstance(char_seq, list):
            self.chars = torch.tensor(char_seq, dtype=torch.long)
            self.logprobs = torch.tensor(logprobs, dtype=torch.float32)
        else:
            self.chars = char_seq.clone()
            self.logprobs = logprobs.clone()

    def avg_logprob(self):
        return self.logprobs.mean().item()

    def sort_key(self):
        return -self.avg_logprob()

    def seq_end(self, end_tok):
        return self.chars.view(-1)[-1] == end_tok

    def extend(self, idx, logprob):
        return Beam(
            torch.cat([self.chars, idx.unsqueeze(0)], dim=-1),
            torch.cat([self.logprobs, logprob.unsqueeze(0)], dim=-1),
        )

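# Beam holds a plain token/log-prob sequence and is used by the single-image OCR.infer_beam path.
# Hypothesis below additionally caches per-layer decoder activations so that the batched
# OCR.infer_beam_batch path can decode incrementally via next_token_batch.
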
DECODE_BLOCK_LENGTH = 8

class Hypothesis:

    def __init__(self, device, start_tok: int, end_tok: int, padding_tok: int, memory_idx: int, num_layers: int, embd_dim: int):
        self.device = device
        self.start_tok = start_tok
        self.end_tok = end_tok
        self.padding_tok = padding_tok
        self.memory_idx = memory_idx
        self.embd_size = embd_dim
        self.num_layers = num_layers
        # L, 1, E
        self.cached_activations = [torch.zeros(0, 1, self.embd_size).to(self.device)] * (num_layers + 1)
        self.out_idx = torch.LongTensor([start_tok]).to(self.device)
        self.out_logprobs = torch.FloatTensor([0]).to(self.device)
        self.length = 0

    def seq_end(self):
        return self.out_idx.view(-1)[-1] == self.end_tok

    def logprob(self):
        return self.out_logprobs.mean().item()

    def sort_key(self):
        return -self.logprob()

    def prob(self):
        return self.out_logprobs.mean().exp().item()

    def __len__(self):
        return self.length

    def extend(self, idx, logprob):
        ret = Hypothesis(self.device, self.start_tok, self.end_tok, self.padding_tok, self.memory_idx, self.num_layers, self.embd_size)
        ret.cached_activations = [item.clone() for item in self.cached_activations]
        ret.length = self.length + 1
        ret.out_idx = torch.cat([self.out_idx, torch.LongTensor([idx]).to(self.device)], dim=0)
        ret.out_logprobs = torch.cat([self.out_logprobs, torch.FloatTensor([logprob]).to(self.device)], dim=0)
        return ret

    def output(self):
        return self.cached_activations[-1]

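# next_token_batch advances every hypothesis by one decoder step in a single batched call: the
# previously cached activations of each hypothesis serve as the self-attention keys/values, the
# new step's activations are appended to each cache, and the final layer output (one embedding
# per hypothesis) is returned for the character/color heads.
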
def next_token_batch(
    hyps: List[Hypothesis],
    memory: torch.Tensor,  # S, K, E
    memory_mask: torch.BoolTensor,
    decoders: nn.TransformerDecoder,
    pe: PositionalEncoding,
    embd: nn.Embedding
):
    layer: nn.TransformerDecoderLayer
    N = len(hyps)

    # N
    last_toks = torch.stack([item.out_idx[-1] for item in hyps], dim=0)
    # 1, N, E
    tgt: torch.FloatTensor = pe(embd(last_toks).unsqueeze_(0), offset=len(hyps[0]))

    # # L, N
    # out_idxs = torch.stack([item.out_idx for item in hyps], dim = 0).permute(1, 0)
    # # L, N, E
    # tgt2: torch.FloatTensor = pe(embd(out_idxs))
    # # 1, N, E
    # tgt_v2 = tgt2[-1, :, :].unsqueeze_(0)
    # print(((tgt_v1 - tgt_v2) ** 2).sum())
    # tgt = tgt_v2

    # S, N, E
    memory = torch.stack([memory[:, idx, :] for idx in [item.memory_idx for item in hyps]], dim=1)
    for l, layer in enumerate(decoders.layers):
        # TODO: keys and values are recomputed every time
        # L - 1, N, E
        combined_activations = torch.cat([item.cached_activations[l] for item in hyps], dim=1)
        # L, N, E
        combined_activations = torch.cat([combined_activations, tgt], dim=0)
        for i in range(N):
            hyps[i].cached_activations[l] = combined_activations[:, i: i + 1, :]
        tgt2 = layer.self_attn(tgt, combined_activations, combined_activations)[0]
        tgt = tgt + layer.dropout1(tgt2)
        tgt = layer.norm1(tgt)
        tgt2 = layer.multihead_attn(tgt, memory, memory, key_padding_mask=memory_mask)[0]
        tgt = tgt + layer.dropout2(tgt2)
        tgt = layer.norm2(tgt)
        tgt2 = layer.linear2(layer.dropout(layer.activation(layer.linear1(tgt))))
        tgt = tgt + layer.dropout3(tgt2)
        # 1, N, E
        tgt = layer.norm3(tgt)
    #print(tgt[0, 0, 0])
    for i in range(N):
        hyps[i].cached_activations[decoders.num_layers] = torch.cat([hyps[i].cached_activations[decoders.num_layers], tgt[:, i: i + 1, :]], dim=0)
    # N, E
    return tgt.squeeze_(0)

class OCR(nn.Module):

    def __init__(self, dictionary, max_len):
        super(OCR, self).__init__()
        self.max_len = max_len
        self.dictionary = dictionary
        self.dict_size = len(dictionary)
        self.backbone = ResNet_FeatureExtractor(3, 320)
        encoder = nn.TransformerEncoderLayer(320, 4, dropout=0.0)
        decoder = nn.TransformerDecoderLayer(320, 4, dropout=0.0)
        self.encoders = nn.TransformerEncoder(encoder, 3)
        self.decoders = nn.TransformerDecoder(decoder, 2)
        self.pe = PositionalEncoding(320, max_len=max_len)
        self.embd = nn.Embedding(self.dict_size, 320)
        self.pred1 = nn.Sequential(nn.Linear(320, 320), nn.ReLU(), nn.Dropout(0.1))
        self.pred = nn.Linear(320, self.dict_size)
        self.pred.weight = self.embd.weight
        self.color_pred1 = nn.Sequential(nn.Linear(320, 64), nn.ReLU())
        self.fg_r_pred = nn.Linear(64, 1)
        self.fg_g_pred = nn.Linear(64, 1)
        self.fg_b_pred = nn.Linear(64, 1)
        self.bg_r_pred = nn.Linear(64, 1)
        self.bg_g_pred = nn.Linear(64, 1)
        self.bg_b_pred = nn.Linear(64, 1)

    def forward(self,
                img: torch.FloatTensor,
                char_idx: torch.LongTensor,
                mask: torch.BoolTensor,
                source_mask: torch.BoolTensor
                ):
        feats = self.backbone(img)
        feats = torch.einsum('n e h s -> s n e', feats)
        feats = self.pe(feats)
        memory = self.encoders(feats, src_key_padding_mask=source_mask)
        N, L = char_idx.shape
        char_embd = self.embd(char_idx)
        char_embd = torch.einsum('n t e -> t n e', char_embd)
        char_embd = self.pe(char_embd)
        causal_mask = generate_square_subsequent_mask(L).to(img.device)
        decoded = self.decoders(char_embd, memory, tgt_mask=causal_mask, tgt_key_padding_mask=mask, memory_key_padding_mask=source_mask)
        decoded = decoded.permute(1, 0, 2)
        pred_char_logits = self.pred(self.pred1(decoded))
        color_feats = self.color_pred1(decoded)
        return pred_char_logits, \
            self.fg_r_pred(color_feats), \
            self.fg_g_pred(color_feats), \
            self.fg_b_pred(color_feats), \
            self.bg_r_pred(color_feats), \
            self.bg_g_pred(color_feats), \
            self.bg_b_pred(color_feats)

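    # infer_beam_batch performs batched beam search: each sample keeps up to `beams_k` open
    # hypotheses, ranked by mean token log-probability (Hypothesis.sort_key). A hypothesis that
    # emits `end_tok` moves to the finished pool, and a sample stops expanding once it has
    # `max_finished_hypos` finished hypotheses or `max_seq_length` steps have elapsed.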
    def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok=1, end_tok=2, pad_tok=0, max_finished_hypos: int = 2, max_seq_length=384):
        N, C, H, W = img.shape
        assert H == 32 and C == 3
        feats = self.backbone(img)
        feats = torch.einsum('n e h s -> s n e', feats)
        valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
        input_mask = torch.zeros(N, feats.size(0), dtype=torch.bool).to(img.device)
        for i, l in enumerate(valid_feats_length):
            input_mask[i, l:] = True
        feats = self.pe(feats)
        memory = self.encoders(feats, src_key_padding_mask=input_mask)
        hypos = [Hypothesis(img.device, start_tok, end_tok, pad_tok, i, self.decoders.num_layers, 320) for i in range(N)]
        # N, E
        decoded = next_token_batch(hypos, memory, input_mask, self.decoders, self.pe, self.embd)
        # N, n_chars
        pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
        # N, k
        pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim=1)
        new_hypos = []
        finished_hypos = defaultdict(list)
        for i in range(N):
            for k in range(beams_k):
                new_hypos.append(hypos[i].extend(pred_chars_index[i, k], pred_chars_values[i, k]))
        hypos = new_hypos
        for _ in range(max_seq_length):
            # N * k, E
            decoded = next_token_batch(hypos, memory, torch.stack([input_mask[hyp.memory_idx] for hyp in hypos]), self.decoders, self.pe, self.embd)
            # N * k, n_chars
            pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
            # N * k, k
            pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim=1)
            hypos_per_sample = defaultdict(list)
            h: Hypothesis
            for i, h in enumerate(hypos):
                for k in range(beams_k):
                    hypos_per_sample[h.memory_idx].append(h.extend(pred_chars_index[i, k], pred_chars_values[i, k]))
            hypos = []
            # hypos_per_sample now contains N * k^2 hypos
            for i in hypos_per_sample.keys():
                cur_hypos: List[Hypothesis] = hypos_per_sample[i]
                cur_hypos = sorted(cur_hypos, key=lambda a: a.sort_key())[: beams_k + 1]
                #print(cur_hypos[0].out_idx[-1])
                to_added_hypos = []
                sample_done = False
                for h in cur_hypos:
                    if h.seq_end():
                        finished_hypos[i].append(h)
                        if len(finished_hypos[i]) >= max_finished_hypos:
                            sample_done = True
                            break
                    else:
                        if len(to_added_hypos) < beams_k:
                            to_added_hypos.append(h)
                if not sample_done:
                    hypos.extend(to_added_hypos)
            if len(hypos) == 0:
                break
        # add remaining hypos to finished
        for i in range(N):
            if i not in finished_hypos:
                cur_hypos: List[Hypothesis] = hypos_per_sample[i]
                cur_hypo = sorted(cur_hypos, key=lambda a: a.sort_key())[0]
                finished_hypos[i].append(cur_hypo)
        assert len(finished_hypos) == N
        result = []
        for i in range(N):
            cur_hypos = finished_hypos[i]
            cur_hypo = sorted(cur_hypos, key=lambda a: a.sort_key())[0]
            decoded = cur_hypo.output()
            color_feats = self.color_pred1(decoded)
            fg_r, fg_g, fg_b, bg_r, bg_g, bg_b = self.fg_r_pred(color_feats), \
                self.fg_g_pred(color_feats), \
                self.fg_b_pred(color_feats), \
                self.bg_r_pred(color_feats), \
                self.bg_g_pred(color_feats), \
                self.bg_b_pred(color_feats)
            result.append((cur_hypo.out_idx, cur_hypo.prob(), fg_r, fg_g, fg_b, bg_r, bg_g, bg_b))
        return result

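    # infer_beam is the single-image (N == 1) variant used by test_inference(); unlike
    # infer_beam_batch it re-runs the full decoder over the whole prefix at every step instead of
    # reusing cached activations.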
    def infer_beam(self, img: torch.FloatTensor, beams_k: int = 5, start_tok=1, end_tok=2, pad_tok=0, max_seq_length=384):
        N, C, H, W = img.shape
        assert H == 32 and N == 1 and C == 3
        feats = self.backbone(img)
        feats = torch.einsum('n e h s -> s n e', feats)
        feats = self.pe(feats)
        memory = self.encoders(feats)

        def run(tokens, add_start_tok=True, char_only=True):
            if add_start_tok:
                if isinstance(tokens, list):
                    # N(=1), L
                    tokens = torch.tensor([start_tok] + tokens, dtype=torch.long, device=img.device).unsqueeze_(0)
                else:
                    # N, L
                    tokens = torch.cat([torch.tensor([start_tok], dtype=torch.long, device=img.device), tokens], dim=-1).unsqueeze_(0)
            N, L = tokens.shape
            embd = self.embd(tokens)
            embd = torch.einsum('n t e -> t n e', embd)
            embd = self.pe(embd)
            causal_mask = generate_square_subsequent_mask(L).to(img.device)
            decoded = self.decoders(embd, memory, tgt_mask=causal_mask)
            decoded = decoded.permute(1, 0, 2)
            pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
            if char_only:
                return pred_char_logprob
            else:
                color_feats = self.color_pred1(decoded)
                return pred_char_logprob, \
                    self.fg_r_pred(color_feats), \
                    self.fg_g_pred(color_feats), \
                    self.fg_b_pred(color_feats), \
                    self.bg_r_pred(color_feats), \
                    self.bg_g_pred(color_feats), \
                    self.bg_b_pred(color_feats)

        # N, L, embd_size
        initial_char_logprob = run([])
        # N, L
        initial_pred_chars_values, initial_pred_chars_index = torch.topk(initial_char_logprob, beams_k, dim=2)
        # beams_k, L
        initial_pred_chars_values = initial_pred_chars_values.squeeze(0).permute(1, 0)
        initial_pred_chars_index = initial_pred_chars_index.squeeze(0).permute(1, 0)
        beams = sorted([Beam(tok, logprob) for tok, logprob in zip(initial_pred_chars_index, initial_pred_chars_values)], key=lambda a: a.sort_key())
        for _ in range(max_seq_length):
            new_beams = []
            all_ended = True
            for beam in beams:
                if not beam.seq_end(end_tok):
                    logprobs = run(beam.chars)
                    pred_chars_values, pred_chars_index = torch.topk(logprobs, beams_k, dim=2)
                    # beams_k, L
                    pred_chars_values = pred_chars_values.squeeze(0).permute(1, 0)
                    pred_chars_index = pred_chars_index.squeeze(0).permute(1, 0)
                    #print(pred_chars_index.view(-1)[-1])
                    new_beams.extend([beam.extend(tok[-1], logprob[-1]) for tok, logprob in zip(pred_chars_index, pred_chars_values)])
                    #new_beams.extend([Beam(tok, logprob) for tok, logprob in zip(pred_chars_index, pred_chars_values)]) # extend other top k
                    all_ended = False
                else:
                    new_beams.append(beam)  # seq ended, add back to queue
            beams = sorted(new_beams, key=lambda a: a.sort_key())[: beams_k]  # keep top k
            #print(beams[0].chars)
            if all_ended:
                break
        final_tokens = beams[0].chars[:-1]
        #print(beams[0].logprobs.mean().exp())
        return run(final_tokens, char_only=False), beams[0].logprobs.mean().exp().item()

def test():
    with open('../SynthText/alphabet-all-v2.txt', 'r') as fp:
        dictionary = [s[:-1] for s in fp.readlines()]
    img = torch.randn(4, 3, 32, 1224)
    idx = torch.zeros(4, 32).long()
    mask = torch.zeros(4, 32).bool()
    model = ResNet_FeatureExtractor(3, 256)
    out = model(img)

def test_inference():
    with torch.no_grad():
        with open('../SynthText/alphabet-all-v3.txt', 'r') as fp:
            dictionary = [s[:-1] for s in fp.readlines()]
        img = torch.zeros(1, 3, 32, 128)
        model = OCR(dictionary, 32)
        m = torch.load("ocr_ar_v2-3-test.ckpt", map_location='cpu')
        model.load_state_dict(m['model'])
        model.eval()
        (char_probs, _, _, _, _, _, _), _ = model.infer_beam(img, max_seq_length=20)
        _, pred_chars_index = char_probs.max(2)
        pred_chars_index = pred_chars_index.squeeze_(0)
        seq = []
        for chid in pred_chars_index:
            ch = dictionary[chid]
            if ch == '<SP>':
                ch = ' '
            seq.append(ch)
        print(''.join(seq))

if __name__ == "__main__":
    test()