radnerf_torso.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import random
  5. import modules.radnerfs.raymarching as raymarching
  6. from modules.radnerfs.encoders.encoding import get_encoder
  7. from modules.radnerfs.renderer import NeRFRenderer
  8. from modules.radnerfs.radnerf import RADNeRF
  9. from modules.radnerfs.cond_encoder import AudioNet, AudioAttNet, MLP
  10. from modules.radnerfs.utils import trunc_exp
  11. from modules.radnerfs.utils import custom_meshgrid, convert_poses
  12. from utils.commons.hparams import hparams
  13. class RADNeRFTorso(RADNeRF):
  14. def __init__(self, hparams):
  15. super().__init__(hparams)
  16. density_grid_torso = torch.zeros([self.grid_size ** 2]) # [H * H]
  17. self.register_buffer('density_grid_torso', density_grid_torso)
  18. self.mean_density_torso = 0
  19. self.density_thresh_torso = hparams['density_thresh_torso']
  20. self.torso_individual_embedding_num = hparams['individual_embedding_num']
  21. self.torso_individual_embedding_dim = hparams['torso_individual_embedding_dim']
  22. if self.torso_individual_embedding_dim > 0:
  23. self.torso_individual_codes = nn.Parameter(torch.randn(self.torso_individual_embedding_num, self.torso_individual_embedding_dim) * 0.1)
  24. self.torso_pose_embedder, self.pose_embedding_dim = get_encoder('frequency', input_dim=6, multires=4)
  25. self.torso_deform_pos_embedder, self.torso_deform_pos_dim = get_encoder('frequency', input_dim=2, multires=10) # input 2D position
  26. self.torso_embedder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048)
  27. deform_net_in_dim = self.torso_deform_pos_dim + self.pose_embedding_dim + self.torso_individual_embedding_dim
  28. canonicial_net_in_dim = self.torso_in_dim + self.torso_deform_pos_dim + self.pose_embedding_dim + self.torso_individual_embedding_dim
  29. if hparams['torso_head_aware']:
  30. head_aware_out_dim = 16
  31. self.head_color_weights_encoder = nn.Sequential(*[
  32. nn.Linear(3+1, 16, bias=True),
  33. nn.LeakyReLU(0.02, True),
  34. nn.Linear(16, 32, bias=True),
  35. nn.LeakyReLU(0.02, True),
  36. nn.Linear(32, head_aware_out_dim, bias=True),
  37. ])
  38. deform_net_in_dim += head_aware_out_dim
  39. canonicial_net_in_dim += head_aware_out_dim
  40. self.torso_deform_net = MLP(deform_net_in_dim, 2, 64, 3)
  41. self.torso_canonicial_net = MLP(canonicial_net_in_dim, 4, 32, 3)
  42. def forward_torso(self, x, poses, c=None, image=None, weights_sum=None):
  43. # x: [N, 2] in [-1, 1]
  44. # head poses: [1, 6]
  45. # c: [1, ind_dim], individual code
  46. # test: shrink x
  47. x = x * hparams['torso_shrink']
  48. # deformation-based
  49. enc_pose = self.torso_pose_embedder(poses)
  50. enc_x = self.torso_deform_pos_embedder(x)
  51. if c is not None:
  52. h = torch.cat([enc_x, enc_pose.repeat(x.shape[0], 1), c.repeat(x.shape[0], 1)], dim=-1)
  53. else:
  54. h = torch.cat([enc_x, enc_pose.repeat(x.shape[0], 1)], dim=-1)
  55. if hparams['torso_head_aware']:
  56. if image is None:
  57. image = torch.zeros([x.shape[0],3], dtype=h.dtype, device=h.device)
  58. weights_sum = torch.zeros([x.shape[0],1], dtype=h.dtype, device=h.device)
  59. head_color_weights_inp = torch.cat([image, weights_sum],dim=-1)
  60. head_color_weights_encoding = self.head_color_weights_encoder(head_color_weights_inp)
  61. h = torch.cat([h, head_color_weights_encoding],dim=-1)
  62. dx = self.torso_deform_net(h)
  63. x = (x + dx).clamp(-1, 1).float()
  64. x = self.torso_embedder(x, bound=1)
  65. h = torch.cat([x, h], dim=-1)
  66. h = self.torso_canonicial_net(h)
  67. alpha = torch.sigmoid(h[..., :1])
  68. color = torch.sigmoid(h[..., 1:])
  69. return alpha, color, dx
  70. def render(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
  71. # rays_o, rays_d: [B, N, 3], assumes B == 1
  72. # cond: [B, 29, 16]
  73. # bg_coords: [1, N, 2]
  74. # return: pred_rgb: [B, N, 3]
  75. ### run head nerf with no_grad to get the renderred head
  76. with torch.no_grad():
  77. prefix = rays_o.shape[:-1]
  78. rays_o = rays_o.contiguous().view(-1, 3)
  79. rays_d = rays_d.contiguous().view(-1, 3)
  80. bg_coords = bg_coords.contiguous().view(-1, 2)
  81. N = rays_o.shape[0] # N = B * N, in fact
  82. device = rays_o.device
  83. results = {}
  84. # pre-calculate near far
  85. nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
  86. nears = nears.detach()
  87. fars = fars.detach()
  88. # encode audio
  89. cond_feat = self.cal_cond_feat(cond) # [1, 64]
  90. if self.individual_embedding_dim > 0:
  91. if self.training:
  92. ind_code = self.individual_embeddings[index]
  93. # use a fixed ind code for the unknown test data.
  94. else:
  95. ind_code = self.individual_embeddings[0]
  96. else:
  97. ind_code = None
  98. if self.training:
  99. # setup counter
  100. counter = self.step_counter[self.local_step % 16]
  101. counter.zero_() # set to 0
  102. self.local_step += 1
  103. xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
  104. # xyzs, dirs, deltas, rays, points2rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
  105. sigmas, rgbs, ambient = self(xyzs, dirs, cond_feat, ind_code)
  106. sigmas = self.density_scale * sigmas
  107. #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')
  108. weights_sum, ambient_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ambient.abs().sum(-1), deltas, rays)
  109. # for training only
  110. results['weights_sum'] = weights_sum
  111. results['ambient'] = ambient_sum
  112. else:
  113. dtype = torch.float32
  114. weights_sum = torch.zeros(N, dtype=dtype, device=device)
  115. depth = torch.zeros(N, dtype=dtype, device=device)
  116. image = torch.zeros(N, 3, dtype=dtype, device=device)
  117. n_alive = N
  118. rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
  119. rays_t = nears.clone() # [N]
  120. step = 0
  121. while step < max_steps:
  122. # count alive rays
  123. n_alive = rays_alive.shape[0]
  124. # exit loop
  125. if n_alive <= 0:
  126. break
  127. # decide compact_steps
  128. n_step = max(min(N // n_alive, 8), 1)
  129. xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
  130. sigmas, rgbs, ambient = self(xyzs, dirs, cond_feat, ind_code)
  131. sigmas = self.density_scale * sigmas
  132. raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)
  133. rays_alive = rays_alive[rays_alive >= 0]
  134. step += n_step
  135. # background
  136. if bg_color is None:
  137. bg_color = 1
  138. ### Start Rendering Torso
  139. if self.torso_individual_embedding_dim > 0:
  140. if self.training:
  141. torso_individual_code = self.torso_individual_codes[index]
  142. # use a fixed ind code for the unknown test data.
  143. else:
  144. torso_individual_code = self.torso_individual_codes[0]
  145. else:
  146. torso_individual_code = None
  147. # 2D density grid for acceleration...
  148. density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso)
  149. occupancy = F.grid_sample(self.density_grid_torso.view(1, 1, self.grid_size, self.grid_size), bg_coords.view(1, -1, 1, 2), align_corners=True).view(-1)
  150. mask = occupancy > density_thresh_torso
  151. # masked query of torso
  152. torso_alpha = torch.zeros([N, 1], device=device)
  153. torso_color = torch.zeros([N, 3], device=device)
  154. if mask.any():
  155. if hparams['torso_head_aware']:
  156. if random.random() < 0.5:
  157. torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, torso_individual_code, image[mask], weights_sum.unsqueeze(-1)[mask])
  158. else:
  159. torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, torso_individual_code, None, None)
  160. else:
  161. torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, torso_individual_code)
  162. torso_alpha[mask] = torso_alpha_mask.float()
  163. torso_color[mask] = torso_color_mask.float()
  164. results['deform'] = deform
  165. # first mix torso with background
  166. bg_color = torso_color * torso_alpha + bg_color * (1 - torso_alpha)
  167. results['torso_alpha_map'] = torso_alpha
  168. results['torso_rgb_map'] = bg_color
  169. # then mix the head image with the torso_bg
  170. image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
  171. image = image.view(*prefix, 3)
  172. image = image.clamp(0, 1)
  173. depth = torch.clamp(depth - nears, min=0) / (fars - nears)
  174. depth = depth.view(*prefix)
  175. results['depth_map'] = depth
  176. results['rgb_map'] = image # head_image if train, else com_image
  177. return results
  178. @torch.no_grad()
  179. def update_extra_state(self, decay=0.95, S=128):
  180. # forbid updating head if is training torso...
  181. # only update torso density grid
  182. tmp_grid_torso = torch.zeros_like(self.density_grid_torso)
  183. # random pose, random ind_code
  184. rand_idx = random.randint(0, self.poses.shape[0] - 1)
  185. pose = convert_poses(self.poses[[rand_idx]]).to(self.density_bitfield.device)
  186. if self.torso_individual_embedding_dim > 0:
  187. ind_code = self.torso_individual_codes[[rand_idx]]
  188. else:
  189. ind_code = None
  190. X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
  191. Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
  192. half_grid_size = 1 / self.grid_size
  193. for xs in X:
  194. for ys in Y:
  195. xx, yy = custom_meshgrid(xs, ys)
  196. coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=-1) # [N, 2], in [0, 128)
  197. indices = (coords[:, 1] * self.grid_size + coords[:, 0]).long() # NOTE: xy transposed!
  198. xys = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 2] in [-1, 1]
  199. xys = xys * (1 - half_grid_size)
  200. # add noise in [-hgs, hgs]
  201. xys += (torch.rand_like(xys) * 2 - 1) * half_grid_size
  202. # query density
  203. alphas, _, _ = self.forward_torso(xys, pose, ind_code) # [N, 1]
  204. # assign
  205. tmp_grid_torso[indices] = alphas.squeeze(1).float()
  206. # dilate
  207. tmp_grid_torso = tmp_grid_torso.view(1, 1, self.grid_size, self.grid_size)
  208. tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=5, stride=1, padding=2)
  209. tmp_grid_torso = tmp_grid_torso.view(-1)
  210. self.density_grid_torso = torch.maximum(self.density_grid_torso * decay, tmp_grid_torso)
  211. self.mean_density_torso = torch.mean(self.density_grid_torso).item()