basemodel.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. import cv2
  2. import copy
  3. import torch
  4. import torch.nn as nn
  5. from .utils.yolov5_utils import fuse_conv_and_bn
  6. from .utils.weight_init import init_weights
  7. from .yolov5.yolo import load_yolov5_ckpt
  8. from .yolov5.common import C3, Conv
  # Forward-mode selectors understood by UnetHead.forward / TextDetector.
  TEXTDET_MASK = 0       # UnetHead returns only the segmentation mask
  TEXTDET_DET = 1        # UnetHead returns (f80, f40, u40) for the DB head
  TEXTDET_INFERENCE = 2  # UnetHead returns (mask, [f80, f40, u40])
  class double_conv_up_c3(nn.Module):
      """C3 block followed by a 2x transposed-conv upsample, BatchNorm and ReLU.

      Args:
          in_ch: extra input channels on top of ``mid_ch`` (callers pass a
              skip feature concatenated with the decoder path; 0 when there
              is no skip connection). The C3 block consumes ``in_ch + mid_ch``.
          mid_ch: output width of the C3 block.
          out_ch: channels produced by the upsampling stage.
          act: activation option forwarded to ``C3``.
      """

      def __init__(self, in_ch, mid_ch, out_ch, act=True):
          super().__init__()
          stages = [
              C3(in_ch + mid_ch, mid_ch, act=act),
              # kernel 4 / stride 2 / padding 1 exactly doubles H and W.
              nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
              nn.BatchNorm2d(out_ch),
              nn.ReLU(inplace=True),
          ]
          self.conv = nn.Sequential(*stages)

      def forward(self, x):
          """Run the block; the result has ``out_ch`` channels at 2x resolution."""
          return self.conv(x)
  class double_conv_c3(nn.Module):
      """C3 block, optionally preceded by a 2x average-pool downsample.

      Args:
          in_ch: input channels of the C3 block.
          out_ch: output channels of the C3 block.
          stride: when greater than 1, the input is first average-pooled
              with a 2x2 window and stride 2 (always a factor of 2,
              regardless of the exact ``stride`` value).
          act: activation option forwarded to ``C3``.
      """

      def __init__(self, in_ch, out_ch, stride=1, act=True):
          super(double_conv_c3, self).__init__()
          # Always define ``down``: previously it was only set when stride > 1,
          # so forward() raised AttributeError for the default stride=1.
          self.down = nn.AvgPool2d(2, stride=2) if stride > 1 else None
          self.conv = C3(in_ch, out_ch, act=act)

      def forward(self, x):
          """Downsample (if configured), then apply the C3 block."""
          if self.down is not None:
              x = self.down(x)
          return self.conv(x)
  class UnetHead(nn.Module):
      """UNet-style decoder that turns five backbone features into a text mask.

      Channel counts implied by the layer definitions: ``f160``=64,
      ``f80``=128, ``f40``=256, ``f20``=512, ``f3``=512. The ``C@S`` shape
      comments (channels @ spatial side) assume a 640x640 input image, as the
      original comments did -- TODO confirm for other input sizes.
      """

      def __init__(self, act=True) -> None:
          super().__init__()
          # Construction order is deliberate: it fixes the RNG consumption of
          # default initialisation and the parameters()/state_dict() ordering.
          self.down_conv1 = double_conv_c3(512, 512, 2, act=act)
          self.upconv0 = double_conv_up_c3(0, 512, 256, act=act)
          self.upconv2 = double_conv_up_c3(256, 512, 256, act=act)
          self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
          self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
          self.upconv5 = double_conv_up_c3(64, 128, 64, act=act)
          self.upconv6 = nn.Sequential(
              nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
              nn.Sigmoid(),
          )

      def forward(self, f160, f80, f40, f20, f3, forward_mode=TEXTDET_MASK):
          """Decode backbone features according to ``forward_mode``.

          Returns:
              ``TEXTDET_DET``: the tuple ``(f80, f40, u40)`` (mask branch skipped).
              ``TEXTDET_MASK``: the single-channel sigmoid mask.
              any other value: ``(mask, [f80, f40, u40])``.
          """
          d10 = self.down_conv1(f3)                             # 512@10
          u20 = self.upconv0(d10)                               # 256@20
          u40 = self.upconv2(torch.cat([f20, u20], dim=1))      # 256@40
          if forward_mode == TEXTDET_DET:
              return f80, f40, u40

          u80 = self.upconv3(torch.cat([f40, u40], dim=1))      # 256@80
          u160 = self.upconv4(torch.cat([f80, u80], dim=1))     # 128@160
          u320 = self.upconv5(torch.cat([f160, u160], dim=1))   # 64@320
          mask = self.upconv6(u320)                             # 1@640
          if forward_mode == TEXTDET_MASK:
              return mask
          return mask, [f80, f40, u40]

      def init_weight(self, init_func):
          """Apply ``init_func`` recursively to every submodule."""
          self.apply(init_func)
  class DBHead(nn.Module):
      """Differentiable Binarization (DB) head for text-line detection.

      Fuses the ``(f80, f40, u40)`` features (as returned by ``UnetHead`` in
      ``TEXTDET_DET`` mode) through two ``double_conv_up_c3`` stages, then
      predicts a shrink (probability) map and a threshold map. Both branches
      end in two stride-2 upsamplings, i.e. 4x the fused feature resolution.

      Args:
          in_channels: width of the fused feature map feeding both branches.
          k: steepness of the approximate step function (``step_function``).
          shrink_with_sigmoid: in training mode, when False the raw binarize
              logits are appended as a fourth output channel.
          act: activation option forwarded to the C3 blocks.
      """

      def __init__(self, in_channels, k = 50, shrink_with_sigmoid=True, act=True):
          super().__init__()
          self.k = k
          self.shrink_with_sigmoid = shrink_with_sigmoid
          self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
          self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
          self.conv = nn.Sequential(
              nn.Conv2d(128, in_channels, 1),
              nn.BatchNorm2d(in_channels),
              nn.ReLU(inplace=True)
          )
          self.binarize = nn.Sequential(
              nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
              nn.BatchNorm2d(in_channels // 4),
              nn.ReLU(inplace=True),
              nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
              nn.BatchNorm2d(in_channels // 4),
              nn.ReLU(inplace=True),
              nn.ConvTranspose2d(in_channels // 4, 1, 2, 2)
          )
          self.thresh = self._init_thresh(in_channels)

      def forward(self, f80, f40, u40, shrink_with_sigmoid=True, step_eval=False):
          """Predict the DB maps.

          Args:
              f80, f40, u40: features from ``UnetHead`` in ``TEXTDET_DET`` mode.
              shrink_with_sigmoid: ignored -- the value given to the
                  constructor is always used; kept for signature compatibility.
              step_eval: in eval mode, return the approximate binary map
                  instead of the shrink/threshold maps.

          Returns:
              training: ``cat(shrink, threshold, binary)`` along dim 1, plus the
                  raw logits as a 4th channel when ``shrink_with_sigmoid`` is
                  False.
              eval: ``step_function(shrink, threshold)`` if ``step_eval``,
                  otherwise ``cat(shrink, threshold)`` along dim 1.
          """
          shrink_with_sigmoid = self.shrink_with_sigmoid
          u80 = self.upconv3(torch.cat([f40, u40], dim=1))  # 256@80
          x = self.upconv4(torch.cat([f80, u80], dim=1))    # 128@160
          x = self.conv(x)
          threshold_maps = self.thresh(x)
          x = self.binarize(x)  # raw (pre-sigmoid) logits
          shrink_maps = torch.sigmoid(x)
          if self.training:
              binary_maps = self.step_function(shrink_maps, threshold_maps)
              if shrink_with_sigmoid:
                  return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
              return torch.cat((shrink_maps, threshold_maps, binary_maps, x), dim=1)
          if step_eval:
              return self.step_function(shrink_maps, threshold_maps)
          return torch.cat((shrink_maps, threshold_maps), dim=1)

      def init_weight(self, init_func):
          """Apply ``init_func`` recursively to every submodule."""
          self.apply(init_func)

      def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
          """Build the threshold branch (4x upsampling, sigmoid output).

          Also stores the branch on ``self.thresh`` and returns it. With
          ``serial`` the branch expects one extra input channel.
          """
          in_channels = inner_channels
          if serial:
              in_channels += 1
          self.thresh = nn.Sequential(
              nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
              nn.BatchNorm2d(inner_channels // 4),
              nn.ReLU(inplace=True),
              self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
              nn.BatchNorm2d(inner_channels // 4),
              nn.ReLU(inplace=True),
              self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
              nn.Sigmoid())
          return self.thresh

      def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
          """Return a 2x upsampling module mapping ``in_channels`` -> ``out_channels``.

          ``smooth=False`` uses a 2x2 stride-2 transposed convolution. ``smooth=True``
          uses nearest-neighbour upsampling plus a 3x3 conv, followed (for a
          single output channel) by a 1x1 projection. Both paths produce
          exactly twice the input height and width.
          """
          if not smooth:
              return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
          inter_out_channels = in_channels if out_channels == 1 else out_channels
          module_list = [
              nn.Upsample(scale_factor=2, mode='nearest'),
              nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
          if out_channels == 1:
              # A 1x1 projection needs padding=0; padding=1 grew H/W by 2 and
              # broke concatenation with the shrink maps.
              module_list.append(nn.Conv2d(inter_out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True))
          # nn.Sequential takes modules as positional arguments, not a list.
          return nn.Sequential(*module_list)

      def step_function(self, x, y):
          """Approximate step ``1 / (1 + exp(-k * (x - y)))`` (== sigmoid(k*(x-y)))."""
          return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
  class TextDetector(nn.Module):
      """Training wrapper: YOLOv5 backbone + UNet mask head + optional DB head.

      The backbone is always run under ``torch.no_grad()`` and kept in eval
      mode; only the active head is trained.

      Args:
          weights: checkpoint passed to ``load_yolov5_ckpt`` for the backbone.
          map_location: device mapping used when loading the backbone.
          forward_mode: ``TEXTDET_MASK`` or ``TEXTDET_DET``.
          act: activation option forwarded to the heads.
      """

      def __init__(self, weights, map_location='cpu', forward_mode=TEXTDET_MASK, act=True):
          super(TextDetector, self).__init__()
          yolov5s_backbone = load_yolov5_ckpt(weights=weights, map_location=map_location)
          yolov5s_backbone.eval()
          # Backbone layers whose outputs are returned (presumably consumed as
          # f160, f80, f40, f20, f3 by UnetHead); layers after the last are cut.
          out_indices = [1, 3, 5, 7, 9]
          yolov5s_backbone.out_indices = out_indices
          yolov5s_backbone.model = yolov5s_backbone.model[:max(out_indices)+1]
          self.act = act
          # Registration order (seg_net before backbone) is kept so that
          # parameters()/state_dict() ordering stays unchanged.
          self.seg_net = UnetHead(act=act)
          self.backbone = yolov5s_backbone
          self.dbnet = None
          self.forward_mode = forward_mode

      def train_mask(self):
          """Switch to mask training: only ``seg_net`` is in train mode."""
          self.forward_mode = TEXTDET_MASK
          self.backbone.eval()
          self.seg_net.train()

      def initialize_db(self, unet_weights):
          """Create the DB head and hand the trained UNet decoder stages to it.

          Loads ``unet_weights`` (a file whose ``'weights'`` entry is the
          ``seg_net`` state dict), initialises the DB head with
          ``init_weights``, copies ``seg_net.upconv3``/``upconv4`` into it and
          deletes ``upconv3``..``upconv6`` from ``seg_net``. After this call
          ``TEXTDET_MASK`` mode can no longer run.
          """
          self.dbnet = DBHead(64, act=self.act)
          # NOTE: torch.load unpickles the file; only load trusted checkpoints.
          self.seg_net.load_state_dict(torch.load(unet_weights, map_location='cpu')['weights'])
          self.dbnet.init_weight(init_weights)
          self.dbnet.upconv3 = copy.deepcopy(self.seg_net.upconv3)
          self.dbnet.upconv4 = copy.deepcopy(self.seg_net.upconv4)
          del self.seg_net.upconv3
          del self.seg_net.upconv4
          del self.seg_net.upconv5
          del self.seg_net.upconv6

      def train_db(self):
          """Switch to DB training: only ``dbnet`` is in train mode.

          Raises:
              RuntimeError: if ``initialize_db`` has not been called yet.
          """
          if self.dbnet is None:
              # Fail before mutating any state instead of an AttributeError
              # halfway through.
              raise RuntimeError('initialize_db() must be called before train_db()')
          self.forward_mode = TEXTDET_DET
          self.backbone.eval()
          self.seg_net.eval()
          self.dbnet.train()

      def forward(self, x):
          """Run the backbone and the head selected by ``self.forward_mode``.

          Raises:
              ValueError: for modes other than ``TEXTDET_MASK``/``TEXTDET_DET``
                  (previously these silently returned None).
          """
          forward_mode = self.forward_mode
          if forward_mode not in (TEXTDET_MASK, TEXTDET_DET):
              raise ValueError(f'unsupported forward_mode for TextDetector: {forward_mode!r}')
          with torch.no_grad():
              outs = self.backbone(x)
          if forward_mode == TEXTDET_MASK:
              return self.seg_net(*outs, forward_mode=forward_mode)
          # DB training: the shared UNet encoder part is frozen as well.
          with torch.no_grad():
              outs = self.seg_net(*outs, forward_mode=forward_mode)
          return self.dbnet(*outs)
  def get_base_det_models(model_path, device='cpu', half=False, act='leaky'):
      """Load the block detector and the two text heads from one checkpoint.

      Args:
          model_path: path to a checkpoint dict with keys ``'blk_det'``,
              ``'text_seg'`` and ``'text_det'``.
          device: device the models are mapped and moved to.
          half: if True, convert the models to float16 after moving them.
          act: activation option forwarded to ``UnetHead`` / ``DBHead``.

      Returns:
          ``(blk_det, text_seg, text_det)``, all in eval mode on ``device``.
      """
      # NOTE: torch.load unpickles the file; only load trusted checkpoints.
      textdetector_dict = torch.load(model_path, map_location=device)
      blk_det = load_yolov5_ckpt(textdetector_dict['blk_det'], map_location=device)
      text_seg = UnetHead(act=act)
      text_seg.load_state_dict(textdetector_dict['text_seg'])
      text_det = DBHead(64, act=act)
      text_det.load_state_dict(textdetector_dict['text_det'])
      models = []
      for model in (blk_det, text_seg, text_det):
          # text_seg/text_det are constructed on the CPU, so they must be moved
          # to ``device`` explicitly -- also when half=True.
          model = model.eval().to(device)
          models.append(model.half() if half else model)
      return tuple(models)
  class TextDetBase(nn.Module):
      """Inference model chaining block detection, text segmentation and DB heads.

      Args:
          model_path: checkpoint consumed by ``get_base_det_models``.
          device: target device.
          half: load the sub-models in float16.
          fuse: fold BatchNorm into the conv of every yolov5 ``Conv`` module
              in ``text_seg`` and ``text_det``.
          act: activation option forwarded to the heads.
      """

      def __init__(self, model_path, device='cpu', half=False, fuse=False, act='leaky'):
          super().__init__()
          loaded = get_base_det_models(model_path, device, half, act=act)
          self.blk_det, self.text_seg, self.text_det = loaded
          if fuse:
              self.fuse()

      def fuse(self):
          """Fuse Conv+BN pairs in ``text_seg`` and ``text_det`` in place."""
          def fold_bn(net):
              for sub in net.modules():
                  if not isinstance(sub, Conv) or not hasattr(sub, 'bn'):
                      continue
                  sub.conv = fuse_conv_and_bn(sub.conv, sub.bn)  # merged conv
                  delattr(sub, 'bn')                             # BN no longer needed
                  sub.forward = sub.forward_fuse                 # skip the BN step
              return net

          for attr_name in ('text_seg', 'text_det'):
              setattr(self, attr_name, fold_bn(getattr(self, attr_name)))

      def forward(self, features):
          """Run the full pipeline on the input batch ``features``.

          Returns ``(blks[0], mask, lines)``: the first element of the block
          detector's prediction, the segmentation mask and the DB-head output
          (shrink/threshold maps in eval mode).
          """
          blks, feats = self.blk_det(features, detect=True)
          mask, feats = self.text_seg(*feats, forward_mode=TEXTDET_INFERENCE)
          lines = self.text_det(*feats, step_eval=False)
          return blks[0], mask, lines
  class TextDetBaseDNN:
      """OpenCV-DNN runner for an ONNX export of the text detector.

      Args:
          input_size: side length, in pixels, of the square network input.
          model_path: path to the ONNX model file.
      """

      def __init__(self, input_size, model_path):
          self.input_size = input_size
          self.model = cv2.dnn.readNetFromONNX(model_path)
          # Output layer names are resolved once and reused on every call.
          self.uoln = self.model.getUnconnectedOutLayersNames()

      def __call__(self, im_in):
          """Return ``(blks, mask, lines_map)`` for the image ``im_in``.

          The image is scaled by 1/255 and resized to ``input_size`` x
          ``input_size`` with ``cv2.dnn.blobFromImage`` before inference.
          """
          side = self.input_size
          blob = cv2.dnn.blobFromImage(im_in, scalefactor=1 / 255.0, size=(side, side))
          self.model.setInput(blob)
          net_outputs = self.model.forward(self.uoln)
          blks, mask, lines_map = net_outputs
          return blks, mask, lines_map