unet1d.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from collections import OrderedDict
  2. import torch
  3. import torch.nn as nn
  4. class UNet1d(nn.Module):
  5. def __init__(self, in_channels=3, out_channels=1, init_features=128, multi=None):
  6. super(UNet1d, self).__init__()
  7. if multi is None:
  8. multi = [1, 2, 2, 4]
  9. features = init_features
  10. self.encoder1 = UNet1d._block(in_channels, features * multi[0], name="enc1")
  11. self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
  12. self.encoder2 = UNet1d._block(features * multi[0], features * multi[1], name="enc2")
  13. self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
  14. self.encoder3 = UNet1d._block(features * multi[1], features * multi[2], name="enc3")
  15. self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
  16. self.encoder4 = UNet1d._block(features * multi[2], features * multi[3], name="enc4")
  17. self.pool4 = nn.MaxPool1d(kernel_size=2, stride=2)
  18. self.bottleneck = UNet1d._block(features * multi[3], features * multi[3], name="bottleneck")
  19. self.upconv4 = nn.ConvTranspose1d(
  20. features * multi[3], features * multi[3], kernel_size=2, stride=2
  21. )
  22. self.decoder4 = UNet1d._block((features * multi[3]) * 2, features * multi[3], name="dec4")
  23. self.upconv3 = nn.ConvTranspose1d(
  24. features * multi[3], features * multi[2], kernel_size=2, stride=2
  25. )
  26. self.decoder3 = UNet1d._block((features * multi[2]) * 2, features * multi[2], name="dec3")
  27. self.upconv2 = nn.ConvTranspose1d(
  28. features * multi[2], features * multi[1], kernel_size=2, stride=2
  29. )
  30. self.decoder2 = UNet1d._block((features * multi[1]) * 2, features * multi[1], name="dec2")
  31. self.upconv1 = nn.ConvTranspose1d(
  32. features * multi[1], features * multi[0], kernel_size=2, stride=2
  33. )
  34. self.decoder1 = UNet1d._block(features * multi[0] * 2, features * multi[0], name="dec1")
  35. self.conv = nn.Conv1d(
  36. in_channels=features * multi[0], out_channels=out_channels, kernel_size=1
  37. )
  38. def forward(self, x, nonpadding=None):
  39. if nonpadding is None:
  40. nonpadding = torch.ones_like(x)[:, :, :1]
  41. enc1 = self.encoder1(x.transpose(1, 2)) * nonpadding.transpose(1, 2)
  42. enc2 = self.encoder2(self.pool1(enc1))
  43. enc3 = self.encoder3(self.pool2(enc2))
  44. enc4 = self.encoder4(self.pool3(enc3))
  45. bottleneck = self.bottleneck(self.pool4(enc4))
  46. dec4 = self.upconv4(bottleneck)
  47. dec4 = torch.cat((dec4, enc4), dim=1)
  48. dec4 = self.decoder4(dec4)
  49. dec3 = self.upconv3(dec4)
  50. dec3 = torch.cat((dec3, enc3), dim=1)
  51. dec3 = self.decoder3(dec3)
  52. dec2 = self.upconv2(dec3)
  53. dec2 = torch.cat((dec2, enc2), dim=1)
  54. dec2 = self.decoder2(dec2)
  55. dec1 = self.upconv1(dec2)
  56. dec1 = torch.cat((dec1, enc1), dim=1)
  57. dec1 = self.decoder1(dec1)
  58. return self.conv(dec1).transpose(1, 2) * nonpadding
  59. @staticmethod
  60. def _block(in_channels, features, name):
  61. return nn.Sequential(
  62. OrderedDict(
  63. [
  64. (
  65. name + "conv1",
  66. nn.Conv1d(
  67. in_channels=in_channels,
  68. out_channels=features,
  69. kernel_size=5,
  70. padding=2,
  71. bias=False,
  72. ),
  73. ),
  74. (name + "norm1", nn.GroupNorm(4, features)),
  75. (name + "tanh1", nn.Tanh()),
  76. (
  77. name + "conv2",
  78. nn.Conv1d(
  79. in_channels=features,
  80. out_channels=features,
  81. kernel_size=5,
  82. padding=2,
  83. bias=False,
  84. ),
  85. ),
  86. (name + "norm2", nn.GroupNorm(4, features)),
  87. (name + "tanh2", nn.Tanh()),
  88. ]
  89. )
  90. )
  91. class UNet2d(nn.Module):
  92. def __init__(self, in_channels=3, out_channels=1, init_features=32, multi=None):
  93. super(UNet2d, self).__init__()
  94. features = init_features
  95. self.encoder1 = UNet2d._block(in_channels, features * multi[0], name="enc1")
  96. self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
  97. self.encoder2 = UNet2d._block(features * multi[0], features * multi[1], name="enc2")
  98. self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
  99. self.encoder3 = UNet2d._block(features * multi[1], features * multi[2], name="enc3")
  100. self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
  101. self.encoder4 = UNet2d._block(features * multi[2], features * multi[3], name="enc4")
  102. self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
  103. self.bottleneck = UNet2d._block(features * multi[3], features * multi[3], name="bottleneck")
  104. self.upconv4 = nn.ConvTranspose2d(
  105. features * multi[3], features * multi[3], kernel_size=2, stride=2
  106. )
  107. self.decoder4 = UNet2d._block((features * multi[3]) * 2, features * multi[3], name="dec4")
  108. self.upconv3 = nn.ConvTranspose2d(
  109. features * multi[3], features * multi[2], kernel_size=2, stride=2
  110. )
  111. self.decoder3 = UNet2d._block((features * multi[2]) * 2, features * multi[2], name="dec3")
  112. self.upconv2 = nn.ConvTranspose2d(
  113. features * multi[2], features * multi[1], kernel_size=2, stride=2
  114. )
  115. self.decoder2 = UNet2d._block((features * multi[1]) * 2, features * multi[1], name="dec2")
  116. self.upconv1 = nn.ConvTranspose2d(
  117. features * multi[1], features * multi[0], kernel_size=2, stride=2
  118. )
  119. self.decoder1 = UNet2d._block(features * multi[0] * 2, features * multi[0], name="dec1")
  120. self.conv = nn.Conv2d(
  121. in_channels=features * multi[0], out_channels=out_channels, kernel_size=1
  122. )
  123. def forward(self, x):
  124. enc1 = self.encoder1(x)
  125. enc2 = self.encoder2(self.pool1(enc1))
  126. enc3 = self.encoder3(self.pool2(enc2))
  127. enc4 = self.encoder4(self.pool3(enc3))
  128. bottleneck = self.bottleneck(self.pool4(enc4))
  129. dec4 = self.upconv4(bottleneck)
  130. dec4 = torch.cat((dec4, enc4), dim=1)
  131. dec4 = self.decoder4(dec4)
  132. dec3 = self.upconv3(dec4)
  133. dec3 = torch.cat((dec3, enc3), dim=1)
  134. dec3 = self.decoder3(dec3)
  135. dec2 = self.upconv2(dec3)
  136. dec2 = torch.cat((dec2, enc2), dim=1)
  137. dec2 = self.decoder2(dec2)
  138. dec1 = self.upconv1(dec2)
  139. dec1 = torch.cat((dec1, enc1), dim=1)
  140. dec1 = self.decoder1(dec1)
  141. x = self.conv(dec1)
  142. return x
  143. @staticmethod
  144. def _block(in_channels, features, name):
  145. return nn.Sequential(
  146. OrderedDict(
  147. [
  148. (
  149. name + "conv1",
  150. nn.Conv2d(
  151. in_channels=in_channels,
  152. out_channels=features,
  153. kernel_size=3,
  154. padding=1,
  155. bias=False,
  156. ),
  157. ),
  158. (name + "norm1", nn.GroupNorm(4, features)),
  159. (name + "tanh1", nn.Tanh()),
  160. (
  161. name + "conv2",
  162. nn.Conv2d(
  163. in_channels=features,
  164. out_channels=features,
  165. kernel_size=3,
  166. padding=1,
  167. bias=False,
  168. ),
  169. ),
  170. (name + "norm2", nn.GroupNorm(4, features)),
  171. (name + "tanh2", nn.Tanh()),
  172. (name + "conv3", nn.Conv2d(
  173. in_channels=features,
  174. out_channels=features,
  175. kernel_size=1,
  176. padding=0,
  177. bias=True,
  178. )),
  179. ]
  180. )
  181. )