match_skin_histogram.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from argparse import Namespace
  2. import os
  3. from os.path import join as pjoin
  4. from typing import Optional
  5. import cv2
  6. import torch
  7. from tools import (
  8. parse_face,
  9. match_histogram,
  10. )
  11. from utils.torch_helpers import make_image
  12. from utils.misc import stem
  13. def match_skin_histogram(
  14. imgs: torch.Tensor,
  15. sibling_img: torch.Tensor,
  16. spectral_sensitivity,
  17. im_sibling_dir: str,
  18. mask_dir: str,
  19. matched_hist_fn: Optional[str] = None,
  20. normalize=None, # normalize the range of the tensor
  21. ):
  22. """
  23. Extract the skin of the input and sibling images. Create a new input image by matching
  24. its histogram to the sibling.
  25. """
  26. # TODO: Currently only allows imgs of batch size 1
  27. im_sibling_dir = os.path.abspath(im_sibling_dir)
  28. mask_dir = os.path.abspath(mask_dir)
  29. img_np = make_image(imgs)[0]
  30. sibling_np = make_image(sibling_img)[0][...,::-1]
  31. # save img, sibling
  32. os.makedirs(im_sibling_dir, exist_ok=True)
  33. im_name, sibling_name = 'input.png', 'sibling.png'
  34. cv2.imwrite(pjoin(im_sibling_dir, im_name), img_np)
  35. cv2.imwrite(pjoin(im_sibling_dir, sibling_name), sibling_np)
  36. # face parsing
  37. parse_face.main(
  38. Namespace(in_dir=im_sibling_dir, out_dir=mask_dir, include_hair=False)
  39. )
  40. # match_histogram
  41. mh_args = match_histogram.parse_args(
  42. args=[
  43. pjoin(im_sibling_dir, im_name),
  44. pjoin(im_sibling_dir, sibling_name),
  45. ],
  46. namespace=Namespace(
  47. out=matched_hist_fn if matched_hist_fn else pjoin(im_sibling_dir, "match_histogram.png"),
  48. src_mask=pjoin(mask_dir, im_name),
  49. ref_mask=pjoin(mask_dir, sibling_name),
  50. spectral_sensitivity=spectral_sensitivity,
  51. )
  52. )
  53. matched_np = match_histogram.main(mh_args) / 255.0 # [0, 1]
  54. matched = torch.FloatTensor(matched_np).permute(2, 0, 1)[None,...] #BCHW
  55. if normalize is not None:
  56. matched = normalize(matched)
  57. return matched