123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250 |
- import cv2
- import copy
- import torch
- import torch.nn as nn
- from .utils.yolov5_utils import fuse_conv_and_bn
- from .utils.weight_init import init_weights
- from .yolov5.yolo import load_yolov5_ckpt
- from .yolov5.common import C3, Conv
# Forward modes understood by UnetHead.forward (and TextDetector):
#   TEXTDET_MASK      -> run the whole decoder and return only the mask.
#   TEXTDET_DET       -> stop early and return the (f80, f40, u40) features.
#   TEXTDET_INFERENCE -> return the mask together with [f80, f40, u40].
TEXTDET_MASK, TEXTDET_DET, TEXTDET_INFERENCE = 0, 1, 2
class double_conv_up_c3(nn.Module):
    """C3 block followed by a 2x transposed-convolution upsample.

    Args:
        in_ch: extra input channels; the C3 block receives ``in_ch + mid_ch``
            channels in total.
        mid_ch: output width of the C3 block and input width of the upsample.
        out_ch: channels produced after upsampling.
        act: forwarded unchanged to ``C3``.
    """

    def __init__(self, in_ch, mid_ch, out_ch, act=True):
        super().__init__()
        # Keep the attribute name ``conv`` and this layer order: state_dict
        # keys are ``conv.0.*`` .. ``conv.3.*``.
        layers = [
            C3(in_ch + mid_ch, mid_ch, act=act),
            # kernel 4 / stride 2 / padding 1 exactly doubles H and W.
            nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through C3, then upsample it by a factor of 2."""
        upsampled = self.conv(x)
        return upsampled
class double_conv_c3(nn.Module):
    """Optional 2x average-pool downsample followed by a C3 block.

    Args:
        in_ch: input channels of the C3 block.
        out_ch: output channels of the C3 block.
        stride: when greater than 1, the input is first average-pooled with a
            2x2 window and stride 2. The pooling factor is always 2,
            whatever the exact ``stride`` value.
        act: forwarded unchanged to ``C3``.
    """

    def __init__(self, in_ch, out_ch, stride=1, act=True):
        super(double_conv_c3, self).__init__()
        # Assigned unconditionally so forward() also works when stride == 1
        # (``None`` means "no downsampling").
        self.down = nn.AvgPool2d(2, stride=2) if stride > 1 else None
        self.conv = C3(in_ch, out_ch, act=act)

    def forward(self, x):
        """Downsample ``x`` if configured, then run it through C3."""
        if self.down is not None:
            x = self.down(x)
        return self.conv(x)
class UnetHead(nn.Module):
    """U-Net style decoder that turns backbone features into a text mask.

    The submodule names (``down_conv1``, ``upconv0`` .. ``upconv6``) form the
    state_dict layout and must not be renamed.
    """

    def __init__(self, act=True) -> None:
        super().__init__()
        # Construction order is kept as-is (it determines default-init RNG use).
        self.down_conv1 = double_conv_c3(512, 512, 2, act=act)
        self.upconv0 = double_conv_up_c3(0, 512, 256, act=act)
        self.upconv2 = double_conv_up_c3(256, 512, 256, act=act)
        self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
        self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
        self.upconv5 = double_conv_up_c3(64, 128, 64, act=act)
        # Final 2x upsample to a single-channel map squashed by a sigmoid.
        self.upconv6 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, f160, f80, f40, f20, f3, forward_mode=TEXTDET_MASK):
        """Decode backbone features.

        Each ``upconv*`` stage doubles the spatial size of its input. The
        deepest feature ``f3`` is first halved by ``down_conv1``.

        Returns:
            TEXTDET_DET: the tuple ``(f80, f40, u40)``; the decoder stops early.
            TEXTDET_MASK: the sigmoid mask only.
            any other mode: ``(mask, [f80, f40, u40])``.
        """
        bottleneck = self.down_conv1(f3)
        u20 = self.upconv0(bottleneck)
        u40 = self.upconv2(torch.cat([f20, u20], dim=1))
        if forward_mode == TEXTDET_DET:
            return f80, f40, u40

        # Climb the remaining levels, concatenating the skip feature first.
        feat = u40
        stages = ((f40, self.upconv3), (f80, self.upconv4), (f160, self.upconv5))
        for skip, stage in stages:
            feat = stage(torch.cat([skip, feat], dim=1))
        mask = self.upconv6(feat)

        if forward_mode == TEXTDET_MASK:
            return mask
        return mask, [f80, f40, u40]

    def init_weight(self, init_func):
        """Apply ``init_func`` recursively to every submodule."""
        self.apply(init_func)
class DBHead(nn.Module):
    """Differentiable-binarization head producing shrink and threshold maps.

    Args:
        in_channels: width of the 1x1 projection that feeds the binarize and
            threshold branches.
        k: steepness of ``step_function`` (``sigmoid(k * (x - y))``).
        shrink_with_sigmoid: default for ``forward``'s argument of the same
            name. When False, the training output also includes the raw
            binarize logits.
        act: forwarded to the decoder blocks.
    """

    def __init__(self, in_channels, k=50, shrink_with_sigmoid=True, act=True):
        super().__init__()
        self.k = k
        self.shrink_with_sigmoid = shrink_with_sigmoid
        self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
        self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
        self.conv = nn.Sequential(
            nn.Conv2d(128, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        # Two 2x transposed-conv stages -> 4x upsampled single-channel logits.
        self.binarize = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
        )
        self.thresh = self._init_thresh(in_channels)

    def forward(self, f80, f40, u40, shrink_with_sigmoid=None, step_eval=False):
        """Decode the features and compute the DB maps.

        Args:
            f80, f40, u40: features as returned by ``UnetHead`` in DET mode.
            shrink_with_sigmoid: overrides the constructor setting when not
                None. Previously this argument was silently ignored.
            step_eval: in eval mode, return the approximate binary map instead
                of the stacked shrink/threshold maps.

        Returns:
            training: channel-wise concat of (shrink, threshold, binary) maps,
                plus the raw binarize logits when ``shrink_with_sigmoid`` is
                False.
            eval: ``step_function(shrink, threshold)`` if ``step_eval``, else
                the concat of (shrink, threshold).
        """
        if shrink_with_sigmoid is None:
            shrink_with_sigmoid = self.shrink_with_sigmoid
        u80 = self.upconv3(torch.cat([f40, u40], dim=1))
        x = self.upconv4(torch.cat([f80, u80], dim=1))
        x = self.conv(x)
        threshold_maps = self.thresh(x)
        x = self.binarize(x)
        shrink_maps = torch.sigmoid(x)
        if self.training:
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            maps = [shrink_maps, threshold_maps, binary_maps]
            if not shrink_with_sigmoid:
                maps.append(x)  # raw binarize logits
            return torch.cat(maps, dim=1)
        if step_eval:
            return self.step_function(shrink_maps, threshold_maps)
        return torch.cat((shrink_maps, threshold_maps), dim=1)

    def init_weight(self, init_func):
        """Apply ``init_func`` recursively to every submodule."""
        self.apply(init_func)

    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
        """Build the threshold branch, store it as ``self.thresh`` and return it.

        ``serial`` adds one extra input channel to the first convolution.
        The branch ends in a sigmoid and upsamples 4x overall.
        """
        in_channels = inner_channels
        if serial:
            in_channels += 1
        self.thresh = nn.Sequential(
            nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
            nn.Sigmoid(),
        )
        return self.thresh

    def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
        """Return a module that upsamples its input exactly 2x.

        smooth=False: a 2x2 / stride-2 transposed convolution.
        smooth=True: nearest-neighbour upsample plus a 3x3 conv, followed by a
        1x1 projection when ``out_channels == 1``.
        """
        if not smooth:
            return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        inter_out_channels = in_channels if out_channels == 1 else out_channels
        module_list = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias),
        ]
        if out_channels == 1:
            # padding=0: a 1x1 conv must not enlarge the map (padding=1 would
            # add 2 pixels per dimension and break the exact 2x contract).
            module_list.append(
                nn.Conv2d(inter_out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True))
        # nn.Sequential takes modules as positional arguments, not a list.
        return nn.Sequential(*module_list)

    def step_function(self, x, y):
        """Approximate step ``1 / (1 + exp(-k * (x - y)))``.

        Computed with ``torch.sigmoid``, which is mathematically identical but
        avoids overflowing ``exp`` (and NaN gradients) in reduced precision.
        """
        return torch.sigmoid(self.k * (x - y))
class TextDetector(nn.Module):
    """Training wrapper: YOLOv5 backbone + UnetHead, plus a DBHead once initialized.

    The backbone is put in eval mode and always run under ``torch.no_grad()``.

    Args:
        weights: checkpoint passed to ``load_yolov5_ckpt``.
        map_location: device mapping used when loading the checkpoint.
        forward_mode: ``TEXTDET_MASK`` or ``TEXTDET_DET``.
        act: forwarded to ``UnetHead`` / ``DBHead``.
    """

    def __init__(self, weights, map_location='cpu', forward_mode=TEXTDET_MASK, act=True):
        super(TextDetector, self).__init__()
        yolov5s_backbone = load_yolov5_ckpt(weights=weights, map_location=map_location)
        yolov5s_backbone.eval()
        out_indices = [1, 3, 5, 7, 9]
        yolov5s_backbone.out_indices = out_indices
        # Drop every backbone layer after the last tapped index.
        yolov5s_backbone.model = yolov5s_backbone.model[:max(out_indices) + 1]
        self.act = act
        self.seg_net = UnetHead(act=act)
        self.backbone = yolov5s_backbone
        self.dbnet = None
        self.forward_mode = forward_mode

    def train_mask(self):
        """Switch to mask training: only ``seg_net`` is in train mode."""
        self.forward_mode = TEXTDET_MASK
        self.backbone.eval()
        self.seg_net.train()

    def initialize_db(self, unet_weights):
        """Create the DBHead, seeding its decoder from trained UnetHead weights.

        Loads ``unet_weights`` (a file whose ``'weights'`` entry is a UnetHead
        state_dict) into ``seg_net``. It then copies ``upconv3``/``upconv4``
        into the new DBHead and removes the now-unused upper decoder stages
        from ``seg_net``.
        """
        self.dbnet = DBHead(64, act=self.act)
        self.seg_net.load_state_dict(torch.load(unet_weights, map_location='cpu')['weights'])
        self.dbnet.init_weight(init_weights)
        self.dbnet.upconv3 = copy.deepcopy(self.seg_net.upconv3)
        self.dbnet.upconv4 = copy.deepcopy(self.seg_net.upconv4)
        del self.seg_net.upconv3
        del self.seg_net.upconv4
        del self.seg_net.upconv5
        del self.seg_net.upconv6

    def _require_dbnet(self):
        """Raise a clear error if ``initialize_db`` has not been called."""
        if self.dbnet is None:
            raise RuntimeError('DBHead is not initialized; call initialize_db() first')

    def train_db(self):
        """Switch to DB training: only ``dbnet`` is in train mode."""
        self._require_dbnet()
        self.forward_mode = TEXTDET_DET
        self.backbone.eval()
        self.seg_net.eval()
        self.dbnet.train()

    def forward(self, x):
        """Run the model in the current ``forward_mode``.

        Raises:
            ValueError: if ``forward_mode`` is neither TEXTDET_MASK nor
                TEXTDET_DET. Previously such modes silently returned None.
            RuntimeError: in DET mode before ``initialize_db`` was called.
        """
        forward_mode = self.forward_mode
        if forward_mode not in (TEXTDET_MASK, TEXTDET_DET):
            raise ValueError(f'unsupported forward_mode: {forward_mode!r}')
        with torch.no_grad():
            outs = self.backbone(x)
        if forward_mode == TEXTDET_MASK:
            return self.seg_net(*outs, forward_mode=forward_mode)
        self._require_dbnet()
        # seg_net is frozen while training the DB head.
        with torch.no_grad():
            outs = self.seg_net(*outs, forward_mode=forward_mode)
        return self.dbnet(*outs)
def get_base_det_models(model_path, device='cpu', half=False, act='leaky'):
    """Load the block detector, text-segmentation head and DB head from one file.

    Args:
        model_path: path to a ``torch.load``-able dict with the keys
            ``'blk_det'``, ``'text_seg'`` and ``'text_det'``.
        device: device all three models are moved to.
        half: if True, convert the models to float16 after moving them.
        act: activation setting forwarded to ``UnetHead`` / ``DBHead``.

    Returns:
        ``(blk_det, text_seg, text_det)``, all in eval mode on ``device``.
    """
    textdetector_dict = torch.load(model_path, map_location=device)
    blk_det = load_yolov5_ckpt(textdetector_dict['blk_det'], map_location=device)
    text_seg = UnetHead(act=act)
    text_seg.load_state_dict(textdetector_dict['text_seg'])
    text_det = DBHead(64, act=act)
    text_det.load_state_dict(textdetector_dict['text_det'])
    models = []
    for model in (blk_det, text_seg, text_det):
        # The heads are constructed on the default device and load_state_dict
        # copies values into the existing parameters, so they must be moved
        # explicitly -- in the half=True case too.
        model = model.eval().to(device)
        models.append(model.half() if half else model)
    return tuple(models)
class TextDetBase(nn.Module):
    """Inference wrapper chaining block detection, text segmentation and DB lines.

    Args:
        model_path, device, half, act: forwarded to ``get_base_det_models``.
        fuse: if True, call ``fuse()`` right after loading.
    """

    def __init__(self, model_path, device='cpu', half=False, fuse=False, act='leaky'):
        super().__init__()
        loaded = get_base_det_models(model_path, device, half, act=act)
        self.blk_det, self.text_seg, self.text_det = loaded
        if fuse:
            self.fuse()

    def fuse(self):
        """Fuse conv+BN inside every ``Conv`` module of ``text_seg`` and ``text_det``.

        Each matching module gets its ``bn`` merged into ``conv`` via
        ``fuse_conv_and_bn``, loses its ``bn`` attribute, and switches to
        ``forward_fuse``. ``blk_det`` is left untouched here.
        """

        def _fuse(model):
            for module in model.modules():
                if not isinstance(module, Conv) or not hasattr(module, 'bn'):
                    continue
                module.conv = fuse_conv_and_bn(module.conv, module.bn)
                delattr(module, 'bn')
                module.forward = module.forward_fuse
            return model

        self.text_seg = _fuse(self.text_seg)
        self.text_det = _fuse(self.text_det)

    def forward(self, features):
        """Run all three stages on the input batch.

        Returns:
            ``(blks[0], mask, lines)``: the first block-detector output, the
            segmentation mask, and the DB head's eval-mode output.
        """
        blks, backbone_feats = self.blk_det(features, detect=True)
        mask, det_feats = self.text_seg(*backbone_feats, forward_mode=TEXTDET_INFERENCE)
        lines = self.text_det(*det_feats, step_eval=False)
        return blks[0], mask, lines
class TextDetBaseDNN:
    """Run the exported ONNX text detector through OpenCV's DNN module.

    Args:
        input_size: side length of the square network input.
        model_path: path to the ONNX model file.
    """

    def __init__(self, input_size, model_path):
        self.input_size = input_size
        self.model = cv2.dnn.readNetFromONNX(model_path)
        # Output layer names requested on every forward pass.
        self.uoln = self.model.getUnconnectedOutLayersNames()

    def __call__(self, im_in):
        """Run the network on one image.

        The image is scaled by 1/255 and resized to an
        ``input_size`` x ``input_size`` blob before inference.

        Returns:
            ``(blks, mask, lines_map)`` -- the network outputs, in the order
            given by ``getUnconnectedOutLayersNames``.
        """
        side = self.input_size
        blob = cv2.dnn.blobFromImage(im_in, scalefactor=1 / 255.0, size=(side, side))
        self.model.setInput(blob)
        outputs = self.model.forward(self.uoln)
        blks, mask, lines_map = outputs
        return blks, mask, lines_map
|