# inpainting_attn.py: gated, weight-standardized inpainting generator with a
# global-attention branch.
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
  6. def relu_nf(x):
  7. return F.relu(x) * 1.7139588594436646
  8. def gelu_nf(x):
  9. return F.gelu(x) * 1.7015043497085571
  10. def silu_nf(x):
  11. return F.silu(x) * 1.7881293296813965
  12. class LambdaLayer(nn.Module):
  13. def __init__(self, f):
  14. super(LambdaLayer, self).__init__()
  15. self.f = f
  16. def forward(self, x):
  17. return self.f(x)
  18. class ScaledWSConv2d(nn.Conv2d):
  19. """2D Conv layer with Scaled Weight Standardization."""
  20. def __init__(self, in_channels, out_channels, kernel_size,
  21. stride=1, padding=0,
  22. dilation=1, groups=1, bias=True, gain=True,
  23. eps=1e-4):
  24. nn.Conv2d.__init__(self, in_channels, out_channels,
  25. kernel_size, stride,
  26. padding, dilation,
  27. groups, bias)
  28. #nn.init.kaiming_normal_(self.weight)
  29. if gain:
  30. self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
  31. else:
  32. self.gain = None
  33. # Epsilon, a small constant to avoid dividing by zero.
  34. self.eps = eps
  35. def get_weight(self):
  36. # Get Scaled WS weight OIHW;
  37. fan_in = np.prod(self.weight.shape[1:])
  38. var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdims=True)
  39. scale = torch.rsqrt(torch.max(
  40. var * fan_in, torch.tensor(self.eps).to(var.device))) * self.gain.view_as(var).to(var.device)
  41. shift = mean * scale
  42. return self.weight * scale - shift
  43. def forward(self, x):
  44. return F.conv2d(x, self.get_weight(), self.bias,
  45. self.stride, self.padding,
  46. self.dilation, self.groups)
  47. class ScaledWSTransposeConv2d(nn.ConvTranspose2d):
  48. """2D Transpose Conv layer with Scaled Weight Standardization."""
  49. def __init__(self, in_channels: int,
  50. out_channels: int,
  51. kernel_size,
  52. stride = 1,
  53. padding = 0,
  54. output_padding = 0,
  55. groups: int = 1,
  56. bias: bool = True,
  57. dilation: int = 1,
  58. gain=True,
  59. eps=1e-4):
  60. nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation, 'zeros')
  61. #nn.init.kaiming_normal_(self.weight)
  62. if gain:
  63. self.gain = nn.Parameter(torch.ones(self.in_channels, 1, 1, 1))
  64. else:
  65. self.gain = None
  66. # Epsilon, a small constant to avoid dividing by zero.
  67. self.eps = eps
  68. def get_weight(self):
  69. # Get Scaled WS weight OIHW;
  70. fan_in = np.prod(self.weight.shape[1:])
  71. var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdims=True)
  72. scale = torch.rsqrt(torch.max(
  73. var * fan_in, torch.tensor(self.eps).to(var.device))) * self.gain.view_as(var).to(var.device)
  74. shift = mean * scale
  75. return self.weight * scale - shift
  76. def forward(self, x, output_size: Optional[List[int]] = None):
  77. output_padding = self._output_padding(
  78. input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)
  79. return F.conv_transpose2d(x, self.get_weight(), self.bias, self.stride, self.padding,
  80. output_padding, self.groups, self.dilation)
  81. class GatedWSConvPadded(nn.Module):
  82. def __init__(self, in_ch, out_ch, ks, stride = 1, dilation = 1):
  83. super(GatedWSConvPadded, self).__init__()
  84. self.in_ch = in_ch
  85. self.out_ch = out_ch
  86. self.padding = nn.ReflectionPad2d((ks - 1) // 2)
  87. self.conv = ScaledWSConv2d(in_ch, out_ch, kernel_size = ks, stride = stride)
  88. self.conv_gate = ScaledWSConv2d(in_ch, out_ch, kernel_size = ks, stride = stride)
  89. def forward(self, x):
  90. x = self.padding(x)
  91. signal = self.conv(x)
  92. gate = torch.sigmoid(self.conv_gate(x))
  93. return signal * gate * 1.8
  94. class GatedWSTransposeConvPadded(nn.Module):
  95. def __init__(self, in_ch, out_ch, ks, stride = 1):
  96. super(GatedWSTransposeConvPadded, self).__init__()
  97. self.in_ch = in_ch
  98. self.out_ch = out_ch
  99. self.conv = ScaledWSTransposeConv2d(in_ch, out_ch, kernel_size = ks, stride = stride, padding = (ks - 1) // 2)
  100. self.conv_gate = ScaledWSTransposeConv2d(in_ch, out_ch, kernel_size = ks, stride = stride, padding = (ks - 1) // 2)
  101. def forward(self, x):
  102. signal = self.conv(x)
  103. gate = torch.sigmoid(self.conv_gate(x))
  104. return signal * gate * 1.8
  105. class ResBlock(nn.Module):
  106. def __init__(self, ch, alpha = 0.2, beta = 1.0, dilation = 1):
  107. super(ResBlock, self).__init__()
  108. self.alpha = alpha
  109. self.beta = beta
  110. self.c1 = GatedWSConvPadded(ch, ch, 3, dilation = dilation)
  111. self.c2 = GatedWSConvPadded(ch, ch, 3, dilation = dilation)
  112. def forward(self, x):
  113. skip = x
  114. x = self.c1(relu_nf(x / self.beta))
  115. x = self.c2(relu_nf(x))
  116. x = x * self.alpha
  117. return x + skip
  118. # from https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting
  119. class GlobalAttention(nn.Module):
  120. """ Self attention Layer"""
  121. def __init__(self, in_dim):
  122. super(GlobalAttention, self).__init__()
  123. self.channel_in = in_dim
  124. self.query_conv = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
  125. self.key_conv = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
  126. self.value_conv = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
  127. self.softmax = nn.Softmax(dim=-1) #
  128. self.rate = 1
  129. self.gamma = nn.parameter.Parameter(torch.tensor([1.0], requires_grad=True), requires_grad=True)
  130. def forward(self, a, b, c):
  131. m_batchsize, C, height, width = a.size() # B, C, H, W
  132. c = F.interpolate(c, size=(height, width), mode='nearest')
  133. proj_query = self.query_conv(a).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B, C, N -> B N C
  134. proj_key = self.key_conv(b).view(m_batchsize, -1, width * height) # B, C, N
  135. feature_similarity = torch.bmm(proj_query, proj_key) # B, N, N
  136. mask = c.view(m_batchsize, -1, width * height) # B, C, N
  137. mask = mask.repeat(1, height * width, 1).permute(0, 2, 1) # B, 1, H, W -> B, C, H, W // B
  138. feature_pruning = feature_similarity * mask
  139. attention = self.softmax(feature_pruning) # B, N, C
  140. feature_pruning = torch.bmm(self.value_conv(a).view(m_batchsize, -1, width * height),
  141. attention.permute(0, 2, 1)) # -. B, C, N
  142. out = feature_pruning.view(m_batchsize, C, height, width) # B, C, H, W
  143. out = a * c + self.gamma * (1.0 - c) * out
  144. return out
  145. class CoarseGenerator(nn.Module):
  146. def __init__(self, in_ch = 4, out_ch = 3, ch = 32, alpha = 0.2):
  147. super(CoarseGenerator, self).__init__()
  148. self.head = nn.Sequential(
  149. GatedWSConvPadded(in_ch, ch, 3, stride = 1),
  150. LambdaLayer(relu_nf),
  151. GatedWSConvPadded(ch, ch * 2, 4, stride = 2),
  152. LambdaLayer(relu_nf),
  153. GatedWSConvPadded(ch * 2, ch * 4, 4, stride = 2),
  154. )
  155. self.beta = 1.0
  156. self.alpha = alpha
  157. self.body_conv = []
  158. self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta))
  159. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  160. self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta))
  161. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  162. self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 2))
  163. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  164. self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 4))
  165. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  166. self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 8))
  167. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  168. self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 16))
  169. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  170. self.body_conv = nn.Sequential(*self.body_conv)
  171. self.tail = nn.Sequential(
  172. LambdaLayer(relu_nf),
  173. GatedWSConvPadded(ch * 8, ch * 8, 3, 1),
  174. LambdaLayer(relu_nf),
  175. GatedWSConvPadded(ch * 8, ch * 4, 3, 1),
  176. LambdaLayer(relu_nf),
  177. GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
  178. LambdaLayer(relu_nf),
  179. GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
  180. LambdaLayer(relu_nf),
  181. GatedWSTransposeConvPadded(ch * 4, ch * 2, 4, 2),
  182. LambdaLayer(relu_nf),
  183. GatedWSTransposeConvPadded(ch * 2, ch, 4, 2),
  184. LambdaLayer(relu_nf),
  185. GatedWSConvPadded(ch, out_ch, 3, stride = 1),
  186. )
  187. self.beta = 1.0
  188. self.body_attn_1 = ResBlock(ch * 4, self.alpha, self.beta)
  189. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  190. self.body_attn_2 = ResBlock(ch * 4, self.alpha, self.beta)
  191. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  192. self.body_attn_3 = ResBlock(ch * 4, self.alpha, self.beta)
  193. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  194. self.body_attn_4 = ResBlock(ch * 4, self.alpha, self.beta)
  195. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  196. self.body_attn_attn = GlobalAttention(in_dim = ch * 4)
  197. self.body_attn_5 = ResBlock(ch * 4, self.alpha, self.beta)
  198. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  199. self.body_attn_6 = ResBlock(ch * 4, self.alpha, self.beta)
  200. self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
  201. def forward(self, img, mask):
  202. x = torch.cat([mask, img], dim = 1)
  203. x = self.head(x)
  204. attn = self.body_attn_1(x)
  205. attn = self.body_attn_2(attn)
  206. attn = self.body_attn_3(attn)
  207. attn = self.body_attn_4(attn)
  208. attn = self.body_attn_attn(attn, attn, mask)
  209. attn = self.body_attn_5(attn)
  210. attn = self.body_attn_6(attn)
  211. conv = self.body_conv(x)
  212. x = self.tail(torch.cat([conv, attn], dim = 1))
  213. return torch.clip(x, -1, 1)
  214. class InpaintingVanilla(nn.Module):
  215. def __init__(self):
  216. super(InpaintingVanilla, self).__init__()
  217. self.coarse_generator = CoarseGenerator(4, 3, 32)
  218. def forward(self, x, mask):
  219. x_stage1 = self.coarse_generator(x, mask)
  220. return x_stage1