inpainting_aot.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. from typing import List, Optional
  2. import numpy as np
  3. import os
  4. import shutil
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from .inpainting_lama_mpe import LamaMPEInpainter
  9. class AotInpainter(LamaMPEInpainter):
  10. _MODEL_MAPPING = {
  11. 'model': {
  12. 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/inpainting.ckpt',
  13. 'hash': '878d541c68648969bc1b042a6e997f3a58e49b6c07c5636ad55130736977149f',
  14. 'file': '.',
  15. },
  16. }
  17. def __init__(self, *args, **kwargs):
  18. os.makedirs(self.model_dir, exist_ok=True)
  19. if os.path.exists('inpainting.ckpt'):
  20. shutil.move('inpainting.ckpt', self._get_file_path('inpainting.ckpt'))
  21. super().__init__(*args, **kwargs)
  22. async def _load(self, device: str):
  23. self.model = AOTGenerator()
  24. sd = torch.load(self._get_file_path('inpainting.ckpt'), map_location='cpu')
  25. self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
  26. self.model.eval()
  27. self.device = device
  28. if device.startswith('cuda') or device == 'mps':
  29. self.model.to(device)
  30. def relu_nf(x):
  31. return F.relu(x) * 1.7139588594436646
  32. def gelu_nf(x):
  33. return F.gelu(x) * 1.7015043497085571
  34. def silu_nf(x):
  35. return F.silu(x) * 1.7881293296813965
  36. class LambdaLayer(nn.Module):
  37. def __init__(self, f):
  38. super(LambdaLayer, self).__init__()
  39. self.f = f
  40. def forward(self, x):
  41. return self.f(x)
  42. class ScaledWSConv2d(nn.Conv2d):
  43. """2D Conv layer with Scaled Weight Standardization."""
  44. def __init__(self, in_channels, out_channels, kernel_size,
  45. stride=1, padding=0,
  46. dilation=1, groups=1, bias=True, gain=True,
  47. eps=1e-4):
  48. nn.Conv2d.__init__(self, in_channels, out_channels,
  49. kernel_size, stride,
  50. padding, dilation,
  51. groups, bias)
  52. #nn.init.kaiming_normal_(self.weight)
  53. if gain:
  54. self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
  55. else:
  56. self.gain = None
  57. # Epsilon, a small constant to avoid dividing by zero.
  58. self.eps = eps
  59. def get_weight(self):
  60. # Get Scaled WS weight OIHW;
  61. fan_in = np.prod(self.weight.shape[1:])
  62. var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdims=True)
  63. scale = torch.rsqrt(torch.max(
  64. var * fan_in, torch.tensor(self.eps).to(var.device))) * self.gain.view_as(var).to(var.device)
  65. shift = mean * scale
  66. return self.weight * scale - shift
  67. def forward(self, x):
  68. return F.conv2d(x, self.get_weight(), self.bias,
  69. self.stride, self.padding,
  70. self.dilation, self.groups)
  71. class ScaledWSTransposeConv2d(nn.ConvTranspose2d):
  72. """2D Transpose Conv layer with Scaled Weight Standardization."""
  73. def __init__(self, in_channels: int,
  74. out_channels: int,
  75. kernel_size,
  76. stride = 1,
  77. padding = 0,
  78. output_padding = 0,
  79. groups: int = 1,
  80. bias: bool = True,
  81. dilation: int = 1,
  82. gain=True,
  83. eps=1e-4):
  84. nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation, 'zeros')
  85. #nn.init.kaiming_normal_(self.weight)
  86. if gain:
  87. self.gain = nn.Parameter(torch.ones(self.in_channels, 1, 1, 1))
  88. else:
  89. self.gain = None
  90. # Epsilon, a small constant to avoid dividing by zero.
  91. self.eps = eps
  92. def get_weight(self):
  93. # Get Scaled WS weight OIHW;
  94. fan_in = np.prod(self.weight.shape[1:])
  95. var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdims=True)
  96. scale = torch.rsqrt(torch.max(
  97. var * fan_in, torch.tensor(self.eps).to(var.device))) * self.gain.view_as(var).to(var.device)
  98. shift = mean * scale
  99. return self.weight * scale - shift
  100. def forward(self, x, output_size: Optional[List[int]] = None):
  101. output_padding = self._output_padding(
  102. input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)
  103. return F.conv_transpose2d(x, self.get_weight(), self.bias, self.stride, self.padding,
  104. output_padding, self.groups, self.dilation)
  105. class GatedWSConvPadded(nn.Module):
  106. def __init__(self, in_ch, out_ch, ks, stride = 1, dilation = 1):
  107. super(GatedWSConvPadded, self).__init__()
  108. self.in_ch = in_ch
  109. self.out_ch = out_ch
  110. self.padding = nn.ReflectionPad2d(((ks - 1) * dilation) // 2)
  111. self.conv = ScaledWSConv2d(in_ch, out_ch, kernel_size = ks, stride = stride, dilation = dilation)
  112. self.conv_gate = ScaledWSConv2d(in_ch, out_ch, kernel_size = ks, stride = stride, dilation = dilation)
  113. def forward(self, x):
  114. x = self.padding(x)
  115. signal = self.conv(x)
  116. gate = torch.sigmoid(self.conv_gate(x))
  117. return signal * gate * 1.8
  118. class GatedWSTransposeConvPadded(nn.Module):
  119. def __init__(self, in_ch, out_ch, ks, stride = 1):
  120. super(GatedWSTransposeConvPadded, self).__init__()
  121. self.in_ch = in_ch
  122. self.out_ch = out_ch
  123. self.conv = ScaledWSTransposeConv2d(in_ch, out_ch, kernel_size = ks, stride = stride, padding = (ks - 1) // 2)
  124. self.conv_gate = ScaledWSTransposeConv2d(in_ch, out_ch, kernel_size = ks, stride = stride, padding = (ks - 1) // 2)
  125. def forward(self, x):
  126. signal = self.conv(x)
  127. gate = torch.sigmoid(self.conv_gate(x))
  128. return signal * gate * 1.8
  129. class ResBlock(nn.Module):
  130. def __init__(self, ch, alpha = 0.2, beta = 1.0, dilation = 1):
  131. super(ResBlock, self).__init__()
  132. self.alpha = alpha
  133. self.beta = beta
  134. self.c1 = GatedWSConvPadded(ch, ch, 3, dilation = dilation)
  135. self.c2 = GatedWSConvPadded(ch, ch, 3, dilation = dilation)
  136. def forward(self, x):
  137. skip = x
  138. x = self.c1(relu_nf(x / self.beta))
  139. x = self.c2(relu_nf(x))
  140. x = x * self.alpha
  141. return x + skip
  142. def my_layer_norm(feat):
  143. mean = feat.mean((2, 3), keepdim=True)
  144. std = feat.std((2, 3), keepdim=True) + 1e-9
  145. feat = 2 * (feat - mean) / std - 1
  146. feat = 5 * feat
  147. return feat
  148. class AOTBlock(nn.Module):
  149. def __init__(self, dim, rates = [2, 4, 8, 16]):
  150. super(AOTBlock, self).__init__()
  151. self.rates = rates
  152. for i, rate in enumerate(rates):
  153. self.__setattr__(
  154. 'block{}'.format(str(i).zfill(2)),
  155. nn.Sequential(
  156. nn.ReflectionPad2d(rate),
  157. nn.Conv2d(dim, dim//4, 3, padding=0, dilation=rate),
  158. nn.ReLU(True)))
  159. self.fuse = nn.Sequential(
  160. nn.ReflectionPad2d(1),
  161. nn.Conv2d(dim, dim, 3, padding=0, dilation=1))
  162. self.gate = nn.Sequential(
  163. nn.ReflectionPad2d(1),
  164. nn.Conv2d(dim, dim, 3, padding=0, dilation=1))
  165. def forward(self, x):
  166. out = [self.__getattr__(f'block{str(i).zfill(2)}')(x) for i in range(len(self.rates))]
  167. out = torch.cat(out, 1)
  168. out = self.fuse(out)
  169. mask = my_layer_norm(self.gate(x))
  170. mask = torch.sigmoid(mask)
  171. return x * (1 - mask) + out * mask
  172. class ResBlockDis(nn.Module):
  173. def __init__(self, in_planes, planes, stride=1):
  174. super(ResBlockDis, self).__init__()
  175. self.bn1 = nn.InstanceNorm2d(in_planes)
  176. self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3 if stride == 1 else 4, stride=stride, padding=1)
  177. self.bn2 = nn.InstanceNorm2d(planes)
  178. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
  179. self.planes = planes
  180. self.in_planes = in_planes
  181. self.stride = stride
  182. self.shortcut = nn.Sequential()
  183. if stride > 1:
  184. self.shortcut = nn.Sequential(nn.AvgPool2d(2, 2), nn.Conv2d(in_planes, planes, kernel_size=1))
  185. elif in_planes != planes and stride == 1:
  186. self.shortcut = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1))
  187. def forward(self, x):
  188. sc = self.shortcut(x)
  189. x = self.conv1(F.leaky_relu(self.bn1(x), 0.2))
  190. x = self.conv2(F.leaky_relu(self.bn2(x), 0.2))
  191. return sc + x
  192. from torch.nn.utils import spectral_norm
  193. class Discriminator(nn.Module):
  194. def __init__(self, in_ch = 3, in_planes = 64, blocks = [2, 2, 2], alpha = 0.2):
  195. super(Discriminator, self).__init__()
  196. self.in_planes = in_planes
  197. self.conv = nn.Sequential(
  198. spectral_norm(nn.Conv2d(in_ch, in_planes, 4, stride=2, padding=1, bias=False)),
  199. nn.LeakyReLU(0.2, inplace=True),
  200. spectral_norm(nn.Conv2d(in_planes, in_planes*2, 4, stride=2, padding=1, bias=False)),
  201. nn.LeakyReLU(0.2, inplace=True),
  202. spectral_norm(nn.Conv2d(in_planes*2, in_planes*4, 4, stride=2, padding=1, bias=False)),
  203. nn.LeakyReLU(0.2, inplace=True),
  204. spectral_norm(nn.Conv2d(in_planes*4, in_planes*8, 4, stride=1, padding=1, bias=False)),
  205. nn.LeakyReLU(0.2, inplace=True),
  206. nn.Conv2d(512, 1, 4, stride=1, padding=1)
  207. )
  208. def forward(self, x):
  209. x = self.conv(x)
  210. return x
  211. class AOTGenerator(nn.Module):
  212. def __init__(self, in_ch = 4, out_ch = 3, ch = 32, alpha = 0.0):
  213. super(AOTGenerator, self).__init__()
  214. self.head = nn.Sequential(
  215. GatedWSConvPadded(in_ch, ch, 3, stride = 1),
  216. LambdaLayer(relu_nf),
  217. GatedWSConvPadded(ch, ch * 2, 4, stride = 2),
  218. LambdaLayer(relu_nf),
  219. GatedWSConvPadded(ch * 2, ch * 4, 4, stride = 2),
  220. )
  221. self.body_conv = nn.Sequential(*[AOTBlock(ch * 4) for _ in range(10)])
  222. self.tail = nn.Sequential(
  223. GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
  224. LambdaLayer(relu_nf),
  225. GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
  226. LambdaLayer(relu_nf),
  227. GatedWSTransposeConvPadded(ch * 4, ch * 2, 4, 2),
  228. LambdaLayer(relu_nf),
  229. GatedWSTransposeConvPadded(ch * 2, ch, 4, 2),
  230. LambdaLayer(relu_nf),
  231. GatedWSConvPadded(ch, out_ch, 3, stride = 1),
  232. )
  233. def forward(self, img, mask):
  234. x = torch.cat([mask, img], dim = 1)
  235. x = self.head(x)
  236. conv = self.body_conv(x)
  237. x = self.tail(conv)
  238. if self.training:
  239. return x
  240. else:
  241. return torch.clip(x, -1, 1)
  242. def test():
  243. img = torch.randn(4, 3, 256, 256).cuda()
  244. mask = torch.randn(4, 1, 256, 256).cuda()
  245. net = AOTGenerator().cuda()
  246. y1 = net(img, mask)
  247. print(y1.shape)