12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- from argparse import Namespace
- import os
- from os.path import join as pjoin
- from typing import Optional
- import cv2
- import torch
- from tools import (
- parse_face,
- match_histogram,
- )
- from utils.torch_helpers import make_image
- from utils.misc import stem
- def match_skin_histogram(
- imgs: torch.Tensor,
- sibling_img: torch.Tensor,
- spectral_sensitivity,
- im_sibling_dir: str,
- mask_dir: str,
- matched_hist_fn: Optional[str] = None,
- normalize=None, # normalize the range of the tensor
- ):
- """
- Extract the skin of the input and sibling images. Create a new input image by matching
- its histogram to the sibling.
- """
- # TODO: Currently only allows imgs of batch size 1
- im_sibling_dir = os.path.abspath(im_sibling_dir)
- mask_dir = os.path.abspath(mask_dir)
- img_np = make_image(imgs)[0]
- sibling_np = make_image(sibling_img)[0][...,::-1]
- # save img, sibling
- os.makedirs(im_sibling_dir, exist_ok=True)
- im_name, sibling_name = 'input.png', 'sibling.png'
- cv2.imwrite(pjoin(im_sibling_dir, im_name), img_np)
- cv2.imwrite(pjoin(im_sibling_dir, sibling_name), sibling_np)
- # face parsing
- parse_face.main(
- Namespace(in_dir=im_sibling_dir, out_dir=mask_dir, include_hair=False)
- )
- # match_histogram
- mh_args = match_histogram.parse_args(
- args=[
- pjoin(im_sibling_dir, im_name),
- pjoin(im_sibling_dir, sibling_name),
- ],
- namespace=Namespace(
- out=matched_hist_fn if matched_hist_fn else pjoin(im_sibling_dir, "match_histogram.png"),
- src_mask=pjoin(mask_dir, im_name),
- ref_mask=pjoin(mask_dir, sibling_name),
- spectral_sensitivity=spectral_sensitivity,
- )
- )
- matched_np = match_histogram.main(mh_args) / 255.0 # [0, 1]
- matched = torch.FloatTensor(matched_np).permute(2, 0, 1)[None,...] #BCHW
- if normalize is not None:
- matched = normalize(matched)
- return matched
|