crop_head.py

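"""Align and crop face images in a folder to 224x224 head crops.

For each image, 68 facial landmarks are detected with the face_alignment
library and reduced to three anchor points (the two eye centres and the mouth
centre). A similarity transform estimated from the average anchors maps every
frame onto a fixed template, the mouth position is smoothed across consecutive
frames, and the warped crops are saved to a '<folder>_cropped' directory next
to the input folder.
"""
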
import face_alignment
import os
import cv2
import skimage.transform as trans
import argparse
import torch
import numpy as np
import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_affine(src):
    # Canonical template positions of the two eye centres and the mouth centre
    # in the 224x224 crop.
    dst = np.array([[87, 59],
                    [137, 59],
                    [112, 120]], dtype=np.float32)
    tform = trans.SimilarityTransform()
    tform.estimate(src, dst)
    M = tform.params[0:2, :]
    return M


def affine_align_img(img, M, crop_size=224):
    warped = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
    return warped


def affine_align_3landmarks(landmarks, M):
    # Append a homogeneous coordinate so the 2x3 affine applies as a single matmul.
    new_landmarks = np.concatenate([landmarks, np.ones((3, 1))], 1)
    affined_landmarks = np.matmul(new_landmarks, M.transpose())
    return affined_landmarks


def get_eyes_mouths(landmark):
    # Collapse the 68-point landmark set to three anchors: the two eye centres
    # (points 36-41 and 42-47) and the inner-mouth centre (points 60-67).
    three_points = np.zeros((3, 2))
    three_points[0] = landmark[36:42].mean(0)
    three_points[1] = landmark[42:48].mean(0)
    three_points[2] = landmark[60:68].mean(0)
    return three_points


def get_mouth_bias(three_points):
    # Offset that moves the aligned mouth centre back onto its template position (112, 120).
    bias = np.array([112, 120]) - three_points[2]
    return bias
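

# The helpers above compose as follows: a minimal sketch with hypothetical
# landmark values ('frame' stands for any BGR image array), not a snippet
# taken from the script itself.
#   pts = np.array([[80.0, 95.0],     # eye centre 1
#                   [142.0, 93.0],    # eye centre 2
#                   [110.0, 158.0]])  # mouth centre
#   M = get_affine(pts)                         # 2x3 similarity transform onto the template
#   aligned = affine_align_3landmarks(pts, M)   # the same points in 224x224 crop coordinates
#   bias = get_mouth_bias(aligned)              # residual mouth offset from (112, 120)
#   crop = affine_align_img(frame, M)           # aligned 224x224 crop of the frame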


def align_folder(folder_path, folder_save_path):
    # Detect 68 2D landmarks for every image in the folder.
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)
    preds = fa.get_landmarks_from_directory(folder_path)

    sumpoints = 0
    three_points_list = []

    # First pass: collect the three anchor points per image and their running sum.
    for img in tqdm.tqdm(preds.keys(), desc='preprocessing..'):
        if preds[img] is None:
            print('preprocessing failed')
            return False
        pred_points = np.array(preds[img])
        if len(pred_points.shape) != 3:
            print('preprocessing failed')
            return False
        num_faces, size, _ = pred_points.shape
        if num_faces == 1 and size == 68:
            three_points = get_eyes_mouths(pred_points[0])
            sumpoints += three_points
            three_points_list.append(three_points)
        else:
            print('preprocessing failed')
            return False

    # One similarity transform shared by all frames, estimated from the average anchors.
    avg_points = sumpoints / len(preds)
    M = get_affine(avg_points)

    # Second pass: per-frame mouth offset, smoothed over time, then warp and save.
    p_bias = None
    for i, img_pth in tqdm.tqdm(enumerate(preds.keys()), desc='affine and save', total=len(preds)):
        three_points = three_points_list[i]
        affined_3landmarks = affine_align_3landmarks(three_points, M)
        bias = get_mouth_bias(affined_3landmarks)
        if p_bias is not None:
            # Exponential smoothing to reduce frame-to-frame jitter.
            bias = p_bias * 0.2 + bias * 0.8
        p_bias = bias
        M_i = M.copy()
        M_i[:, 2] = M[:, 2] + bias
        img = cv2.imread(img_pth)
        warped = affine_align_img(img, M_i)
        img_save_path = os.path.join(folder_save_path, os.path.basename(img_pth))
        cv2.imwrite(img_save_path, warped)
    print('cropped files saved at {}'.format(folder_save_path))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--folder_path', help='the folder which needs processing')
    args = parser.parse_args()
    if os.path.isdir(args.folder_path):
        home_path = '/'.join(args.folder_path.split('/')[:-1])
        save_img_path = os.path.join(home_path, args.folder_path.split('/')[-1] + '_cropped')
        os.makedirs(save_img_path, exist_ok=True)
        align_folder(args.folder_path, save_img_path)


if __name__ == '__main__':
    main()
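

# Example invocation (a usage sketch; assumes the file is saved as crop_head.py):
#   python crop_head.py --folder_path /path/to/frames
# The aligned 224x224 crops are written to a sibling folder, e.g. /path/to/frames_cropped.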