123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692 |
- import math
- from typing import List
- from collections import defaultdict
- import os
- import shutil
- import cv2
- import numpy as np
- import einops
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .common import OfflineOCR
- from ..utils import TextBlock, Quadrilateral, chunks
- from ..utils.bubble import is_ignore
class Model32pxOCR(OfflineOCR):
    """OCR backend that recognises text lines rescaled to a height of 32 px.

    The class wires the ``OCR`` network defined in this module into the
    ``OfflineOCR`` interface: it relocates the checkpoint/alphabet files,
    loads them, and converts text-line regions into text plus predicted
    foreground/background colours.
    """

    # Description of the model archive (URL, hash and the files it contains).
    # NOTE(review): presumably consumed by the base class to download the
    # files - the consumer is not visible in this file.
    _MODEL_MAPPING = {
        'model': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr.zip',
            'hash': '47405638b96fa2540a5ee841a4cd792f25062c09d9458a973362d40785f95d7a',
            'archive': {
                'ocr.ckpt': '.',
                'alphabet-all-v5.txt': '.',
            },
        },
    }

    def __init__(self, *args, **kwargs):
        """Move model files found in the working directory into the model dir."""
        os.makedirs(self.model_dir, exist_ok=True)
        # Files left in the current working directory are relocated to where
        # ``_get_file_path`` expects them before the base class initialises.
        for file_name in ('ocr.ckpt', 'alphabet-all-v5.txt'):
            if os.path.exists(file_name):
                shutil.move(file_name, self._get_file_path(file_name))
        super().__init__(*args, **kwargs)

    async def _load(self, device: str):
        """Read the alphabet, build the network and load its weights.

        Args:
            device: Target device name. The model is moved to it only when it
                is ``'cuda'`` or ``'mps'``; otherwise it stays on the CPU.
        """
        with open(self._get_file_path('alphabet-all-v5.txt'), 'r', encoding = 'utf-8') as fp:
            # One dictionary entry per line. Only the line terminator is
            # stripped, so a final line without a trailing newline keeps its
            # last character (slicing with [:-1] would have dropped it).
            dictionary = [line.rstrip('\n') for line in fp]

        self.model = OCR(dictionary, 768)
        sd = torch.load(self._get_file_path('ocr.ckpt'), map_location = 'cpu')
        # The checkpoint is either a bare state dict or wraps it under 'model'.
        self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
        self.model.eval()
        self.device = device
        self.use_gpu = device == 'cuda' or device == 'mps'
        if self.use_gpu:
            self.model = self.model.to(device)

    async def _unload(self):
        """Drop the reference to the loaded network."""
        del self.model

    @staticmethod
    def _channel_to_byte(channel: torch.Tensor) -> int:
        """Clip a colour prediction to [0, 1], average it and scale to 0-255."""
        return (torch.clip(channel.view(-1), 0, 1).mean() * 255).long().item()

    async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args: dict, verbose: bool = False) -> List[TextBlock]:
        """Recognise the text of every text line in ``image``.

        Args:
            image: Source image handed to ``get_transformed_region``.
            textlines: Regions passed through ``self._generate_text_direction``.
            args: Options dict; only ``'ignore_bubble'`` is read (default 0).
            verbose: When true, every region crop is written to
                ``result/ocrs/<n>.png``.

        Returns:
            When the first region is a ``Quadrilateral``: the regions whose
            prediction probability is at least 0.7, annotated with text and
            colours. Otherwise ``textlines`` itself, whose items were updated
            in place.
        """
        text_height = 32
        max_chunk_size = 16
        ignore_bubble = args.get('ignore_bubble', 0)

        quadrilaterals = list(self._generate_text_direction(textlines))
        region_imgs = [q.get_transformed_region(image, d, text_height) for q, d in quadrilaterals]
        out_regions = []

        perm = range(len(region_imgs))
        is_quadrilaterals = False
        if len(quadrilaterals) > 0 and isinstance(quadrilaterals[0][0], Quadrilateral):
            # Sort by width so each chunk batches regions of similar size.
            perm = sorted(range(len(region_imgs)), key = lambda x: region_imgs[x].shape[1])
            is_quadrilaterals = True

        ix = 0
        for indices in chunks(perm, max_chunk_size):
            N = len(indices)
            widths = [region_imgs[i].shape[1] for i in indices]
            # NOTE(review): by operator precedence this is (4 * (w + 7)) // 4,
            # i.e. exactly max(widths) + 7, not a multiple of 4. Left as is so
            # the padded input width (and thus the numerics) does not change.
            max_width = 4 * (max(widths) + 7) // 4
            region = np.zeros((N, text_height, max_width, 3), dtype = np.uint8)
            for i, idx in enumerate(indices):
                tmp = region_imgs[idx]
                W = tmp.shape[1]
                # Determine whether to skip the text block, and return True to skip.
                # A skipped slot stays all-zero but is still part of the batch
                # that is fed to the model below.
                if 1 <= ignore_bubble <= 50 and is_ignore(tmp, ignore_bubble):
                    ix += 1
                    continue
                region[i, :, : W, :] = tmp
                if verbose:
                    os.makedirs('result/ocrs/', exist_ok=True)
                    crop = cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR)
                    if quadrilaterals[idx][1] == 'v':
                        # Direction 'v' crops are rotated before being dumped.
                        crop = cv2.rotate(crop, cv2.ROTATE_90_CLOCKWISE)
                    cv2.imwrite(f'result/ocrs/{ix}.png', crop)
                ix += 1

            # uint8 [0, 255] -> float [-1, 1], channels first.
            image_tensor = (torch.from_numpy(region).float() - 127.5) / 127.5
            image_tensor = einops.rearrange(image_tensor, 'N H W C -> N C H W')
            if self.use_gpu:
                image_tensor = image_tensor.to(self.device)
            with torch.no_grad():
                ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255)

            for i, (pred_chars_index, prob, fr, fg, fb, br, bg, bb) in enumerate(ret):
                if prob < 0.7:
                    continue
                fr, fg, fb, br, bg, bb = (
                    self._channel_to_byte(channel) for channel in (fr, fg, fb, br, bg, bb)
                )
                seq = []
                for chid in pred_chars_index:
                    ch = self.model.dictionary[chid]
                    if ch == '<S>':
                        continue
                    if ch == '</S>':
                        break
                    if ch == '<SP>':
                        ch = ' '
                    seq.append(ch)
                txt = ''.join(seq)
                self.logger.info(f'prob: {prob} {txt} fg: ({fr}, {fg}, {fb}) bg: ({br}, {bg}, {bb})')
                cur_region = quadrilaterals[indices[i]][0]
                if isinstance(cur_region, Quadrilateral):
                    cur_region.text = txt
                    cur_region.prob = prob
                    cur_region.fg_r = fr
                    cur_region.fg_g = fg
                    cur_region.fg_b = fb
                    cur_region.bg_r = br
                    cur_region.bg_g = bg
                    cur_region.bg_b = bb
                else:
                    cur_region.text.append(txt)
                    cur_region.update_font_colors(np.array([fr, fg, fb]), np.array([br, bg, bb]))
                out_regions.append(cur_region)

        if is_quadrilaterals:
            return out_regions
        return textlines
class ResNet(nn.Module):
    """Convolutional backbone built from four stacks of residual blocks.

    Args:
        input_channel: Number of channels of the input image.
        output_channel: Number of channels of the final feature map; the
            earlier stages use 1/8, 1/4 and 1/2 of it.
        block: Residual block class; called as
            ``block(inplanes, planes, stride, downsample)`` and must expose an
            ``expansion`` attribute.
        layers: Number of blocks in each of the four stacks.
    """

    def __init__(self, input_channel, output_channel, block, layers):
        super().__init__()

        stem_channels = int(output_channel / 8)
        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
        stage1, stage2, stage3, stage4 = self.output_channel_block

        # ``_make_layer`` reads and updates ``inplanes`` as stacks are built.
        self.inplanes = stem_channels

        # Stem.
        self.conv0_1 = nn.Conv2d(input_channel, stem_channels,
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(stem_channels)
        self.conv0_2 = nn.Conv2d(stem_channels, self.inplanes,
                                 kernel_size=3, stride=1, padding=1, bias=False)

        # Stage 1 (the "maxpool" attributes are average pools).
        self.maxpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, stage1, layers[0])
        self.bn1 = nn.BatchNorm2d(stage1)
        self.conv1 = nn.Conv2d(stage1, stage1, kernel_size=3, stride=1, padding=1, bias=False)

        # Stage 2.
        self.maxpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, stage2, layers[1], stride=1)
        self.bn2 = nn.BatchNorm2d(stage2)
        self.conv2 = nn.Conv2d(stage2, stage2, kernel_size=3, stride=1, padding=1, bias=False)

        # Stage 3: from here on only the height is strided.
        self.maxpool3 = nn.AvgPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, stage3, layers[2], stride=1)
        self.bn3 = nn.BatchNorm2d(stage3)
        self.conv3 = nn.Conv2d(stage3, stage3, kernel_size=3, stride=1, padding=1, bias=False)

        # Stage 4.
        self.layer4 = self._make_layer(block, stage4, layers[3], stride=1)
        self.bn4_1 = nn.BatchNorm2d(stage4)
        self.conv4_1 = nn.Conv2d(stage4, stage4, kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
        self.bn4_2 = nn.BatchNorm2d(stage4)
        self.conv4_2 = nn.Conv2d(stage4, stage4, kernel_size=2, stride=1, padding=0, bias=False)
        self.bn4_3 = nn.BatchNorm2d(stage4)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a stack of ``blocks`` residual blocks producing ``planes`` channels."""
        out_planes = planes * block.expansion

        # A BN + 1x1 conv projection is needed on the shortcut whenever the
        # first block changes the resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.BatchNorm2d(self.inplanes),
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
            )

        stack = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stack.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stack)

    def forward(self, x):
        """Run the backbone and return the final feature map."""
        # Stem.
        x = self.conv0_2(F.relu(self.bn0_1(self.conv0_1(x))))

        # Stages 1-3: pool -> residual stack -> BN -> ReLU -> conv.
        x = self.layer1(self.maxpool1(x))
        x = self.conv1(F.relu(self.bn1(x)))

        x = self.layer2(self.maxpool2(x))
        x = self.conv2(F.relu(self.bn2(x)))

        x = self.layer3(self.maxpool3(x))
        x = self.conv3(F.relu(self.bn3(x)))

        # Stage 4 ends with two 2x2 convolutions and a final batch norm.
        x = self.layer4(x)
        x = self.conv4_1(F.relu(self.bn4_1(x)))
        x = self.conv4_2(F.relu(self.bn4_2(x)))
        return self.bn4_3(x)
class BasicBlock(nn.Module):
    """Residual block: two BN -> ReLU -> 3x3 conv units plus a shortcut.

    Args:
        inplanes: Channels of the input.
        planes: Channels produced by both convolutions.
        stride: Stored on the instance; the convolutions themselves are
            built with the default stride of 1.
        downsample: Optional module applied to the input before it is added
            to the convolution branch.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        """Return a bias-free 3x3 convolution with padding 1."""
        return nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Return ``branch(x) + shortcut(x)``."""
        branch = self.conv1(F.relu(self.bn1(x)))
        branch = self.conv2(F.relu(self.bn2(branch)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return branch + shortcut
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 convolution whose padding equals its dilation."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 convolution."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf)

    Thin wrapper holding a ``ResNet`` made of ``BasicBlock`` stacks with
    3, 6, 7 and 5 blocks.
    """

    def __init__(self, input_channel, output_channel=128):
        super().__init__()
        blocks_per_stage = [3, 6, 7, 5]
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, blocks_per_stage)

    def forward(self, input):
        """Return the backbone feature map for ``input``."""
        features = self.ConvNet(input)
        return features
class PositionalEncoding(nn.Module):
    """Add fixed sine/cosine position codes to a sequence-first tensor.

    Args:
        d_model: Size of the last (embedding) dimension.
        dropout: Probability of the ``dropout`` submodule. The submodule is
            created but ``forward`` does not apply it.
        max_len: Number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # (max_len, 1) positions times (d_model / 2,) inverse frequencies.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        # Even columns hold the sine, odd columns the cosine.
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x, offset = 0):
        """Return ``x`` plus the codes for positions ``offset .. offset + len(x)``."""
        stop = offset + x.size(0)
        # Dropout is intentionally not applied here.
        return x + self.pe[offset: stop, :]
def generate_square_subsequent_mask(sz):
    """Return an ``sz x sz`` float mask for causal attention.

    Entries on and below the diagonal are ``0.0``; entries above it are
    ``-inf``.
    """
    blocked = torch.full((sz, sz), float('-inf'))
    # Keep -inf strictly above the diagonal; everything else becomes 0.
    return torch.triu(blocked, diagonal=1)
class AddCoords(nn.Module):
    """Append normalised coordinate channels to a 4-D tensor.

    Args:
        with_r: When true, a third channel holding the distance of each
            position from the point (0.5, 0.5) is appended as well.
    """

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)

        Returns:
            ``input_tensor`` with two extra channels (three if ``with_r``):
            the coordinate along ``x_dim`` and along ``y_dim``, each scaled
            to the range [-1, 1].
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()
        grid_shape = (batch_size, 1, x_dim, y_dim)

        # Index -> [0, 1] -> [-1, 1] along each spatial axis.
        x_coords = torch.arange(x_dim).float() / (x_dim - 1) * 2 - 1
        y_coords = torch.arange(y_dim).float() / (y_dim - 1) * 2 - 1

        # Broadcast each axis over the full grid and match the input's type.
        xx_channel = x_coords.view(1, 1, x_dim, 1).expand(*grid_shape).type_as(input_tensor)
        yy_channel = y_coords.view(1, 1, 1, y_dim).expand(*grid_shape).type_as(input_tensor)

        channels = [input_tensor, xx_channel, yy_channel]
        if self.with_r:
            radius = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
            channels.append(radius)

        return torch.cat(channels, dim=1)
class Beam:
    """One beam-search candidate: a token sequence and its per-token log-probs.

    Args:
        char_seq: Token ids, either a list or a tensor. ``None`` (the
            default) means an empty sequence.
        logprobs: Log-probabilities matching ``char_seq``, in the same form.
            ``None`` (the default) means an empty sequence.
    """

    def __init__(self, char_seq = None, logprobs = None):
        # ``None`` sentinels replace the former mutable ``[]`` defaults.
        if char_seq is None:
            char_seq = []
        if logprobs is None:
            logprobs = []
        # L
        if isinstance(char_seq, list):
            self.chars = torch.tensor(char_seq, dtype=torch.long)
            self.logprobs = torch.tensor(logprobs, dtype=torch.float32)
        else:
            # Tensors are cloned so the beam owns its own storage.
            self.chars = char_seq.clone()
            self.logprobs = logprobs.clone()

    def avg_logprob(self):
        """Return the mean log-probability of the sequence as a float."""
        return self.logprobs.mean().item()

    def sort_key(self):
        """Return the negated mean log-prob, so ascending sort puts the best first."""
        return -self.avg_logprob()

    def seq_end(self, end_tok):
        """Return whether the last token equals ``end_tok`` (a 0-d bool tensor)."""
        return self.chars.view(-1)[-1] == end_tok

    def extend(self, idx, logprob):
        """Return a new beam with ``idx`` / ``logprob`` (0-d tensors) appended."""
        return Beam(
            torch.cat([self.chars, idx.unsqueeze(0)], dim = -1),
            torch.cat([self.logprobs, logprob.unsqueeze(0)], dim = -1),
        )
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a decoding block size used elsewhere - TODO confirm or remove.
DECODE_BLOCK_LENGTH = 8
class Hypothesis:
    """State of one partial decode in the batched beam search.

    Holds the emitted token ids, their log-probabilities and, per decoder
    layer, the activations cached so far (one extra slot stores the output
    of the last layer).

    Args:
        device: Device on which the bookkeeping tensors are created.
        start_tok: Id of the start token; seeds ``out_idx``.
        end_tok: Id of the end token, used by ``seq_end``.
        padding_tok: Id of the padding token (stored only).
        memory_idx: Column of the encoder memory this hypothesis decodes.
        num_layers: Number of decoder layers.
        embd_dim: Embedding size of the cached activations.
    """

    def __init__(self, device, start_tok: int, end_tok: int, padding_tok: int, memory_idx: int, num_layers: int, embd_dim: int):
        self.device = device
        self.start_tok = start_tok
        self.end_tok = end_tok
        self.padding_tok = padding_tok
        self.memory_idx = memory_idx
        self.embd_size = embd_dim
        self.num_layers = num_layers
        # L, 1, E - every slot starts out as the same empty tensor.
        empty_cache = torch.zeros(0, 1, self.embd_size).to(self.device)
        self.cached_activations = [empty_cache] * (num_layers + 1)
        self.out_idx = torch.LongTensor([start_tok]).to(self.device)
        # The start token carries a log-prob of 0, which is included in the
        # mean computed by ``logprob`` / ``prob``.
        self.out_logprobs = torch.FloatTensor([0]).to(self.device)
        self.length = 0

    def seq_end(self):
        """Return whether the last emitted token is the end token (0-d bool tensor)."""
        last_token = self.out_idx.view(-1)[-1]
        return last_token == self.end_tok

    def logprob(self):
        """Return the mean of ``out_logprobs`` as a float."""
        return self.out_logprobs.mean().item()

    def sort_key(self):
        """Return the negated mean log-prob, so ascending sort puts the best first."""
        return -self.logprob()

    def prob(self):
        """Return ``exp`` of the mean log-prob as a float."""
        return self.out_logprobs.mean().exp().item()

    def __len__(self):
        return self.length

    def extend(self, idx, logprob):
        """Return a copy of this hypothesis with one more token appended."""
        child = Hypothesis(self.device, self.start_tok, self.end_tok, self.padding_tok, self.memory_idx, self.num_layers, self.embd_size)
        child.cached_activations = [cache.clone() for cache in self.cached_activations]
        child.length = self.length + 1
        new_idx = torch.LongTensor([idx]).to(self.device)
        new_logprob = torch.FloatTensor([logprob]).to(self.device)
        child.out_idx = torch.cat([self.out_idx, new_idx], dim = 0)
        child.out_logprobs = torch.cat([self.out_logprobs, new_logprob], dim = 0)
        return child

    def output(self):
        """Return the cached activations of the last slot (final decoder output)."""
        return self.cached_activations[-1]
def next_token_batch(
    hyps: List[Hypothesis],
    memory: torch.Tensor, # S, K, E
    memory_mask: torch.BoolTensor,
    decoders: nn.TransformerDecoder,
    pe: PositionalEncoding,
    embd: nn.Embedding
    ):
    """Run one incremental decoder step for a batch of hypotheses.

    Only the last token of every hypothesis is fed through the decoder
    layers; earlier positions are taken from each hypothesis' per-layer
    activation cache, which is updated in place.

    Args:
        hyps: Hypotheses to advance. The positional offset is taken from
            ``len(hyps[0])``, so all of them are treated as equally long.
        memory: Encoder output; column ``hyp.memory_idx`` belongs to ``hyp``.
        memory_mask: Key padding mask passed to the cross attention, one row
            per hypothesis.
        decoders: Decoder stack whose layers are applied manually.
        pe: Positional encoding applied to the embedded token.
        embd: Token embedding.

    Returns:
        The final layer output for the new position, one row per hypothesis.
    """
    # N
    last_toks = torch.stack([hyp.out_idx[-1] for hyp in hyps], dim = 0)
    # 1, N, E
    tgt: torch.FloatTensor = pe(embd(last_toks).unsqueeze_(0), offset = len(hyps[0]))
    # S, N, E - gather each hypothesis' own memory column.
    memory = torch.stack([memory[:, hyp.memory_idx, :] for hyp in hyps], dim = 1)

    for depth, layer in enumerate(decoders.layers):
        # TODO: keys and values are recomputed every time
        # L - 1, N, E
        history = torch.cat([hyp.cached_activations[depth] for hyp in hyps], dim = 1)
        # L, N, E
        history = torch.cat([history, tgt], dim = 0)
        for col, hyp in enumerate(hyps):
            hyp.cached_activations[depth] = history[:, col: col + 1, :]

        # Self attention: the new position attends to the whole history.
        attended = layer.self_attn(tgt, history, history)[0]
        tgt = layer.norm1(tgt + layer.dropout1(attended))

        # Cross attention over the encoder memory.
        attended = layer.multihead_attn(tgt, memory, memory, key_padding_mask = memory_mask)[0]
        tgt = layer.norm2(tgt + layer.dropout2(attended))

        # Feed-forward sub-layer.
        hidden = layer.linear2(layer.dropout(layer.activation(layer.linear1(tgt))))
        # 1, N, E
        tgt = layer.norm3(tgt + layer.dropout3(hidden))

    # The extra cache slot accumulates the output of the last layer.
    out_slot = decoders.num_layers
    for col, hyp in enumerate(hyps):
        hyp.cached_activations[out_slot] = torch.cat([hyp.cached_activations[out_slot], tgt[:, col: col + 1, :]], dim = 0)

    # N, E
    return tgt.squeeze_(0)
class OCR(nn.Module):
    """Text-line recogniser: ResNet features, transformer encoder/decoder.

    The model works on 320-dimensional features with 4 attention heads,
    3 encoder layers and 2 decoder layers. The character classifier shares
    its weight matrix with the token embedding, and six small heads predict
    the foreground and background RGB components.

    Args:
        dictionary: Sequence of tokens; its length is the vocabulary size.
        max_len: Number of positions held by the positional encoding.
    """

    def __init__(self, dictionary, max_len):
        super().__init__()
        self.max_len = max_len
        self.dictionary = dictionary
        self.dict_size = len(dictionary)
        self.backbone = ResNet_FeatureExtractor(3, 320)
        encoder = nn.TransformerEncoderLayer(320, 4, dropout = 0.0)
        decoder = nn.TransformerDecoderLayer(320, 4, dropout = 0.0)
        self.encoders = nn.TransformerEncoder(encoder, 3)
        self.decoders = nn.TransformerDecoder(decoder, 2)
        self.pe = PositionalEncoding(320, max_len = max_len)
        self.embd = nn.Embedding(self.dict_size, 320)
        self.pred1 = nn.Sequential(nn.Linear(320, 320), nn.ReLU(), nn.Dropout(0.1))
        self.pred = nn.Linear(320, self.dict_size)
        # Weight tying between the output projection and the embedding.
        self.pred.weight = self.embd.weight
        self.color_pred1 = nn.Sequential(nn.Linear(320, 64), nn.ReLU())
        self.fg_r_pred = nn.Linear(64, 1)
        self.fg_g_pred = nn.Linear(64, 1)
        self.fg_b_pred = nn.Linear(64, 1)
        self.bg_r_pred = nn.Linear(64, 1)
        self.bg_g_pred = nn.Linear(64, 1)
        self.bg_b_pred = nn.Linear(64, 1)

    def _predict_colors(self, color_feats):
        """Return ``(fg_r, fg_g, fg_b, bg_r, bg_g, bg_b)`` for ``color_feats``."""
        return (
            self.fg_r_pred(color_feats),
            self.fg_g_pred(color_feats),
            self.fg_b_pred(color_feats),
            self.bg_r_pred(color_feats),
            self.bg_g_pred(color_feats),
            self.bg_b_pred(color_feats),
        )

    def forward(self,
        img: torch.FloatTensor,
        char_idx: torch.LongTensor,
        mask: torch.BoolTensor,
        source_mask: torch.BoolTensor
        ):
        """Teacher-forced pass over whole target sequences.

        Args:
            img: Input image batch.
            char_idx: Target token ids, one row per sample.
            mask: Key padding mask for the target tokens.
            source_mask: Key padding mask for the encoder features.

        Returns:
            A 7-tuple: character logits followed by the six colour
            predictions (fg r/g/b, bg r/g/b).
        """
        feats = self.backbone(img)
        # Sum out the height axis and move the sequence axis first.
        feats = torch.einsum('n e h s -> s n e', feats)
        memory = self.encoders(self.pe(feats), src_key_padding_mask = source_mask)

        _, L = char_idx.shape
        char_embd = torch.einsum('n t e -> t n e', self.embd(char_idx))
        char_embd = self.pe(char_embd)
        causal_mask = generate_square_subsequent_mask(L).to(img.device)
        decoded = self.decoders(char_embd, memory, tgt_mask = causal_mask, tgt_key_padding_mask = mask, memory_key_padding_mask = source_mask)
        decoded = decoded.permute(1, 0, 2)

        pred_char_logits = self.pred(self.pred1(decoded))
        color_feats = self.color_pred1(decoded)
        return (pred_char_logits, *self._predict_colors(color_feats))

    def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384):
        """Beam-search decode a batch of images with cached decoder state.

        Args:
            img: Image batch of shape (N, 3, 32, W).
            img_widths: Unpadded width of every image; used to mask the
                encoder positions that only cover padding.
            beams_k: Beam width per sample.
            start_tok: Id of the start token.
            end_tok: Id of the end token.
            pad_tok: Id of the padding token (forwarded to ``Hypothesis``).
            max_finished_hypos: A sample stops being expanded once this many
                of its hypotheses have ended.
            max_seq_length: Maximum number of expansion steps.

        Returns:
            One tuple per sample:
            ``(token_ids, prob, fg_r, fg_g, fg_b, bg_r, bg_g, bg_b)`` for its
            best hypothesis.
        """
        N, C, H, W = img.shape
        assert H == 32 and C == 3
        feats = self.backbone(img)
        feats = torch.einsum('n e h s -> s n e', feats)

        # Mark feature positions beyond each image's own width as padding.
        valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
        input_mask = torch.zeros(N, feats.size(0), dtype = torch.bool).to(img.device)
        for row, valid in enumerate(valid_feats_length):
            input_mask[row, valid:] = True

        feats = self.pe(feats)
        memory = self.encoders(feats, src_key_padding_mask = input_mask)

        # First step: one hypothesis per sample, expanded into k beams each.
        hypos = [Hypothesis(img.device, start_tok, end_tok, pad_tok, i, self.decoders.num_layers, 320) for i in range(N)]
        # N, E
        decoded = next_token_batch(hypos, memory, input_mask, self.decoders, self.pe, self.embd)
        # N, n_chars
        pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
        # N, k
        pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1)
        finished_hypos = defaultdict(list)
        hypos = [
            hypos[i].extend(pred_chars_index[i, k], pred_chars_values[i, k])
            for i in range(N)
            for k in range(beams_k)
        ]

        for _ in range(max_seq_length):
            # N * k, E
            step_mask = torch.stack([input_mask[hyp.memory_idx] for hyp in hypos])
            decoded = next_token_batch(hypos, memory, step_mask, self.decoders, self.pe, self.embd)
            # N * k, n_chars
            pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
            # N * k, k
            pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1)

            # Expand every hypothesis k ways, grouped by source sample.
            hypos_per_sample = defaultdict(list)
            for row, hyp in enumerate(hypos):
                for k in range(beams_k):
                    hypos_per_sample[hyp.memory_idx].append(hyp.extend(pred_chars_index[row, k], pred_chars_values[row, k]))

            hypos = []
            # hypos_per_sample now contains N * k^2 hypos
            for sample_idx, candidates in hypos_per_sample.items():
                ranked = sorted(candidates, key = lambda a: a.sort_key())[: beams_k + 1]
                survivors = []
                sample_done = False
                for candidate in ranked:
                    if candidate.seq_end():
                        finished_hypos[sample_idx].append(candidate)
                        if len(finished_hypos[sample_idx]) >= max_finished_hypos:
                            sample_done = True
                            break
                    elif len(survivors) < beams_k:
                        survivors.append(candidate)
                if not sample_done:
                    hypos.extend(survivors)

            if len(hypos) == 0:
                break

        # add remaining hypos to finished
        for i in range(N):
            if i not in finished_hypos:
                finished_hypos[i].append(min(hypos_per_sample[i], key = lambda a: a.sort_key()))
        assert len(finished_hypos) == N

        result = []
        for i in range(N):
            best = min(finished_hypos[i], key = lambda a: a.sort_key())
            decoded = best.output()
            color_feats = self.color_pred1(decoded)
            fg_r, fg_g, fg_b, bg_r, bg_g, bg_b = self._predict_colors(color_feats)
            result.append((best.out_idx, best.prob(), fg_r, fg_g, fg_b, bg_r, bg_g, bg_b))
        return result

    def infer_beam(self, img: torch.FloatTensor, beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_seq_length = 384):
        """Beam-search decode a single image, re-running the decoder each step.

        Args:
            img: Image tensor of shape (1, 3, 32, W).
            beams_k: Beam width.
            start_tok: Id of the start token.
            end_tok: Id of the end token.
            pad_tok: Unused here.
            max_seq_length: Maximum number of expansion steps.

        Returns:
            ``(outputs, prob)`` where ``outputs`` is the 7-tuple of character
            log-probs and colour predictions for the best beam, and ``prob``
            is ``exp`` of that beam's mean log-probability.
        """
        N, C, H, W = img.shape
        assert H == 32 and N == 1 and C == 3
        feats = self.backbone(img)
        feats = torch.einsum('n e h s -> s n e', feats)
        feats = self.pe(feats)
        memory = self.encoders(feats)

        def run(tokens, add_start_tok = True, char_only = True):
            """Decode ``tokens`` in full and return per-position predictions."""
            if add_start_tok:
                start = [start_tok]
                if isinstance(tokens, list):
                    # N(=1), L
                    tokens = torch.tensor(start + tokens, dtype = torch.long, device = img.device).unsqueeze_(0)
                else:
                    # N, L
                    tokens = torch.cat([torch.tensor(start, dtype = torch.long, device = img.device), tokens], dim = -1).unsqueeze_(0)
            _, L = tokens.shape
            embd = torch.einsum('n t e -> t n e', self.embd(tokens))
            embd = self.pe(embd)
            causal_mask = generate_square_subsequent_mask(L).to(img.device)
            decoded = self.decoders(embd, memory, tgt_mask = causal_mask)
            decoded = decoded.permute(1, 0, 2)
            pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
            if char_only:
                return pred_char_logprob
            color_feats = self.color_pred1(decoded)
            return (pred_char_logprob, *self._predict_colors(color_feats))

        def best_first(candidates):
            """Sort beams so the highest average log-prob comes first."""
            return sorted(candidates, key = lambda a: a.sort_key())

        # N, L, embd_size
        initial_char_logprob = run([])
        # N, L
        initial_pred_chars_values, initial_pred_chars_index = torch.topk(initial_char_logprob, beams_k, dim = 2)
        # beams_k, L
        initial_pred_chars_values = initial_pred_chars_values.squeeze(0).permute(1, 0)
        initial_pred_chars_index = initial_pred_chars_index.squeeze(0).permute(1, 0)
        beams = best_first([Beam(tok, logprob) for tok, logprob in zip(initial_pred_chars_index, initial_pred_chars_values)])

        for _ in range(max_seq_length):
            new_beams = []
            all_ended = True
            for beam in beams:
                if beam.seq_end(end_tok):
                    # seq ended, add back to queue
                    new_beams.append(beam)
                    continue
                logprobs = run(beam.chars)
                pred_chars_values, pred_chars_index = torch.topk(logprobs, beams_k, dim = 2)
                # beams_k, L
                pred_chars_values = pred_chars_values.squeeze(0).permute(1, 0)
                pred_chars_index = pred_chars_index.squeeze(0).permute(1, 0)
                # Only the prediction for the last position extends the beam.
                new_beams.extend([beam.extend(tok[-1], logprob[-1]) for tok, logprob in zip(pred_chars_index, pred_chars_values)])
                all_ended = False
            # keep top k
            beams = best_first(new_beams)[: beams_k]
            if all_ended:
                break

        # Drop the final token of the best beam before the last full pass.
        final_tokens = beams[0].chars[:-1]
        return run(final_tokens, char_only = False), beams[0].logprobs.mean().exp().item()
def test():
    """Smoke test: push a random batch through the feature extractor.

    The forward pass only has to complete without raising; nothing is
    returned. The alphabet file and the index/mask tensors that used to be
    created here were never used, so they are no longer required.
    """
    img = torch.randn(4, 3, 32, 1224)
    model = ResNet_FeatureExtractor(3, 256)
    model(img)
def test_inference():
    """Manual check: decode a blank image with a local checkpoint and print it.

    Reads ``../SynthText/alphabet-all-v3.txt`` and ``ocr_ar_v2-3-test.ckpt``
    from paths relative to the working directory.
    """
    with torch.no_grad():
        with open('../SynthText/alphabet-all-v3.txt', 'r', encoding = 'utf-8') as fp:
            # Strip only the line terminator so a last line without a
            # trailing newline keeps its final character.
            dictionary = [s.rstrip('\n') for s in fp]
        img = torch.zeros(1, 3, 32, 128)
        model = OCR(dictionary, 32)
        m = torch.load("ocr_ar_v2-3-test.ckpt", map_location='cpu')
        model.load_state_dict(m['model'])
        model.eval()
        # ``infer_beam`` returns ((log-probs, six colour tensors), prob):
        # seven inner values, not eight.
        (char_probs, *_colors), _ = model.infer_beam(img, max_seq_length = 20)
        _, pred_chars_index = char_probs.max(2)
        pred_chars_index = pred_chars_index.squeeze_(0)
        seq = []
        for chid in pred_chars_index:
            ch = dictionary[chid]
            if ch == '<SP>':
                # Assignment, not comparison: map the space token to ' '.
                ch = ' '
            seq.append(ch)
        print(''.join(seq))
# Running this file directly executes the ``test`` smoke check.
if __name__ == "__main__":
    test()
|