initialize.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from argparse import ArgumentParser, Namespace
  2. from typing import (
  3. List,
  4. Tuple,
  5. )
  6. import numpy as np
  7. from PIL import Image
  8. import torch
  9. from torch import nn
  10. import torch.nn.functional as F
  11. from torchvision.transforms import (
  12. Compose,
  13. Grayscale,
  14. Resize,
  15. ToTensor,
  16. )
  17. from models.encoder import Encoder
  18. from models.encoder4editing import (
  19. get_latents as get_e4e_latents,
  20. setup_model as setup_e4e_model,
  21. )
  22. from utils.misc import (
  23. optional_string,
  24. iterable_to_str,
  25. stem,
  26. )
  27. class ColorEncoderArguments:
  28. def __init__(self):
  29. parser = ArgumentParser("Encode an image via a feed-forward encoder")
  30. self.add_arguments(parser)
  31. self.parser = parser
  32. @staticmethod
  33. def add_arguments(parser: ArgumentParser):
  34. parser.add_argument("--encoder_ckpt", default=None,
  35. help="encoder checkpoint path. initialize w with encoder output if specified")
  36. parser.add_argument("--encoder_size", type=int, default=256,
  37. help="Resize to this size to pass as input to the encoder")
  38. class InitializerArguments:
  39. @classmethod
  40. def add_arguments(cls, parser: ArgumentParser):
  41. ColorEncoderArguments.add_arguments(parser)
  42. cls.add_e4e_arguments(parser)
  43. parser.add_argument("--mix_layer_range", default=[10, 18], type=int, nargs=2,
  44. help="replace layers <start> to <end> in the e4e code by the color code")
  45. parser.add_argument("--init_latent", default=None, help="path to init wp")
  46. @staticmethod
  47. def to_string(args: Namespace):
  48. return (f"init{stem(args.init_latent).lstrip('0')[:10]}" if args.init_latent
  49. else f"init({iterable_to_str(args.mix_layer_range)})")
  50. #+ optional_string(args.init_noise > 0, f"-initN{args.init_noise}")
  51. @staticmethod
  52. def add_e4e_arguments(parser: ArgumentParser):
  53. parser.add_argument("--e4e_ckpt", default='checkpoint/e4e_ffhq_encode.pt',
  54. help="e4e checkpoint path.")
  55. parser.add_argument("--e4e_size", type=int, default=256,
  56. help="Resize to this size to pass as input to the e4e")
  57. def create_color_encoder(args: Namespace):
  58. encoder = Encoder(1, args.encoder_size, 512)
  59. ckpt = torch.load(args.encoder_ckpt)
  60. encoder.load_state_dict(ckpt["model"])
  61. return encoder
  62. def transform_input(img: Image):
  63. tsfm = Compose([
  64. Grayscale(),
  65. Resize(args.encoder_size),
  66. ToTensor(),
  67. ])
  68. return tsfm(img)
  69. def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
  70. assert args.encoder_size is not None
  71. imgs = Resize(args.encoder_size)(imgs)
  72. color_encoder = create_color_encoder(args).to(imgs.device)
  73. color_encoder.eval()
  74. with torch.no_grad():
  75. latent = color_encoder(imgs)
  76. return latent.detach()
  77. def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
  78. return F.interpolate(imgs, size=size, mode='bilinear')
  79. class Initializer(nn.Module):
  80. def __init__(self, args: Namespace):
  81. super().__init__()
  82. self.path = None
  83. if args.init_latent is not None:
  84. self.path = args.init_latent
  85. return
  86. assert args.encoder_size is not None
  87. self.color_encoder = create_color_encoder(args)
  88. self.color_encoder.eval()
  89. self.color_encoder_size = args.encoder_size
  90. self.e4e, e4e_opts = setup_e4e_model(args.e4e_ckpt)
  91. assert 'cars_' not in e4e_opts.dataset_type
  92. self.e4e.decoder.eval()
  93. self.e4e.eval()
  94. self.e4e_size = args.e4e_size
  95. self.mix_layer_range = args.mix_layer_range
  96. def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
  97. """
  98. Get the color W code
  99. """
  100. imgs = resize(imgs, self.color_encoder_size)
  101. latent = self.color_encoder(imgs)
  102. return latent
  103. def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
  104. imgs = resize(imgs, self.e4e_size)
  105. imgs = (imgs - 0.5) / 0.5
  106. if imgs.shape[1] == 1: # 1 channel
  107. imgs = imgs.repeat(1, 3, 1, 1)
  108. return get_e4e_latents(self.e4e, imgs)
  109. def load(self, device: torch.device):
  110. latent_np = np.load(self.path)
  111. return torch.tensor(latent_np, device=device)[None, ...]
  112. def forward(self, imgs: torch.Tensor) -> torch.Tensor:
  113. if self.path is not None:
  114. return self.load(imgs.device)
  115. shape_code = self.encode_shape(imgs)
  116. color_code = self.encode_color(imgs)
  117. # style mix
  118. latent = shape_code
  119. start, end = self.mix_layer_range
  120. latent[:, start:end] = color_code
  121. return latent