secc_renderer.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. from einops import rearrange
  5. from deep_3drecon.util.mesh_renderer import MeshRenderer
  6. from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
  7. class SECC_Renderer(nn.Module):
  8. def __init__(self, rasterize_size=None, device="cuda"):
  9. super().__init__()
  10. self.face_model = ParametricFaceModel('deep_3drecon/BFM')
  11. self.fov = 2 * np.arctan(self.face_model.center / self.face_model.focal) * 180 / np.pi
  12. self.znear = 5.
  13. self.zfar = 15.
  14. if rasterize_size is None:
  15. rasterize_size = 2*self.face_model.center
  16. self.face_renderer = MeshRenderer(rasterize_fov=self.fov, znear=self.znear, zfar=self.zfar, rasterize_size=rasterize_size, use_opengl=False).cuda()
  17. face_feat = np.load("deep_3drecon/ncc_code.npy", allow_pickle=True)
  18. self.face_feat = torch.tensor(face_feat.T).unsqueeze(0).to(device=device)
  19. del_index_re = np.load('deep_3drecon/bfm_right_eye_faces.npy')
  20. del_index_re = del_index_re - 1
  21. del_index_le = np.load('deep_3drecon/bfm_left_eye_faces.npy')
  22. del_index_le = del_index_le - 1
  23. face_buf_list = []
  24. for i in range(self.face_model.face_buf.shape[0]):
  25. if i not in del_index_re and i not in del_index_le:
  26. face_buf_list.append(self.face_model.face_buf[i])
  27. face_buf_arr = np.array(face_buf_list)
  28. self.face_buf = torch.tensor(face_buf_arr).to(device=device)
  29. def forward(self, id, exp, euler, trans):
  30. """
  31. id, exp, euler, euler: [B, C] or [B, T, C]
  32. return:
  33. MASK: [B, 1, 512, 512], value[0. or 1.0], 1.0 denotes is face
  34. SECC MAP: [B, 3, 512, 512], value[0~1]
  35. if input is BTC format, return [B, C, T, H, W]
  36. """
  37. bs = id.shape[0]
  38. is_btc_flag = id.ndim == 3
  39. if is_btc_flag:
  40. t = id.shape[1]
  41. bs = bs * t
  42. id, exp, euler, trans = id.reshape([bs,-1]), exp.reshape([bs,-1]), euler.reshape([bs,-1]), trans.reshape([bs,-1])
  43. face_vertex = self.face_model.compute_face_vertex(id, exp, euler, trans)
  44. face_mask, _, secc_face = self.face_renderer(
  45. face_vertex, self.face_buf.unsqueeze(0).repeat([bs, 1, 1]), feat=self.face_feat.repeat([bs,1,1]))
  46. secc_face = (secc_face - 0.5) / 0.5 # scale to -1~1
  47. if is_btc_flag:
  48. bs = bs // t
  49. face_mask = rearrange(face_mask, "(n t) c h w -> n c t h w", n=bs, t=t)
  50. secc_face = rearrange(secc_face, "(n t) c h w -> n c t h w", n=bs, t=t)
  51. return face_mask, secc_face
  52. if __name__ == '__main__':
  53. import imageio
  54. renderer = SECC_Renderer(rasterize_size=512)
  55. ret = np.load("data/processed/videos/May/vid_coeff_fit.npy", allow_pickle=True).tolist()
  56. idx = 6
  57. id = torch.tensor(ret['id']).cuda()[idx:idx+1]
  58. exp = torch.tensor(ret['exp']).cuda()[idx:idx+1]
  59. angle = torch.tensor(ret['euler']).cuda()[idx:idx+1]
  60. trans = torch.tensor(ret['trans']).cuda()[idx:idx+1]
  61. mask, secc = renderer(id, exp, angle*0, trans*0) # [1, 1, 512, 512], [1, 3, 512, 512]
  62. out_mask = mask[0].permute(1,2,0)
  63. out_mask = (out_mask * 127.5 + 127.5).int().cpu().numpy()
  64. imageio.imwrite("out_mask.png", out_mask)
  65. out_img = secc[0].permute(1,2,0)
  66. out_img = (out_img * 127.5 + 127.5).int().cpu().numpy()
  67. imageio.imwrite("out_secc.png", out_img)