prepare_landmarks_metfaces.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import argparse
  2. import glob
  3. import os
  4. import numpy as np
  5. import PIL
  6. import tqdm
  7. from PIL import Image, ImageOps
  8. try:
  9. import pyspng
  10. except ImportError:
  11. pyspng = None
  12. from mtcnn import MTCNN
  13. N_IMGS = 1
  14. def get_file_ext(fname):
  15. return os.path.splitext(fname)[1].lower()
  16. def main(args):
  17. detect_base_dir = os.path.join(args.save_dir, "detections")
  18. detect_res_dir = os.path.join(detect_base_dir, "results")
  19. os.makedirs(detect_res_dir, exist_ok=True)
  20. detector = MTCNN()
  21. sorted_f_list = sorted(list(glob.glob(os.path.join(args.data_dir, "*.png"))))
  22. print("\nsorted_f_list: ", len(sorted_f_list), sorted_f_list[:5], "\n")
  23. for i, f_path in tqdm.tqdm(enumerate(sorted_f_list), total=len(sorted_f_list)):
  24. basename = os.path.splitext(os.path.basename(f_path))[0]
  25. if pyspng is not None and get_file_ext(f_path) == ".png":
  26. with open(f_path, "rb") as fin:
  27. img = pyspng.load(fin.read())
  28. else:
  29. img = np.array(PIL.Image.open(f_path))
  30. if args.xflip == 1:
  31. tmp_f_path = f"{f_path.split('.png')[0]}_xflip.png"
  32. ImageOps.mirror(Image.fromarray(img)).save(tmp_f_path)
  33. else:
  34. text_path = f"{detect_res_dir}/{basename}.txt"
  35. result = detector.detect_faces(img)
  36. try:
  37. keypoints = result[0]["keypoints"]
  38. with open(text_path, "w") as f:
  39. for value in keypoints.values():
  40. f.write(f"{value[0]}\t{value[1]}\n")
  41. # print(f"File successfully written: {text_path}")
  42. except:
  43. if i == 0:
  44. mode = "w"
  45. else:
  46. mode = "a"
  47. with open(os.path.join(detect_base_dir, "fail_list.txt"), mode) as fail_f:
  48. fail_f.write(f"{os.path.basename(f_path)}\n")
  49. print("\n[fail] ", os.path.basename(f_path), "\n")
  50. if not os.path.exists(os.path.join(detect_base_dir, "fail_list.txt")):
  51. with open(os.path.join(detect_base_dir, "fail_list.txt"), "w") as fail_f:
  52. fail_f.write(f"\n")
  53. if __name__ == "__main__":
  54. parser = argparse.ArgumentParser(description="Get landmarks from images.")
  55. parser.add_argument("--data_dir", type=str, default=None, help="folder for metfaces")
  56. parser.add_argument("--save_dir", type=str, default=None, help="save dir")
  57. parser.add_argument("--xflip", type=int, default=0, help="Whether to do xflip.")
  58. args = parser.parse_args()
  59. main(args)