123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- from argparse import ArgumentParser, Namespace
- from typing import (
- List,
- Tuple,
- )
- import numpy as np
- from PIL import Image
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torchvision.transforms import (
- Compose,
- Grayscale,
- Resize,
- ToTensor,
- )
- from models.encoder import Encoder
- from models.encoder4editing import (
- get_latents as get_e4e_latents,
- setup_model as setup_e4e_model,
- )
- from utils.misc import (
- optional_string,
- iterable_to_str,
- stem,
- )
class ColorEncoderArguments:
    """Command-line options for the feed-forward color encoder.

    Instantiating the class builds a standalone ``ArgumentParser`` (stored as
    ``self.parser``) pre-populated with the encoder options. Other parsers can
    reuse the same options through :meth:`add_arguments`.
    """

    def __init__(self):
        # The string is a human-readable summary, so it must be passed as
        # ``description``. As the first positional argument it would be taken
        # as ``prog`` and printed as the program name in the usage line.
        parser = ArgumentParser(description="Encode an image via a feed-forward encoder")
        self.add_arguments(parser)
        self.parser = parser

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register ``--encoder_ckpt`` and ``--encoder_size`` on *parser*."""
        parser.add_argument("--encoder_ckpt", default=None,
                            help="encoder checkpoint path. initialize w with encoder output if specified")
        parser.add_argument("--encoder_size", type=int, default=256,
                            help="Resize to this size to pass as input to the encoder")
class InitializerArguments:
    """Command-line options that control how the initial latent is produced.

    Bundles the color-encoder options, the e4e options, the layer range used
    for style mixing, and an optional path to a precomputed latent.
    """

    @classmethod
    def add_arguments(cls, parser: ArgumentParser):
        """Register every initializer option on *parser*."""
        # Registration order determines the order options appear in --help.
        ColorEncoderArguments.add_arguments(parser)
        cls.add_e4e_arguments(parser)
        parser.add_argument(
            "--mix_layer_range",
            type=int,
            nargs=2,
            default=[10, 18],
            help="replace layers <start> to <end> in the e4e code by the color code",
        )
        parser.add_argument("--init_latent", help="path to init wp", default=None)

    @staticmethod
    def to_string(args: Namespace):
        """Return a short tag describing the chosen initialization.

        When ``args.init_latent`` is set, the tag is ``init`` followed by at
        most 10 characters of ``stem(args.init_latent)`` with leading zeros
        stripped. Otherwise it is ``init(...)`` around
        ``iterable_to_str(args.mix_layer_range)``.
        """
        # NOTE: an ``-initN<noise>`` suffix was once sketched here, but no
        # ``--init_noise`` option is registered by add_arguments.
        if not args.init_latent:
            return f"init({iterable_to_str(args.mix_layer_range)})"
        latent_tag = stem(args.init_latent).lstrip('0')[:10]
        return f"init{latent_tag}"

    @staticmethod
    def add_e4e_arguments(parser: ArgumentParser):
        """Register ``--e4e_ckpt`` and ``--e4e_size`` on *parser*."""
        parser.add_argument(
            "--e4e_ckpt",
            default='checkpoint/e4e_ffhq_encode.pt',
            help="e4e checkpoint path.",
        )
        parser.add_argument(
            "--e4e_size",
            type=int,
            default=256,
            help="Resize to this size to pass as input to the e4e",
        )
def create_color_encoder(args: Namespace):
    """Build the color encoder and load its weights from ``args.encoder_ckpt``.

    Args:
        args: parsed options; ``encoder_size`` and ``encoder_ckpt`` are read.

    Returns:
        ``Encoder(1, args.encoder_size, 512)`` with its state dict loaded from
        the checkpoint's ``"model"`` entry. The module is not moved to any
        device here; callers such as ``encode_color`` call ``.to(...)``.

    Raises:
        ValueError: if ``args.encoder_ckpt`` is not set.
    """
    if args.encoder_ckpt is None:
        # --encoder_ckpt defaults to None; fail with a clear message rather
        # than an opaque error from inside torch.load.
        raise ValueError("create_color_encoder requires --encoder_ckpt to be set")
    encoder = Encoder(1, args.encoder_size, 512)
    # Map tensors to the CPU so checkpoints saved on a GPU also load on
    # CPU-only hosts. load_state_dict copies the values into the module's
    # own parameters, so the final result is the same either way.
    ckpt = torch.load(args.encoder_ckpt, map_location="cpu")
    encoder.load_state_dict(ckpt["model"])
    return encoder
def transform_input(img: "Image.Image", size: int = 256) -> torch.Tensor:
    """Convert a PIL image into the grayscale tensor fed to the color encoder.

    Args:
        img: input PIL image.
        size: value passed to torchvision ``Resize``. The default of 256
            matches the ``--encoder_size`` command-line default.

    Returns:
        The result of ``Grayscale`` -> ``Resize(size)`` -> ``ToTensor``.
    """
    # ``size`` is an explicit parameter because this module defines no
    # global ``args``; the old ``args.encoder_size`` lookup raised NameError.
    tsfm = Compose([
        Grayscale(),
        Resize(size),
        ToTensor(),
    ])
    return tsfm(img)
def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
    """One-shot helper: compute the color code for a batch of images.

    On every call, a fresh color encoder is built via ``create_color_encoder``,
    moved to the images' device, switched to eval mode, and run without
    gradient tracking.

    Args:
        imgs: image tensor, resized with torchvision ``Resize(args.encoder_size)``.
        args: parsed options; ``encoder_size`` must not be None.

    Returns:
        The encoder output, detached from any autograd graph.
    """
    assert args.encoder_size is not None
    # NOTE(review): torchvision ``Resize`` with an int matches the shorter edge
    # and keeps the aspect ratio, whereas ``Initializer.encode_color`` uses
    # ``resize`` (F.interpolate to size x size). Confirm which input shape
    # the encoder expects for non-square images.
    resized = Resize(args.encoder_size)(imgs)
    model = create_color_encoder(args).to(resized.device)
    model.eval()
    with torch.no_grad():
        code = model(resized)
    return code.detach()
def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
    """Bilinearly resample a batch so each spatial dimension equals ``size``.

    Thin wrapper over ``F.interpolate`` with ``mode="bilinear"`` (which needs
    a 4-D input) and the default ``align_corners``. Unlike torchvision
    ``Resize`` with an int, the aspect ratio is not preserved.
    """
    resampled = F.interpolate(imgs, size=size, mode="bilinear")
    return resampled
class Initializer(nn.Module):
    """Produce the initial latent code for a batch of images.

    Behavior depends on ``args.init_latent``:

    * set: ``forward`` ignores the images and returns the array loaded from
      that path with ``np.load``, with a leading batch axis added;
    * unset: the e4e model supplies a "shape" code, and the color encoder's
      output is written into layers ``mix_layer_range[0]:mix_layer_range[1]``
      of it (style mixing).
    """

    def __init__(self, args: Namespace):
        super().__init__()
        self.path = args.init_latent
        if self.path is not None:
            # A precomputed latent short-circuits everything: no encoders are built.
            return

        assert args.encoder_size is not None
        # Color branch. It is registered before e4e, so its parameters come
        # first in state_dict order.
        self.color_encoder = create_color_encoder(args)
        self.color_encoder.eval()
        self.color_encoder_size = args.encoder_size

        # Shape branch: e4e model. Checkpoints whose dataset_type mentions
        # 'cars_' are rejected.
        self.e4e, e4e_opts = setup_e4e_model(args.e4e_ckpt)
        assert 'cars_' not in e4e_opts.dataset_type
        self.e4e.decoder.eval()
        self.e4e.eval()
        self.e4e_size = args.e4e_size

        self.mix_layer_range = args.mix_layer_range

    def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
        """Get the color W code of ``imgs`` after resizing to ``color_encoder_size``."""
        return self.color_encoder(resize(imgs, self.color_encoder_size))

    def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
        """Get the e4e latent of ``imgs``.

        The images are resized to ``e4e_size`` and rescaled with
        ``(x - 0.5) / 0.5`` (i.e. [0, 1] -> [-1, 1], assuming inputs lie in
        [0, 1]). Single-channel input is tiled to three channels first.
        """
        x = resize(imgs, self.e4e_size)
        x = (x - 0.5) / 0.5
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return get_e4e_latents(self.e4e, x)

    def load(self, device: torch.device):
        """Read the array at ``self.path`` onto ``device`` and prepend a batch axis."""
        array = np.load(self.path)
        return torch.tensor(array, device=device).unsqueeze(0)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        """Return the initial latent for ``imgs`` (see the class docstring)."""
        if self.path is None:
            mixed = self.encode_shape(imgs)
            color_code = self.encode_color(imgs)
            first, last = self.mix_layer_range
            # Written in place into the tensor returned by encode_shape.
            # NOTE(review): relies on color_code broadcasting into
            # mixed[:, first:last]; confirm the encoder's output shape for
            # batch sizes > 1.
            mixed[:, first:last] = color_code
            return mixed
        return self.load(imgs.device)
|