"""Inpainter built on the ``AOTGenerator`` network defined in this module."""
import os
import shutil
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from .inpainting_lama_mpe import LamaMPEInpainter


class AotInpainter(LamaMPEInpainter):
    """Inpainter that loads ``AOTGenerator`` weights from ``inpainting.ckpt``.

    Only construction and ``_load`` are overridden here; the rest of the
    inference pipeline presumably comes from ``LamaMPEInpainter``.
    """

    _MODEL_MAPPING = {
        'model': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/inpainting.ckpt',
            'hash': '878d541c68648969bc1b042a6e997f3a58e49b6c07c5636ad55130736977149f',
            'file': '.',
        },
    }

    def __init__(self, *args, **kwargs):
        # If an 'inpainting.ckpt' sits in the current working directory
        # (presumably from an older download layout), move it into the
        # model directory before the base class initialises.
        os.makedirs(self.model_dir, exist_ok=True)
        if os.path.exists('inpainting.ckpt'):
            shutil.move('inpainting.ckpt', self._get_file_path('inpainting.ckpt'))
        super().__init__(*args, **kwargs)

    async def _load(self, device: str):
        """Build the generator, load its checkpoint and place it on ``device``.

        Args:
            device: Torch device string. The model is moved only for
                ``'cuda…'`` and ``'mps'``; otherwise it stays on the CPU.
        """
        self.model = AOTGenerator()
        # NOTE(review): torch.load unpickles the file. It is expected to be
        # the hash-pinned checkpoint from _MODEL_MAPPING (presumably verified
        # by the base class) — never point this at untrusted files.
        sd = torch.load(self._get_file_path('inpainting.ckpt'), map_location='cpu')
        # Some checkpoints wrap the weights under a 'model' key.
        self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
        self.model.eval()
        self.device = device
        if device.startswith('cuda') or device == 'mps':
            self.model.to(device)


def relu_nf(x):
    """ReLU scaled by a constant (presumably the normalizer-free gain that
    restores unit variance for standard-normal inputs)."""
    return F.relu(x) * 1.7139588594436646


def gelu_nf(x):
    """GELU scaled by a constant (presumably its normalizer-free gain)."""
    return F.gelu(x) * 1.7015043497085571


def silu_nf(x):
    """SiLU scaled by a constant (presumably its normalizer-free gain)."""
    return F.silu(x) * 1.7881293296813965


class LambdaLayer(nn.Module):
    """Wrap a plain callable so it can be placed inside ``nn.Sequential``."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        return self.f(x)


def _standardize_weight(weight, gain, eps):
    """Return a Scaled-Weight-Standardized copy of a 4-D conv weight.

    Each slice ``weight[i]`` is centred on its mean and divided by
    ``sqrt(max(var * fan_in, eps))``, where ``fan_in`` is the number of
    elements in one slice. The result has (unbiased) variance ``1 / fan_in``
    per slice. When ``gain`` is not ``None``, it is reshaped to match the
    per-slice statistics and multiplied in.

    Args:
        weight: 4-D weight tensor; statistics are taken over dims 1-3.
        gain: Per-slice gain with ``weight.shape[0]`` elements, or ``None``.
        eps: Lower bound on ``var * fan_in`` to avoid division by zero.
    """
    fan_in = int(np.prod(weight.shape[1:]))
    var, mean = torch.var_mean(weight, dim=(1, 2, 3), keepdim=True)
    scale = torch.rsqrt(torch.clamp(var * fan_in, min=eps))
    if gain is not None:
        scale = scale * gain.view_as(var)
    shift = mean * scale
    return weight * scale - shift


class ScaledWSConv2d(nn.Conv2d):
    """2D Conv layer with Scaled Weight Standardization.

    The weight is re-standardized on every forward pass. The optional
    learnable gain has one entry per output channel.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, gain=True,
                 eps=1e-4):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride, padding, dilation, groups, bias)
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = None
        # Epsilon, a small constant to avoid dividing by zero.
        self.eps = eps

    def get_weight(self):
        """Return the standardized (and gained) weight in OIHW layout."""
        return _standardize_weight(self.weight, self.gain, self.eps)

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class ScaledWSTransposeConv2d(nn.ConvTranspose2d):
    """2D Transpose Conv layer with Scaled Weight Standardization.

    ``ConvTranspose2d`` stores its weight as ``(in, out // groups, kH, kW)``,
    so the statistics and the optional gain are per *input* channel.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size,
                 stride=1, padding=0, output_padding=0, groups: int = 1,
                 bias: bool = True, dilation: int = 1, gain=True, eps=1e-4):
        nn.ConvTranspose2d.__init__(self, in_channels, out_channels,
                                    kernel_size, stride, padding,
                                    output_padding, groups, bias, dilation,
                                    'zeros')
        if gain:
            self.gain = nn.Parameter(torch.ones(self.in_channels, 1, 1, 1))
        else:
            self.gain = None
        # Epsilon, a small constant to avoid dividing by zero.
        self.eps = eps

    def get_weight(self):
        """Return the standardized (and gained) weight."""
        return _standardize_weight(self.weight, self.gain, self.eps)

    def _resolve_output_padding(self, x, output_size):
        """Return the output padding needed to reach ``output_size``.

        With ``output_size=None`` the configured ``self.output_padding`` is
        used. Otherwise the last two entries of ``output_size`` are taken as
        the target (H, W). The padding must lie in ``[0, stride)``, the same
        rule ``nn.ConvTranspose2d`` applies.

        Raises:
            ValueError: if ``output_size`` is too short or not reachable.
        """
        if output_size is None:
            return self.output_padding
        if len(output_size) < 2:
            raise ValueError(
                f'output_size must have at least 2 entries, got {list(output_size)}')
        padding_out = []
        for d in range(2):
            in_size = x.shape[d - 2]
            min_size = ((in_size - 1) * self.stride[d] - 2 * self.padding[d]
                        + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            extra = int(output_size[d - 2]) - min_size
            if not 0 <= extra < self.stride[d]:
                raise ValueError(
                    f'requested output size {list(output_size)} is not reachable '
                    f'from input spatial size {list(x.shape[-2:])}')
            padding_out.append(extra)
        return tuple(padding_out)

    def forward(self, x, output_size: Optional[List[int]] = None):
        # Computed locally: the original passed the builtin ``input`` to the
        # private ``_output_padding``, whose signature also differs between
        # PyTorch versions.
        output_padding = self._resolve_output_padding(x, output_size)
        return F.conv_transpose2d(x, self.get_weight(), self.bias, self.stride,
                                  self.padding, output_padding, self.groups,
                                  self.dilation)


class GatedWSConvPadded(nn.Module):
    """Reflection-padded gated conv: ``conv(x) * sigmoid(conv_gate(x)) * 1.8``.

    The fixed 1.8 factor presumably restores activation scale after gating,
    in the same spirit as the ``*_nf`` constants above.
    """

    def __init__(self, in_ch, out_ch, ks, stride=1, dilation=1):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        # Same padding on every side; for odd ks with stride 1 this keeps the
        # spatial size.
        self.padding = nn.ReflectionPad2d(((ks - 1) * dilation) // 2)
        self.conv = ScaledWSConv2d(in_ch, out_ch, kernel_size=ks,
                                   stride=stride, dilation=dilation)
        self.conv_gate = ScaledWSConv2d(in_ch, out_ch, kernel_size=ks,
                                        stride=stride, dilation=dilation)

    def forward(self, x):
        x = self.padding(x)
        signal = self.conv(x)
        gate = torch.sigmoid(self.conv_gate(x))
        return signal * gate * 1.8


class GatedWSTransposeConvPadded(nn.Module):
    """Gated transpose conv: ``conv(x) * sigmoid(conv_gate(x)) * 1.8``.

    Both branches use zero padding ``(ks - 1) // 2``.
    """

    def __init__(self, in_ch, out_ch, ks, stride=1):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.conv = ScaledWSTransposeConv2d(in_ch, out_ch, kernel_size=ks,
                                            stride=stride,
                                            padding=(ks - 1) // 2)
        self.conv_gate = ScaledWSTransposeConv2d(in_ch, out_ch, kernel_size=ks,
                                                 stride=stride,
                                                 padding=(ks - 1) // 2)

    def forward(self, x):
        signal = self.conv(x)
        gate = torch.sigmoid(self.conv_gate(x))
        return signal * gate * 1.8


class ResBlock(nn.Module):
    """Residual block: ``x + alpha * c2(relu_nf(c1(relu_nf(x / beta))))``."""

    def __init__(self, ch, alpha=0.2, beta=1.0, dilation=1):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.c1 = GatedWSConvPadded(ch, ch, 3, dilation=dilation)
        self.c2 = GatedWSConvPadded(ch, ch, 3, dilation=dilation)

    def forward(self, x):
        skip = x
        x = self.c1(relu_nf(x / self.beta))
        x = self.c2(relu_nf(x))
        x = x * self.alpha
        return x + skip


def my_layer_norm(feat):
    """Standardize each map over dims (2, 3), then return ``5 * (2 * z - 1)``.

    Assumes a 4-D ``(N, C, H, W)`` input. A tiny epsilon is added to the std
    to avoid division by zero.
    """
    mean = feat.mean((2, 3), keepdim=True)
    std = feat.std((2, 3), keepdim=True) + 1e-9
    feat = 2 * (feat - mean) / std - 1
    feat = 5 * feat
    return feat


class AOTBlock(nn.Module):
    """Multi-dilation context block with a learned per-pixel blend gate.

    Each rate gets a branch ``ReflectionPad2d(rate) -> 3x3 conv (dilation
    rate) -> ReLU`` producing ``dim // 4`` channels. The branches are
    concatenated and fused by a 3x3 conv. The result is blended with the
    input using ``sigmoid(my_layer_norm(gate(x)))``.

    The concatenation must have ``dim`` channels for ``fuse``. This holds for
    4 rates when ``dim`` is divisible by 4.
    """

    def __init__(self, dim, rates=(2, 4, 8, 16)):
        super().__init__()
        self.rates = rates
        for i, rate in enumerate(rates):
            # Names block00, block01, ... are part of the checkpoint keys.
            setattr(self, f'block{i:02d}', nn.Sequential(
                nn.ReflectionPad2d(rate),
                nn.Conv2d(dim, dim // 4, 3, padding=0, dilation=rate),
                nn.ReLU(True)))
        self.fuse = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, padding=0, dilation=1))
        self.gate = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, padding=0, dilation=1))

    def forward(self, x):
        out = [getattr(self, f'block{i:02d}')(x) for i in range(len(self.rates))]
        out = torch.cat(out, 1)
        out = self.fuse(out)
        mask = torch.sigmoid(my_layer_norm(self.gate(x)))
        return x * (1 - mask) + out * mask


class ResBlockDis(nn.Module):
    """Pre-activation residual block (InstanceNorm + LeakyReLU(0.2)).

    With ``stride > 1`` the first conv uses a 4x4 kernel and the shortcut is
    ``AvgPool2d(2, 2)`` followed by a 1x1 conv.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes,
                               kernel_size=3 if stride == 1 else 4,
                               stride=stride, padding=1)
        self.bn2 = nn.InstanceNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1)
        self.planes = planes
        self.in_planes = in_planes
        self.stride = stride

        # Identity by default; project when resolution or width changes.
        self.shortcut = nn.Sequential()
        if stride > 1:
            self.shortcut = nn.Sequential(
                nn.AvgPool2d(2, 2), nn.Conv2d(in_planes, planes, kernel_size=1))
        elif in_planes != planes and stride == 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1))

    def forward(self, x):
        sc = self.shortcut(x)
        x = self.conv1(F.leaky_relu(self.bn1(x), 0.2))
        x = self.conv2(F.leaky_relu(self.bn2(x), 0.2))
        return sc + x


class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a 1-channel score map.

    Three stride-2 and one stride-1 spectrally-normalized 4x4 convs widen the
    features to ``in_planes * 8``. A final 4x4 conv maps them to one channel.

    ``blocks`` and ``alpha`` are accepted for interface compatibility but are
    not used.
    """

    def __init__(self, in_ch=3, in_planes=64, blocks=(2, 2, 2), alpha=0.2):
        super().__init__()
        self.in_planes = in_planes
        self.conv = nn.Sequential(
            spectral_norm(nn.Conv2d(in_ch, in_planes, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(in_planes, in_planes * 2, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(in_planes * 2, in_planes * 4, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(in_planes * 4, in_planes * 8, 4, stride=1, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            # Was hard-coded to 512, which only matched in_planes == 64.
            nn.Conv2d(in_planes * 8, 1, 4, stride=1, padding=1),
        )

    def forward(self, x):
        return self.conv(x)


class AOTGenerator(nn.Module):
    """Encoder / AOT-body / decoder inpainting generator.

    ``forward(img, mask)`` concatenates ``[mask, img]`` on the channel axis,
    so their channel counts must sum to ``in_ch`` (4 by default; presumably
    a 1-channel mask plus a 3-channel image). The head halves the spatial
    size twice. The body is 10 ``AOTBlock``s at ``ch * 4`` channels. The tail
    doubles the size twice with 4x4 / stride-2 / padding-1 transpose convs.

    ``alpha`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, in_ch=4, out_ch=3, ch=32, alpha=0.0):
        super().__init__()

        self.head = nn.Sequential(
            GatedWSConvPadded(in_ch, ch, 3, stride=1),
            LambdaLayer(relu_nf),
            GatedWSConvPadded(ch, ch * 2, 4, stride=2),
            LambdaLayer(relu_nf),
            GatedWSConvPadded(ch * 2, ch * 4, 4, stride=2),
        )

        self.body_conv = nn.Sequential(*[AOTBlock(ch * 4) for _ in range(10)])

        self.tail = nn.Sequential(
            GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
            LambdaLayer(relu_nf),
            GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
            LambdaLayer(relu_nf),
            GatedWSTransposeConvPadded(ch * 4, ch * 2, 4, 2),
            LambdaLayer(relu_nf),
            GatedWSTransposeConvPadded(ch * 2, ch, 4, 2),
            LambdaLayer(relu_nf),
            GatedWSConvPadded(ch, out_ch, 3, stride=1),
        )

    def forward(self, img, mask):
        x = torch.cat([mask, img], dim=1)
        x = self.head(x)
        conv = self.body_conv(x)
        x = self.tail(conv)
        if self.training:
            return x
        # In eval mode the output is clipped to [-1, 1] (presumably the
        # image normalisation range — confirm against the caller).
        return torch.clip(x, -1, 1)


def test():
    """Smoke test: run a random batch through ``AOTGenerator`` and print the
    output shape. Uses CUDA when available, otherwise the CPU."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    img = torch.randn(4, 3, 256, 256, device=device)
    mask = torch.randn(4, 1, 256, 256, device=device)
    net = AOTGenerator().to(device)
    y1 = net(img, mask)
    print(y1.shape)