# DBNet_resnet101.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet101
import DBHead
import einops
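# Text detection network: ResNet-101 encoder, a U-Net-style decoder with extra
# downsampling levels, a DB (Differentiable Binarization) head on the 1/4-resolution
# features, and a sigmoid mask head on the 1/2-resolution features.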
class ImageMultiheadSelfAttention(nn.Module):
    def __init__(self, planes):
        super(ImageMultiheadSelfAttention, self).__init__()
        self.attn = nn.MultiheadAttention(planes, 8)

    def forward(self, x):
        res = x
        n, c, h, w = x.shape
        # Flatten the spatial grid into the (sequence, batch, channels) layout expected by nn.MultiheadAttention.
        x = einops.rearrange(x, 'n c h w -> (h w) n c')
        x = self.attn(x, x, x)[0]
        x = einops.rearrange(x, '(h w) n c -> n c h w', n=n, c=c, h=h, w=w)
        # Residual connection around the attention block.
        return res + x
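# Example (sketch): self-attention over all spatial positions of a feature map,
# returning a tensor of the same shape:
#   attn = ImageMultiheadSelfAttention(256)
#   y = attn(torch.randn(1, 256, 32, 32))   # y.shape == (1, 256, 32, 32)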
class double_conv(nn.Module):
    # Two 3x3 conv + BN + ReLU layers, optionally preceded by a 2x2 average-pool downsample.
    def __init__(self, in_ch, mid_ch, out_ch, stride=1, planes=256):
        super(double_conv, self).__init__()
        # `planes` was only used by an earlier, commented-out Bottleneck variant;
        # it is kept so existing call sites keep working.
        self.planes = planes
        self.down = None
        if stride > 1:
            self.down = nn.AvgPool2d(2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        if self.down is not None:
            x = self.down(x)
        x = self.conv(x)
        return x
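# e.g. double_conv(0, 512, 512, stride=2) maps (N, 512, H, W) -> (N, 512, H // 2, W // 2);
# this is how TextDetection builds the extra encoder levels below the backbone.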
class double_conv_up(nn.Module):
    # Same as double_conv, but finishes with a stride-2 transposed conv that doubles the resolution.
    def __init__(self, in_ch, mid_ch, out_ch, stride=1, planes=256):
        super(double_conv_up, self).__init__()
        # `planes` was only used by an earlier, commented-out Bottleneck variant;
        # it is kept so existing call sites keep working.
        self.planes = planes
        self.down = None
        if stride > 1:
            self.down = nn.AvgPool2d(2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, mid_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        if self.down is not None:
            x = self.down(x)
        x = self.conv(x)
        return x
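# e.g. double_conv_up(256, 512, 256) expects a (N, 256 + 512, H, W) input (decoder features
# concatenated with a skip connection) and returns (N, 256, 2 * H, 2 * W).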
class TextDetection(nn.Module):
    def __init__(self, pretrained=None):
        super(TextDetection, self).__init__()
        self.backbone = resnet101(pretrained=bool(pretrained))
        self.conv_db = DBHead.DBHead(64, 0)
        self.conv_mask = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )
        # Extra encoder levels below the backbone; each halves the resolution.
        self.down_conv1 = double_conv(0, 512, 512, 2)
        self.down_conv2 = double_conv(0, 512, 512, 2)
        self.down_conv3 = double_conv(0, 512, 512, 2)
        # Decoder: each block takes the previous decoder output (in_ch) concatenated
        # with a skip connection (mid_ch) and upsamples by 2.
        self.upconv1 = double_conv_up(0, 512, 256)
        self.upconv2 = double_conv_up(256, 512, 256)
        self.upconv3 = double_conv_up(256, 512, 256)
        self.upconv4 = double_conv_up(256, 512, 256, planes=128)
        self.upconv5 = double_conv_up(256, 256, 128, planes=64)
        self.upconv6 = double_conv_up(128, 128, 64, planes=32)
        self.upconv7 = double_conv_up(64, 64, 64, planes=16)
        # 1x1 projections from the ResNet-101 bottleneck widths (256/512/1024/2048)
        # down to the decoder widths (64/128/256/512).
        self.proj_h4 = nn.Conv2d(64 * 4, 64, 1)
        self.proj_h8 = nn.Conv2d(128 * 4, 128, 1)
        self.proj_h16 = nn.Conv2d(256 * 4, 256, 1)
        self.proj_h32 = nn.Conv2d(512 * 4, 512, 1)
    def forward(self, x):
        # ResNet-101 stem: 64 ch @ 1/4 of the input resolution
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        # Backbone stages (bottleneck blocks, expansion 4)
        h4 = self.backbone.layer1(x)     # 256 ch @ 1/4
        h8 = self.backbone.layer2(h4)    # 512 ch @ 1/8
        h16 = self.backbone.layer3(h8)   # 1024 ch @ 1/16
        h32 = self.backbone.layer4(h16)  # 2048 ch @ 1/32
        # 1x1 projections to the decoder channel widths
        h4 = self.proj_h4(h4)      # 64 ch @ 1/4
        h8 = self.proj_h8(h8)      # 128 ch @ 1/8
        h16 = self.proj_h16(h16)   # 256 ch @ 1/16
        h32 = self.proj_h32(h32)   # 512 ch @ 1/32
        # Extra downsampling beyond the backbone
        h64 = self.down_conv1(h32)    # 512 ch @ 1/64
        h128 = self.down_conv2(h64)   # 512 ch @ 1/128
        h256 = self.down_conv3(h128)  # 512 ch @ 1/256
        # U-Net-style decoder with skip connections
        up256 = self.upconv1(h256)                              # 256 ch @ 1/128
        up128 = self.upconv2(torch.cat([up256, h128], dim=1))   # 256 ch @ 1/64
        up64 = self.upconv3(torch.cat([up128, h64], dim=1))     # 256 ch @ 1/32
        up32 = self.upconv4(torch.cat([up64, h32], dim=1))      # 256 ch @ 1/16
        up16 = self.upconv5(torch.cat([up32, h16], dim=1))      # 128 ch @ 1/8
        up8 = self.upconv6(torch.cat([up16, h8], dim=1))        # 64 ch @ 1/4
        up4 = self.upconv7(torch.cat([up8, h4], dim=1))         # 64 ch @ 1/2
        # DB head on the 1/4-resolution features, mask head on the 1/2-resolution features
        return self.conv_db(up8), self.conv_mask(up4)
if __name__ == '__main__':
    device = torch.device("cuda:0")
    net = TextDetection().to(device)
    img = torch.randn(2, 3, 1024, 1024).to(device)
    db, seg = net(img)
    print(db.shape)
    print(seg.shape)
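# For the 1024x1024 test input above, the mask output `seg` is (2, 1, 512, 512)
# (half the input resolution); the shape of `db` depends on DBHead.DBHead's output convention.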