projector.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. from argparse import Namespace
  2. import os
  3. from os.path import join as pjoin
  4. import random
  5. import sys
  6. from typing import (
  7. Iterable,
  8. Optional,
  9. )
  10. import cv2
  11. import numpy as np
  12. from PIL import Image
  13. import torch
  14. from torch.utils.tensorboard import SummaryWriter
  15. from torchvision.transforms import (
  16. Compose,
  17. Grayscale,
  18. Resize,
  19. ToTensor,
  20. Normalize,
  21. )
  22. from losses.joint_loss import JointLoss
  23. from model import Generator
  24. from tools.initialize import Initializer
  25. from tools.match_skin_histogram import match_skin_histogram
  26. from utils.projector_arguments import ProjectorArguments
  27. from utils import torch_helpers as th
  28. from utils.torch_helpers import make_image
  29. from utils.misc import stem
  30. from utils.optimize import Optimizer
  31. from models.degrade import (
  32. Degrade,
  33. Downsample,
  34. )
  35. def set_random_seed(seed: int):
  36. # FIXME (xuanluo): this setup still allows randomness somehow
  37. torch.manual_seed(seed)
  38. random.seed(seed)
  39. np.random.seed(seed)
  40. def read_images(paths: str, max_size: Optional[int] = None):
  41. transform = Compose(
  42. [
  43. Grayscale(),
  44. ToTensor(),
  45. ]
  46. )
  47. imgs = []
  48. for path in paths:
  49. img = Image.open(path)
  50. if max_size is not None and img.width > max_size:
  51. img = img.resize((max_size, max_size))
  52. img = transform(img)
  53. imgs.append(img)
  54. imgs = torch.stack(imgs, 0)
  55. return imgs
  56. def normalize(img: torch.Tensor, mean=0.5, std=0.5):
  57. """[0, 1] -> [-1, 1]"""
  58. return (img - mean) / std
  59. def create_generator(args: Namespace, device: torch.device):
  60. generator = Generator(args.generator_size, 512, 8)
  61. generator.load_state_dict(torch.load(args.ckpt)['g_ema'], strict=False)
  62. generator.eval()
  63. generator = generator.to(device)
  64. return generator
  65. def save(
  66. path_prefixes: Iterable[str],
  67. imgs: torch.Tensor, # BCHW
  68. latents: torch.Tensor,
  69. noises: torch.Tensor,
  70. imgs_rand: Optional[torch.Tensor] = None,
  71. ):
  72. assert len(path_prefixes) == len(imgs) and len(latents) == len(path_prefixes)
  73. if imgs_rand is not None:
  74. assert len(imgs) == len(imgs_rand)
  75. imgs_arr = make_image(imgs)
  76. for path_prefix, img, latent, noise in zip(path_prefixes, imgs_arr, latents, noises):
  77. os.makedirs(os.path.dirname(path_prefix), exist_ok=True)
  78. cv2.imwrite(path_prefix + ".png", img[...,::-1])
  79. torch.save({"latent": latent.detach().cpu(), "noise": noise.detach().cpu()},
  80. path_prefix + ".pt")
  81. if imgs_rand is not None:
  82. imgs_arr = make_image(imgs_rand)
  83. for path_prefix, img in zip(path_prefixes, imgs_arr):
  84. cv2.imwrite(path_prefix + "-rand.png", img[...,::-1])
  85. def main(args):
  86. opt_str = ProjectorArguments.to_string(args)
  87. print(opt_str)
  88. if args.rand_seed is not None:
  89. set_random_seed(args.rand_seed)
  90. device = th.device()
  91. # read inputs. TODO imgs_orig has channel 1
  92. imgs_orig = read_images([args.input], max_size=args.generator_size).to(device)
  93. imgs = normalize(imgs_orig) # actually this will be overwritten by the histogram matching result
  94. # initialize
  95. with torch.no_grad():
  96. init = Initializer(args).to(device)
  97. latent_init = init(imgs_orig)
  98. # create generator
  99. generator = create_generator(args, device)
  100. # init noises
  101. with torch.no_grad():
  102. noises_init = generator.make_noise()
  103. # create a new input by matching the input's histogram to the sibling image
  104. with torch.no_grad():
  105. sibling, _, sibling_rgbs = generator([latent_init], input_is_latent=True, noise=noises_init)
  106. mh_dir = pjoin(args.results_dir, stem(args.input))
  107. imgs = match_skin_histogram(
  108. imgs, sibling,
  109. args.spectral_sensitivity,
  110. pjoin(mh_dir, "input_sibling"),
  111. pjoin(mh_dir, "skin_mask"),
  112. matched_hist_fn=mh_dir.rstrip(os.sep) + f"_{args.spectral_sensitivity}.png",
  113. normalize=normalize,
  114. ).to(device)
  115. torch.cuda.empty_cache()
  116. # TODO imgs has channel 3
  117. degrade = Degrade(args).to(device)
  118. rgb_levels = generator.get_latent_size(args.coarse_min) // 2 + len(args.wplus_step) - 1
  119. criterion = JointLoss(
  120. args, imgs,
  121. sibling=sibling.detach(), sibling_rgbs=sibling_rgbs[:rgb_levels]).to(device)
  122. # save initialization
  123. save(
  124. [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}-init")],
  125. sibling, latent_init, noises_init,
  126. )
  127. writer = SummaryWriter(pjoin(args.log_dir, f"{stem(args.input)}/{opt_str}"))
  128. # start optimize
  129. latent, noises = Optimizer.optimize(generator, criterion, degrade, imgs, latent_init, noises_init, args, writer=writer)
  130. # generate output
  131. img_out, _, _ = generator([latent], input_is_latent=True, noise=noises)
  132. img_out_rand_noise, _, _ = generator([latent], input_is_latent=True)
  133. # save output
  134. save(
  135. [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}")],
  136. img_out, latent, noises,
  137. imgs_rand=img_out_rand_noise
  138. )
  139. def parse_args():
  140. return ProjectorArguments().parse()
  141. if __name__ == "__main__":
  142. sys.exit(main(parse_args()))