bfm.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. """This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
  2. """
  3. import numpy as np
  4. import torch
  5. import torch.nn.functional as F
  6. from scipy.io import loadmat
  7. import os
  8. # from utils.commons.tensor_utils import convert_like
  9. def perspective_projection(focal, center):
  10. # return p.T (N, 3) @ (3, 3)
  11. return np.array([
  12. focal, 0, center,
  13. 0, focal, center,
  14. 0, 0, 1
  15. ]).reshape([3, 3]).astype(np.float32).transpose() # 注意这里的transpose!
  16. class SH:
  17. def __init__(self):
  18. self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
  19. self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
  20. class ParametricFaceModel:
  21. def __init__(self,
  22. bfm_folder='./BFM',
  23. recenter=True,
  24. camera_distance=10.,
  25. init_lit=np.array([
  26. 0.8, 0, 0, 0, 0, 0, 0, 0, 0
  27. ]),
  28. focal=1015.,
  29. center=112.,
  30. is_train=True,
  31. default_name='BFM_model_front.mat',
  32. keypoint_mode='mediapipe'):
  33. model = loadmat(os.path.join(bfm_folder, default_name))
  34. # mean face shape. [3*N,1]
  35. self.mean_shape = model['meanshape'].astype(np.float32)
  36. # identity basis. [3*N,80]
  37. self.id_base = model['idBase'].astype(np.float32)
  38. # expression basis. [3*N,64]
  39. self.exp_base = model['exBase'].astype(np.float32)
  40. # mean face texture. [3*N,1] (0-255)
  41. self.mean_tex = model['meantex'].astype(np.float32)
  42. # texture basis. [3*N,80]
  43. self.tex_base = model['texBase'].astype(np.float32)
  44. # face indices for each vertex that lies in. starts from 0. [N,8]
  45. self.point_buf = model['point_buf'].astype(np.int64) - 1
  46. # vertex indices for each face. starts from 0. [F,3]
  47. self.face_buf = model['tri'].astype(np.int64) - 1
  48. # vertex indices for 68 landmarks. starts from 0. [68,1]
  49. if keypoint_mode == 'mediapipe':
  50. self.keypoints = np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)
  51. unmatch_mask = self.keypoints < 0
  52. self.keypoints[unmatch_mask] = 0
  53. else:
  54. self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
  55. if is_train:
  56. # vertex indices for small face region to compute photometric error. starts from 0.
  57. self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
  58. # vertex indices for each face from small face region. starts from 0. [f,3]
  59. self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
  60. # vertex indices for pre-defined skin region to compute reflectance loss
  61. self.skin_mask = np.squeeze(model['skinmask'])
  62. if recenter:
  63. mean_shape = self.mean_shape.reshape([-1, 3])
  64. mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
  65. self.mean_shape = mean_shape.reshape([-1, 1])
  66. self.key_mean_shape = self.mean_shape.reshape([-1, 3])[self.keypoints, :].reshape([-1, 3])
  67. self.key_id_base = self.id_base.reshape([-1, 3,80])[self.keypoints, :].reshape([-1, 80])
  68. self.key_exp_base = self.exp_base.reshape([-1, 3, 64])[self.keypoints, :].reshape([-1, 64])
  69. self.focal = focal
  70. self.center = center
  71. self.persc_proj = perspective_projection(focal, center)
  72. self.device = 'cpu'
  73. self.camera_distance = camera_distance
  74. self.SH = SH()
  75. self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
  76. self.initialized = False
  77. def to(self, device):
  78. self.device = device
  79. for key, value in self.__dict__.items():
  80. if type(value).__module__ == np.__name__:
  81. setattr(self, key, torch.tensor(value).to(device))
  82. self.initialized = True
  83. return self
  84. def compute_shape(self, id_coeff, exp_coeff):
  85. """
  86. Return:
  87. face_shape -- torch.tensor, size (B, N, 3)
  88. Parameters:
  89. id_coeff -- torch.tensor, size (B, 80), identity coeffs
  90. exp_coeff -- torch.tensor, size (B, 64), expression coeffs
  91. """
  92. batch_size = id_coeff.shape[0]
  93. id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
  94. exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
  95. face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
  96. return face_shape.reshape([batch_size, -1, 3])
  97. def compute_key_shape(self, id_coeff, exp_coeff):
  98. """
  99. Return:
  100. face_shape -- torch.tensor, size (B, N, 3)
  101. Parameters:
  102. id_coeff -- torch.tensor, size (B, 80), identity coeffs
  103. exp_coeff -- torch.tensor, size (B, 64), expression coeffs
  104. """
  105. batch_size = id_coeff.shape[0]
  106. id_part = torch.einsum('ij,aj->ai', self.key_id_base, id_coeff)
  107. exp_part = torch.einsum('ij,aj->ai', self.key_exp_base, exp_coeff)
  108. face_shape = id_part + exp_part + self.key_mean_shape.reshape([1, -1])
  109. return face_shape.reshape([batch_size, -1, 3])
  110. def compute_texture(self, tex_coeff, normalize=True):
  111. """
  112. Return:
  113. face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
  114. Parameters:
  115. tex_coeff -- torch.tensor, size (B, 80)
  116. """
  117. batch_size = tex_coeff.shape[0]
  118. face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
  119. if normalize:
  120. face_texture = face_texture / 255.
  121. return face_texture.reshape([batch_size, -1, 3])
  122. def compute_norm(self, face_shape):
  123. """
  124. Return:
  125. vertex_norm -- torch.tensor, size (B, N, 3)
  126. Parameters:
  127. face_shape -- torch.tensor, size (B, N, 3)
  128. """
  129. v1 = face_shape[:, self.face_buf[:, 0]]
  130. v2 = face_shape[:, self.face_buf[:, 1]]
  131. v3 = face_shape[:, self.face_buf[:, 2]]
  132. e1 = v1 - v2
  133. e2 = v2 - v3
  134. face_norm = torch.cross(e1, e2, dim=-1)
  135. face_norm = F.normalize(face_norm, dim=-1, p=2)
  136. face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
  137. vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
  138. vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
  139. return vertex_norm
  140. def compute_color(self, face_texture, face_norm, gamma):
  141. """
  142. Return:
  143. face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
  144. Parameters:
  145. face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
  146. face_norm -- torch.tensor, size (B, N, 3), rotated face normal
  147. gamma -- torch.tensor, size (B, 27), SH coeffs
  148. """
  149. batch_size = gamma.shape[0]
  150. v_num = face_texture.shape[1]
  151. a, c = self.SH.a, self.SH.c
  152. gamma = gamma.reshape([batch_size, 3, 9])
  153. gamma = gamma + self.init_lit
  154. gamma = gamma.permute(0, 2, 1)
  155. Y = torch.cat([
  156. a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
  157. -a[1] * c[1] * face_norm[..., 1:2],
  158. a[1] * c[1] * face_norm[..., 2:],
  159. -a[1] * c[1] * face_norm[..., :1],
  160. a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
  161. -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
  162. 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
  163. -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
  164. 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
  165. ], dim=-1)
  166. r = Y @ gamma[..., :1]
  167. g = Y @ gamma[..., 1:2]
  168. b = Y @ gamma[..., 2:]
  169. face_color = torch.cat([r, g, b], dim=-1) * face_texture
  170. return face_color
  171. @staticmethod
  172. def compute_rotation(angles, device='cpu'):
  173. """
  174. Return:
  175. rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
  176. Parameters:
  177. angles -- torch.tensor, size (B, 3), radian
  178. """
  179. batch_size = angles.shape[0]
  180. angles = angles.to(device)
  181. ones = torch.ones([batch_size, 1]).to(device)
  182. zeros = torch.zeros([batch_size, 1]).to(device)
  183. x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
  184. rot_x = torch.cat([
  185. ones, zeros, zeros,
  186. zeros, torch.cos(x), -torch.sin(x),
  187. zeros, torch.sin(x), torch.cos(x)
  188. ], dim=1).reshape([batch_size, 3, 3])
  189. rot_y = torch.cat([
  190. torch.cos(y), zeros, torch.sin(y),
  191. zeros, ones, zeros,
  192. -torch.sin(y), zeros, torch.cos(y)
  193. ], dim=1).reshape([batch_size, 3, 3])
  194. rot_z = torch.cat([
  195. torch.cos(z), -torch.sin(z), zeros,
  196. torch.sin(z), torch.cos(z), zeros,
  197. zeros, zeros, ones
  198. ], dim=1).reshape([batch_size, 3, 3])
  199. rot = rot_z @ rot_y @ rot_x
  200. return rot.permute(0, 2, 1)
  201. def to_camera(self, face_shape):
  202. face_shape[..., -1] = self.camera_distance - face_shape[..., -1] # reverse the depth axis, add a fixed offset of length
  203. return face_shape
  204. def to_image(self, face_shape):
  205. """
  206. Return:
  207. face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
  208. Parameters:
  209. face_shape -- torch.tensor, size (B, N, 3)
  210. """
  211. # to image_plane
  212. face_proj = face_shape @ self.persc_proj
  213. face_proj = face_proj[..., :2] / face_proj[..., 2:]
  214. return face_proj
  215. def transform(self, face_shape, rot, trans):
  216. """
  217. Return:
  218. face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
  219. Parameters:
  220. face_shape -- torch.tensor, si≥ze (B, N, 3)
  221. rot -- torch.tensor, size (B, 3, 3)
  222. trans -- torch.tensor, size (B, 3)
  223. """
  224. return face_shape @ rot + trans.unsqueeze(1)
  225. def get_landmarks(self, face_proj):
  226. """
  227. Return:
  228. face_lms -- torch.tensor, size (B, 68, 2)
  229. Parameters:
  230. face_proj -- torch.tensor, size (B, N, 2)
  231. """
  232. return face_proj[:, self.keypoints]
  233. def split_coeff(self, coeffs):
  234. """
  235. Return:
  236. coeffs_dict -- a dict of torch.tensors
  237. Parameters:
  238. coeffs -- torch.tensor, size (B, 256)
  239. """
  240. id_coeffs = coeffs[:, :80]
  241. exp_coeffs = coeffs[:, 80: 144]
  242. tex_coeffs = coeffs[:, 144: 224]
  243. angles = coeffs[:, 224: 227]
  244. gammas = coeffs[:, 227: 254]
  245. translations = coeffs[:, 254:]
  246. return {
  247. 'id': id_coeffs,
  248. 'exp': exp_coeffs,
  249. 'tex': tex_coeffs,
  250. 'angle': angles,
  251. 'gamma': gammas,
  252. 'trans': translations
  253. }
  254. def compute_for_render(self, coeffs):
  255. """
  256. Return:
  257. face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
  258. face_color -- torch.tensor, size (B, N, 3), in RGB order
  259. landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
  260. Parameters:
  261. coeffs -- torch.tensor, size (B, 257)
  262. """
  263. coef_dict = self.split_coeff(coeffs)
  264. face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
  265. rotation = self.compute_rotation(coef_dict['angle'], device=self.device)
  266. face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
  267. face_vertex = self.to_camera(face_shape_transformed)
  268. face_proj = self.to_image(face_vertex)
  269. landmark = self.get_landmarks(face_proj)
  270. face_texture = self.compute_texture(coef_dict['tex'])
  271. face_norm = self.compute_norm(face_shape)
  272. face_norm_roted = face_norm @ rotation
  273. face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
  274. return face_vertex, face_texture, face_color, landmark
  275. def compute_face_vertex(self, id, exp, angle, trans):
  276. """
  277. Return:
  278. face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
  279. face_color -- torch.tensor, size (B, N, 3), in RGB order
  280. landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
  281. Parameters:
  282. coeffs -- torch.tensor, size (B, 257)
  283. """
  284. if not self.initialized:
  285. self.to(id.device)
  286. face_shape = self.compute_shape(id, exp)
  287. rotation = self.compute_rotation(angle, device=self.device)
  288. face_shape_transformed = self.transform(face_shape, rotation, trans)
  289. face_vertex = self.to_camera(face_shape_transformed)
  290. return face_vertex
  291. def compute_for_landmark_fit(self, id, exp, angles, trans, ret=None):
  292. """
  293. Return:
  294. face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
  295. face_color -- torch.tensor, size (B, N, 3), in RGB order
  296. landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
  297. Parameters:
  298. coeffs -- torch.tensor, size (B, 257)
  299. """
  300. face_shape = self.compute_key_shape(id, exp)
  301. rotation = self.compute_rotation(angles, device=self.device)
  302. face_shape_transformed = self.transform(face_shape, rotation, trans)
  303. face_vertex = self.to_camera(face_shape_transformed)
  304. face_proj = self.to_image(face_vertex)
  305. landmark = face_proj
  306. return landmark
  307. def compute_for_landmark_fit_nerf(self, id, exp, angles, trans, ret=None):
  308. """
  309. Return:
  310. face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
  311. face_color -- torch.tensor, size (B, N, 3), in RGB order
  312. landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
  313. Parameters:
  314. coeffs -- torch.tensor, size (B, 257)
  315. """
  316. face_shape = self.compute_key_shape(id, exp)
  317. rotation = self.compute_rotation(angles, device=self.device)
  318. face_shape_transformed = self.transform(face_shape, rotation, trans)
  319. face_vertex = face_shape_transformed # no to_camera
  320. face_proj = self.to_image(face_vertex)
  321. landmark = face_proj
  322. return landmark
    # Earlier variant of compute_for_landmark_fit (full mesh + get_landmarks), kept for reference:
    # def compute_for_landmark_fit(self, id, exp, angles, trans, ret={}):
    #     """
    #     Return:
    #         face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
    #         face_color -- torch.tensor, size (B, N, 3), in RGB order
    #         landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
    #     Parameters:
    #         coeffs -- torch.tensor, size (B, 257)
    #     """
    #     face_shape = self.compute_shape(id, exp)
    #     rotation = self.compute_rotation(angles)
    #     face_shape_transformed = self.transform(face_shape, rotation, trans)
    #     face_vertex = self.to_camera(face_shape_transformed)
    #     face_proj = self.to_image(face_vertex)
    #     landmark = self.get_landmarks(face_proj)
    #     return landmark
  339. def compute_for_render_fit(self, id, exp, angles, trans, tex, gamma):
  340. """
  341. Return:
  342. face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
  343. face_color -- torch.tensor, size (B, N, 3), in RGB order
  344. landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
  345. Parameters:
  346. coeffs -- torch.tensor, size (B, 257)
  347. """
  348. face_shape = self.compute_shape(id, exp)
  349. rotation = self.compute_rotation(angles, device=self.device)
  350. face_shape_transformed = self.transform(face_shape, rotation, trans)
  351. face_vertex = self.to_camera(face_shape_transformed)
  352. face_proj = self.to_image(face_vertex)
  353. landmark = self.get_landmarks(face_proj)
  354. face_texture = self.compute_texture(tex)
  355. face_norm = self.compute_norm(face_shape)
  356. face_norm_roted = face_norm @ rotation
  357. face_color = self.compute_color(face_texture, face_norm_roted, gamma)
  358. return face_color, face_vertex, landmark