mesh_renderer.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. """This script is the differentiable renderer for Deep3DFaceRecon_pytorch
  2. Note: the antialiasing step is missing in the current version.
  3. """
  4. import torch
  5. import torch.nn.functional as F
  6. import kornia
  7. from kornia.geometry.camera import pixel2cam
  8. import numpy as np
  9. from typing import List
  10. from scipy.io import loadmat
  11. from torch import nn
  12. import traceback
  13. try:
  14. import pytorch3d.ops
  15. from pytorch3d.structures import Meshes
  16. from pytorch3d.renderer import (
  17. look_at_view_transform,
  18. FoVPerspectiveCameras,
  19. DirectionalLights,
  20. RasterizationSettings,
  21. MeshRenderer,
  22. MeshRasterizer,
  23. SoftPhongShader,
  24. TexturesUV,
  25. )
  26. except:
  27. traceback.print_exc()
  28. # def ndc_projection(x=0.1, n=1.0, f=50.0):
  29. # return np.array([[n/x, 0, 0, 0],
  30. # [ 0, n/-x, 0, 0],
  31. # [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
  32. # [ 0, 0, -1, 0]]).astype(np.float32)
  33. class MeshRenderer(nn.Module):
  34. def __init__(self,
  35. rasterize_fov,
  36. znear=0.1,
  37. zfar=10,
  38. rasterize_size=224,**args):
  39. super(MeshRenderer, self).__init__()
  40. # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
  41. # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
  42. # torch.diag(torch.tensor([1., -1, -1, 1])))
  43. self.rasterize_size = rasterize_size
  44. self.fov = rasterize_fov
  45. self.znear = znear
  46. self.zfar = zfar
  47. self.rasterizer = None
  48. def forward(self, vertex, tri, feat=None):
  49. """
  50. Return:
  51. mask -- torch.tensor, size (B, 1, H, W)
  52. depth -- torch.tensor, size (B, 1, H, W)
  53. features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
  54. Parameters:
  55. vertex -- torch.tensor, size (B, N, 3)
  56. tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
  57. feat(optional) -- torch.tensor, size (B, N ,C), features
  58. """
  59. device = vertex.device
  60. rsize = int(self.rasterize_size)
  61. # ndc_proj = self.ndc_proj.to(device)
  62. # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
  63. if vertex.shape[-1] == 3:
  64. vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
  65. vertex[..., 0] = -vertex[..., 0]
  66. # vertex_ndc = vertex @ ndc_proj.t()
  67. if self.rasterizer is None:
  68. self.rasterizer = MeshRasterizer()
  69. print("create rasterizer on device cuda:%d"%device.index)
  70. # ranges = None
  71. # if isinstance(tri, List) or len(tri.shape) == 3:
  72. # vum = vertex_ndc.shape[1]
  73. # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
  74. # fstartidx = torch.cumsum(fnum, dim=0) - fnum
  75. # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
  76. # for i in range(tri.shape[0]):
  77. # tri[i] = tri[i] + i*vum
  78. # vertex_ndc = torch.cat(vertex_ndc, dim=0)
  79. # tri = torch.cat(tri, dim=0)
  80. # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
  81. tri = tri.type(torch.int32).contiguous()
  82. # rasterize
  83. cameras = FoVPerspectiveCameras(
  84. device=device,
  85. fov=self.fov,
  86. znear=self.znear,
  87. zfar=self.zfar,
  88. )
  89. raster_settings = RasterizationSettings(
  90. image_size=rsize
  91. )
  92. # print(vertex.shape, tri.shape)
  93. if tri.ndim == 2:
  94. tri = tri.unsqueeze(0)
  95. mesh = Meshes(vertex.contiguous()[...,:3], tri)
  96. fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
  97. rast_out = fragments.pix_to_face.squeeze(-1)
  98. depth = fragments.zbuf
  99. # render depth
  100. depth = depth.permute(0, 3, 1, 2)
  101. mask = (rast_out > 0).float().unsqueeze(1)
  102. depth = mask * depth
  103. image = None
  104. if feat is not None:
  105. attributes = feat.reshape(-1,3)[mesh.faces_packed()]
  106. image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
  107. fragments.bary_coords,
  108. attributes)
  109. # print(image.shape)
  110. image = image.squeeze(-2).permute(0, 3, 1, 2)
  111. image = mask * image
  112. return mask, depth, image