# model_48px_ctc.py

import os
import math
import shutil
import cv2
from typing import List, Tuple, Optional
import numpy as np
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import OfflineOCR
from ..utils import TextBlock, Quadrilateral, AvgMeter, chunks
from ..utils.bubble import is_ignore


class Model48pxCTCOCR(OfflineOCR):
    _MODEL_MAPPING = {
        'model': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip',
            'hash': 'fc61c52f7a811bc72c54f6be85df814c6b60f63585175db27cb94a08e0c30101',
            'archive': {
                'ocr-ctc.ckpt': '.',
                'alphabet-all-v5.txt': '.',
            },
        },
    }

    def __init__(self, *args, **kwargs):
        os.makedirs(self.model_dir, exist_ok=True)
        if os.path.exists('ocr-ctc.ckpt'):
            shutil.move('ocr-ctc.ckpt', self._get_file_path('ocr-ctc.ckpt'))
        if os.path.exists('alphabet-all-v5.txt'):
            shutil.move('alphabet-all-v5.txt', self._get_file_path('alphabet-all-v5.txt'))
        super().__init__(*args, **kwargs)

    async def _load(self, device: str):
        with open(self._get_file_path('alphabet-all-v5.txt'), 'r', encoding='utf-8') as fp:
            dictionary = [s[:-1] for s in fp.readlines()]
        self.model: OCR = OCR(dictionary, 768)
        sd = torch.load(self._get_file_path('ocr-ctc.ckpt'), map_location='cpu')
        sd = sd['model'] if 'model' in sd else sd
        del sd['encoders.layers.0.pe.pe']
        del sd['encoders.layers.1.pe.pe']
        del sd['encoders.layers.2.pe.pe']
        self.model.load_state_dict(sd, strict=False)
        self.model.eval()
        self.device = device
        self.use_gpu = device == 'cuda' or device == 'mps'
        if self.use_gpu:
            self.model = self.model.to(device)

    async def _unload(self):
        del self.model

    async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args: dict, verbose: bool = False) -> List[TextBlock]:
        text_height = 48
        max_chunk_size = 16
        ignore_bubble = args.get('ignore_bubble', 0)

        quadrilaterals = list(self._generate_text_direction(textlines))
        region_imgs = [q.get_transformed_region(image, d, text_height) for q, d in quadrilaterals]
        out_regions = []

        perm = range(len(region_imgs))
        is_quadrilaterals = False
        if len(quadrilaterals) > 0:
            if isinstance(quadrilaterals[0][0], Quadrilateral):
                is_quadrilaterals = True
                # Sort regions by width so similarly sized crops end up in the same chunk
                perm = sorted(range(len(region_imgs)), key=lambda x: region_imgs[x].shape[1])

        ix = 0
        for indices in chunks(perm, max_chunk_size):
            N = len(indices)
            widths = [region_imgs[i].shape[1] for i in indices]
            max_width = (4 * (max(widths) + 7) // 4) + 128
            region = np.zeros((N, text_height, max_width, 3), dtype=np.uint8)
            for i, idx in enumerate(indices):
                W = region_imgs[idx].shape[1]
                tmp = region_imgs[idx]
                # is_ignore() returns True when this text region should be skipped
                # (ignore_bubble acts as a threshold and must lie between 1 and 50).
                if 1 <= ignore_bubble <= 50 and is_ignore(region_imgs[idx], ignore_bubble):
                    ix += 1
                    continue
                region[i, :, :W, :] = tmp
                if verbose:
                    os.makedirs('result/ocrs/', exist_ok=True)
                    if quadrilaterals[idx][1] == 'v':
                        cv2.imwrite(f'result/ocrs/{ix}.png', cv2.rotate(cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR), cv2.ROTATE_90_CLOCKWISE))
                    else:
                        cv2.imwrite(f'result/ocrs/{ix}.png', cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR))
                ix += 1

            images = (torch.from_numpy(region).float() - 127.5) / 127.5
            images = einops.rearrange(images, 'N H W C -> N C H W')
            if self.use_gpu:
                images = images.to(self.device)

            with torch.inference_mode():
                texts = self.model.decode(images, widths, 0, verbose=verbose)

            for i, single_line in enumerate(texts):
                if not single_line:
                    continue
                cur_texts = []
                total_fr = AvgMeter()
                total_fg = AvgMeter()
                total_fb = AvgMeter()
                total_br = AvgMeter()
                total_bg = AvgMeter()
                total_bb = AvgMeter()
                total_logprob = AvgMeter()
                for chid, logprob, fr, fg, fb, br, bg, bb in single_line:
                    ch = self.model.dictionary[chid]
                    if ch == '<SP>':
                        ch = ' '
                    cur_texts.append(ch)
                    total_logprob(logprob)
                    if ch != ' ':
                        total_fr(int(fr * 255))
                        total_fg(int(fg * 255))
                        total_fb(int(fb * 255))
                        total_br(int(br * 255))
                        total_bg(int(bg * 255))
                        total_bb(int(bb * 255))
                prob = np.exp(total_logprob())
                if prob < 0.5:
                    continue
                txt = ''.join(cur_texts)
                fr = int(total_fr())
                fg = int(total_fg())
                fb = int(total_fb())
                br = int(total_br())
                bg = int(total_bg())
                bb = int(total_bb())
                self.logger.info(f'prob: {prob} {txt} fg: ({fr}, {fg}, {fb}) bg: ({br}, {bg}, {bb})')
                cur_region = quadrilaterals[indices[i]][0]
                if isinstance(cur_region, Quadrilateral):
                    cur_region.text = txt
                    cur_region.prob = prob
                    cur_region.fg_r = fr
                    cur_region.fg_g = fg
                    cur_region.fg_b = fb
                    cur_region.bg_r = br
                    cur_region.bg_g = bg
                    cur_region.bg_b = bb
                else:
                    cur_region.text.append(txt)
                    cur_region.update_font_colors(np.array([fr, fg, fb]), np.array([br, bg, bb]))
                out_regions.append(cur_region)

        if is_quadrilaterals:
            return out_regions
        return textlines
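

# Illustrative sketch of the batching done in `_infer` above: variable-width
# text-line crops are padded into one fixed-width uint8 batch and normalized to
# [-1, 1] before being fed to the network. The helper name `pad_and_normalize`
# is hypothetical and exists only for illustration.
def pad_and_normalize(region_imgs: List[np.ndarray], text_height: int = 48) -> torch.Tensor:
    widths = [img.shape[1] for img in region_imgs]
    # Same width rounding plus margin that `_infer` uses for its batch buffer.
    max_width = (4 * (max(widths) + 7) // 4) + 128
    batch = np.zeros((len(region_imgs), text_height, max_width, 3), dtype=np.uint8)
    for i, img in enumerate(region_imgs):
        batch[i, :, :img.shape[1], :] = img
    images = (torch.from_numpy(batch).float() - 127.5) / 127.5  # uint8 -> [-1, 1]
    return einops.rearrange(images, 'N H W C -> N C H W')       # NHWC -> NCHW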


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x, offset=0):
        x = x + self.pe[:, offset: offset + x.size(1), :]
        return x
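

# Shape sketch for PositionalEncoding (illustrative): the sinusoidal table is
# registered as a (1, max_len, d_model) buffer and `forward` adds the slice that
# matches the input sequence length. `_pe_shape_demo` is a hypothetical helper.
def _pe_shape_demo() -> torch.Tensor:
    pe = PositionalEncoding(d_model=320, max_len=2048)
    x = torch.zeros(2, 100, 320)   # (batch, seq, feature)
    y = pe(x)                      # adds pe[:, 0:100, :], broadcast over the batch
    assert y.shape == x.shape
    return y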


class CustomTransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="gelu",
                 layer_norm_eps=1e-5, batch_first=False, norm_first=False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(CustomTransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                               **factory_kwargs)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.pe = PositionalEncoding(d_model, max_len=2048)
        self.activation = F.gelu

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(CustomTransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, is_causal=None) -> torch.Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(self, x: torch.Tensor,
                  attn_mask: Optional[torch.Tensor], key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = self.self_attn(self.pe(x), self.pe(x), x,  # no PE for the values
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
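

# Usage sketch for CustomTransformerEncoderLayer (illustrative): unlike the stock
# nn.TransformerEncoderLayer, positional encoding is applied inside `_sa_block`
# to the queries and keys only, so callers feed raw features. The helper name
# `_encoder_layer_demo` is hypothetical.
def _encoder_layer_demo() -> torch.Tensor:
    layer = CustomTransformerEncoderLayer(d_model=320, nhead=8, dim_feedforward=1280,
                                          batch_first=True, norm_first=True)
    feats = torch.randn(2, 64, 320)   # (batch, seq, feature)
    out = layer(feats)
    assert out.shape == feats.shape
    return out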


class ResNet(nn.Module):
    def __init__(self, input_channel, output_channel, block, layers):
        super(ResNet, self).__init__()

        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]

        self.inplanes = int(output_channel / 8)
        self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 8),
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(int(output_channel / 8))
        self.conv0_2 = nn.Conv2d(int(output_channel / 8), self.inplanes,
                                 kernel_size=3, stride=1, padding=1, bias=False)

        self.maxpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
        self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[0],
                               kernel_size=3, stride=1, padding=1, bias=False)

        self.maxpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
        self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
        self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[1],
                               kernel_size=3, stride=1, padding=1, bias=False)

        self.maxpool3 = nn.AvgPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
        self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
        self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[2],
                               kernel_size=3, stride=1, padding=1, bias=False)

        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
        self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
        self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
                                 kernel_size=3, stride=(2, 1), padding=(1, 1), bias=False)
        self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
        self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
                                 kernel_size=3, stride=1, padding=0, bias=False)
        self.bn4_3 = nn.BatchNorm2d(self.output_channel_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.BatchNorm2d(self.inplanes),
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = F.relu(x)
        x = self.conv0_2(x)

        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv1(x)

        x = self.maxpool2(x)
        x = self.layer2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.conv2(x)

        x = self.maxpool3(x)
        x = self.layer3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.conv3(x)

        x = self.layer4(x)
        x = self.bn4_1(x)
        x = F.relu(x)
        x = self.conv4_1(x)
        x = self.bn4_2(x)
        x = F.relu(x)
        x = self.conv4_2(x)
        x = self.bn4_3(x)

        return x


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        "3x3 convolution with padding"
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = F.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = F.relu(out)
        out = self.conv2(out)

        if self.downsample is not None:
            residual = self.downsample(residual)

        return out + residual
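

# Shape sketch for BasicBlock (illustrative): this is a pre-activation residual
# block (BN -> ReLU -> conv twice, then the skip connection is added). Both 3x3
# convolutions are built with stride 1, so the spatial size is preserved unless
# `downsample` is given. `_basic_block_demo` is a hypothetical helper.
def _basic_block_demo() -> torch.Tensor:
    block = BasicBlock(64, 64)
    x = torch.randn(1, 64, 12, 100)
    out = block(x)
    assert out.shape == x.shape
    return out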


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

    def __init__(self, input_channel, output_channel=128):
        super(ResNet_FeatureExtractor, self).__init__()
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [4, 6, 8, 6, 3])

    def forward(self, input):
        return self.ConvNet(input)
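

# Shape sketch for the backbone (illustrative): a 48px-high RGB strip is reduced
# to a (N, 320, 1, W') feature map, i.e. a width-wise sequence of 320-d vectors,
# which OCR squeezes and feeds to the transformer encoder. `_backbone_demo` is a
# hypothetical helper.
def _backbone_demo() -> torch.Tensor:
    backbone = ResNet_FeatureExtractor(3, 320)
    feats = backbone(torch.randn(1, 3, 48, 512))
    assert feats.shape[1] == 320 and feats.shape[2] == 1
    return feats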


class OCR(nn.Module):
    def __init__(self, dictionary, max_len):
        super(OCR, self).__init__()
        self.max_len = max_len
        self.dictionary = dictionary
        self.dict_size = len(dictionary)
        self.backbone = ResNet_FeatureExtractor(3, 320)
        enc = CustomTransformerEncoderLayer(320, 8, 320 * 4, dropout=0.05, batch_first=True, norm_first=True)
        self.encoders = nn.TransformerEncoder(enc, 3)
        self.char_pred_norm = nn.Sequential(nn.LayerNorm(320), nn.Dropout(0.1), nn.GELU())
        self.char_pred = nn.Linear(320, self.dict_size)
        self.color_pred1 = nn.Sequential(nn.Linear(320, 6))

    def forward(self, img: torch.FloatTensor):
        feats = self.backbone(img).squeeze(2)
        feats = self.encoders(feats.permute(0, 2, 1))
        pred_char_logits = self.char_pred(self.char_pred_norm(feats))
        pred_color_values = self.color_pred1(feats)
        return pred_char_logits, pred_color_values

    def decode(self, img: torch.Tensor, img_widths: List[int], blank, verbose=False) -> List[List[Tuple[str, float, int, int, int, int, int, int]]]:
        N, C, H, W = img.shape
        assert H == 48 and C == 3
        feats = self.backbone(img).squeeze(2)
        feats = self.encoders(feats.permute(0, 2, 1))
        pred_char_logits = self.char_pred(self.char_pred_norm(feats))
        pred_color_values = self.color_pred1(feats)
        return self.decode_ctc_top1(pred_char_logits, pred_color_values, blank, verbose=verbose)

    def decode_ctc_top1(self, pred_char_logits, pred_color_values, blank, verbose=False) -> List[List[Tuple[str, float, int, int, int, int, int, int]]]:
        pred_chars: List[List[Tuple[str, float, int, int, int, int, int, int]]] = []
        for _ in range(pred_char_logits.size(0)):
            pred_chars.append([])
        logprobs = pred_char_logits.log_softmax(2)
        _, preds_index = logprobs.max(2)
        preds_index = preds_index.cpu()
        pred_color_values = pred_color_values.cpu().clamp_(0, 1)
        for b in range(pred_char_logits.size(0)):
            # if verbose:
            #     print('------------------------------')
            last_ch = blank
            for t in range(pred_char_logits.size(1)):
                pred_ch = preds_index[b, t]
                if pred_ch != last_ch and pred_ch != blank:
                    lp = logprobs[b, t, pred_ch].item()
                    # if verbose:
                    #     if lp < math.log(0.9):
                    #         top5 = torch.topk(logprobs[b, t], 5)
                    #         top5_idx = top5.indices
                    #         top5_val = top5.values
                    #         r = ''
                    #         for i in range(5):
                    #             r += f'{self.dictionary[top5_idx[i]]}: {math.exp(top5_val[i])}, '
                    #         print(r)
                    #     else:
                    #         print(f'{self.dictionary[pred_ch]}: {math.exp(lp)}')
                    pred_chars[b].append((
                        pred_ch,
                        lp,
                        pred_color_values[b, t][0].item(),
                        pred_color_values[b, t][1].item(),
                        pred_color_values[b, t][2].item(),
                        pred_color_values[b, t][3].item(),
                        pred_color_values[b, t][4].item(),
                        pred_color_values[b, t][5].item()
                    ))
                last_ch = pred_ch
        return pred_chars

    def eval_ocr(self, input_lengths, target_lengths, pred_char_logits, pred_color_values, gt_char_index, gt_color_values, blank, blank1):
        correct_char = 0
        total_char = 0
        color_diff = 0
        color_diff_dom = 0
        _, preds_index = pred_char_logits.max(2)
        pred_chars = torch.zeros_like(gt_char_index).cpu()
        for b in range(pred_char_logits.size(0)):
            last_ch = blank
            i = 0
            for t in range(input_lengths[b]):
                pred_ch = preds_index[b, t]
                if pred_ch != last_ch and pred_ch != blank:
                    total_char += 1
                    if gt_char_index[b, i] == pred_ch:
                        correct_char += 1
                        if pred_ch != blank1:
                            color_diff += ((pred_color_values[b, t] - gt_color_values[b, i]).abs().mean() * 255.0).item()
                            color_diff_dom += 1
                    pred_chars[b, i] = pred_ch
                    i += 1
                    if i >= gt_color_values.size(1) or i >= gt_char_index.size(1):
                        break
                last_ch = pred_ch
        return correct_char / (total_char + 1), color_diff / (color_diff_dom + 1), pred_chars
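

# End-to-end decode sketch (illustrative): a randomly initialized OCR model run
# on a dummy 48px-high strip. With trained weights, `decode` returns one list per
# image of (char_id, logprob, fr, fg, fb, br, bg, bb) tuples after CTC collapsing.
# The tiny alphabet below is hypothetical; in this module the real one is loaded
# from alphabet-all-v5.txt and index 0 is passed as the CTC blank.
def _decode_demo():
    dictionary = ['<BLANK>', '<SP>', 'a', 'b', 'c']
    model = OCR(dictionary, 768)
    model.eval()
    img = torch.randn(1, 3, 48, 256)
    with torch.inference_mode():
        return model.decode(img, [256], blank=0)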


def test2():
    with open('alphabet-all-v5.txt', 'r') as fp:
        dictionary = [s[:-1] for s in fp.readlines()]
    img = torch.randn(4, 3, 48, 1536)
    idx = torch.zeros(4, 32).long()
    mask = torch.zeros(4, 32).bool()
    model = OCR(dictionary, 1024)
    pred_char_logits, pred_color_values = model(img)
    print(pred_char_logits.shape, pred_color_values.shape)


def test_inference():
    with torch.no_grad():
        with open('../SynthText/alphabet-all-v3.txt', 'r') as fp:
            dictionary = [s[:-1] for s in fp.readlines()]
        img = torch.zeros(1, 3, 32, 128)
        model = OCR(dictionary, 32)
        m = torch.load("ocr_ar_v2-3-test.ckpt", map_location='cpu')
        model.load_state_dict(m['model'])
        model.eval()

        (char_probs, _, _, _, _, _, _, _), _ = model.infer_beam(img, max_seq_length=20)
        _, pred_chars_index = char_probs.max(2)
        pred_chars_index = pred_chars_index.squeeze_(0)
        seq = []
        for chid in pred_chars_index:
            ch = dictionary[chid]
            if ch == '<SP>':
                ch = ' '
            seq.append(ch)
        print(''.join(seq))


if __name__ == "__main__":
    test2()