# face3d_helper.py

import os
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat

from deep_3drecon.deep_3drecon_models.bfm import perspective_projection


class Face3DHelper(nn.Module):
    def __init__(self, bfm_dir='deep_3drecon/BFM', keypoint_mode='lm68', use_gpu=True):
        super().__init__()
        self.keypoint_mode = keypoint_mode  # lm68 | mediapipe
        self.bfm_dir = bfm_dir
        self.load_3dmm()
        if use_gpu:
            self.to("cuda")

    def load_3dmm(self):
        model = loadmat(os.path.join(self.bfm_dir, "BFM_model_front.mat"))
        self.register_buffer('mean_shape', torch.from_numpy(model['meanshape'].transpose()).float())  # mean face shape. [3*N, 1], N=35709 vertices, xyz=3 ==> 3*N=107127
        mean_shape = self.mean_shape.reshape([-1, 3])
        # re-center the mean shape at the origin
        mean_shape = mean_shape - torch.mean(mean_shape, dim=0, keepdim=True)
        self.mean_shape = mean_shape.reshape([-1, 1])
        self.register_buffer('id_base', torch.from_numpy(model['idBase']).float())  # identity basis. [3*N, 80], 80 identity eigen-faces
        self.register_buffer('exp_base', torch.from_numpy(model['exBase']).float())  # expression basis. [3*N, 64], 64 expression eigen-faces
        self.register_buffer('mean_texure', torch.from_numpy(model['meantex'].transpose()).float())  # mean face texture. [3*N, 1], values in 0-255
        self.register_buffer('tex_base', torch.from_numpy(model['texBase']).float())  # texture basis. [3*N, 80], rgb=3
        self.register_buffer('point_buf', torch.from_numpy(model['point_buf']).float())  # indices of the triangles each vertex belongs to. 1-indexed. [N, 8] (1-F)
        self.register_buffer('face_buf', torch.from_numpy(model['tri']).float())  # vertex indices of each triangle. 1-indexed. [F, 3] (1-N)
        if self.keypoint_mode == 'mediapipe':
            self.register_buffer('key_points', torch.from_numpy(np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)))
            unmatch_mask = self.key_points < 0
            self.key_points[unmatch_mask] = 0
        else:
            self.register_buffer('key_points', torch.from_numpy(model['keypoints'].squeeze().astype(np.int_)).long())  # vertex indices of the 68 facial landmarks. 1-indexed. [68, 1]
        self.register_buffer('key_mean_shape', self.mean_shape.reshape([-1, 3])[self.key_points, :])
        self.register_buffer('key_id_base', self.id_base.reshape([-1, 3, 80])[self.key_points, :, :].reshape([-1, 80]))
        self.register_buffer('key_exp_base', self.exp_base.reshape([-1, 3, 64])[self.key_points, :, :].reshape([-1, 64]))
        self.key_id_base_np = self.key_id_base.cpu().numpy()
        self.key_exp_base_np = self.key_exp_base.cpu().numpy()
        self.register_buffer('persc_proj', torch.tensor(perspective_projection(focal=1015, center=112)))

    def split_coeff(self, coeff):
        """
        coeff: Tensor[B, T, c=257] or [T, c=257]
        """
        ret_dict = {
            'identity': coeff[..., :80],  # identity, [b, t, c=80]
            'expression': coeff[..., 80:144],  # expression, [b, t, c=64]
            'texture': coeff[..., 144:224],  # texture, [b, t, c=80]
            'euler': coeff[..., 224:227],  # euler angles for pose, [b, t, c=3]
            'translation': coeff[..., 254:257],  # translation, [b, t, c=3]
            'gamma': coeff[..., 227:254]  # lighting, [b, t, c=27]
        }
        return ret_dict
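
    # Hedged usage sketch (illustrative only, not part of the original file): the 257-dim
    # layout above follows Deep3DFaceRecon-style fitted coefficients; the helper instance
    # and dummy tensor below are assumptions for demonstration.
    #   helper = Face3DHelper(use_gpu=False)
    #   parts = helper.split_coeff(torch.zeros([2, 10, 257]))  # dummy [B, T, 257] input
    #   parts['identity'].shape    # torch.Size([2, 10, 80])
    #   parts['expression'].shape  # torch.Size([2, 10, 64])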

    def reconstruct_face_mesh(self, id_coeff, exp_coeff):
        """
        Generate a pose-independent 3D face mesh.
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        mean_face = self.mean_shape.squeeze().reshape([1, -1])  # [3*N, 1] ==> [1, 3*N]
        id_base, exp_base = self.id_base, self.exp_base  # [3*N, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0, 1))  # [t,c] @ [c,3*N] ==> [t,3*N]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0, 1))  # [t,c] @ [c,3*N] ==> [t,3*N]
        face = mean_face + identity_diff_face + expression_diff_face  # [t,3*N]
        face = face.reshape([face.shape[0], -1, 3])  # [t,N,3]
        # optional re-centering with mean_xyz, so the face lies in [-1, 1]:
        # mean_xyz = self.mean_shape.squeeze().reshape([-1, 3]).mean(dim=0)  # [1, 3]
        # face_mesh = face - mean_xyz.unsqueeze(0)  # [t,N,3]
        return face

    def reconstruct_cano_lm3d(self, id_coeff, exp_coeff):
        """
        Generate canonical (pose-free) 3D landmarks from the keypoint bases.
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        mean_face = self.key_mean_shape.squeeze().reshape([1, -1])  # [3*68, 1] ==> [1, 3*68]
        id_base, exp_base = self.key_id_base, self.key_exp_base  # [3*68, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0, 1))  # [t,c] @ [c,3*68] ==> [t,3*68]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0, 1))  # [t,c] @ [c,3*68] ==> [t,3*68]
        face = mean_face + identity_diff_face + expression_diff_face  # [t,3*68]
        face = face.reshape([face.shape[0], -1, 3])  # [t,68,3]
        # optional re-centering with mean_xyz, so the landmarks lie in [-1, 1]:
        # mean_xyz = self.key_mean_shape.squeeze().reshape([-1, 3]).mean(dim=0)  # [1, 3]
        # lm3d = face - mean_xyz.unsqueeze(0)  # [t,68,3]
        return face

    def reconstruct_lm3d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
        """
        Generate posed 3D landmarks from the keypoint bases.
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        euler: Tensor[T, c=3], euler angles in radians
        trans: Tensor[T, c=3], translation
        """
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        mean_face = self.key_mean_shape.squeeze().reshape([1, -1])  # [3*68, 1] ==> [1, 3*68]
        id_base, exp_base = self.key_id_base, self.key_exp_base  # [3*68, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0, 1))  # [t,c] @ [c,3*68] ==> [t,3*68]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0, 1))  # [t,c] @ [c,3*68] ==> [t,3*68]
        face = mean_face + identity_diff_face + expression_diff_face  # [t,3*68]
        face = face.reshape([face.shape[0], -1, 3])  # [t,68,3]
        # apply the head pose (rotation + translation)
        rot = self.compute_rotation(euler)
        lm3d = face @ rot + trans.unsqueeze(1)  # [t, 68, 3]
        # to camera coordinates
        if to_camera:
            lm3d[..., -1] = 10 - lm3d[..., -1]
        return lm3d

    def reconstruct_lm2d_nerf(self, id_coeff, exp_coeff, euler, trans):
        lm2d = self.reconstruct_lm2d(id_coeff, exp_coeff, euler, trans, to_camera=False)
        lm2d[..., 0] = 1 - lm2d[..., 0]
        lm2d[..., 1] = 1 - lm2d[..., 1]
        return lm2d

    def reconstruct_lm2d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
        """
        Generate 2D landmarks (normalized to [0, 1]) by projecting the posed 3D keypoints.
        id_coeff: Tensor[T, c=80] or [B, T, c=80]
        exp_coeff: Tensor[T, c=64] or [B, T, c=64]
        euler: Tensor[T, c=3] or [B, T, c=3]
        trans: Tensor[T, c=3] or [B, T, c=3]
        """
        is_btc_flag = id_coeff.ndim == 3
        if is_btc_flag:
            b, t, _ = id_coeff.shape
            id_coeff = id_coeff.reshape([b * t, -1])
            exp_coeff = exp_coeff.reshape([b * t, -1])
            euler = euler.reshape([b * t, -1])
            trans = trans.reshape([b * t, -1])
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        mean_face = self.key_mean_shape.squeeze().reshape([1, -1])  # [3*68, 1] ==> [1, 3*68]
        id_base, exp_base = self.key_id_base, self.key_exp_base  # [3*68, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0, 1))  # [t,c] @ [c,3*68] ==> [t,3*68]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0, 1))  # [t,c] @ [c,3*68] ==> [t,3*68]
        face = mean_face + identity_diff_face + expression_diff_face  # [t,3*68]
        face = face.reshape([face.shape[0], -1, 3])  # [t,68,3]
        # apply the head pose (rotation + translation)
        rot = self.compute_rotation(euler)
        lm3d = face @ rot + trans.unsqueeze(1)  # [t, 68, 3]
        # to camera coordinates
        if to_camera:
            lm3d[..., -1] = 10 - lm3d[..., -1]
        # project onto the image plane
        lm3d = lm3d @ self.persc_proj
        lm2d = lm3d[..., :2] / lm3d[..., 2:]
        # flip y and normalize to [0, 1] (the projection assumes a 224x224 image)
        lm2d[..., 1] = 224 - lm2d[..., 1]
        lm2d /= 224
        if is_btc_flag:
            return lm2d.reshape([b, t, -1, 2])
        return lm2d
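
    # Hedged usage sketch (illustrative only): shapes follow the 68-landmark mode; the
    # zero tensors below are stand-ins, not real fitted coefficients.
    #   helper = Face3DHelper(use_gpu=False)
    #   T = 5
    #   lm2d = helper.reconstruct_lm2d(torch.zeros([T, 80]), torch.zeros([T, 64]),
    #                                  torch.zeros([T, 3]), torch.zeros([T, 3]))
    #   lm2d.shape  # torch.Size([5, 68, 2]), values roughly in [0, 1]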

    def compute_rotation(self, euler):
        """
        Return:
            rot -- torch.tensor, size (B, 3, 3), used as pts @ rot
        Parameters:
            euler -- torch.tensor, size (B, 3), euler angles in radians
        """
        batch_size = euler.shape[0]
        euler = euler.to(self.key_id_base.device)
        ones = torch.ones([batch_size, 1]).to(self.key_id_base.device)
        zeros = torch.zeros([batch_size, 1]).to(self.key_id_base.device)
        x, y, z = euler[:, :1], euler[:, 1:2], euler[:, 2:]
        rot_x = torch.cat([
            ones, zeros, zeros,
            zeros, torch.cos(x), -torch.sin(x),
            zeros, torch.sin(x), torch.cos(x)
        ], dim=1).reshape([batch_size, 3, 3])
        rot_y = torch.cat([
            torch.cos(y), zeros, torch.sin(y),
            zeros, ones, zeros,
            -torch.sin(y), zeros, torch.cos(y)
        ], dim=1).reshape([batch_size, 3, 3])
        rot_z = torch.cat([
            torch.cos(z), -torch.sin(z), zeros,
            torch.sin(z), torch.cos(z), zeros,
            zeros, zeros, ones
        ], dim=1).reshape([batch_size, 3, 3])
        rot = rot_z @ rot_y @ rot_x
        return rot.permute(0, 2, 1)
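
    # Note (added, hedged): the final transpose means the returned matrix is meant for
    # right-multiplying row-vector points, matching `face @ rot` above. Quick self-check
    # with an illustrative 90-degree yaw about z (`helper` is a hypothetical instance):
    #   R = helper.compute_rotation(torch.tensor([[0.0, 0.0, np.pi / 2]]))
    #   torch.tensor([[[1.0, 0.0, 0.0]]]) @ R  # the x-axis point maps to about (0, 1, 0)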

    def reconstruct_idexp_lm3d(self, id_coeff, exp_coeff):
        """
        Generate identity+expression landmark offsets (no mean face) from the keypoint bases.
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        id_base, exp_base = self.key_id_base, self.key_exp_base  # [3*68, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0, 1))  # [t,c] @ [c,3*68] ==> [t,3*68]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0, 1))  # [t,c] @ [c,3*68] ==> [t,3*68]
        face = identity_diff_face + expression_diff_face  # [t,3*68]
        face = face.reshape([face.shape[0], -1, 3])  # [t,68,3]
        lm3d = face * 10  # scale up the offsets
        return lm3d

    def reconstruct_idexp_lm3d_np(self, id_coeff, exp_coeff):
        """
        NumPy version of reconstruct_idexp_lm3d.
        id_coeff: ndarray[T, c=80]
        exp_coeff: ndarray[T, c=64]
        """
        id_base, exp_base = self.key_id_base_np, self.key_exp_base_np  # [3*68, C]
        identity_diff_face = np.dot(id_coeff, id_base.T)  # [t,c] @ [c,3*68] ==> [t,3*68]
        expression_diff_face = np.dot(exp_coeff, exp_base.T)  # [t,c] @ [c,3*68] ==> [t,3*68]
        face = identity_diff_face + expression_diff_face  # [t,3*68]
        face = face.reshape([face.shape[0], -1, 3])  # [t,68,3]
        lm3d = face * 10  # scale up the offsets
        return lm3d

    def get_eye_mouth_lm_from_lm3d(self, lm3d):
        eye_lm = lm3d[:, 17:48]  # [T, 31, 3]
        mouth_lm = lm3d[:, 48:68]  # [T, 20, 3]
        return eye_lm, mouth_lm

    def get_eye_mouth_lm_from_lm3d_batch(self, lm3d):
        eye_lm = lm3d[:, :, 17:48]  # [B, T, 31, 3]
        mouth_lm = lm3d[:, :, 48:68]  # [B, T, 20, 3]
        return eye_lm, mouth_lm

    def close_mouth_for_idexp_lm3d(self, idexp_lm3d, freeze_as_first_frame=True):
        idexp_lm3d = idexp_lm3d.reshape([-1, 68, 3])
        num_frames = idexp_lm3d.shape[0]
        eps = 0.0
        # [n_landmarks=68, xyz=3]; x is left-right, y is up-down, z is depth
        idexp_lm3d[:, 49:54, 1] = (idexp_lm3d[:, 49:54, 1] + idexp_lm3d[:, range(59, 54, -1), 1]) / 2 + eps * 2
        idexp_lm3d[:, range(59, 54, -1), 1] = (idexp_lm3d[:, 49:54, 1] + idexp_lm3d[:, range(59, 54, -1), 1]) / 2 - eps * 2
        idexp_lm3d[:, 61:64, 1] = (idexp_lm3d[:, 61:64, 1] + idexp_lm3d[:, range(67, 64, -1), 1]) / 2 + eps
        idexp_lm3d[:, range(67, 64, -1), 1] = (idexp_lm3d[:, 61:64, 1] + idexp_lm3d[:, range(67, 64, -1), 1]) / 2 - eps
        idexp_lm3d[:, 49:54, 1] += (0.03 - idexp_lm3d[:, 49:54, 1].mean(dim=1) + idexp_lm3d[:, 61:64, 1].mean(dim=1)).unsqueeze(1).repeat([1, 5])
        idexp_lm3d[:, range(59, 54, -1), 1] += (-0.03 - idexp_lm3d[:, range(59, 54, -1), 1].mean(dim=1) + idexp_lm3d[:, range(67, 64, -1), 1].mean(dim=1)).unsqueeze(1).repeat([1, 5])
        if freeze_as_first_frame:
            idexp_lm3d[:, 48:68] = idexp_lm3d[0, 48:68].unsqueeze(0).clone().repeat([num_frames, 1, 1]) * 0
        return idexp_lm3d.cpu()
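
    # Note (added, hedged): the index ranges above assume the standard 68-point
    # (iBUG/dlib-style) landmark ordering, where 36-47 are the eyes, 48-59 the outer lip
    # and 60-67 the inner lip; upper/lower lip pairs such as 49:54 vs 59..55 and
    # 61:64 vs 67..65 are averaged along y to pinch the mouth shut.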

    def close_eyes_for_idexp_lm3d(self, idexp_lm3d):
        idexp_lm3d = idexp_lm3d.reshape([-1, 68, 3])
        eps = 0.003
        idexp_lm3d[:, 37:39, 1] = (idexp_lm3d[:, 37:39, 1] + idexp_lm3d[:, range(41, 39, -1), 1]) / 2 + eps
        idexp_lm3d[:, range(41, 39, -1), 1] = (idexp_lm3d[:, 37:39, 1] + idexp_lm3d[:, range(41, 39, -1), 1]) / 2 - eps
        idexp_lm3d[:, 43:45, 1] = (idexp_lm3d[:, 43:45, 1] + idexp_lm3d[:, range(47, 45, -1), 1]) / 2 + eps
        idexp_lm3d[:, range(47, 45, -1), 1] = (idexp_lm3d[:, 43:45, 1] + idexp_lm3d[:, range(47, 45, -1), 1]) / 2 - eps
        return idexp_lm3d


if __name__ == '__main__':
    import cv2
    import imageio

    font = cv2.FONT_HERSHEY_SIMPLEX
    face_mesh_helper = Face3DHelper('deep_3drecon/BFM')
    coeff_npy = 'data/coeff_fit_mp/crop_nana_003_coeff_fit_mp.npy'
    coeff_dict = np.load(coeff_npy, allow_pickle=True).tolist()
    lm3d = face_mesh_helper.reconstruct_lm2d(
        torch.tensor(coeff_dict['id']).cuda(), torch.tensor(coeff_dict['exp']).cuda(),
        torch.tensor(coeff_dict['euler']).cuda(), torch.tensor(coeff_dict['trans']).cuda())
    WH = 512
    lm3d = (lm3d * WH).cpu().int().numpy()
    eye_idx = list(range(36, 48))
    mouth_idx = list(range(48, 68))
    debug_name = 'debug_lm3d.mp4'
    writer = imageio.get_writer(debug_name, fps=25)
    for i_img in range(len(lm3d)):
        lm2d = lm3d[i_img, :, :2]  # [68, 2]
        img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
        for i in range(len(lm2d)):
            x, y = lm2d[i]
            x, y = int(x), int(y)  # plain Python ints keep cv2 happy across versions
            if i in eye_idx:
                color = (0, 0, 255)
            elif i in mouth_idx:
                color = (0, 255, 0)
            else:
                color = (255, 0, 0)
            img = cv2.circle(img, center=(x, y), radius=3, color=color, thickness=-1)
            img = cv2.putText(img, f"{i}", org=(x, y), fontFace=font, fontScale=0.3, color=(255, 0, 0))
        writer.append_data(img)
    writer.close()
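
    # Hedged extra check (added; not part of the original script): with dummy zero
    # coefficients, the id/exp-only landmark offsets come back as [T, 68, 3] in lm68 mode.
    dummy_id, dummy_exp = torch.zeros([1, 80]), torch.zeros([1, 64])
    idexp_lm3d = face_mesh_helper.reconstruct_idexp_lm3d(dummy_id, dummy_exp)
    print('idexp_lm3d shape:', idexp_lm3d.shape)  # expected torch.Size([1, 68, 3])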